diff --git a/.gitignore b/.gitignore index e2da6af..efd697e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ /models/*.bin /models/*.safetensors .directory -src/__pycache__/ \ No newline at end of file +modules/__pycache__/ +src/__pycache__/ diff --git a/modules/__pycache__/easing.cpython-310.pyc b/modules/__pycache__/easing.cpython-310.pyc index b45baa8..8544553 100644 Binary files a/modules/__pycache__/easing.cpython-310.pyc and b/modules/__pycache__/easing.cpython-310.pyc differ diff --git a/src/ak_audioreactive_dynamic_dilation_mask.py b/src/ak_audioreactive_dynamic_dilation_mask.py index e5cffa3..42a48b5 100644 --- a/src/ak_audioreactive_dynamic_dilation_mask.py +++ b/src/ak_audioreactive_dynamic_dilation_mask.py @@ -8,7 +8,7 @@ class AK_AudioreactiveDynamicDilationMask: def __init__(self): - pass + pass @classmethod def INPUT_TYPES(s): @@ -23,6 +23,13 @@ def INPUT_TYPES(s): "min_radius": ("INT",{ "default": 0 }), + "quality_factor": ("FLOAT", { + "default": 0.25, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "display": "number", + }), }, } @@ -44,6 +51,7 @@ def VALIDATE_INPUTS(cls, input_types): - shape: The shape of the dilation - max_radius: The maximum radius of the dilation - min_radius: The minimum radius of the dilation + - quality_factor: The quality factor of the dilation """ def create_circular_kernel(self, radius): @@ -54,30 +62,40 @@ def create_circular_kernel(self, radius): kernel[mask] = 1 return kernel - def dilate_mask_with_amplitude(self, mask, normalized_amp, shape="circle", max_radius=25, min_radius=0): + def dilate_mask_with_amplitude(self, mask, normalized_amp, shape="circle", max_radius=25, min_radius=0, quality_factor=0.25): dup = copy.deepcopy(mask.cpu().numpy()) # Convert normalize_amp into a float list from numpy array if it is not already a list if not isinstance(normalized_amp, list): normalized_amp = normalized_amp.tolist() - # Pre-compute circular kernels if shape is "circle" - if shape == "circle": - circular_kernels = [self.create_circular_kernel(r) for r in range(max_radius+1)] + epsilon = 1e-6 + if quality_factor < epsilon: + shape = "square" - for index, (mask, amp) in enumerate(zip(dup, normalized_amp)): + for index, (mask_frame, amp) in enumerate(zip(dup, normalized_amp)): # Scale the amplitude to fluctuate between min_radius and max_radius - current_radius = min_radius + amp * (max_radius - min_radius) + radius = min_radius + amp * (max_radius - min_radius) - if current_radius <= 0: + if radius <= 0: continue + s = abs(int(radius * quality_factor if shape == "circle" else radius)) + d = s * 2 + 1 + if shape == "circle": - k = circular_kernels[int(current_radius)] + k = np.zeros((d, d), np.uint8) + k = cv2.circle(k, (s, s), s, 1, -1) else: - d = 2 * int(current_radius) + 1 k = np.ones((d, d), np.uint8) - dup[index] = cv2.dilate(mask, k, iterations=1) + iterations = int(1 / quality_factor if quality_factor >= epsilon else 1) + + if radius > 0: + dilated_mask = cv2.dilate(mask_frame, k, iterations=iterations) + else: + dilated_mask = cv2.erode(mask_frame, k, iterations=iterations) + + dup[index] = dilated_mask return (torch.from_numpy(dup),) \ No newline at end of file