diff --git a/recipes/promptable-content-moderation/.gitignore b/recipes/promptable-content-moderation/.gitignore
new file mode 100644
index 00000000..d5cb7282
--- /dev/null
+++ b/recipes/promptable-content-moderation/.gitignore
@@ -0,0 +1,52 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+*.dll
+
+# Virtual Environment
+venv/
+env/
+ENV/
+.venv/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Project specific
+inputs/*
+outputs/*
+!inputs/.gitkeep
+!outputs/.gitkeep
+inputs/
+outputs/
+
+# Model files
+*.pth
+*.onnx
+*.pt
+
+# Logs
+*.log
+
+certificate.pem
\ No newline at end of file
diff --git a/recipes/promptable-content-moderation/README.md b/recipes/promptable-content-moderation/README.md
new file mode 100644
index 00000000..872e5a1a
--- /dev/null
+++ b/recipes/promptable-content-moderation/README.md
@@ -0,0 +1,159 @@
+# Promptable Content Moderation with Moondream
+
+Welcome to the future of content moderation with Moondream 2B, a powerful and lightweight vision-language model that enables detection and moderation of video content using natural language prompts.
+
+[Try it now.](https://huggingface.co/spaces/moondream/content-moderation)
+
+## Features
+
+- Content moderation through natural language prompts
+- Multiple visualization styles
+- Intelligent scene detection and tracking:
+  - DeepSORT tracking with scene-aware reset
+  - Persistent moderation across frames
+  - Smart tracker reset at scene boundaries
+- Optional grid-based detection for improved accuracy on complex scenes
+- Frame-by-frame processing with IoU-based merging
+- Web-compatible output format
+- Test mode (process only first X seconds)
+- Advanced moderation analysis with multiple visualization plots
+
+## Examples
+
+| Prompt | Output |
+|--------|-----------------|
+| "white cigarette" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-cig.gif) |
+| "gun" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-gu.gif) |
+| "confederate flag" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-conflag.gif) |
+
+## Requirements
+
+### Python Dependencies
+
+For Windows users, before installing the other requirements, first install PyTorch with CUDA support:
+
+```bash
+pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121
+```
+
+Then install the remaining dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+### System Requirements
+
+- FFmpeg (required for video processing)
+- libvips (required for image processing)
+
+Installation by platform:
+
+- Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
+- macOS: `brew install ffmpeg libvips`
+- Windows:
+  - Download FFmpeg from [ffmpeg.org](https://ffmpeg.org/download.html)
+  - Follow the [libvips Windows installation guide](https://docs.moondream.ai/quick-start)
+
+## Installation
+
+1. Clone this repository and create a new virtual environment:
+
+```bash
+git clone https://github.com/vikhyat/moondream.git && cd moondream/recipes/promptable-content-moderation
+python -m venv .venv
+source .venv/bin/activate  # On Windows: .venv\Scripts\activate
+```
+
+2. Install Python dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+3. 
Install ffmpeg and libvips:
+   - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
+   - On macOS: `brew install ffmpeg libvips`
+   - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html)
+
+> Downloading libvips for Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start)
+
+## Usage
+
+The easiest way to use this tool is through its web interface, which provides a user-friendly experience for video content moderation.
+
+### Web Interface
+
+1. Start the web interface:
+
+```bash
+python app.py
+```
+
+2. Open the provided URL in your browser (typically http://127.0.0.1:7860)
+
+3. Use the interface to:
+   - Upload your video file
+   - Specify content to moderate (e.g., "face", "cigarette", "gun")
+   - Choose a redaction style (default: obfuscated-pixel)
+   - OPTIONAL: Configure advanced settings
+     - Processing speed/quality
+     - Grid size for detection
+     - Test mode for quick validation (default: on, 3 seconds)
+   - Process the video and download the results
+   - Analyze detection patterns with the visualization tools
+
+## Output Files
+
+The tool generates two types of output files in the `outputs` directory:
+
+1. Processed Videos:
+   - Format: `[style]_[content_type]_[original_filename].mp4`
+   - Example: `censor_inappropriate_video.mp4`
+
+2. Detection Data:
+   - Format: `[style]_[content_type]_[original_filename]_detections.json`
+   - Contains frame-by-frame detection information
+   - Used for visualization and analysis
+
+## Technical Details
+
+### Scene Detection and Tracking
+
+The tool uses advanced scene detection and object tracking:
+
+1. Scene Detection:
+   - Powered by PySceneDetect's ContentDetector
+   - Automatically identifies scene changes in videos
+   - Configurable detection threshold (default: 30.0)
+   - Helps maintain tracking accuracy across scene boundaries
+
+2. Object Tracking:
+   - DeepSORT tracking for consistent object identification
+   - Automatic tracker reset at scene changes
+   - Maintains object identity within scenes
+   - Prevents tracking errors across scene boundaries
+
+3. 
Integration Benefits: + - More accurate object tracking + - Better handling of scene transitions + - Reduced false positives in tracking + - Improved tracking consistency + +## Best Practices + +- Use test mode for initial configuration +- Enable grid-based detection for complex scenes +- Choose appropriate redaction style based on content type: + - Censor: Complete content blocking + - Blur styles: Less intrusive moderation + - Bounding Box: Content review and analysis +- Monitor system resources during processing +- Use appropriate processing quality settings based on your needs + +## Notes + +- Processing time depends on video length, resolution, GPU availability, and chosen settings +- GPU is strongly recommended for faster processing +- Grid-based detection increases accuracy but requires more processing time (each grid cell is processed independently) +- Test mode processes only first X seconds (default: 3 seconds) for quick validation diff --git a/recipes/promptable-content-moderation/app.py b/recipes/promptable-content-moderation/app.py new file mode 100644 index 00000000..e33ef110 --- /dev/null +++ b/recipes/promptable-content-moderation/app.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 +import gradio as gr +import os +from main import load_moondream, process_video, load_sam_model +import shutil +import torch +from visualization import visualize_detections +from persistence import load_detection_data +import matplotlib.pyplot as plt +import io +from PIL import Image +import pandas as pd +from video_visualization import create_video_visualization + +# Get absolute path to workspace root +WORKSPACE_ROOT = os.path.dirname(os.path.abspath(__file__)) + +# Check CUDA availability +print(f"Is CUDA available: {torch.cuda.is_available()}") +# We want to get True +print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") +# GPU Name + +# Initialize Moondream model globally for reuse (will be loaded on first use) +model, tokenizer = None, None + + +def process_video_file( + video_file, + target_object, + box_style, + ffmpeg_preset, + grid_rows, + grid_cols, + test_mode, + test_duration, +): + """Process a video file through the Gradio interface.""" + try: + if not video_file: + raise gr.Error("Please upload a video file") + + # Load models if not already loaded + global model, tokenizer + if model is None or tokenizer is None: + model, tokenizer = load_moondream() + + # Ensure input/output directories exist using absolute paths + inputs_dir = os.path.join(WORKSPACE_ROOT, "inputs") + outputs_dir = os.path.join(WORKSPACE_ROOT, "outputs") + os.makedirs(inputs_dir, exist_ok=True) + os.makedirs(outputs_dir, exist_ok=True) + + # Copy uploaded video to inputs directory + video_filename = f"input_{os.path.basename(video_file)}" + input_video_path = os.path.join(inputs_dir, video_filename) + shutil.copy2(video_file, input_video_path) + + try: + # Process the video + output_path = process_video( + input_video_path, + target_object, + test_mode=test_mode, + test_duration=test_duration, + ffmpeg_preset=ffmpeg_preset, + grid_rows=grid_rows, + grid_cols=grid_cols, + box_style=box_style, + ) + + # Get the corresponding JSON path + base_name = os.path.splitext(os.path.basename(video_filename))[0] + json_path = os.path.join( + outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json" + ) + + # Verify output exists and is readable + if not output_path or not os.path.exists(output_path): + print(f"Warning: Output path {output_path} does not exist") + # Try to find the output 
based on expected naming convention + expected_output = os.path.join( + outputs_dir, f"{box_style}_{target_object}_{video_filename}" + ) + if os.path.exists(expected_output): + output_path = expected_output + else: + # Try searching in outputs directory for any matching file + matching_files = [ + f + for f in os.listdir(outputs_dir) + if f.startswith(f"{box_style}_{target_object}_") + ] + if matching_files: + output_path = os.path.join(outputs_dir, matching_files[0]) + else: + raise gr.Error("Failed to locate output video") + + # Convert output path to absolute path if it isn't already + if not os.path.isabs(output_path): + output_path = os.path.join(WORKSPACE_ROOT, output_path) + + print(f"Returning output path: {output_path}") + return output_path, json_path + + finally: + # Clean up input file + try: + if os.path.exists(input_video_path): + os.remove(input_video_path) + except: + pass + + except Exception as e: + print(f"Error in process_video_file: {str(e)}") + raise gr.Error(f"Error processing video: {str(e)}") + + +def create_visualization_plots(json_path): + """Create visualization plots and return them as images.""" + try: + # Load the data + data = load_detection_data(json_path) + if not data: + return None, None, None, None, None, None, None, None, "No data found" + + # Convert to DataFrame + rows = [] + for frame_data in data["frame_detections"]: + frame = frame_data["frame"] + timestamp = frame_data["timestamp"] + for obj in frame_data["objects"]: + rows.append( + { + "frame": frame, + "timestamp": timestamp, + "keyword": obj["keyword"], + "x1": obj["bbox"][0], + "y1": obj["bbox"][1], + "x2": obj["bbox"][2], + "y2": obj["bbox"][3], + "area": (obj["bbox"][2] - obj["bbox"][0]) + * (obj["bbox"][3] - obj["bbox"][1]), + "center_x": (obj["bbox"][0] + obj["bbox"][2]) / 2, + "center_y": (obj["bbox"][1] + obj["bbox"][3]) / 2, + } + ) + + if not rows: + return ( + None, + None, + None, + None, + None, + None, + None, + None, + "No detections found in the data", + ) + + df = pd.DataFrame(rows) + plots = [] + + # Create each plot and convert to image + for plot_num in range(8): # Increased to 8 plots + plt.figure(figsize=(8, 6)) + + if plot_num == 0: + # Plot 1: Number of detections per frame (Original) + detections_per_frame = df.groupby("frame").size() + plt.plot(detections_per_frame.index, detections_per_frame.values) + plt.xlabel("Frame") + plt.ylabel("Number of Detections") + plt.title("Detections Per Frame") + + elif plot_num == 1: + # Plot 2: Distribution of detection areas (Original) + df["area"].hist(bins=30) + plt.xlabel("Detection Area (normalized)") + plt.ylabel("Count") + plt.title("Distribution of Detection Areas") + + elif plot_num == 2: + # Plot 3: Average detection area over time (Original) + avg_area = df.groupby("frame")["area"].mean() + plt.plot(avg_area.index, avg_area.values) + plt.xlabel("Frame") + plt.ylabel("Average Detection Area") + plt.title("Average Detection Area Over Time") + + elif plot_num == 3: + # Plot 4: Heatmap of detection centers (Original) + plt.hist2d(df["center_x"], df["center_y"], bins=30) + plt.colorbar() + plt.xlabel("X Position") + plt.ylabel("Y Position") + plt.title("Detection Center Heatmap") + + elif plot_num == 4: + # Plot 5: Time-based Detection Density + # Shows when in the video most detections occur + df["time_bucket"] = pd.qcut(df["timestamp"], q=20, labels=False) + time_density = df.groupby("time_bucket").size() + plt.bar(time_density.index, time_density.values) + plt.xlabel("Video Timeline (20 segments)") + plt.ylabel("Number of 
Detections") + plt.title("Detection Density Over Video Duration") + + elif plot_num == 5: + # Plot 6: Screen Region Analysis + # Divide screen into 3x3 grid and show detection counts + try: + df["grid_x"] = pd.qcut( + df["center_x"], + q=3, + labels=["Left", "Center", "Right"], + duplicates="drop", + ) + df["grid_y"] = pd.qcut( + df["center_y"], + q=3, + labels=["Top", "Middle", "Bottom"], + duplicates="drop", + ) + region_counts = ( + df.groupby(["grid_y", "grid_x"]).size().unstack(fill_value=0) + ) + plt.imshow(region_counts, cmap="YlOrRd") + plt.colorbar(label="Detection Count") + for i in range(3): + for j in range(3): + plt.text( + j, i, region_counts.iloc[i, j], ha="center", va="center" + ) + plt.xticks(range(3), ["Left", "Center", "Right"]) + plt.yticks(range(3), ["Top", "Middle", "Bottom"]) + plt.title("Screen Region Analysis") + except Exception as e: + plt.text( + 0.5, + 0.5, + "Insufficient variation in detection positions", + ha="center", + va="center", + ) + plt.title("Screen Region Analysis (Not Available)") + + elif plot_num == 6: + # Plot 7: Detection Size Categories + # Categorize detections by size for content moderation + try: + size_labels = [ + "Small (likely far/background)", + "Medium-small", + "Medium-large", + "Large (likely foreground/close)", + ] + + # Handle cases with limited unique values + unique_areas = df["area"].nunique() + if unique_areas >= 4: + df["size_category"] = pd.qcut( + df["area"], q=4, labels=size_labels, duplicates="drop" + ) + else: + # Alternative binning for limited unique values + df["size_category"] = pd.cut( + df["area"], + bins=unique_areas, + labels=size_labels[:unique_areas], + ) + + size_dist = df["size_category"].value_counts() + plt.pie(size_dist.values, labels=size_dist.index, autopct="%1.1f%%") + plt.title("Detection Size Distribution") + except Exception as e: + plt.text( + 0.5, + 0.5, + "Insufficient variation in detection sizes", + ha="center", + va="center", + ) + plt.title("Detection Size Distribution (Not Available)") + + elif plot_num == 7: + # Plot 8: Temporal Pattern Analysis + # Show patterns of when detections occur in sequence + try: + detection_gaps = df.sort_values("frame")["frame"].diff() + if len(detection_gaps.dropna().unique()) > 1: + plt.hist( + detection_gaps.dropna(), + bins=min(30, len(detection_gaps.dropna().unique())), + edgecolor="black", + ) + plt.xlabel("Frames Between Detections") + plt.ylabel("Frequency") + plt.title("Detection Temporal Pattern Analysis") + else: + plt.text( + 0.5, + 0.5, + "Uniform detection intervals", + ha="center", + va="center", + ) + plt.title("Temporal Pattern Analysis (Uniform)") + except Exception as e: + plt.text( + 0.5, 0.5, "Insufficient temporal data", ha="center", va="center" + ) + plt.title("Temporal Pattern Analysis (Not Available)") + + # Save plot to bytes + buf = io.BytesIO() + plt.savefig(buf, format="png", bbox_inches="tight") + buf.seek(0) + plots.append(Image.open(buf)) + plt.close() + + # Enhanced summary text + summary = f"""Summary Statistics: +Total frames analyzed: {len(data['frame_detections'])} +Total detections: {len(df)} +Average detections per frame: {len(df) / len(data['frame_detections']):.2f} + +Detection Patterns: +- Peak detection count: {df.groupby('frame').size().max()} (in a single frame) +- Most common screen region: {df.groupby(['grid_y', 'grid_x']).size().idxmax()} +- Average detection size: {df['area'].mean():.3f} +- Median frames between detections: {detection_gaps.median():.1f} + +Video metadata: +""" + for key, value in 
data["video_metadata"].items(): + summary += f"{key}: {value}\n" + + return ( + plots[0], + plots[1], + plots[2], + plots[3], + plots[4], + plots[5], + plots[6], + plots[7], + summary, + ) + + except Exception as e: + print(f"Error creating visualization: {str(e)}") + import traceback + + traceback.print_exc() + return ( + None, + None, + None, + None, + None, + None, + None, + None, + f"Error creating visualization: {str(e)}", + ) + + +# Create the Gradio interface +with gr.Blocks(title="Promptable Content Moderation") as app: + with gr.Tabs(): + with gr.Tab("Process Video"): + gr.Markdown("# Promptable Content Moderation with Moondream") + gr.Markdown( + """ + Powered by [Moondream 2B](https://github.com/vikhyat/moondream). + + Upload a video and specify what to moderate. The app will process each frame and moderate any visual content that matches the prompt. For help, join the [Moondream Discord](https://discord.com/invite/tRUdpjDQfH). + """ + ) + + with gr.Row(): + with gr.Column(): + # Input components + video_input = gr.Video(label="Upload Video") + + detect_input = gr.Textbox( + label="What to Moderate", + placeholder="e.g. face, cigarette, gun, etc.", + value="face", + info="Moondream can moderate anything that you can describe in natural language", + ) + + process_btn = gr.Button("Process Video", variant="primary") + + with gr.Accordion("Advanced Settings", open=False): + box_style_input = gr.Radio( + choices=[ + "censor", + "bounding-box", + "hitmarker", + "sam", + "sam-fast", + "fuzzy-blur", + "pixelated-blur", + "intense-pixelated-blur", + "obfuscated-pixel", + ], + value="obfuscated-pixel", + label="Visualization Style", + info="Choose how to display moderations: censor (black boxes), bounding-box (red boxes with labels), hitmarker (COD-style markers), sam (precise segmentation), sam-fast (faster but less precise segmentation), fuzzy-blur (Gaussian blur), pixelated-blur (pixelated with blur), obfuscated-pixel (advanced pixelation with neighborhood averaging)", + ) + preset_input = gr.Dropdown( + choices=[ + "ultrafast", + "superfast", + "veryfast", + "faster", + "fast", + "medium", + "slow", + "slower", + "veryslow", + ], + value="medium", + label="Processing Speed (faster = lower quality)", + ) + with gr.Row(): + rows_input = gr.Slider( + minimum=1, maximum=4, value=1, step=1, label="Grid Rows" + ) + cols_input = gr.Slider( + minimum=1, + maximum=4, + value=1, + step=1, + label="Grid Columns", + ) + + test_mode_input = gr.Checkbox( + label="Test Mode (Process first 3 seconds only)", + value=True, + info="Enable to quickly test settings on a short clip before processing the full video (recommended). If using the data visualizations, disable.", + ) + + test_duration_input = gr.Slider( + minimum=1, + maximum=10, + value=3, + step=1, + label="Test Mode Duration (seconds)", + info="Number of seconds to process in test mode", + ) + + gr.Markdown( + """ + Note: Processing in test mode will only process the first 3 seconds of the video and is recommended for testing settings. + """ + ) + + gr.Markdown( + """ + We can get a rough estimate of how long the video will take to process by multiplying the videos framerate * seconds * the number of rows and columns and assuming 0.12 seconds processing time per detection. + For example, a 3 second video at 30fps with 2x2 grid, the estimated time is 3 * 30 * 2 * 2 * 0.12 = 43.2 seconds (tested on a 4090 GPU). 
+ + Note: Using the SAM visualization style will increase processing time significantly as it performs additional segmentation for each detection. The sam-fast option uses a smaller model for faster processing at the cost of some accuracy. + """ + ) + + with gr.Column(): + # Output components + video_output = gr.Video(label="Processed Video") + json_output = gr.Text(label="Detection Data Path", visible=False) + + # About section under the video output + gr.Markdown( + """ + ### Links: + - [GitHub Repository](https://github.com/vikhyat/moondream) + - [Hugging Face](https://huggingface.co/vikhyatk/moondream2) + - [Quick Start](https://docs.moondream.ai/quick-start) + - [Moondream Recipes](https://docs.moondream.ai/recipes) + """ + ) + + with gr.Tab("Analyze Results"): + gr.Markdown("# Detection Analysis") + gr.Markdown( + """ + Analyze the detection results from processed videos. The analysis includes: + - Basic detection statistics and patterns + - Temporal and spatial distribution analysis + - Size-based categorization + - Screen region analysis + - Detection density patterns + """ + ) + + with gr.Row(): + json_input = gr.File( + label="Upload Detection Data (JSON)", + file_types=[".json"], + ) + analyze_btn = gr.Button("Analyze", variant="primary") + + with gr.Row(): + with gr.Column(): + plot1 = gr.Image( + label="Detections Per Frame", + ) + plot2 = gr.Image( + label="Detection Areas Distribution", + ) + plot5 = gr.Image( + label="Detection Density Timeline", + ) + plot6 = gr.Image( + label="Screen Region Analysis", + ) + + with gr.Column(): + plot3 = gr.Image( + label="Average Detection Area Over Time", + ) + plot4 = gr.Image( + label="Detection Center Heatmap", + ) + plot7 = gr.Image( + label="Detection Size Categories", + ) + plot8 = gr.Image( + label="Temporal Pattern Analysis", + ) + + stats_output = gr.Textbox( + label="Statistics", + info="Summary of key metrics and patterns found in the detection data.", + lines=12, + max_lines=15, + interactive=False, + ) + + # with gr.Tab("Video Visualizations"): + # gr.Markdown("# Real-time Detection Visualization") + # gr.Markdown( + # """ + # Watch the detection patterns unfold in real-time. 
Choose from: + # - Timeline: Shows number of detections over time + # - Gauge: Simple yes/no indicator for current frame detections + # """ + # ) + + # with gr.Row(): + # json_input_realtime = gr.File( + # label="Upload Detection Data (JSON)", + # file_types=[".json"], + # ) + # viz_style = gr.Radio( + # choices=["timeline", "gauge"], + # value="timeline", + # label="Visualization Style", + # info="Choose between timeline view or simple gauge indicator" + # ) + # visualize_btn = gr.Button("Visualize", variant="primary") + + # with gr.Row(): + # video_visualization = gr.Video( + # label="Detection Visualization", + # interactive=False + # ) + # stats_realtime = gr.Textbox( + # label="Video Statistics", + # lines=6, + # max_lines=8, + # interactive=False + # ) + + # Event handlers + process_outputs = process_btn.click( + fn=process_video_file, + inputs=[ + video_input, + detect_input, + box_style_input, + preset_input, + rows_input, + cols_input, + test_mode_input, + test_duration_input, + ], + outputs=[video_output, json_output], + ) + + # Auto-analyze after processing + process_outputs.then( + fn=create_visualization_plots, + inputs=[json_output], + outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output], + ) + + # Manual analysis button + analyze_btn.click( + fn=create_visualization_plots, + inputs=[json_input], + outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output], + ) + + # Video visualization button + # visualize_btn.click( + # fn=lambda json_file, style: create_video_visualization(json_file.name if json_file else None, style), + # inputs=[json_input_realtime, viz_style], + # outputs=[video_visualization, stats_realtime], + # ) + +if __name__ == "__main__": + app.launch(share=True) diff --git a/recipes/promptable-content-moderation/deep_sort_integration.py b/recipes/promptable-content-moderation/deep_sort_integration.py new file mode 100644 index 00000000..d1aecbb4 --- /dev/null +++ b/recipes/promptable-content-moderation/deep_sort_integration.py @@ -0,0 +1,74 @@ +import numpy as np +import torch +from deep_sort_realtime.deepsort_tracker import DeepSort +from datetime import datetime + + +class DeepSORTTracker: + def __init__(self, max_age=5): + """Initialize DeepSORT tracker.""" + self.max_age = max_age + self.tracker = self._create_tracker() + + def _create_tracker(self): + """Create a new instance of DeepSort tracker.""" + return DeepSort( + max_age=self.max_age, + embedder="mobilenet", # Using default MobileNetV2 embedder + today=datetime.now().date(), # For track naming and daily ID reset + ) + + def reset(self): + """Reset the tracker state by creating a new instance.""" + print("Resetting DeepSORT tracker...") + self.tracker = self._create_tracker() + + def update(self, frame, detections): + """Update tracking with new detections. 
+ + Args: + frame: Current video frame (numpy array) + detections: List of (box, keyword) tuples where box is [x1, y1, x2, y2] normalized + + Returns: + List of (box, keyword, track_id) tuples + """ + if not detections: + return [] + + height, width = frame.shape[:2] + + # Convert normalized coordinates to absolute and format detections + detection_list = [] + for box, keyword in detections: + x1 = int(box[0] * width) + y1 = int(box[1] * height) + x2 = int(box[2] * width) + y2 = int(box[3] * height) + w = x2 - x1 + h = y2 - y1 + + # Format: ([left,top,w,h], confidence, detection_class) + detection_list.append(([x1, y1, w, h], 1.0, keyword)) + + # Update tracker + tracks = self.tracker.update_tracks(detection_list, frame=frame) + + # Convert back to normalized coordinates with track IDs + tracked_objects = [] + for track in tracks: + if not track.is_confirmed(): + continue + + ltrb = track.to_ltrb() # Get [left,top,right,bottom] format + x1, y1, x2, y2 = ltrb + + # Normalize coordinates + x1 = max(0.0, min(1.0, x1 / width)) + y1 = max(0.0, min(1.0, y1 / height)) + x2 = max(0.0, min(1.0, x2 / width)) + y2 = max(0.0, min(1.0, y2 / height)) + + tracked_objects.append(([x1, y1, x2, y2], track.det_class, track.track_id)) + + return tracked_objects diff --git a/recipes/promptable-content-moderation/main.py b/recipes/promptable-content-moderation/main.py new file mode 100644 index 00000000..454b10ac --- /dev/null +++ b/recipes/promptable-content-moderation/main.py @@ -0,0 +1,1326 @@ +#!/usr/bin/env python3 +import cv2, os, subprocess, argparse +from PIL import Image +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, SamModel, SamProcessor +from tqdm import tqdm +import numpy as np +from datetime import datetime +from deep_sort_integration import DeepSORTTracker +from scenedetect import detect, ContentDetector +from functools import lru_cache + +# Constants +DEFAULT_TEST_MODE_DURATION = 3 # Process only first 3 seconds in test mode by default +FFMPEG_PRESETS = [ + "ultrafast", + "superfast", + "veryfast", + "faster", + "fast", + "medium", + "slow", + "slower", + "veryslow", +] +FONT = cv2.FONT_HERSHEY_SIMPLEX # Font for bounding-box-style labels + +# Detection parameters +IOU_THRESHOLD = 0.5 # IoU threshold for considering boxes related + +# Hitmarker parameters +HITMARKER_SIZE = 20 # Size of the hitmarker in pixels +HITMARKER_GAP = 3 # Size of the empty space in the middle (reduced from 8) +HITMARKER_THICKNESS = 2 # Thickness of hitmarker lines +HITMARKER_COLOR = (255, 255, 255) # White color for hitmarker +HITMARKER_SHADOW_COLOR = (80, 80, 80) # Lighter gray for shadow effect +HITMARKER_SHADOW_OFFSET = 1 # Smaller shadow offset + +# SAM parameters +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Initialize model variables as None +sam_model = None +sam_processor = None +slimsam_model = None +slimsam_processor = None + + +@lru_cache(maxsize=2) # Cache both regular and slim SAM models +def get_sam_model(slim=False): + """Get cached SAM model and processor.""" + global sam_model, sam_processor, slimsam_model, slimsam_processor + + if slim: + if slimsam_model is None: + print("Loading SlimSAM model for the first time...") + slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to( + device + ) + slimsam_processor = SamProcessor.from_pretrained( + "nielsr/slimsam-50-uniform" + ) + return slimsam_model, slimsam_processor + else: + if sam_model is None: + print("Loading SAM model for the first time...") + sam_model = 
SamModel.from_pretrained("facebook/sam-vit-huge").to(device) + sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + return sam_model, sam_processor + + +def load_sam_model(slim=False): + """Load SAM model and processor with caching.""" + return get_sam_model(slim=slim) + + +def generate_color_pair(): + """Generate a generic light blue and dark blue color pair for SAM visualization.""" + dark_rgb = [0, 0, 139] # Dark blue + light_rgb = [173, 216, 230] # Light blue + return dark_rgb, light_rgb + + +def create_mask_overlay(image, masks, points=None, labels=None): + """Create a mask overlay with contours for multiple SAM visualizations. + + Args: + image: PIL Image to overlay masks on + masks: List of binary masks or single mask + points: Optional list of (x,y) points for labels + labels: Optional list of label strings for each point + """ + # Convert single mask to list for uniform processing + if not isinstance(masks, list): + masks = [masks] + + # Create empty overlays + overlay = np.zeros((*image.size[::-1], 4), dtype=np.uint8) + outline = np.zeros((*image.size[::-1], 4), dtype=np.uint8) + + # Process each mask + for i, mask in enumerate(masks): + # Convert binary mask to uint8 + mask_uint8 = (mask > 0).astype(np.uint8) + + # Dilation to fill gaps + kernel = np.ones((5, 5), np.uint8) + mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1) + + # Find contours of the dilated mask + contours, _ = cv2.findContours( + mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + + # Generate random color pair for this segmentation + dark_color, light_color = generate_color_pair() + + # Add to the overlays + overlay[mask_dilated > 0] = [*light_color, 90] # Light color with 35% opacity + cv2.drawContours( + outline, contours, -1, (*dark_color, 255), 2 + ) # Dark color outline + + # Convert to PIL images + mask_overlay = Image.fromarray(overlay, "RGBA") + outline_overlay = Image.fromarray(outline, "RGBA") + + # Composite the layers + result = image.convert("RGBA") + result.paste(mask_overlay, (0, 0), mask_overlay) + result.paste(outline_overlay, (0, 0), outline_overlay) + + # Add labels if provided + if points and labels: + result_array = np.array(result) + for (x, y), label in zip(points, labels): + label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0] + cv2.putText( + result_array, + label, + (int(x - label_size[0] // 2), int(y - 20)), + FONT, + 0.5, + (255, 255, 255), + 1, + cv2.LINE_AA, + ) + result = Image.fromarray(result_array) + + return result + + +def process_sam_detection(image, center_x, center_y, slim=False): + """Process a single detection point with SAM. 
+ + Returns: + tuple: (mask, result_pil) where mask is the binary mask and result_pil is the visualization + """ + if not isinstance(image, Image.Image): + image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Get appropriate model from cache + model, processor = get_sam_model(slim) + + # Process the image with SAM + inputs = processor( + image, input_points=[[[center_x, center_y]]], return_tensors="pt" + ).to(device) + + with torch.no_grad(): + outputs = model(**inputs) + + mask = processor.post_process_masks( + outputs.pred_masks.cpu(), + inputs["original_sizes"].cpu(), + inputs["reshaped_input_sizes"].cpu(), + )[0][0][0].numpy() + + # Create the visualization + result = create_mask_overlay(image, mask) + return mask, result + + +def load_moondream(): + """Load Moondream model and tokenizer.""" + model = AutoModelForCausalLM.from_pretrained( + "vikhyatk/moondream2", trust_remote_code=True, device_map={"": "cuda"} + ) + tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2") + return model, tokenizer + + +def get_video_properties(video_path): + """Get basic video properties.""" + video = cv2.VideoCapture(video_path) + fps = video.get(cv2.CAP_PROP_FPS) + frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + video.release() + return {"fps": fps, "frame_count": frame_count, "width": width, "height": height} + + +def is_valid_bounding_box(bounding_box): + """Check if bounding box coordinates are reasonable.""" + x1, y1, x2, y2 = bounding_box + width = x2 - x1 + height = y2 - y1 + + # Reject boxes that are too large (over 90% of frame in both dimensions) + if width > 0.9 and height > 0.9: + return False + + # Reject boxes that are too small (less than 1% of frame) + if width < 0.01 or height < 0.01: + return False + + return True + + +def split_frame_into_grid(frame, grid_rows, grid_cols): + """Split a frame into a grid of tiles.""" + height, width = frame.shape[:2] + tile_height = height // grid_rows + tile_width = width // grid_cols + tiles = [] + tile_positions = [] + + for i in range(grid_rows): + for j in range(grid_cols): + y1 = i * tile_height + y2 = (i + 1) * tile_height if i < grid_rows - 1 else height + x1 = j * tile_width + x2 = (j + 1) * tile_width if j < grid_cols - 1 else width + + tile = frame[y1:y2, x1:x2] + tiles.append(tile) + tile_positions.append((x1, y1, x2, y2)) + + return tiles, tile_positions + + +def convert_tile_coords_to_frame(box, tile_pos, frame_shape): + """Convert coordinates from tile space to frame space.""" + frame_height, frame_width = frame_shape[:2] + tile_x1, tile_y1, tile_x2, tile_y2 = tile_pos + tile_width = tile_x2 - tile_x1 + tile_height = tile_y2 - tile_y1 + + x1_tile_abs = box[0] * tile_width + y1_tile_abs = box[1] * tile_height + x2_tile_abs = box[2] * tile_width + y2_tile_abs = box[3] * tile_height + + x1_frame_abs = tile_x1 + x1_tile_abs + y1_frame_abs = tile_y1 + y1_tile_abs + x2_frame_abs = tile_x1 + x2_tile_abs + y2_frame_abs = tile_y1 + y2_tile_abs + + x1_norm = x1_frame_abs / frame_width + y1_norm = y1_frame_abs / frame_height + x2_norm = x2_frame_abs / frame_width + y2_norm = y2_frame_abs / frame_height + + x1_norm = max(0.0, min(1.0, x1_norm)) + y1_norm = max(0.0, min(1.0, y1_norm)) + x2_norm = max(0.0, min(1.0, x2_norm)) + y2_norm = max(0.0, min(1.0, y2_norm)) + + return [x1_norm, y1_norm, x2_norm, y2_norm] + + +def merge_tile_detections(tile_detections, iou_threshold=0.5): + """Merge detections from different 
tiles using NMS-like approach.""" + if not tile_detections: + return [] + + all_boxes = [] + all_keywords = [] + + # Collect all boxes and their keywords + for detections in tile_detections: + for box, keyword in detections: + all_boxes.append(box) + all_keywords.append(keyword) + + if not all_boxes: + return [] + + # Convert to numpy for easier processing + boxes = np.array(all_boxes) + + # Calculate areas + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + areas = (x2 - x1) * (y2 - y1) + + # Sort boxes by area + order = areas.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + + if order.size == 1: + break + + # Calculate IoU with rest of boxes + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + inter = w * h + + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + # Get indices of boxes with IoU less than threshold + inds = np.where(ovr <= iou_threshold)[0] + order = order[inds + 1] + + return [(all_boxes[i], all_keywords[i]) for i in keep] + + +def detect_objects_in_frame( + model, tokenizer, image, target_object, grid_rows=1, grid_cols=1 +): + """Detect specified objects in a frame using grid-based analysis.""" + if grid_rows == 1 and grid_cols == 1: + return detect_objects_in_frame_single(model, tokenizer, image, target_object) + + # Convert numpy array to PIL Image if needed + if not isinstance(image, Image.Image): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Split frame into tiles + tiles, tile_positions = split_frame_into_grid(image, grid_rows, grid_cols) + + # Process each tile + tile_detections = [] + for tile, tile_pos in zip(tiles, tile_positions): + # Convert tile to PIL Image + tile_pil = Image.fromarray(tile) + + # Detect objects in tile + response = model.detect(tile_pil, target_object) + + if response and "objects" in response and response["objects"]: + objects = response["objects"] + tile_objects = [] + + for obj in objects: + if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]): + box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]] + + if is_valid_bounding_box(box): + # Convert tile coordinates to frame coordinates + frame_box = convert_tile_coords_to_frame( + box, tile_pos, image.shape + ) + tile_objects.append((frame_box, target_object)) + + if tile_objects: # Only append if we found valid objects + tile_detections.append(tile_objects) + + # Merge detections from all tiles + merged_detections = merge_tile_detections(tile_detections) + return merged_detections + + +def detect_objects_in_frame_single(model, tokenizer, image, target_object): + """Single-frame detection function.""" + detected_objects = [] + + # Convert numpy array to PIL Image if needed + if not isinstance(image, Image.Image): + image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Detect objects + response = model.detect(image, target_object) + + # Check if we have valid objects + if response and "objects" in response and response["objects"]: + objects = response["objects"] + + for obj in objects: + if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]): + box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]] + # If box is valid (not full-frame), add it + if is_valid_bounding_box(box): + detected_objects.append((box, target_object)) + + return detected_objects + + +def draw_hitmarker( + frame, 
center_x, center_y, size=HITMARKER_SIZE, color=HITMARKER_COLOR, shadow=True +): + """Draw a COD-style hitmarker cross with more space in the middle.""" + half_size = size // 2 + + # Draw shadow first if enabled + if shadow: + # Top-left to center shadow + cv2.line( + frame, + ( + center_x - half_size + HITMARKER_SHADOW_OFFSET, + center_y - half_size + HITMARKER_SHADOW_OFFSET, + ), + ( + center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + ), + HITMARKER_SHADOW_COLOR, + HITMARKER_THICKNESS, + ) + # Top-right to center shadow + cv2.line( + frame, + ( + center_x + half_size + HITMARKER_SHADOW_OFFSET, + center_y - half_size + HITMARKER_SHADOW_OFFSET, + ), + ( + center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + ), + HITMARKER_SHADOW_COLOR, + HITMARKER_THICKNESS, + ) + # Bottom-left to center shadow + cv2.line( + frame, + ( + center_x - half_size + HITMARKER_SHADOW_OFFSET, + center_y + half_size + HITMARKER_SHADOW_OFFSET, + ), + ( + center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + ), + HITMARKER_SHADOW_COLOR, + HITMARKER_THICKNESS, + ) + # Bottom-right to center shadow + cv2.line( + frame, + ( + center_x + half_size + HITMARKER_SHADOW_OFFSET, + center_y + half_size + HITMARKER_SHADOW_OFFSET, + ), + ( + center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, + ), + HITMARKER_SHADOW_COLOR, + HITMARKER_THICKNESS, + ) + + # Draw main hitmarker + # Top-left to center + cv2.line( + frame, + (center_x - half_size, center_y - half_size), + (center_x - HITMARKER_GAP, center_y - HITMARKER_GAP), + color, + HITMARKER_THICKNESS, + ) + # Top-right to center + cv2.line( + frame, + (center_x + half_size, center_y - half_size), + (center_x + HITMARKER_GAP, center_y - HITMARKER_GAP), + color, + HITMARKER_THICKNESS, + ) + # Bottom-left to center + cv2.line( + frame, + (center_x - half_size, center_y + half_size), + (center_x - HITMARKER_GAP, center_y + HITMARKER_GAP), + color, + HITMARKER_THICKNESS, + ) + # Bottom-right to center + cv2.line( + frame, + (center_x + half_size, center_y + half_size), + (center_x + HITMARKER_GAP, center_y + HITMARKER_GAP), + color, + HITMARKER_THICKNESS, + ) + + +def draw_ad_boxes(frame, detected_objects, detect_keyword, model, box_style="censor"): + height, width = frame.shape[:2] + + points = [] + # Only get points if we need them for hitmarker or SAM styles + if box_style in ["hitmarker", "sam", "sam-fast"]: + frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + try: + point_response = model.point(frame_pil, detect_keyword) + + if isinstance(point_response, dict) and "points" in point_response: + points = point_response["points"] + except Exception as e: + print(f"Error during point detection: {str(e)}") + points = [] + + # Only load SAM models and process points if we're using SAM styles and have points + if box_style in ["sam", "sam-fast"] and points: + # Start with the original PIL image + frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + + # Collect all masks and points + all_masks = [] + point_coords = [] + point_labels = [] + + for point in points: + try: + center_x = int(float(point["x"]) * width) + center_y = int(float(point["y"]) * height) + + # Get mask and visualization + mask, _ = process_sam_detection( + frame_pil, center_x, center_y, slim=(box_style == "sam-fast") + ) + + # Collect mask and point data + 
all_masks.append(mask) + point_coords.append((center_x, center_y)) + point_labels.append(detect_keyword) + + except Exception as e: + print(f"Error processing individual SAM point: {str(e)}") + print(f"Point data: {point}") + + if all_masks: + # Create final visualization with all masks + result_pil = create_mask_overlay( + frame_pil, all_masks, point_coords, point_labels + ) + frame = cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR) + + # Process other visualization styles + for detection in detected_objects: + try: + # Handle both tracked and untracked detections + if len(detection) == 3: # Tracked detection with ID + box, keyword, track_id = detection + else: # Regular detection without tracking + box, keyword = detection + track_id = None + + x1 = int(box[0] * width) + y1 = int(box[1] * height) + x2 = int(box[2] * width) + y2 = int(box[3] * height) + + x1 = max(0, min(x1, width - 1)) + y1 = max(0, min(y1, height - 1)) + x2 = max(0, min(x2, width - 1)) + y2 = max(0, min(y2, height - 1)) + + if x2 > x1 and y2 > y1: + if box_style == "censor": + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1) + elif box_style == "bounding-box": + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3) + + label = ( + f"{detect_keyword}" if track_id is not None else detect_keyword + ) + label_size = cv2.getTextSize(label, FONT, 0.7, 2)[0] + cv2.rectangle( + frame, (x1, y1 - 25), (x1 + label_size[0], y1), (0, 0, 255), -1 + ) + cv2.putText( + frame, + label, + (x1, y1 - 6), + FONT, + 0.7, + (255, 255, 255), + 2, + cv2.LINE_AA, + ) + elif box_style == "fuzzy-blur": + # Extract ROI + roi = frame[y1:y2, x1:x2] + # Apply Gaussian blur with much larger kernel for intense blur + blurred_roi = cv2.GaussianBlur(roi, (125, 125), 0) + # Replace original ROI with blurred version + frame[y1:y2, x1:x2] = blurred_roi + elif box_style == "pixelated-blur": + # Extract ROI + roi = frame[y1:y2, x1:x2] + # Pixelate by resizing down and up + h, w = roi.shape[:2] + temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR) + pixelated = cv2.resize( + temp, (w, h), interpolation=cv2.INTER_NEAREST + ) + # Mix up the pixelated frame slightly by adding random noise + noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8) + pixelated = cv2.add(pixelated, noise) + # Apply stronger Gaussian blur to smooth edges + blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0) + # Replace original ROI + frame[y1:y2, x1:x2] = blurred_pixelated + elif box_style == "obfuscated-pixel": + # Calculate expansion amount based on 10% of object dimensions + box_width = x2 - x1 + box_height = y2 - y1 + expand_x = int(box_width * 0.10) + expand_y = int(box_height * 0.10) + + # Expand the bounding box by 10% in all directions + x1_expanded = max(0, x1 - expand_x) + y1_expanded = max(0, y1 - expand_y) + x2_expanded = min(width - 1, x2 + expand_x) + y2_expanded = min(height - 1, y2 + expand_y) + + # Extract ROI with much larger padding for true background sampling + padding = 100 # Much larger padding to get true background + y1_pad = max(0, y1_expanded - padding) + y2_pad = min(height, y2_expanded + padding) + x1_pad = max(0, x1_expanded - padding) + x2_pad = min(width, x2_expanded + padding) + + # Get the padded region including background + padded_roi = frame[y1_pad:y2_pad, x1_pad:x2_pad] + + # Create mask that excludes a larger region around the detection + h, w = y2_expanded - y1_expanded, x2_expanded - x1_expanded + bg_mask = np.ones(padded_roi.shape[:2], dtype=bool) + + # Exclude a larger region around the detection from 
background sampling + exclusion_padding = 50 # Area to exclude around detection + exclude_y1 = padding - exclusion_padding + exclude_y2 = padding + h + exclusion_padding + exclude_x1 = padding - exclusion_padding + exclude_x2 = padding + w + exclusion_padding + + # Make sure exclusion coordinates are valid + exclude_y1 = max(0, exclude_y1) + exclude_y2 = min(padded_roi.shape[0], exclude_y2) + exclude_x1 = max(0, exclude_x1) + exclude_x2 = min(padded_roi.shape[1], exclude_x2) + + # Mark the exclusion zone in the mask + bg_mask[exclude_y1:exclude_y2, exclude_x1:exclude_x2] = False + + # If we have enough background pixels, calculate average color + if np.any(bg_mask): + bg_color = np.mean(padded_roi[bg_mask], axis=0).astype(np.uint8) + else: + # Fallback to edges if we couldn't get enough background + edge_samples = np.concatenate( + [ + padded_roi[0], # Top edge + padded_roi[-1], # Bottom edge + padded_roi[:, 0], # Left edge + padded_roi[:, -1], # Right edge + ] + ) + bg_color = np.mean(edge_samples, axis=0).astype(np.uint8) + + # Create base pixelated version (of the expanded region) + temp = cv2.resize( + frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded], + (6, 6), + interpolation=cv2.INTER_LINEAR, + ) + pixelated = cv2.resize( + temp, (w, h), interpolation=cv2.INTER_NEAREST + ) + + # Blend heavily towards background color + blend_factor = 0.9 # Much stronger blend with background + blended = cv2.addWeighted( + pixelated, + 1 - blend_factor, + np.full((h, w, 3), bg_color, dtype=np.uint8), + blend_factor, + 0, + ) + + # Replace original ROI with blended version (using expanded coordinates) + frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = blended + elif box_style == "intense-pixelated-blur": + # Expand the bounding box by pixels in all directions + x1_expanded = max(0, x1 - 15) + y1_expanded = max(0, y1 - 15) + x2_expanded = min(width - 1, x2 + 25) + y2_expanded = min(height - 1, y2 + 25) + + # Extract ROI + roi = frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] + # Pixelate by resizing down and up + h, w = roi.shape[:2] + temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR) + pixelated = cv2.resize( + temp, (w, h), interpolation=cv2.INTER_NEAREST + ) + # Mix up the pixelated frame slightly by adding random noise + noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8) + pixelated = cv2.add(pixelated, noise) + # Apply stronger Gaussian blur to smooth edges + blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0) + # Replace original ROI + frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = ( + blurred_pixelated + ) + elif box_style == "hitmarker": + if points: + for point in points: + try: + print(f"Processing point: {point}") + center_x = int(float(point["x"]) * width) + center_y = int(float(point["y"]) * height) + print( + f"Converted coordinates: ({center_x}, {center_y})" + ) + + draw_hitmarker(frame, center_x, center_y) + + label = ( + f"{detect_keyword}" + if track_id is not None + else detect_keyword + ) + label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0] + cv2.putText( + frame, + label, + ( + center_x - label_size[0] // 2, + center_y - HITMARKER_SIZE - 5, + ), + FONT, + 0.5, + HITMARKER_COLOR, + 1, + cv2.LINE_AA, + ) + except Exception as e: + print(f"Error processing individual point: {str(e)}") + print(f"Point data: {point}") + + except Exception as e: + print(f"Error drawing {box_style} style box: {str(e)}") + print(f"Box data: {box}") + print(f"Keyword: {keyword}") + + return frame + + +def 
filter_temporal_outliers(detections_dict): + """Filter out extremely large detections that take up most of the frame. + Only keeps detections that are reasonable in size. + + Args: + detections_dict: Dictionary of {frame_number: [(box, keyword, track_id), ...]} + """ + filtered_detections = {} + + for t, detections in detections_dict.items(): + # Only keep detections that aren't too large + valid_detections = [] + for detection in detections: + # Handle both tracked and untracked detections + if len(detection) == 3: # Tracked detection with ID + box, keyword, track_id = detection + else: # Regular detection without tracking + box, keyword = detection + track_id = None + + # Calculate box size as percentage of frame + width = box[2] - box[0] + height = box[3] - box[1] + area = width * height + + # If box is less than 90% of frame, keep it + if area < 0.9: + if track_id is not None: + valid_detections.append((box, keyword, track_id)) + else: + valid_detections.append((box, keyword)) + + if valid_detections: + filtered_detections[t] = valid_detections + + return filtered_detections + + +def describe_frames( + video_path, + model, + tokenizer, + detect_keyword, + test_mode=False, + test_duration=DEFAULT_TEST_MODE_DURATION, + grid_rows=1, + grid_cols=1, +): + """Extract and detect objects in frames.""" + props = get_video_properties(video_path) + fps = props["fps"] + + # Initialize DeepSORT tracker + tracker = DeepSORTTracker() + + # If in test mode, only process first N seconds + if test_mode: + frame_count = min(int(fps * test_duration), props["frame_count"]) + else: + frame_count = props["frame_count"] + + ad_detections = {} # Store detection results by frame number + + print("Extracting frames and detecting objects...") + video = cv2.VideoCapture(video_path) + + # Detect scenes first + scenes = detect(video_path, scene_detector) + scene_changes = set(end.get_frames() for _, end in scenes) + print(f"Detected {len(scenes)} scenes") + + frame_count_processed = 0 + with tqdm(total=frame_count) as pbar: + while frame_count_processed < frame_count: + ret, frame = video.read() + if not ret: + break + + # Check if current frame is a scene change + if frame_count_processed in scene_changes: + # Detect objects in the frame + detected_objects = detect_objects_in_frame( + model, + tokenizer, + frame, + detect_keyword, + grid_rows=grid_rows, + grid_cols=grid_cols, + ) + + # Update tracker with current detections + tracked_objects = tracker.update(frame, detected_objects) + + # Store results for every frame, even if empty + ad_detections[frame_count_processed] = tracked_objects + + frame_count_processed += 1 + pbar.update(1) + + video.release() + + if frame_count_processed == 0: + print("No frames could be read from video") + return {} + + return ad_detections + + +def create_detection_video( + video_path, + ad_detections, + detect_keyword, + model, + output_path=None, + ffmpeg_preset="medium", + test_mode=False, + test_duration=DEFAULT_TEST_MODE_DURATION, + box_style="censor", +): + """Create video with detection boxes while preserving audio.""" + if output_path is None: + # Create outputs directory if it doesn't exist + outputs_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "outputs" + ) + os.makedirs(outputs_dir, exist_ok=True) + + # Clean the detect_keyword for filename + safe_keyword = "".join( + x for x in detect_keyword if x.isalnum() or x in (" ", "_", "-") + ) + safe_keyword = safe_keyword.replace(" ", "_") + + # Create output filename + base_name = 
os.path.splitext(os.path.basename(video_path))[0] + output_path = os.path.join( + outputs_dir, f"{box_style}_{safe_keyword}_{base_name}.mp4" + ) + + print(f"Will save output to: {output_path}") + + props = get_video_properties(video_path) + fps, width, height = props["fps"], props["width"], props["height"] + + # If in test mode, only process first few seconds + if test_mode: + frame_count = min(int(fps * test_duration), props["frame_count"]) + print( + f"Test mode enabled: Processing first {test_duration} seconds ({frame_count} frames)" + ) + else: + frame_count = props["frame_count"] + print("Full video mode: Processing entire video") + + video = cv2.VideoCapture(video_path) + + # Create temp output path by adding _temp before the extension + base, ext = os.path.splitext(output_path) + temp_output = f"{base}_temp{ext}" + temp_audio = f"{base}_audio.aac" # Temporary audio file + + out = cv2.VideoWriter( + temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height) + ) + + print("Creating detection video...") + frame_count_processed = 0 + + with tqdm(total=frame_count) as pbar: + while frame_count_processed < frame_count: + ret, frame = video.read() + if not ret: + break + + # Get detections for this exact frame + if frame_count_processed in ad_detections: + current_detections = ad_detections[frame_count_processed] + if current_detections: + frame = draw_ad_boxes( + frame, + current_detections, + detect_keyword, + model, + box_style=box_style, + ) + + out.write(frame) + frame_count_processed += 1 + pbar.update(1) + + video.release() + out.release() + + # Extract audio from original video + try: + if test_mode: + # In test mode, extract only the required duration of audio + subprocess.run( + [ + "ffmpeg", + "-y", + "-i", + video_path, + "-t", + str(test_duration), + "-vn", # No video + "-acodec", + "copy", + temp_audio, + ], + check=True, + ) + else: + subprocess.run( + [ + "ffmpeg", + "-y", + "-i", + video_path, + "-vn", # No video + "-acodec", + "copy", + temp_audio, + ], + check=True, + ) + except subprocess.CalledProcessError as e: + print(f"Error extracting audio: {str(e)}") + if os.path.exists(temp_output): + os.remove(temp_output) + return None + + # Merge processed video with original audio + try: + # Base FFmpeg command + ffmpeg_cmd = [ + "ffmpeg", + "-y", + "-i", + temp_output, + "-i", + temp_audio, + "-c:v", + "libx264", + "-preset", + ffmpeg_preset, + "-crf", + "23", + "-c:a", + "aac", + "-b:a", + "192k", + "-movflags", + "+faststart", # Better web playback + ] + + if test_mode: + # In test mode, ensure output duration matches test_duration + ffmpeg_cmd.extend( + [ + "-t", + str(test_duration), + "-shortest", # Ensure output duration matches shortest input + ] + ) + + ffmpeg_cmd.extend(["-loglevel", "error", output_path]) + + subprocess.run(ffmpeg_cmd, check=True) + + # Clean up temporary files + os.remove(temp_output) + os.remove(temp_audio) + + if not os.path.exists(output_path): + print( + f"Warning: FFmpeg completed but output file not found at {output_path}" + ) + return None + + return output_path + + except subprocess.CalledProcessError as e: + print(f"Error merging audio with video: {str(e)}") + if os.path.exists(temp_output): + os.remove(temp_output) + if os.path.exists(temp_audio): + os.remove(temp_audio) + return None + + +def process_video( + video_path, + target_object, + test_mode=False, + test_duration=DEFAULT_TEST_MODE_DURATION, + ffmpeg_preset="medium", + grid_rows=1, + grid_cols=1, + box_style="censor", +): + """Process a video to detect and visualize 
specified objects.""" + try: + print(f"\nProcessing: {video_path}") + print(f"Looking for: {target_object}") + + # Load model + print("Loading Moondream model...") + model, tokenizer = load_moondream() + + # Get video properties + props = get_video_properties(video_path) + + # Initialize scene detector with ContentDetector + scene_detector = ContentDetector(threshold=30.0) # Adjust threshold as needed + + # Initialize DeepSORT tracker + tracker = DeepSORTTracker() + + # If in test mode, only process first N seconds + if test_mode: + frame_count = min(int(props["fps"] * test_duration), props["frame_count"]) + else: + frame_count = props["frame_count"] + + ad_detections = {} # Store detection results by frame number + + print("Extracting frames and detecting objects...") + video = cv2.VideoCapture(video_path) + + # Detect scenes first + scenes = detect(video_path, scene_detector) + scene_changes = set(end.get_frames() for _, end in scenes) + print(f"Detected {len(scenes)} scenes") + + frame_count_processed = 0 + with tqdm(total=frame_count) as pbar: + while frame_count_processed < frame_count: + ret, frame = video.read() + if not ret: + break + + # Check if current frame is a scene change + if frame_count_processed in scene_changes: + print( + f"Scene change detected at frame {frame_count_processed}. Resetting tracker." + ) + tracker.reset() + + # Detect objects in the frame + detected_objects = detect_objects_in_frame( + model, + tokenizer, + frame, + target_object, + grid_rows=grid_rows, + grid_cols=grid_cols, + ) + + # Update tracker with current detections + tracked_objects = tracker.update(frame, detected_objects) + + # Store results for every frame, even if empty + ad_detections[frame_count_processed] = tracked_objects + + frame_count_processed += 1 + pbar.update(1) + + video.release() + + if frame_count_processed == 0: + print("No frames could be read from video") + return {} + + # Apply filtering + filtered_ad_detections = filter_temporal_outliers(ad_detections) + + # Build detection data structure + detection_data = { + "video_metadata": { + "file_name": os.path.basename(video_path), + "fps": props["fps"], + "width": props["width"], + "height": props["height"], + "total_frames": props["frame_count"], + "duration_sec": props["frame_count"] / props["fps"], + "detect_keyword": target_object, + "test_mode": test_mode, + "grid_size": f"{grid_rows}x{grid_cols}", + "box_style": box_style, + "timestamp": datetime.now().isoformat(), + }, + "frame_detections": [ + { + "frame": frame_num, + "timestamp": frame_num / props["fps"], + "objects": [ + { + "keyword": kw, + "bbox": list(box), # Convert numpy array to list if needed + "track_id": track_id if len(detection) == 3 else None, + } + for detection in filtered_ad_detections.get(frame_num, []) + for box, kw, *track_id in [ + detection + ] # Unpack detection tuple, track_id will be empty list if not present + ], + } + for frame_num in range( + props["frame_count"] + if not test_mode + else min(int(props["fps"] * test_duration), props["frame_count"]) + ) + ], + } + + # Save filtered data + outputs_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "outputs" + ) + os.makedirs(outputs_dir, exist_ok=True) + base_name = os.path.splitext(os.path.basename(video_path))[0] + json_path = os.path.join( + outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json" + ) + + from persistence import save_detection_data + + if not save_detection_data(detection_data, json_path): + print("Warning: Failed to save detection data") + + # 
Create video with filtered data + output_path = create_detection_video( + video_path, + filtered_ad_detections, + target_object, + model, + ffmpeg_preset=ffmpeg_preset, + test_mode=test_mode, + test_duration=test_duration, + box_style=box_style, + ) + + if output_path is None: + print("\nError: Failed to create output video") + return None + + print(f"\nOutput saved to: {output_path}") + print(f"Detection data saved to: {json_path}") + return output_path + + except Exception as e: + print(f"Error processing video: {str(e)}") + import traceback + + traceback.print_exc() + return None + + +def main(): + """Process all videos in the inputs directory.""" + parser = argparse.ArgumentParser( + description="Detect objects in videos using Moondream2" + ) + parser.add_argument( + "--test", action="store_true", help="Process only first 3 seconds of each video" + ) + parser.add_argument( + "--test-duration", + type=int, + default=DEFAULT_TEST_MODE_DURATION, + help=f"Number of seconds to process in test mode (default: {DEFAULT_TEST_MODE_DURATION})", + ) + parser.add_argument( + "--preset", + choices=FFMPEG_PRESETS, + default="medium", + help="FFmpeg encoding preset (default: medium). Faster presets = lower quality", + ) + parser.add_argument( + "--detect", + type=str, + default="face", + help='Object to detect in the video (default: face, use --detect "thing to detect" to override)', + ) + parser.add_argument( + "--rows", + type=int, + default=1, + help="Number of rows to split each frame into (default: 1)", + ) + parser.add_argument( + "--cols", + type=int, + default=1, + help="Number of columns to split each frame into (default: 1)", + ) + parser.add_argument( + "--box-style", + choices=[ + "censor", + "bounding-box", + "hitmarker", + "sam", + "sam-fast", + "fuzzy-blur", + "pixelated-blur", + "intense-pixelated-blur", + "obfuscated-pixel", + ], + default="censor", + help="Style of detection visualization (default: censor)", + ) + args = parser.parse_args() + + input_dir = "inputs" + os.makedirs(input_dir, exist_ok=True) + os.makedirs("outputs", exist_ok=True) + + video_files = [ + f + for f in os.listdir(input_dir) + if f.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm")) + ] + + if not video_files: + print("No video files found in 'inputs' directory") + return + + print(f"Found {len(video_files)} videos to process") + print(f"Will detect: {args.detect}") + if args.test: + print("Running in test mode - processing only first 3 seconds of each video") + print(f"Using FFmpeg preset: {args.preset}") + print(f"Grid size: {args.rows}x{args.cols}") + print(f"Box style: {args.box_style}") + + success_count = 0 + for video_file in video_files: + video_path = os.path.join(input_dir, video_file) + output_path = process_video( + video_path, + args.detect, + test_mode=args.test, + test_duration=args.test_duration, + ffmpeg_preset=args.preset, + grid_rows=args.rows, + grid_cols=args.cols, + box_style=args.box_style, + ) + if output_path: + success_count += 1 + + print( + f"\nProcessing complete. Successfully processed {success_count} out of {len(video_files)} videos." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/recipes/promptable-content-moderation/packages.txt b/recipes/promptable-content-moderation/packages.txt new file mode 100644 index 00000000..8ba77f6e --- /dev/null +++ b/recipes/promptable-content-moderation/packages.txt @@ -0,0 +1,2 @@ +libvips +ffmpeg \ No newline at end of file diff --git a/recipes/promptable-content-moderation/persistence.py b/recipes/promptable-content-moderation/persistence.py new file mode 100644 index 00000000..0374ba3f --- /dev/null +++ b/recipes/promptable-content-moderation/persistence.py @@ -0,0 +1,41 @@ +import json +import os + + +def save_detection_data(data, output_file): + """ + Saves the detection data to a JSON file. + + Args: + data (dict): The complete detection data structure. + output_file (str): Path to the output JSON file. + """ + try: + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + with open(output_file, "w") as f: + json.dump(data, f, indent=4) + print(f"Detection data saved to {output_file}") + return True + except Exception as e: + print(f"Error saving data: {str(e)}") + return False + + +def load_detection_data(input_file): + """ + Loads the detection data from a JSON file. + + Args: + input_file (str): Path to the JSON file. + + Returns: + dict: The loaded detection data, or None if there was an error. + """ + try: + with open(input_file, "r") as f: + return json.load(f) + except Exception as e: + print(f"Error loading data: {str(e)}") + return None diff --git a/recipes/promptable-content-moderation/requirements.txt b/recipes/promptable-content-moderation/requirements.txt new file mode 100644 index 00000000..d9027ee9 --- /dev/null +++ b/recipes/promptable-content-moderation/requirements.txt @@ -0,0 +1,26 @@ +gradio>=4.0.0 +torch>=2.0.0 +# if on windows: pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121 +transformers>=4.36.0 +opencv-python>=4.8.0 +pillow>=10.0.0 +numpy>=1.24.0 +tqdm>=4.66.0 +ffmpeg-python +einops +pyvips-binary +pyvips +accelerate +# for spaces +--extra-index-url https://download.pytorch.org/whl/cu113 +spaces +# SAM dependencies +torchvision>=0.20.1 +matplotlib>=3.7.0 +pandas>=2.0.0 +plotly +# DeepSORT dependencies +deep-sort-realtime>=1.3.2 +scikit-learn # Required for deep-sort-realtime +# Scene detection dependencies (for intelligent scene-aware tracking) +scenedetect[opencv]>=0.6.2 # Provides scene change detection capabilities \ No newline at end of file diff --git a/recipes/promptable-content-moderation/video_visualization.py b/recipes/promptable-content-moderation/video_visualization.py new file mode 100644 index 00000000..e4e09cc8 --- /dev/null +++ b/recipes/promptable-content-moderation/video_visualization.py @@ -0,0 +1,398 @@ +import os +import tempfile +import subprocess +import matplotlib.pyplot as plt +import pandas as pd +import cv2 +import numpy as np +from tqdm import tqdm +from persistence import load_detection_data + + +def create_frame_data(json_path): + """Create frame-by-frame detection data for visualization.""" + try: + data = load_detection_data(json_path) + if not data: + print("No data loaded from JSON file") + return None + + if "video_metadata" not in data or "frame_detections" not in data: + print("Invalid JSON structure: missing required fields") + return None + + # Extract video metadata + metadata = data["video_metadata"] + if "fps" not in metadata or "total_frames" not in metadata: + print("Invalid metadata: missing fps or 
total_frames") + return None + + fps = metadata["fps"] + total_frames = metadata["total_frames"] + + # Create frame data + frame_counts = {} + for frame_data in data["frame_detections"]: + if "frame" not in frame_data or "objects" not in frame_data: + continue # Skip invalid frame data + frame_num = frame_data["frame"] + frame_counts[frame_num] = len(frame_data["objects"]) + + # Fill in missing frames with 0 detections + for frame in range(total_frames): + if frame not in frame_counts: + frame_counts[frame] = 0 + + if not frame_counts: + print("No valid frame data found") + return None + + # Convert to DataFrame + df = pd.DataFrame(list(frame_counts.items()), columns=["frame", "detections"]) + df["timestamp"] = df["frame"] / fps + + return df, metadata + + except Exception as e: + print(f"Error creating frame data: {str(e)}") + import traceback + + traceback.print_exc() + return None + + +def generate_frame_image(df, frame_num, temp_dir, max_y): + """Generate and save a single frame of the visualization.""" + # Set the style to dark background + plt.style.use("dark_background") + + # Set global font to monospace + plt.rcParams["font.family"] = "monospace" + plt.rcParams["font.monospace"] = ["DejaVu Sans Mono"] + + plt.figure(figsize=(10, 6)) + + # Plot data up to current frame + current_data = df[df["frame"] <= frame_num] + plt.plot( + df["frame"], df["detections"], color="#1a1a1a", alpha=0.5 + ) # Darker background line + plt.plot( + current_data["frame"], current_data["detections"], color="#00ff41" + ) # Matrix green + + # Add vertical line for current position + plt.axvline( + x=frame_num, color="#ff0000", linestyle="-", alpha=0.7 + ) # Keep red for position + + # Set consistent axes + plt.xlim(0, len(df) - 1) + plt.ylim(0, max_y * 1.1) # Add 10% padding + + # Add labels with Matrix green color + plt.title(f"FRAME {frame_num:04d} - DETECTIONS OVER TIME", color="#00ff41", pad=20) + plt.xlabel("FRAME NUMBER", color="#00ff41") + plt.ylabel("NUMBER OF DETECTIONS", color="#00ff41") + + # Add current stats in Matrix green with monospace formatting + current_detections = df[df["frame"] == frame_num]["detections"].iloc[0] + plt.text( + 0.02, + 0.98, + f"CURRENT DETECTIONS: {current_detections:02d}", + transform=plt.gca().transAxes, + verticalalignment="top", + color="#00ff41", + family="monospace", + ) + + # Style the grid and ticks + plt.grid(True, color="#1a1a1a", linestyle="-", alpha=0.3) + plt.tick_params(colors="#00ff41") + + # Save frame + frame_path = os.path.join(temp_dir, f"frame_{frame_num:05d}.png") + plt.savefig( + frame_path, bbox_inches="tight", dpi=100, facecolor="black", edgecolor="none" + ) + plt.close() + + return frame_path + + +def generate_gauge_frame(df, frame_num, temp_dir, detect_keyword="OBJECT"): + """Generate a modern square-style binary gauge visualization frame.""" + # Set the style to dark background + plt.style.use("dark_background") + + # Set global font to monospace + plt.rcParams["font.family"] = "monospace" + plt.rcParams["font.monospace"] = ["DejaVu Sans Mono"] + + # Create figure with 16:9 aspect ratio + plt.figure(figsize=(16, 9)) + + # Get current detection state + current_detections = df[df["frame"] == frame_num]["detections"].iloc[0] + has_detection = current_detections > 0 + + # Create a simple gauge visualization + plt.axis("off") + + # Set colors + if has_detection: + color = "#00ff41" # Matrix green for YES + status = "YES" + indicator_pos = 0.8 # Right position + else: + color = "#ff0000" # Red for NO + status = "NO" + indicator_pos = 0.2 # Left 
position + + # Draw background rectangle + background = plt.Rectangle( + (0.1, 0.3), 0.8, 0.2, facecolor="#1a1a1a", edgecolor="#333333", linewidth=2 + ) + plt.gca().add_patch(background) + + # Draw indicator + indicator_width = 0.05 + indicator = plt.Rectangle( + (indicator_pos - indicator_width / 2, 0.25), + indicator_width, + 0.3, + facecolor=color, + edgecolor=None, + ) + plt.gca().add_patch(indicator) + + # Add tick marks + tick_positions = [0.2, 0.5, 0.8] # NO, CENTER, YES + for x in tick_positions: + plt.plot([x, x], [0.3, 0.5], color="#444444", linewidth=2) + + # Add YES/NO labels + plt.text( + 0.8, + 0.2, + "YES", + color="#00ff41", + fontsize=14, + ha="center", + va="center", + family="monospace", + ) + plt.text( + 0.2, + 0.2, + "NO", + color="#ff0000", + fontsize=14, + ha="center", + va="center", + family="monospace", + ) + + # Add status box at top with detection keyword + plt.text( + 0.5, + 0.8, + f"{detect_keyword.upper()} DETECTED?", + color=color, + fontsize=16, + ha="center", + va="center", + family="monospace", + bbox=dict(facecolor="#1a1a1a", edgecolor=color, linewidth=2, pad=10), + ) + + # Add frame counter at bottom + plt.text( + 0.5, + 0.1, + f"FRAME: {frame_num:04d}", + color="#00ff41", + fontsize=14, + ha="center", + va="center", + family="monospace", + ) + + # Add subtle grid lines for depth + for x in np.linspace(0.2, 0.8, 7): + plt.plot([x, x], [0.3, 0.5], color="#222222", linewidth=1, zorder=0) + + # Add glow effect to indicator + for i in range(3): + glow = plt.Rectangle( + (indicator_pos - (indicator_width + i * 0.01) / 2, 0.25 - i * 0.01), + indicator_width + i * 0.01, + 0.3 + i * 0.02, + facecolor=color, + alpha=0.1 / (i + 1), + ) + plt.gca().add_patch(glow) + + # Set consistent plot limits + plt.xlim(0, 1) + plt.ylim(0, 1) + + # Save frame with 16:9 aspect ratio + frame_path = os.path.join(temp_dir, f"gauge_{frame_num:05d}.png") + plt.savefig( + frame_path, + bbox_inches="tight", + dpi=100, + facecolor="black", + edgecolor="none", + pad_inches=0, + ) + plt.close() + + return frame_path + + +def create_video_visualization(json_path, style="timeline"): + """Create a video visualization of the detection data.""" + try: + if not json_path: + return None, "No JSON file provided" + + if not os.path.exists(json_path): + return None, f"File not found: {json_path}" + + # Load and process data + result = create_frame_data(json_path) + if result is None: + return None, "Failed to load detection data from JSON file" + + frame_data, metadata = result + if len(frame_data) == 0: + return None, "No frame data found in JSON file" + + total_frames = metadata["total_frames"] + detect_keyword = metadata.get( + "detect_keyword", "OBJECT" + ) # Get the detection keyword + + # Create temporary directory for frames + with tempfile.TemporaryDirectory() as temp_dir: + max_y = frame_data["detections"].max() + + # Generate each frame + print("Generating frames...") + frame_paths = [] + with tqdm(total=total_frames, desc="Generating frames") as pbar: + for frame in range(total_frames): + try: + if style == "gauge": + frame_path = generate_gauge_frame( + frame_data, frame, temp_dir, detect_keyword + ) + else: # default to timeline + frame_path = generate_frame_image( + frame_data, frame, temp_dir, max_y + ) + if frame_path and os.path.exists(frame_path): + frame_paths.append(frame_path) + else: + print(f"Warning: Failed to generate frame {frame}") + pbar.update(1) + except Exception as e: + print(f"Error generating frame {frame}: {str(e)}") + continue + + if not frame_paths: + return 
None, "Failed to generate any frames" + + # Create output video path + output_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "outputs" + ) + os.makedirs(output_dir, exist_ok=True) + output_video = os.path.join( + output_dir, f"detection_visualization_{style}.mp4" + ) + + # Create temp output path + base, ext = os.path.splitext(output_video) + temp_output = f"{base}_temp{ext}" + + # First pass: Create video with OpenCV VideoWriter + print("Creating initial video...") + # Get frame size from first image + first_frame = cv2.imread(frame_paths[0]) + height, width = first_frame.shape[:2] + + out = cv2.VideoWriter( + temp_output, + cv2.VideoWriter_fourcc(*"mp4v"), + metadata["fps"], + (width, height), + ) + + with tqdm( + total=total_frames, desc="Creating video" + ) as pbar: # Use total_frames here too + for frame_path in frame_paths: + frame = cv2.imread(frame_path) + out.write(frame) + pbar.update(1) + + out.release() + + # Second pass: Convert to web-compatible format + print("Converting to web format...") + try: + subprocess.run( + [ + "ffmpeg", + "-y", + "-i", + temp_output, + "-c:v", + "libx264", + "-preset", + "medium", + "-crf", + "23", + "-movflags", + "+faststart", # Better web playback + "-loglevel", + "error", + output_video, + ], + check=True, + ) + + os.remove(temp_output) # Remove the temporary file + + if not os.path.exists(output_video): + print( + f"Warning: FFmpeg completed but output file not found at {output_video}" + ) + return None, "Failed to create video" + + # Return video path and stats + stats = f"""Video Stats: +FPS: {metadata['fps']} +Total Frames: {metadata['total_frames']} +Duration: {metadata['duration_sec']:.2f} seconds +Max Detections in a Frame: {frame_data['detections'].max()} +Average Detections per Frame: {frame_data['detections'].mean():.2f}""" + + return output_video, stats + + except subprocess.CalledProcessError as e: + print(f"Error running FFmpeg: {str(e)}") + if os.path.exists(temp_output): + os.remove(temp_output) + return None, f"Error creating visualization: {str(e)}" + + except Exception as e: + print(f"Error creating video visualization: {str(e)}") + import traceback + + traceback.print_exc() + return None, f"Error creating visualization: {str(e)}" diff --git a/recipes/promptable-content-moderation/visualization.py b/recipes/promptable-content-moderation/visualization.py new file mode 100644 index 00000000..82695f02 --- /dev/null +++ b/recipes/promptable-content-moderation/visualization.py @@ -0,0 +1,108 @@ +import pandas as pd +import matplotlib.pyplot as plt +from persistence import load_detection_data +import argparse + + +def visualize_detections(json_path): + """ + Visualize detection data from a JSON file. + + Args: + json_path (str): Path to the JSON file containing detection data. 
+ """ + # Load the persisted JSON data + data = load_detection_data(json_path) + if not data: + return + + # Convert the frame detections to a DataFrame + rows = [] + for frame_data in data["frame_detections"]: + frame = frame_data["frame"] + timestamp = frame_data["timestamp"] + for obj in frame_data["objects"]: + rows.append( + { + "frame": frame, + "timestamp": timestamp, + "keyword": obj["keyword"], + "x1": obj["bbox"][0], + "y1": obj["bbox"][1], + "x2": obj["bbox"][2], + "y2": obj["bbox"][3], + "area": (obj["bbox"][2] - obj["bbox"][0]) + * (obj["bbox"][3] - obj["bbox"][1]), + } + ) + + if not rows: + print("No detections found in the data") + return + + df = pd.DataFrame(rows) + + # Create a figure with multiple subplots + fig = plt.figure(figsize=(15, 10)) + + # Plot 1: Number of detections per frame + plt.subplot(2, 2, 1) + detections_per_frame = df.groupby("frame").size() + plt.plot(detections_per_frame.index, detections_per_frame.values) + plt.xlabel("Frame") + plt.ylabel("Number of Detections") + plt.title("Detections Per Frame") + + # Plot 2: Distribution of detection areas + plt.subplot(2, 2, 2) + df["area"].hist(bins=30) + plt.xlabel("Detection Area (normalized)") + plt.ylabel("Count") + plt.title("Distribution of Detection Areas") + + # Plot 3: Average detection area over time + plt.subplot(2, 2, 3) + avg_area = df.groupby("frame")["area"].mean() + plt.plot(avg_area.index, avg_area.values) + plt.xlabel("Frame") + plt.ylabel("Average Detection Area") + plt.title("Average Detection Area Over Time") + + # Plot 4: Heatmap of detection centers + plt.subplot(2, 2, 4) + df["center_x"] = (df["x1"] + df["x2"]) / 2 + df["center_y"] = (df["y1"] + df["y2"]) / 2 + plt.hist2d(df["center_x"], df["center_y"], bins=30) + plt.colorbar() + plt.xlabel("X Position") + plt.ylabel("Y Position") + plt.title("Detection Center Heatmap") + + # Adjust layout and display + plt.tight_layout() + plt.show() + + # Print summary statistics + print("\nSummary Statistics:") + print(f"Total frames analyzed: {len(data['frame_detections'])}") + print(f"Total detections: {len(df)}") + print( + f"Average detections per frame: {len(df) / len(data['frame_detections']):.2f}" + ) + print(f"\nVideo metadata:") + for key, value in data["video_metadata"].items(): + print(f"{key}: {value}") + + +def main(): + parser = argparse.ArgumentParser(description="Visualize object detection data") + parser.add_argument( + "json_file", help="Path to the JSON file containing detection data" + ) + args = parser.parse_args() + + visualize_detections(args.json_file) + + +if __name__ == "__main__": + main()
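
Note on the audio pass in `create_detection_video`: the extraction step stream-copies audio (`-acodec copy`) into a temporary `.aac` file, which assumes the source video already carries an AAC track. Clips with no audio stream, or with a codec that cannot be copied into a raw AAC container, will make that `ffmpeg` call fail and the whole function return `None`. The sketch below is a hypothetical fallback, not part of the recipe: `has_audio_stream` is an assumed helper, and the extraction re-encodes to AAC instead of copying.

```python
import subprocess


def has_audio_stream(video_path):
    """Return True if ffprobe reports at least one audio stream (assumed helper)."""
    result = subprocess.run(
        [
            "ffprobe", "-v", "error", "-select_streams", "a",
            "-show_entries", "stream=codec_type", "-of", "csv=p=0", video_path,
        ],
        capture_output=True, text=True,
    )
    return bool(result.stdout.strip())


def extract_audio(video_path, temp_audio, duration=None):
    """Extract audio, re-encoding to AAC so the temporary .aac file is always valid."""
    cmd = ["ffmpeg", "-y", "-i", video_path]
    if duration is not None:
        cmd += ["-t", str(duration)]  # test mode: only the first N seconds
    cmd += ["-vn", "-c:a", "aac", "-b:a", "192k", temp_audio]
    subprocess.run(cmd, check=True)
```

If `has_audio_stream` returns `False`, the merge step could simply encode the processed `temp_output` straight to `output_path` without the second `-i` input, instead of treating the missing audio as a fatal error.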
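
One subtle point in the JSON serialization inside `process_video`: tracked detections are tuples of either `(bbox, keyword)` or `(bbox, keyword, track_id)`, and the `for box, kw, *track_id in [detection]` unpacking leaves `track_id` as a possibly empty list, so the saved `track_id` field appears to end up as a one-element list rather than a plain integer when a track exists. A small sketch of an equivalent, more explicit unpacking that stores a scalar (my rewrite under that reading, not the recipe's code):

```python
def serialize_objects(detections):
    """Flatten (bbox, keyword[, track_id]) tuples into JSON-friendly dicts (sketch)."""
    objects = []
    for detection in detections:
        box, keyword, *rest = detection
        objects.append(
            {
                "keyword": keyword,
                "bbox": list(box),                      # tolerate numpy arrays
                "track_id": rest[0] if rest else None,  # scalar id, not a list
            }
        )
    return objects
```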
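
Finally, the modules compose outside the CLI and the web interface. The sketch below assumes the entry-point script shown above (the one defining `process_video`) is importable as `main`, that the input file exists, and that the detection JSON follows the `[style]_[keyword]_[name]_detections.json` naming used when it is saved to `outputs/`; the module name and paths are assumptions rather than documented API.

```python
import os

from main import process_video                    # assumed module name for the CLI script above
from persistence import load_detection_data
from video_visualization import create_video_visualization

video_path = "inputs/clip.mp4"                    # hypothetical input file
keyword, style = "cigarette", "censor"

# 1. Run detection + redaction (test_mode limits processing to the first few seconds).
output_video = process_video(video_path, keyword, test_mode=True, box_style=style)
print("redacted video:", output_video)

# 2. The detection data is written alongside it, following the naming used above.
base = os.path.splitext(os.path.basename(video_path))[0]
json_path = os.path.join("outputs", f"{style}_{keyword}_{base}_detections.json")

# 3. Inspect the raw detections, then render the timeline visualization video.
data = load_detection_data(json_path)
print(data["video_metadata"] if data else "no detection data")
viz_video, stats = create_video_visualization(json_path, style="timeline")
print(viz_video, stats, sep="\n")
```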