import json
import traceback

import blobconverter
import cv2
import depthai as dai
import numpy as np
from aiohttp import web
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
from av import VideoFrame
from depthai_sdk import PipelineManager, NNetManager, PreviewManager, Previews


class VideoTransformTrack(VideoStreamTrack):
    """Video track that streams detection-annotated frames from a DepthAI device."""

    def __init__(self):
        super().__init__()  # don't forget this!
        self.dummy = False
        # Build the DepthAI pipeline: color camera preview + MobileNet-SSD detection network
        self.pm = PipelineManager()
        self.pm.createColorCam(xout=True, previewSize=(300, 300))
        self._nnManager = NNetManager(inputSize=(300, 300), nnFamily="mobilenet")
        self.pm.setNnManager(self._nnManager)
        self._nn = self._nnManager.createNN(
            pipeline=self.pm.pipeline, nodes=self.pm.nodes, source=Previews.color.name,
            blobPath=blobconverter.from_zoo("mobilenet-ssd", shaves=6))
        self.pm.addNn(nn=self._nn)
        # Upload the pipeline to the connected device and start it
        self.device = dai.Device(self.pm.pipeline)
        self.device.startPipeline()
        # Host-side queues for preview frames and neural network results
        self.pv = PreviewManager(display=[Previews.color.name], createWindows=False)
        self.pv.createQueues(self.device)
        self._nnManager.createQueues(self.device)

    async def get_frame(self):
        """Fetch the latest color preview frame and overlay any detections on it."""
        self.pv.prepareFrames()
        self._nnData, inNn = self._nnManager.parse()
        if self._nnData is not None:
            self._nnManager.draw(self.pv, self._nnData)
        return self.pv.get(Previews.color.name)

    async def return_frame(self, frame):
        """Wrap a BGR ndarray in a timestamped av.VideoFrame for aiortc."""
        pts, time_base = await self.next_timestamp()
        new_frame = VideoFrame.from_ndarray(frame, format="bgr24")
        new_frame.pts = pts
        new_frame.time_base = time_base
        return new_frame

    async def dummy_recv(self):
        """Return a placeholder frame with an "ERROR" label when the device is unavailable."""
        frame = np.zeros((300, 300, 3), np.uint8)
        y, x = frame.shape[0] / 2, frame.shape[1] / 2
        left, top, right, bottom = int(x - 50), int(y - 30), int(x + 50), int(y + 30)
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), cv2.FILLED)
        cv2.putText(frame, "ERROR", (left, int((bottom + top) / 2 + 10)), cv2.FONT_HERSHEY_DUPLEX, 1.0,
                    (255, 255, 255), 1)
        return await self.return_frame(frame)

    async def recv(self):
        # Once the device has failed, keep serving the placeholder frame
        if self.dummy:
            return await self.dummy_recv()
        try:
            frame = await self.get_frame()
            return await self.return_frame(frame)
        except Exception:
            print(traceback.format_exc())
            print('Switching to dummy mode...')
            self.dummy = True
            return await self.dummy_recv()


async def initWebRTC(raw_offer):
    """Answer a WebRTC offer, attaching the DepthAI video track to the peer connection."""
    rtc_offer = RTCSessionDescription(sdp=raw_offer["sdp"], type=raw_offer["type"])
    pc = RTCPeerConnection()

    # Handle the offer: attach the camera track to every video transceiver the client requested
    await pc.setRemoteDescription(rtc_offer)
    for t in pc.getTransceivers():
        if t.kind == "video":
            pc.addTrack(VideoTransformTrack())

    @pc.on("iceconnectionstatechange")
    async def on_iceconnectionstatechange():
        if pc.iceConnectionState == "failed":
            await pc.close()

    await pc.setLocalDescription(await pc.createAnswer())
    return {
        "sdp": pc.localDescription.sdp,
        "type": pc.localDescription.type
    }


async def offer(request):
    """aiohttp handler: accept a JSON SDP offer and respond with the SDP answer."""
    offer = await request.json()
    answer = await initWebRTC(offer)
    return web.Response(
        content_type="application/json",
        text=json.dumps(answer),
    )


if __name__ == "__main__":
    app = web.Application()
    app.router.add_post("/offer", offer)
    web.run_app(app, access_log=None, port=8080)
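
To exchange SDP with this server, a client POSTs its offer to the /offer endpoint and applies the returned answer. The snippet below is a minimal sketch of such a client using aiortc and aiohttp; the server URL (http://localhost:8080/offer) and the choice to request a single receive-only video transceiver are assumptions for illustration, not part of the original gist.

import asyncio
import json

import aiohttp
from aiortc import RTCPeerConnection, RTCSessionDescription


async def run_client(url="http://localhost:8080/offer"):  # assumed server address
    pc = RTCPeerConnection()
    # Request one incoming video track, matching the server's video transceiver check
    pc.addTransceiver("video", direction="recvonly")

    @pc.on("track")
    def on_track(track):
        print("Receiving", track.kind, "track")

    # Create the SDP offer and send it to the /offer endpoint defined above
    await pc.setLocalDescription(await pc.createOffer())
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json={
            "sdp": pc.localDescription.sdp,
            "type": pc.localDescription.type,
        }) as resp:
            answer = await resp.json()

    # Apply the server's answer to complete the handshake
    await pc.setRemoteDescription(RTCSessionDescription(**answer))
    await asyncio.sleep(10)  # keep the connection open briefly for the demo
    await pc.close()


if __name__ == "__main__":
    asyncio.run(run_client())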