From 624da2ea835d5b013ff3dada778131ea5afa3411 Mon Sep 17 00:00:00 2001
From: akatz-ai <katzfeyalex+akatzgh@gmail.com>
Date: Tue, 5 Nov 2024 15:49:44 -0800
Subject: [PATCH] Updated audioreactive dynamic dilation mask node with
 quality_factor

---
 .gitignore                                    |   3 +-
 modules/__pycache__/easing.cpython-310.pyc    | Bin 7754 -> 7754 bytes
 src/ak_audioreactive_dynamic_dilation_mask.py |  40 +++++++++++++-----
 3 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/.gitignore b/.gitignore
index e2da6af..efd697e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@
 /models/*.bin
 /models/*.safetensors
 .directory
-src/__pycache__/
\ No newline at end of file
+modules/__pycache__/
+src/__pycache__/
diff --git a/modules/__pycache__/easing.cpython-310.pyc b/modules/__pycache__/easing.cpython-310.pyc
index b45baa8927a9034a374f2adf0046c3d068c56d03..854455341754c7c53daed22e7371b97b1699577c 100644
GIT binary patch
delta 49
zcmX?QbIOJ%pO=@50SMUFX{B%EVHRb)HCaG(Es)$N8Vn?b#OhdXF>2h}JVk69I{;vZ
B52^qF

delta 49
zcmX?QbIOJ%pO=@50SIj0Jxtrk!z{}9YqEgoS|GVkG#E$<iPf?EV$}Gxd5YLJb^wyh
B5q1Co

diff --git a/src/ak_audioreactive_dynamic_dilation_mask.py b/src/ak_audioreactive_dynamic_dilation_mask.py
index e5cffa3..42a48b5 100644
--- a/src/ak_audioreactive_dynamic_dilation_mask.py
+++ b/src/ak_audioreactive_dynamic_dilation_mask.py
@@ -8,7 +8,7 @@
 
 class AK_AudioreactiveDynamicDilationMask:
     def __init__(self):
-            pass
+        pass
 
     @classmethod
     def INPUT_TYPES(s):
@@ -23,6 +23,13 @@ def INPUT_TYPES(s):
                 "min_radius": ("INT",{
                     "default": 0
                 }),
+                "quality_factor": ("FLOAT", {
+                    "default": 0.25,
+                    "min": 0.0,
+                    "max": 1.0,
+                    "step": 0.01,
+                    "display": "number",
+                }),
             },
         }
         
@@ -44,6 +51,7 @@ def VALIDATE_INPUTS(cls, input_types):
     - shape: The shape of the dilation
     - max_radius: The maximum radius of the dilation
     - min_radius: The minimum radius of the dilation
+    - quality_factor: The quality factor of the dilation
     """
     
     def create_circular_kernel(self, radius):
@@ -54,30 +62,40 @@ def create_circular_kernel(self, radius):
         kernel[mask] = 1
         return kernel
 
-    def dilate_mask_with_amplitude(self, mask, normalized_amp, shape="circle", max_radius=25, min_radius=0):
+    def dilate_mask_with_amplitude(self, mask, normalized_amp, shape="circle", max_radius=25, min_radius=0, quality_factor=0.25):
         dup = copy.deepcopy(mask.cpu().numpy())
         
         # Convert normalize_amp into a float list from numpy array if it is not already a list
         if not isinstance(normalized_amp, list):
             normalized_amp = normalized_amp.tolist()
 
-        # Pre-compute circular kernels if shape is "circle"
-        if shape == "circle":
-            circular_kernels = [self.create_circular_kernel(r) for r in range(max_radius+1)]
+        epsilon = 1e-6
+        if quality_factor < epsilon:
+            shape = "square"
 
-        for index, (mask, amp) in enumerate(zip(dup, normalized_amp)):
+        for index, (mask_frame, amp) in enumerate(zip(dup, normalized_amp)):
             # Scale the amplitude to fluctuate between min_radius and max_radius
-            current_radius = min_radius + amp * (max_radius - min_radius)
+            radius = min_radius + amp * (max_radius - min_radius)
 
-            if current_radius <= 0:
+            if radius <= 0:
                 continue
 
+            s = abs(int(radius * quality_factor if shape == "circle" else radius))
+            d = s * 2 + 1
+
             if shape == "circle":
-                k = circular_kernels[int(current_radius)]
+                k = np.zeros((d, d), np.uint8)
+                k = cv2.circle(k, (s, s), s, 1, -1)
             else:
-                d = 2 * int(current_radius) + 1
                 k = np.ones((d, d), np.uint8)
 
-            dup[index] = cv2.dilate(mask, k, iterations=1)
+            iterations = int(1 / quality_factor if quality_factor >= epsilon else 1)
+
+            if radius > 0:
+                dilated_mask = cv2.dilate(mask_frame, k, iterations=iterations)
+            else:
+                dilated_mask = cv2.erode(mask_frame, k, iterations=iterations)
+
+            dup[index] = dilated_mask
         
         return (torch.from_numpy(dup),)
\ No newline at end of file