video.py
import logging
import multiprocessing
import typing as T

import av
import numpy as np
import tqdm


def get_video_frame_at_timestamp(video_path: str, offset_s: float = 0) -> np.ndarray:
    """Returns the first frame at or after `offset_s` seconds as an RGB numpy image"""
    av_frame = read_video_frame(video_path, offset_s)
    return color_image_from_frame(av_frame)
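

# Usage sketch (hypothetical path): grab the frame at t=2.5 s as an RGB array
# of shape (height, width, 3).
#
#     img = get_video_frame_at_timestamp("example.mp4", offset_s=2.5)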


def color_image_from_frame(frame: av.VideoFrame) -> np.ndarray:
    """Converts a frame to a color (RGB) numpy image"""
    return frame.to_rgb().to_ndarray()


def greyscale_image_from_frame(frame: av.VideoFrame) -> np.ndarray:
    """Converts a frame to a grayscale numpy image (luma plane only)"""
    plane = frame.planes[0]
    img_np = np.frombuffer(plane, np.uint8)
    try:
        img_np.shape = frame.height, frame.width
    except ValueError:
        # The plane is padded: reshape by line size, then crop to the visible width.
        img_np = img_np.reshape(-1, plane.line_size)
        img_np = np.ascontiguousarray(img_np[:, : frame.width])
    return img_np


def reformat_frame(
    frame: av.VideoFrame, width: T.Optional[int] = None, height: T.Optional[int] = None
) -> av.VideoFrame:
    width = width or frame.width
    height = height or frame.height
    return frame.reformat(width=width, height=height)


def read_video_frame(
    video_path: str,
    timestamp: float = 0,
    width: T.Optional[int] = None,
    height: T.Optional[int] = None,
) -> T.Optional[av.VideoFrame]:
    """
    Returns the first frame at or after `timestamp` from a video

    :param video_path: path or url of a video
    :param timestamp: timestamp of frame from video start (in seconds)
    """
    container = av.open(video_path)
    timestamp_us = int(timestamp * 1e6)  # TODO get container time base
    container.seek(timestamp_us)
    try:
        for frame in container.decode(video=0):
            if frame.time < timestamp:
                continue
            return reformat_frame(frame, height=height, width=width)
    finally:
        container.close()
    return None


def read_video_frames(
    video_path: str, start: float = 0, end: T.Optional[float] = None
) -> T.Iterator[av.VideoFrame]:
    """
    Yields frames from a video

    :param video_path: path or url of a video
    :param start: time to start from (in seconds)
    :param end: time to end (in seconds)
    """
    container = av.open(video_path)
    try:
        if start:
            start_us = int(start * 1e6)  # TODO get container time base
            container.seek(start_us)
        for frame in container.decode(video=0):
            if frame.time < start:
                continue
            if end is not None and frame.time > end:
                break
            if not frame:
                break
            yield frame
    finally:
        container.close()
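

# Usage sketch (hypothetical path): iterate over a 5-second window lazily and
# count the frames without keeping them all in memory.
#
#     n = sum(1 for _ in read_video_frames("example.mp4", start=10, end=15))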


def split_into_chunks(
    start: float, end: float, chunks: int
) -> T.List[T.Tuple[float, float]]:
    """
    Split `start` to `end` into `chunks` equally sized chunks

    >>> split_into_chunks(10, 30, 4)
    [(10.0, 15.0), (15.0, 20.0), (20.0, 25.0), (25.0, 30.0)]
    """
    total = end - start
    segment_size = total / chunks
    result = []
    for i in range(chunks):
        segment_start = start + i * segment_size
        segment_end = segment_start + segment_size
        result.append((segment_start, segment_end))
    return result


class VideoProcessor(object):
    """Processes a video in parallel by splitting it into chunks across worker processes"""

    def __init__(
        self,
        video_path: str,
        start_time: float,
        end_time: float,
        process_function: T.Callable[[av.VideoFrame], T.Any],
        result_function: T.Callable = None,
        show_progress: bool = False,
        return_result: bool = False,
        n_workers: int = multiprocessing.cpu_count(),
        logger=None,
    ):
        """
        :param video_path: path to video
        :param start_time: time at which to start the processing (in seconds)
        :param end_time: time at which to end the processing (in seconds)
        :param process_function: function to run on each frame
        :param result_function: optional callback invoked with each result
        :param show_progress: display a tqdm progress bar
        :param return_result: collect and return the per-worker results
        :param n_workers: number of worker processes
        """
        super().__init__()
        self.video_path = video_path
        self.process_function = process_function
        self.result_function = result_function
        self.start_time = start_time
        self.end_time = end_time
        self.queue = multiprocessing.Queue()
        self.show_progress = show_progress
        self.return_result = return_result
        self.n_workers = n_workers
        self.logger = logger or logging.getLogger(__name__)
        self.status_dict = multiprocessing.Manager().dict()
        container = av.open(video_path)
        container_duration_s = container.duration / av.time_base
        if self.start_time < 0:
            self.start_time = 0
        if self.end_time is None or self.end_time > container_duration_s:
            self.end_time = container_duration_s + 1
        duration_s = self.end_time - self.start_time
        self.estimated_frames = int(
            container.streams.video[0].average_rate * duration_s
        )
        container.close()
        self.section_offsets = split_into_chunks(
            self.start_time, self.end_time, self.n_workers
        )
        self.logger.debug(f"number of workers: {n_workers}")
        print(f"number of workers: {n_workers}")
        self.workers = []
        self.pipe_list = []
        for i, (start, end) in enumerate(self.section_offsets):
            recv_end, send_end = None, None
            if return_result:
                recv_end, send_end = multiprocessing.Pipe(False)
            self.workers.append(self._make_worker(start, end, send_end))
            self.pipe_list.append(recv_end)
            self.logger.debug(f"created worker {i} to work ({start} - {end})")
            print(f"created worker {i} to work ({start} - {end})")

    def _make_worker(self, start, end, send_end):
        worker = VideoProcessorWorker(
            id=len(self.workers) + 1,
            video_path=self.video_path,
            start_time=start,
            end_time=end,
            function=self.process_function,
            out_queue=self.queue,
            send_end=send_end,
            logger=self.logger,
            status_dict=self.status_dict,
        )
        return worker

    def start(self):
        def consumer(queue):
            if self.show_progress:
                progress = tqdm.tqdm(total=self.estimated_frames)
            try:
                for item in iter(queue.get, None):
                    if self.result_function:
                        self.result_function(item)
                    if self.show_progress:
                        progress.update()
            except BrokenPipeError:
                for worker in self.workers:
                    worker.terminate()

        for worker in self.workers:
            worker.daemon = True
            worker.start()
        consumer_process = multiprocessing.Process(target=consumer, args=(self.queue,))
        consumer_process.daemon = True
        consumer_process.start()
        for worker in self.workers:
            worker.join()
            self.logger.debug(f"joined worker-{worker}")
            print(f"joined worker-{worker}")
        self.queue.put(None)  # sentinel to stop the consumer
        consumer_process.join()
        self.logger.info(f"status_dict: {self.status_dict}")
        return (
            self.status_dict.get("status", "ERROR") != "ERROR",
            [x.recv() for x in self.pipe_list if x is not None],
        )


class VideoProcessorWorker(multiprocessing.Process):
    """A worker process that applies a function to every frame of a video section"""

    def __init__(
        self,
        id: int,
        video_path: str,
        function: T.Callable[[av.VideoFrame], T.Any],
        start_time: float,
        end_time: float,
        out_queue: multiprocessing.Queue,
        send_end=None,
        logger=None,
        status_dict=None,
    ):
        """
        :param video_path: path to video
        :param function: function to run on each frame
        :param start_time: time at which to start the processing
        :param end_time: time at which to end the processing
        :param out_queue: queue where to put the results
        :param send_end: optional pipe connection used to send back the collected results
        """
        super().__init__()
        self.id = id
        self.video_path = video_path
        self.function = function
        self.start_time = start_time
        self.end_time = end_time
        self.out_queue = out_queue
        self.send_end = send_end
        self.logger = logger or logging.getLogger(f"worker-{self.id}")
        self.status_dict = status_dict if status_dict is not None else {}
        self.exit = multiprocessing.Event()

    def run(self):
        while not self.exit.is_set():
            try:
                self.logger.debug(
                    f"processing video start={self.start_time:.5f} end={self.end_time:.5f}"
                )
                print(
                    f"processing video start={self.start_time:.5f} end={self.end_time:.5f}"
                )
                n_frames = 0
                first_frame = last_frame = None
                results_list = []
                for frame in read_video_frames(
                    self.video_path, self.start_time, self.end_time
                ):
                    if first_frame is None:
                        first_frame = frame
                    n_frames += 1
                    result = self.function(frame)
                    self.out_queue.put(result)
                    if self.send_end:
                        results_list.append(result)
                    last_frame = frame
                if self.send_end:
                    self.send_end.send(results_list)
                self.logger.debug(
                    f"completed frames={n_frames}"
                    f" returning {len(results_list)} results"
                )
                print(
                    f"completed frames={n_frames}"
                    f" returning {len(results_list)} results"
                )
                self.status_dict["status"] = "SUCCEEDED"
                self.shutdown()
            except Exception as error:
                print(error)
                self.status_dict["status"] = "ERROR"
                self.shutdown()
        print(f"worker:{self.id} exited")

    def shutdown(self):
        self.logger.info("Shutdown initiated")
        print("Shutdown initiated")
        self.exit.set()
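

# Minimal usage sketch. The file name, time range, and worker count below are
# placeholders, and frame_brightness is only a hypothetical example of the
# expected process_function signature: it takes an av.VideoFrame and returns a
# picklable result.
def frame_brightness(frame: av.VideoFrame) -> float:
    """Example process_function: mean luma value of a frame"""
    return float(greyscale_image_from_frame(frame).mean())


if __name__ == "__main__":
    processor = VideoProcessor(
        video_path="example.mp4",  # placeholder path
        start_time=0,
        end_time=10,
        process_function=frame_brightness,
        show_progress=True,
        return_result=True,
        n_workers=2,
    )
    succeeded, per_worker_results = processor.start()
    print(f"succeeded={succeeded}, chunks={len(per_worker_results)}")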