Commit

black format
parsakhaz committed Feb 20, 2025
1 parent a7fb0e9 commit 9ef4b05
Showing 6 changed files with 581 additions and 311 deletions.
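The diff below is a mechanical reformatting pass: the commit runs the Black code formatter over the recipe, so lines are re-wrapped to Black's style without behaviour changes. As a hedged sketch (not part of the commit itself), the same result can be reproduced with Black's Python API; the file path is taken from the file list above, and the 88-character line length is Black's default rather than a setting confirmed by this repository:

    from pathlib import Path

    import black  # pip install black

    path = Path("recipes/promptable-content-moderation/app.py")
    source = path.read_text()
    # format_str applies the same rules as the `black` CLI;
    # Mode(line_length=88) mirrors Black's default line length.
    formatted = black.format_str(source, mode=black.Mode(line_length=88))
    path.write_text(formatted)

In practice the CLI form, `black recipes/promptable-content-moderation/app.py`, is the usual way to produce a commit like this one.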
199 changes: 149 additions & 50 deletions recipes/promptable-content-moderation/app.py
@@ -26,10 +26,18 @@
# Initialize Moondream model globally for reuse (will be loaded on first use)
model, tokenizer = None, None


# Uncomment for Hugging Face Spaces
# @spaces.GPU(duration=120)
def process_video_file(
video_file, target_object, box_style, ffmpeg_preset, grid_rows, grid_cols, test_mode, test_duration
video_file,
target_object,
box_style,
ffmpeg_preset,
grid_rows,
grid_cols,
test_mode,
test_duration,
):
"""Process a video file through the Gradio interface."""
try:
@@ -67,7 +75,9 @@ def process_video_file(

# Get the corresponding JSON path
base_name = os.path.splitext(os.path.basename(video_filename))[0]
json_path = os.path.join(outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json")
json_path = os.path.join(
outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json"
)

# Verify output exists and is readable
if not output_path or not os.path.exists(output_path):
@@ -109,6 +119,7 @@ def process_video_file(
print(f"Error in process_video_file: {str(e)}")
raise gr.Error(f"Error processing video: {str(e)}")


def create_visualization_plots(json_path):
"""Create visualization plots and return them as images."""
try:
@@ -123,52 +134,65 @@ def create_visualization_plots(json_path):
frame = frame_data["frame"]
timestamp = frame_data["timestamp"]
for obj in frame_data["objects"]:
rows.append({
"frame": frame,
"timestamp": timestamp,
"keyword": obj["keyword"],
"x1": obj["bbox"][0],
"y1": obj["bbox"][1],
"x2": obj["bbox"][2],
"y2": obj["bbox"][3],
"area": (obj["bbox"][2] - obj["bbox"][0]) * (obj["bbox"][3] - obj["bbox"][1]),
"center_x": (obj["bbox"][0] + obj["bbox"][2]) / 2,
"center_y": (obj["bbox"][1] + obj["bbox"][3]) / 2
})
rows.append(
{
"frame": frame,
"timestamp": timestamp,
"keyword": obj["keyword"],
"x1": obj["bbox"][0],
"y1": obj["bbox"][1],
"x2": obj["bbox"][2],
"y2": obj["bbox"][3],
"area": (obj["bbox"][2] - obj["bbox"][0])
* (obj["bbox"][3] - obj["bbox"][1]),
"center_x": (obj["bbox"][0] + obj["bbox"][2]) / 2,
"center_y": (obj["bbox"][1] + obj["bbox"][3]) / 2,
}
)

if not rows:
return None, None, None, None, None, None, None, None, "No detections found in the data"
return (
None,
None,
None,
None,
None,
None,
None,
None,
"No detections found in the data",
)

df = pd.DataFrame(rows)
plots = []

# Create each plot and convert to image
for plot_num in range(8): # Increased to 8 plots
plt.figure(figsize=(8, 6))

if plot_num == 0:
# Plot 1: Number of detections per frame (Original)
detections_per_frame = df.groupby("frame").size()
plt.plot(detections_per_frame.index, detections_per_frame.values)
plt.xlabel("Frame")
plt.ylabel("Number of Detections")
plt.title("Detections Per Frame")

elif plot_num == 1:
# Plot 2: Distribution of detection areas (Original)
df["area"].hist(bins=30)
plt.xlabel("Detection Area (normalized)")
plt.ylabel("Count")
plt.title("Distribution of Detection Areas")

elif plot_num == 2:
# Plot 3: Average detection area over time (Original)
avg_area = df.groupby("frame")["area"].mean()
plt.plot(avg_area.index, avg_area.values)
plt.xlabel("Frame")
plt.ylabel("Average Detection Area")
plt.title("Average Detection Area Over Time")

elif plot_num == 3:
# Plot 4: Heatmap of detection centers (Original)
plt.hist2d(df["center_x"], df["center_y"], bins=30)
@@ -191,20 +215,39 @@ def create_visualization_plots(json_path):
# Plot 6: Screen Region Analysis
# Divide screen into 3x3 grid and show detection counts
try:
df["grid_x"] = pd.qcut(df["center_x"], q=3, labels=["Left", "Center", "Right"], duplicates='drop')
df["grid_y"] = pd.qcut(df["center_y"], q=3, labels=["Top", "Middle", "Bottom"], duplicates='drop')
region_counts = df.groupby(["grid_y", "grid_x"]).size().unstack(fill_value=0)
df["grid_x"] = pd.qcut(
df["center_x"],
q=3,
labels=["Left", "Center", "Right"],
duplicates="drop",
)
df["grid_y"] = pd.qcut(
df["center_y"],
q=3,
labels=["Top", "Middle", "Bottom"],
duplicates="drop",
)
region_counts = (
df.groupby(["grid_y", "grid_x"]).size().unstack(fill_value=0)
)
plt.imshow(region_counts, cmap="YlOrRd")
plt.colorbar(label="Detection Count")
for i in range(3):
for j in range(3):
plt.text(j, i, region_counts.iloc[i, j], ha="center", va="center")
plt.text(
j, i, region_counts.iloc[i, j], ha="center", va="center"
)
plt.xticks(range(3), ["Left", "Center", "Right"])
plt.yticks(range(3), ["Top", "Middle", "Bottom"])
plt.title("Screen Region Analysis")
except Exception as e:
plt.text(0.5, 0.5, "Insufficient variation in detection positions",
ha='center', va='center')
plt.text(
0.5,
0.5,
"Insufficient variation in detection positions",
ha="center",
va="center",
)
plt.title("Screen Region Analysis (Not Available)")

elif plot_num == 6:
@@ -215,25 +258,34 @@ def create_visualization_plots(json_path):
"Small (likely far/background)",
"Medium-small",
"Medium-large",
"Large (likely foreground/close)"
"Large (likely foreground/close)",
]

# Handle cases with limited unique values
unique_areas = df["area"].nunique()
if unique_areas >= 4:
df["size_category"] = pd.qcut(df["area"], q=4, labels=size_labels, duplicates='drop')
df["size_category"] = pd.qcut(
df["area"], q=4, labels=size_labels, duplicates="drop"
)
else:
# Alternative binning for limited unique values
df["size_category"] = pd.cut(df["area"],
bins=unique_areas,
labels=size_labels[:unique_areas])

df["size_category"] = pd.cut(
df["area"],
bins=unique_areas,
labels=size_labels[:unique_areas],
)

size_dist = df["size_category"].value_counts()
plt.pie(size_dist.values, labels=size_dist.index, autopct="%1.1f%%")
plt.title("Detection Size Distribution")
except Exception as e:
plt.text(0.5, 0.5, "Insufficient variation in detection sizes",
ha='center', va='center')
plt.text(
0.5,
0.5,
"Insufficient variation in detection sizes",
ha="center",
va="center",
)
plt.title("Detection Size Distribution (Not Available)")

elif plot_num == 7:
Expand All @@ -242,21 +294,32 @@ def create_visualization_plots(json_path):
try:
detection_gaps = df.sort_values("frame")["frame"].diff()
if len(detection_gaps.dropna().unique()) > 1:
plt.hist(detection_gaps.dropna(), bins=min(30, len(detection_gaps.dropna().unique())),
edgecolor="black")
plt.hist(
detection_gaps.dropna(),
bins=min(30, len(detection_gaps.dropna().unique())),
edgecolor="black",
)
plt.xlabel("Frames Between Detections")
plt.ylabel("Frequency")
plt.title("Detection Temporal Pattern Analysis")
else:
plt.text(0.5, 0.5, "Uniform detection intervals", ha='center', va='center')
plt.text(
0.5,
0.5,
"Uniform detection intervals",
ha="center",
va="center",
)
plt.title("Temporal Pattern Analysis (Uniform)")
except Exception as e:
plt.text(0.5, 0.5, "Insufficient temporal data", ha='center', va='center')
plt.text(
0.5, 0.5, "Insufficient temporal data", ha="center", va="center"
)
plt.title("Temporal Pattern Analysis (Not Available)")

# Save plot to bytes
buf = io.BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight')
plt.savefig(buf, format="png", bbox_inches="tight")
buf.seek(0)
plots.append(Image.open(buf))
plt.close()
@@ -278,13 +341,35 @@ def create_visualization_plots(json_path):
for key, value in data["video_metadata"].items():
summary += f"{key}: {value}\n"

return plots[0], plots[1], plots[2], plots[3], plots[4], plots[5], plots[6], plots[7], summary
return (
plots[0],
plots[1],
plots[2],
plots[3],
plots[4],
plots[5],
plots[6],
plots[7],
summary,
)

except Exception as e:
print(f"Error creating visualization: {str(e)}")
import traceback

traceback.print_exc()
return None, None, None, None, None, None, None, None, f"Error creating visualization: {str(e)}"
return (
None,
None,
None,
None,
None,
None,
None,
None,
f"Error creating visualization: {str(e)}",
)


# Create the Gradio interface
with gr.Blocks(title="Promptable Content Moderation") as app:
@@ -315,7 +400,17 @@ def create_visualization_plots(json_path):

with gr.Accordion("Advanced Settings", open=False):
box_style_input = gr.Radio(
choices=["censor", "bounding-box", "hitmarker", "sam", "sam-fast", "fuzzy-blur", "pixelated-blur", "intense-pixelated-blur", "obfuscated-pixel"],
choices=[
"censor",
"bounding-box",
"hitmarker",
"sam",
"sam-fast",
"fuzzy-blur",
"pixelated-blur",
"intense-pixelated-blur",
"obfuscated-pixel",
],
value="obfuscated-pixel",
label="Visualization Style",
info="Choose how to display moderations: censor (black boxes), bounding-box (red boxes with labels), hitmarker (COD-style markers), sam (precise segmentation), sam-fast (faster but less precise segmentation), fuzzy-blur (Gaussian blur), pixelated-blur (pixelated with blur), obfuscated-pixel (advanced pixelation with neighborhood averaging)",
@@ -340,9 +435,13 @@ def create_visualization_plots(json_path):
minimum=1, maximum=4, value=1, step=1, label="Grid Rows"
)
cols_input = gr.Slider(
minimum=1, maximum=4, value=1, step=1, label="Grid Columns"
minimum=1,
maximum=4,
value=1,
step=1,
label="Grid Columns",
)

test_mode_input = gr.Checkbox(
label="Test Mode (Process first 3 seconds only)",
value=True,
@@ -355,7 +454,7 @@ def create_visualization_plots(json_path):
value=3,
step=1,
label="Test Mode Duration (seconds)",
info="Number of seconds to process in test mode"
info="Number of seconds to process in test mode",
)

gr.Markdown(
@@ -401,7 +500,7 @@ def create_visualization_plots(json_path):
- Detection density patterns
"""
)

with gr.Row():
json_input = gr.File(
label="Upload Detection Data (JSON)",
@@ -423,7 +522,7 @@ def create_visualization_plots(json_path):
plot6 = gr.Image(
label="Screen Region Analysis",
)

with gr.Column():
plot3 = gr.Image(
label="Average Detection Area Over Time",
@@ -437,13 +536,13 @@ def create_visualization_plots(json_path):
plot8 = gr.Image(
label="Temporal Pattern Analysis",
)

stats_output = gr.Textbox(
label="Statistics",
info="Summary of key metrics and patterns found in the detection data.",
lines=12,
max_lines=15,
interactive=False
interactive=False,
)

# with gr.Tab("Video Visualizations"):
@@ -455,7 +554,7 @@ def create_visualization_plots(json_path):
# - Gauge: Simple yes/no indicator for current frame detections
# """
# )

# with gr.Row():
# json_input_realtime = gr.File(
# label="Upload Detection Data (JSON)",
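For orientation, a minimal sketch of how the reformatted process_video_file signature might be called directly; only the box style, grid size, test mode, and test duration mirror defaults visible in this diff, while the video path, moderation prompt, and ffmpeg preset are placeholders, and the function's return value is not shown in the hunks above:

    # Hypothetical call illustrating the new multi-line signature.
    # Placeholder values are marked; the others are defaults taken from the
    # Gradio inputs shown in this diff.
    result = process_video_file(
        "input.mp4",                    # video_file: placeholder path
        target_object="cigarette",      # moderation prompt: placeholder
        box_style="obfuscated-pixel",   # default shown in the diff
        ffmpeg_preset="medium",         # placeholder; preset input not shown here
        grid_rows=1,                    # default shown in the diff
        grid_cols=1,                    # default shown in the diff
        test_mode=True,                 # default shown in the diff
        test_duration=3,                # default shown in the diff
    )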
