torchvision_reader.py
from typing import Literal
from pathlib import Path

import torch
from torchvision.io import read_video, read_video_timestamps

from src.video_io.abstract_reader import AbstractVideoReader


class TorchvisionVideoReader(AbstractVideoReader):
    """Video reader based on torchvision.io.read_video.

    Args:
        video_path (str | Path): Path to the input video file.
        mode (Literal["seek", "stream"], optional): Reading mode: "seek" -
            decode each requested frame individually; "stream" - decode the
            whole range spanned by the requested indices and subsample it.
            Defaults to "stream".
        output_format (Literal["THWC", "TCHW"], optional): Data format:
            channels last or channels first. Defaults to "THWC".
        device (str, optional): Device to move the resulting tensor to.
            Defaults to "cuda:0".
    """

    def __init__(self, video_path: str | Path,
                 mode: Literal["seek", "stream"] = "stream",
                 output_format: Literal["THWC", "TCHW"] = "THWC",
                 device: str = "cuda:0"):
        super().__init__(video_path, mode=mode, output_format=output_format,
                         device=device)

    def _initialize_reader(self) -> None:
        # Index the video once up front: collect the presentation timestamp
        # (in seconds) of every frame, plus the average frame rate.
        timestamps, fps = read_video_timestamps(self.video_path,
                                                pts_unit="sec")
        self.timestamps = timestamps
        self.num_frames = len(timestamps)
        self.fps = fps

    def _to_tensor(self, frames: torch.Tensor) -> torch.Tensor:
        # read_video already returns a tensor; just move it to the target device.
        return frames.to(self.device)

    def seek_read(self, frame_indices: list[int]) -> torch.Tensor:
        # Decode each requested frame with its own read_video call, seeking
        # to the frame's exact timestamp (start_pts == end_pts).
        frame_timestamps = [self.timestamps[fid] for fid in frame_indices]
        frames = []
        for ts in frame_timestamps:
            frame, _, _ = read_video(
                self.video_path, start_pts=ts, end_pts=ts,
                pts_unit="sec", output_format=self.output_format)
            frames.append(self._process_frame(frame))
        return torch.cat(frames, dim=0)

    def stream_read(self, frame_indices: list[int]) -> torch.Tensor:
        # Decode the full span between the first and last requested frames in
        # a single call, then subsample the requested indices from it.
        frame_timestamps = [self.timestamps[fid] for fid in frame_indices]
        frames, _, _ = read_video(
            self.video_path, start_pts=min(frame_timestamps),
            end_pts=max(frame_timestamps) + (1 / self.fps),
            pts_unit="sec", output_format=self.output_format)
        # The requested indices are global; shift them so they index into the
        # decoded span, which starts at the smallest requested frame.
        frame_indices_sample = [fid - min(frame_indices)
                                for fid in frame_indices]
        return frames[frame_indices_sample]

    def release(self) -> None:
        # read_video opens and closes the file on every call, so there is no
        # persistent decoder handle to release.
        pass
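

# A minimal usage sketch, not part of the original module. It assumes a local
# file "sample.mp4" exists (hypothetical path) and that AbstractVideoReader's
# constructor calls _initialize_reader; adjust the path and device to your
# setup.
if __name__ == "__main__":
    reader = TorchvisionVideoReader("sample.mp4", mode="stream",
                                    output_format="THWC", device="cpu")
    # "stream" mode decodes the 0..30 span once and subsamples three frames;
    # with mode="seek" the same call would issue one decode per frame instead.
    frames = reader.stream_read([0, 15, 30])
    print(frames.shape)  # e.g. (3, H, W, 3) for THWC output
    reader.release()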