perf(bio): significant refactor of video processing pipeline into more modular design with callbacks and moved most data to GPU where possible for speed-up
danellecline committed Dec 17, 2024
1 parent 7f630de commit 1b3d972
Showing 21 changed files with 899 additions and 1,281 deletions.
6 changes: 3 additions & 3 deletions aipipeline/db_utils.py
@@ -24,15 +24,15 @@
logger.addHandler(handler)


-def get_version_id(api: TatorApi, project: Project, version: str) -> int:
+def get_version_id(api: TatorApi, project_id: int, version: str) -> int:
    """
    Get the version ID for the given project
    :param api: :class:`TatorApi` object
-    :param project: project object
+    :param project_id: Project ID
    :param version: version name
    :return: version ID
    """
-    versions = api.get_version_list(project=project.id)
+    versions = api.get_version_list(project=project_id)
    logger.debug(versions)

    # Find the version by name
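The rest of the hunk is collapsed here; a minimal sketch of how the name lookup presumably concludes, assuming tator-py `Version` objects expose `.name` and `.id` (a hypothetical completion, not the committed code):

```python
# Hypothetical completion (assumption: Version objects carry .name and .id)
for v in versions:
    if v.name == version:
        return v.id
raise ValueError(f"Version {version} not found in project {project_id}")
```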
8 changes: 4 additions & 4 deletions aipipeline/projects/bio/README.md
@@ -20,10 +20,10 @@ For example, to run detections
- and store the results in the "test" section of the database

```shell
-python scripts/run_strided_inference.py \
-  --config ./aipipeline/projects/bio/config/config.yaml \
+cd
+python process.py \
+  --config config/config.yaml \
  --video ./data/ctenophora_sp_A_aug/CTENOPHORA_SP_A_AUG_00001.mp4 \
-  --stride 2 \
-  --endpoint http://fasta-fasta-1d0o3gwgv046e-143598223.us-west-2.elb.amazonaws.com/predict \
-  --section "test"
+  --endpoint http://fasta-fasta-1d0o3gwgv046e-143598223.us-west-2.elb.amazonaws.com/predict
```
Empty file.
2 changes: 1 addition & 1 deletion aipipeline/projects/bio/config/config.yml
@@ -6,7 +6,7 @@ redis:

data:
  processed_path: "/mnt/ML_SCRATCH/901103-biodiversity"
-  version: "Baseline"
+  version: "mega-vits-track-gcam"
  labels: "all"
  download_args: ["--verified --min-saliency 1000"]

Empty file.
45 changes: 45 additions & 0 deletions aipipeline/projects/bio/core/args.py
@@ -0,0 +1,45 @@
# aipipeline, Apache-2.0 license
# Filename: projects/bio/core/args.py
# Description: Argument parser for bio projects

import argparse
from pathlib import Path
from textwrap import dedent

DEFAULT_CONFIG_YAML = Path(__file__).resolve().parent / "config" / "config.yml"
DEFAULT_VIDEO = Path(__file__).resolve().parent / "data" / "V4361_20211006T163256Z_h265_1min.mp4"

def parse_args():
    parser = argparse.ArgumentParser(
        description=dedent("""\
            Run strided video track pipeline with REDIS queue based load.
            Example:
            python3 predict.py /path/to/video.mp4
            """),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--config", default=DEFAULT_CONFIG_YAML, required=False, help=f"Configuration file. Default: {DEFAULT_CONFIG_YAML}", type=str)
    parser.add_argument("--video", help="Video file", required=False, type=str, default=DEFAULT_VIDEO)
    parser.add_argument("--max-seconds", help="Maximum number of seconds to process.", required=False, type=int, default=-1)
    parser.add_argument("--min-frames", help="Minimum number of frames a track must have.", required=False, type=int, default=5)
    parser.add_argument("--min-score-track", help="Minimum score for a track to be valid.", required=False, type=float, default=0.1)
    parser.add_argument("--min-score-det", help="Minimum score for a detection to be valid.", required=False, type=float, default=0.1)
    parser.add_argument("--max-frames-tracked", help="Maximum number of frames a track can have before closing it.", required=False, type=int, default=300)
    parser.add_argument("--version", help="Version name", required=False, type=str)
    parser.add_argument("--gpu-id", help="GPU ID to use for inference.", required=False, type=int, default=0)
    parser.add_argument("--vits-model", help="ViTS model location", required=False, type=str, default="/mnt/DeepSea-AI/models/m3midwater-vit-b-16/")
    parser.add_argument("--skip-load", help="Skip loading the video reference into Tator.", action="store_true")
    parser.add_argument("--imshow", help="Display the video frames.", action="store_true")
    parser.add_argument("--tsv", help="TSV file with video paths per Haddock output", required=False, type=str)
    parser.add_argument("--stride-fps", help="Frames per second to run detection, e.g. 1 is 1 frame every second, "
                                             "5 is 5 frames a second.", default=3, type=int)
    parser.add_argument("--class_name", help="Class name to target inference.", default="Ctenophora sp. A", type=str)
    parser.add_argument("--endpoint-url", help="URL of the inference endpoint.", required=False, type=str)
    parser.add_argument("--det-model", help="Object detection model path.", required=False, type=str)
    parser.add_argument("--batch-size", help="Batch size", default=1, type=int)
    parser.add_argument("--min-depth", help="Minimum depth for detections. Videos that start shallower than this depth are skipped.", default=0, type=int)
    parser.add_argument("--flush", help="Flush the REDIS database.", action="store_true")
    parser.add_argument("--allowed-classes", type=str, nargs="+", help="List of allowed classes.")
    parser.add_argument("--class-remap", type=str, help="Dictionary of class remapping, formatted as a string.")
    return parser.parse_args()
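A hedged usage sketch of how a caller might consume these arguments, including one plausible way to turn the `--class-remap` string into a dict; `ast.literal_eval` is an assumption here, since the committed parsing code is not shown:

```python
# Usage sketch (assumption: --class-remap holds a Python dict literal such as
# "{'Mysida': 'Mysidae'}"; the committed caller may parse it differently).
import ast

args = parse_args()
class_remap = ast.literal_eval(args.class_remap) if args.class_remap else {}
print(f"Processing {args.video} at {args.stride_fps} detection frame(s)/s on GPU {args.gpu_id}")
print(f"Class remap: {class_remap}")
```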
aipipeline/projects/bio/core/bioutils.py
@@ -1,24 +1,37 @@
# aipipeline, Apache-2.0 license
# Filename: projects/bio/core/bioutils.py
# Description: General utility functions for bio projects

import io
import json
import logging
import os
import random
import subprocess
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, List

import cv2
import torch
import requests
from aipipeline.docker.utils import run_docker
from aipipeline.prediction.utils import crop_square_image

logger = logging.getLogger(__name__)


-def get_ancillary_data(dive: str, config_dict: dict, iso_datetime: datetime) -> dict:
+def get_ancillary_data(dive: str, config_dict: dict, iso_datetime: Any) -> dict:
    try:
        # Create a random index for the container name
        index = random.randint(0, 1000)
        platform = dive.split(' ')[:-1]  # remove the last element, which is the dive number
        platform = ''.join(platform)
        if isinstance(iso_datetime, str):
            iso_datetime = datetime.fromisoformat(iso_datetime)
        container = run_docker(
            image=config_dict["docker"]["expd"],
            name=f"expd-{platform}-{iso_datetime:%Y%m%dT%H%M%S%f}-{index}",
@@ -132,4 +145,74 @@ def seconds_to_timestamp(seconds):
timestamp = f"{seconds:.3f}"
else:
timestamp = f"{int(hours):02}:{int(minutes):02}:{int(seconds_):02}"
return timestamp
return timestamp


def show_boxes(batch, predictions):
    scale_w, scale_h = 1280, 1280
    for batch_idx, img in enumerate(batch):
        # Convert the tensor to a numpy array in HWC order
        img = img.cpu().numpy().transpose(1, 2, 0)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        for pred in predictions:
            if pred['batch_idx'] != batch_idx:
                continue
            x1, y1, x2, y2 = pred['x'], pred['y'], pred['x'] + pred['w'], pred['y'] + pred['h']
            # Scale the normalized bounding box to pixel coordinates
            x1, y1, x2, y2 = int(x1 * scale_w), int(y1 * scale_h), int(x2 * scale_w), int(y2 * scale_h)
            # class_name = pred['class_name']
            # conf = pred['confidence']
            img = cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # img = cv2.putText(img, f'{self.model.class_names[int(cls)]} {conf:.2f}', (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.imshow('frame', img)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break


def filter_blur_pred(images: torch.Tensor,
                     predictions: List[dict],
                     crop_path: Path,
                     image_width: int,
                     image_height: int):
    """Crop each detection to a square and drop crops that are too blurry."""

    # Return True if the crop has a Laplacian variance below the threshold, i.e. is blurry
    def detect_blur(image_path, threshold):
        image = cv2.imread(image_path)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        laplacian = cv2.Laplacian(gray, cv2.CV_64F)
        laplacian_variance = laplacian.var()
        logger.debug(f"laplacian_variance: {laplacian_variance} for {image_path}")
        return laplacian_variance < threshold

    filtered_pred = []
    # TODO: get image width and height from the images tensor
    for p in predictions:
        loc = {
            "x": p["x"],
            "y": p["y"],
            "xx": p["x"] + p['w'],
            "xy": p["y"] + p['h'],
            "w": p['w'],
            "h": p['h'],
            "frame": p['frame'],
            "image_width": image_width,
            "image_height": image_height,
            "confidence": p['confidence'],
            "batch_idx": p['batch_idx'],
            "crop_path": (crop_path / f"{uuid.uuid5(uuid.NAMESPACE_DNS, str(p['x']) + str(p['y']) + str(p['w']) + str(p['h']))}.jpg").as_posix()
        }

        crop_square_image(images, loc, 224)

        if detect_blur(loc["crop_path"], 2.0):
            logger.debug(f"Detected blur in {loc['crop_path']}")
            os.remove(loc["crop_path"])
        else:
            filtered_pred.append(loc)

    return filtered_pred
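For context, a sketch of how this filter might slot in after a detection batch; the tensor shape and prediction schema below are assumptions inferred from the keys used above, not the committed calling code:

```python
# Sketch only: wiring filter_blur_pred after inference (all values are made up).
from pathlib import Path
import torch

crop_dir = Path("/tmp/crops")
crop_dir.mkdir(parents=True, exist_ok=True)
images = torch.rand(2, 3, 1280, 1280)  # assumed NCHW batch of frames
predictions = [{"x": 0.1, "y": 0.2, "w": 0.05, "h": 0.07, "frame": 0,
                "confidence": 0.9, "batch_idx": 0}]
kept = filter_blur_pred(images, predictions, crop_dir,
                        image_width=1280, image_height=1280)
print(f"{len(kept)} of {len(predictions)} detections survived the blur filter")
```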

131 changes: 131 additions & 0 deletions aipipeline/projects/bio/core/callback.py
@@ -0,0 +1,131 @@
# aipipeline, Apache-2.0 license
# Filename: projects/bio/core/callback.py
# Description: Custom callback for bio projects
import json
import logging
from datetime import datetime, timedelta

from aipipeline.projects.bio.core.bioutils import get_ancillary_data, get_video_metadata

logger = logging.getLogger(__name__)


class Callback:
    """Base class for callbacks."""

    def on_predict_batch_start(self, batch):
        """Called at the start of each prediction batch."""
        pass

    def on_predict_batch_end(self, predictions):
        """Called at the end of each prediction batch."""
        pass

    def on_predict_start(self, redis_queue, predictor, video_name):
        """Called at the start of prediction for a video."""
        pass
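The file only defines the hooks; a minimal sketch of how a predictor might fan them out to registered callbacks (the dispatch loop and the `Predictor` class shown here are assumptions, not part of this commit):

```python
# Dispatch sketch (assumption: the predictor holds a list of Callback instances).
class Predictor:
    def __init__(self, callbacks=None):
        self.callbacks = callbacks or []
        self.md = None

    def predict(self, redis_queue, video_name, batches):
        for cb in self.callbacks:
            cb.on_predict_start(redis_queue, self, video_name)
        for batch in batches:
            for cb in self.callbacks:
                cb.on_predict_batch_start(batch)
            predictions = ...  # run detection/tracking on the batch
            for cb in self.callbacks:
                cb.on_predict_batch_end(predictions)
```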

class AncillaryCallback(Callback):
    """Custom callback to fetch ancillary data for bio projects."""

    def on_predict_start(self, redis_queue, predictor, video_name):
        logger.info(f"Getting metadata for video: {video_name}")
        try:
            md = get_video_metadata(video_name)
            if md is None:
                logger.error(f"Failed to get video metadata for {video_name}")
            else:
                predictor.md = md
                video_ref_uuid = md["video_reference_uuid"]
                iso_start = md["start_timestamp"]
                video_url = md["uri"]
                # http://mantis.shore.mbari.org/M3/mezzanine/Ventana/2022/09/4432/V4432_20220914T210637Z_h264.mp4
                # https://m3.shore.mbari.org/videos/M3/mezzanine/Ventana/2022/09/4432/V4432_20220914T210637Z_h264.mp4
                # Replace m3.shore.mbari.org/videos with mantis.shore.mbari.org/M3
                video_url = video_url.replace("https://m3.shore.mbari.org/videos", "http://mantis.shore.mbari.org")
                logger.info(f"video_ref_uuid: {video_ref_uuid}")
                redis_queue.hset(
                    f"video_refs_start:{video_ref_uuid}",
                    "start_timestamp",
                    iso_start,
                )
                redis_queue.hset(
                    f"video_refs_load:{video_ref_uuid}",
                    "video_uri",
                    video_url,
                )
        except Exception as e:
            logger.error(f"Error: {e}")
            if predictor.md is None:
                predictor.md = {}
            else:
                # Remove the video reference from the queue
                video_ref_uuid = predictor.md["video_reference_uuid"]
                redis_queue.delete(f"video_refs_start:{video_ref_uuid}")
                redis_queue.delete(f"video_refs_load:{video_ref_uuid}")



class ExportCallback(Callback):
    """Queues closed-track localizations in REDIS for loading into Tator."""

    num_loaded = 0

    def on_predict_start(self, redis_queue, predictor, video_name):
        output_path = predictor.output_path
        logger.info(f"Removing stale crops and JSON from {output_path}")
        # Use distinct loop variables so output_path is not clobbered mid-loop
        for jpg in output_path.rglob("*.jpg"):
            jpg.unlink()
        for json_file in output_path.rglob("*.json"):
            json_file.unlink()

    def on_predict_batch_end(self, batch):
        """Check if any tracks are closed and queue the localizations in REDIS."""
        skip_load, redis_queue, version_id, config_dict, predictor, tracks = batch
        if skip_load:
            return
        closed_tracks = [t for t in tracks if t.is_closed()]

        if len(closed_tracks) > 0:
            start_datetime = datetime.fromisoformat(predictor.md["start_timestamp"])

            for track in closed_tracks:
                logger.info(f"Closed track {track.id}")
                best_frame, best_pt, best_label, best_box, best_score = track.get_best(False)
                best_time_secs = float(best_frame * predictor.frame_stride / predictor.source.frame_rate)
                logger.info(f"Best track {track.id} is {best_pt},{best_box},{best_label},{best_score} in frame {best_frame}")

                loc_datetime = start_datetime + timedelta(seconds=best_time_secs)
                ancillary_data = get_ancillary_data(predictor.md['dive'], config_dict, loc_datetime)

                if ancillary_data is None or "depthMeters" not in ancillary_data:
                    logger.error(f"Failed to get ancillary data for {predictor.md['dive']} {start_datetime}")
                    continue

                new_loc = {
                    "x1": float(max(best_box[0], 0.0)),
                    "y1": float(max(best_box[1], 0.0)),
                    "x2": float(best_box[2]),
                    "y2": float(best_box[3]),
                    "width": int(predictor.source.width),
                    "height": int(predictor.source.height),
                    "frame": int(best_frame * predictor.frame_stride),
                    "version_id": int(version_id),
                    "score": float(best_score[0]),
                    "score_s": float(best_score[1]),
                    "cluster": "-1",
                    "label": best_label[0],
                    "label_s": best_label[1],
                    "dive": predictor.md["dive"],
                    "depth": ancillary_data["depthMeters"],
                    "iso_datetime": loc_datetime.strftime("%Y-%m-%dT%H:%M:%SZ"),
                    "latitude": ancillary_data["latitude"],
                    "longitude": ancillary_data["longitude"],
                    "temperature": ancillary_data["temperature"],
                    "oxygen": ancillary_data["oxygen"],
                }
                logger.info(f"queuing loc: {new_loc} {predictor.md['dive']} {loc_datetime}")
                redis_queue.hset(f"locs:{predictor.md['video_reference_uuid']}", str(self.num_loaded), json.dumps(new_loc))
                self.num_loaded += 1
                logger.info(f"{predictor.source.name} found total possible {self.num_loaded} localizations")
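Downstream, a loader drains these hashes into Tator; a hedged sketch of reading the queued localizations back out of REDIS (the key names follow the `hset` calls above, but the consumer itself is hypothetical):

```python
# Consumer sketch (hypothetical): read queued localizations for one video reference.
import json
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
video_ref_uuid = "..."  # a real video_reference_uuid from video_refs_load:*
entries = r.hgetall(f"locs:{video_ref_uuid}")
for _, raw in sorted(entries.items(), key=lambda kv: int(kv[0])):
    loc = json.loads(raw)
    print(loc["label"], loc["frame"], loc["depth"])
```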