Skip to content

Commit

Permalink
perf(bio): add vits model to strided video track
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Nov 22, 2024
1 parent 8964d86 commit 4c9027a
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions aipipeline/projects/bio/run_strided_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import argparse
import ast
import uuid
from typing import List

from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch
import dotenv
import json
import logging
Expand All @@ -19,7 +20,7 @@
import cv2
import pandas as pd
import redis
import requests
import torch.nn.functional as F

from aipipeline.config_setup import setup_config
from aipipeline.db_utils import init_api_project, get_version_id
Expand Down Expand Up @@ -66,11 +67,12 @@ def display_tracks(frames, tracks, frame_num, out_video):
if pt is not None:
# Offset to the right by 10 pixels for better visibility on small objects
center = (int(pt[0]) + 10, int(pt[1]))
color = (255, 255, 255)
# Create a unique color for box in python based on hash of the filename
color = hash(label) % 256, hash(label) % 256, hash(label) % 256
thickness = 1
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 1
# Draw the track track_id with the label, e.g. 1:Unknown
# Draw the track track_id with the label, e.g. 1:marine organism
frame = cv2.putText(frame, f"{track.id}:{label}", center, font, fontScale, color, thickness,
cv2.LINE_AA)

Expand Down Expand Up @@ -99,6 +101,7 @@ def run_inference_track(
video_file: str,
stride_fps: int,
endpoint_url: str,
model: str,
allowed_class_names: [str] = None,
remapped_class_names: dict = None,
version_id: int = 0,
Expand All @@ -125,11 +128,16 @@ def run_inference_track(
logger.info(f"Error: {e}")
return

# Detector
yv5 = FastAPIYV5(endpoint_url)
vss = FastAPIVSS(base_url=config_dict['vss']['url'],
project= config_dict["vss"]["project"],
threshold=float(config_dict["vss"]["threshold"]))
vss_top_k = int(config_dict["vss"]["top_k"])

# Classifier
model_name = Path(model).name
vit_model = AutoModelForImageClassification.from_pretrained(model)
vit_model.to("cpu")
processor = AutoImageProcessor.from_pretrained(model)

# Video summarization and output
cap = cv2.VideoCapture(video_path.as_posix())
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
Expand All @@ -147,6 +155,7 @@ def run_inference_track(
crop_path.mkdir(parents=True, exist_ok=True)
dive_path.mkdir(parents=True, exist_ok=True)
clean(output_path)
clean(dive_path)
out_video = cv2.VideoWriter(out_video_path.as_posix(), fourcc, 1, (frame_width, frame_height))

if not skip_load:
Expand Down Expand Up @@ -203,13 +212,13 @@ def run_inference_track(
detections.append(t_loc)

# Run the tracker and clean up the localizations
tracks = tracker.update_batch((start_frame, end_frame), frame_stack, detections=detections, max_frames=900, max_cost=10)
tracks = tracker.update_batch((start_frame, end_frame), frame_stack, detections=detections, max_empty_frames=10, max_frames=300, max_cost=10)
display_tracks(frame_stack_full, tracks, start_frame, out_video)
clean(output_path)
all_loc = []

# Check if any tracks are closed and queue the localizations in REDIS
closed_tracks = [t for t in tracks if t.is_closed(end_frame)]
closed_tracks = [t for t in tracks if t.is_closed()]
if len(closed_tracks) > 0:
for track in closed_tracks:
logger.info(f"Closed track {track.id} at frame {frame_idx}")
Expand Down Expand Up @@ -299,22 +308,33 @@ def run_inference_track(
logger.info(f"{video_path.name}: No valid localizations in frame at {current_time_secs} seconds")
continue

# Run the VSS model on the cropped images in batch
images = [read_image(loc["crop_path"]) for loc in data]
# Run the vits model on the cropped images in batch
images = [Image.open(loc["crop_path"]).convert("RGB") for loc in data]
inputs = processor(images=images, return_tensors="pt").to("cpu")
try:
file_paths, best_predictions, best_scores = vss.predict(images, top_k=vss_top_k)
with torch.no_grad():
outputs = vit_model(**inputs)
logits = outputs.logits
# Get the top 3 classes and scores
top_scores, top_classes = torch.topk(logits, 3)
top_classes = top_classes.cpu().numpy()
top_scores = F.softmax(top_scores, dim=-1).cpu().numpy()
best_predictions = [[vit_model.config.id2label[class_idx] for class_idx in class_list]
for class_list in top_classes]
best_scores = [[str(score) for score in score_list] for score_list in top_scores]
except Exception as e:
logger.error(f"Error processing VSS model: {e}")
logger.error(f"Error processing {model_name}: {e}")
return

for loc, best_prediction, best_score in zip(data, best_predictions, best_scores):

if allowed_class_names and best_prediction[0] not in allowed_class_names:
logger.info(f"{video_path.name}: VSS model prediction {best_predictions[0]} not in {allowed_class_names}. Skipping this detection.")
logger.info(f"{video_path.name}: {model_name} prediction {best_predictions[0]} not in {allowed_class_names}. Skipping this detection.")
continue

if best_prediction:
loc["confidence"] = best_score
loc["class_name"] = best_prediction
loc["confidence"] = best_score[0]
loc["class_name"] = best_prediction[0]
if remapped_class_names:
loc["class_name"] = remapped_class_names[loc["class_name"]]
else:
Expand Down Expand Up @@ -360,12 +380,13 @@ def run_inference_track(
return

logger.info(f"Finished processing video {video_path}")
# out_video.release()
out_video.release()

def process_videos(
video_files: [str],
stride_fps: int,
endpoint_url: str,
model: str,
allowed_class_names: [str] = None,
remapped_class_names: dict = None,
version_id: int = 0,
Expand All @@ -378,7 +399,7 @@ def process_videos(
pool = multiprocessing.Pool(processes=num_cpus)
pool.starmap(
run_inference_track,
[(v, stride_fps, endpoint_url, allowed_class_names, remapped_class_names, version_id,
[(v, stride_fps, endpoint_url, model, allowed_class_names, remapped_class_names, version_id,
skip_load, min_confidence, min_depth, max_secs) for v in
video_files],
)
Expand All @@ -400,6 +421,7 @@ def parse_args():
parser.add_argument("--video", help="Video file or directory.", required=False, type=str)
parser.add_argument("--max-seconds", help="Maximum number of seconds to process.", required=False, type=int)
parser.add_argument("--version", help="Version name", required=False, type=str)
parser.add_argument("--model", help="Model location", required=False, type=str, default="/mnt/DeepSea-AI/models/m3midwater-vit-b-16/")
parser.add_argument("--skip-load", help="Skip loading the video reference into Tator.", action="store_true")
parser.add_argument(
"--tsv",
Expand Down Expand Up @@ -507,6 +529,7 @@ def parse_args():
video_path.as_posix(),
args.stride_fps,
args.endpoint_url,
args.model,
args.allowed_classes,
args.class_remap,
version_id,
Expand All @@ -523,6 +546,7 @@ def parse_args():
video_files,
args.stride_fps,
args.endpoint_url,
args.model,
args.allowed_classes,
args.class_remap,
version_id,
Expand All @@ -548,6 +572,7 @@ def parse_args():
video_files,
args.stride_fps,
args.endpoint_url,
args.model,
args.allowed_classes,
args.class_remap,
version_id,
Expand Down

0 comments on commit 4c9027a

Please sign in to comment.