-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from akatz-ai/alex-dev
Updated with flex feature convert nodes, as well as fade between batc…
- Loading branch information
Showing
6 changed files
with
249 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
|
||
class AK_FlexFeatureToFloatList: | ||
def __init__(self): | ||
pass | ||
|
||
@classmethod | ||
def INPUT_TYPES(s): | ||
return { | ||
"required": { | ||
"feature": ("FEATURE", {"forceInput": True}), # 'forceInput' used for custom types like FEATURE | ||
}, | ||
} | ||
|
||
RETURN_TYPES = ("FLOAT",) # Define the output as a list of floats | ||
FUNCTION = "convert_feature_to_float_list" | ||
CATEGORY = "💜Akatz Nodes/Utils" | ||
|
||
DESCRIPTION = """ | ||
AK_FlexFeatureToFloatList: | ||
This node converts a FEATURE type input into a list of float values, one per frame. | ||
- feature: A custom Feature object with attributes: | ||
- feature.get_value_at_frame(i): Method that returns a float value for frame `i`. | ||
- feature.frame_count: The total number of frames in the feature. | ||
""" | ||
|
||
def convert_feature_to_float_list(self, feature): | ||
""" | ||
Convert the values from a Feature object to a list of floats. | ||
Args: | ||
- feature: The Feature object which contains frame values. | ||
Returns: | ||
- tuple: A tuple containing a list of floats, where each float is the value | ||
extracted from the feature for each frame. | ||
""" | ||
# Initialize an empty list to store float values | ||
float_list = [] | ||
|
||
# Iterate over each frame in the feature and extract the normalized value | ||
for i in range(feature.frame_count): | ||
value = feature.get_value_at_frame(i) | ||
float_list.append(value) | ||
|
||
# Return the list as a tuple (since ComfyUI expects tuples) | ||
return (float_list,) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
|
||
class AK_FloatListToFlexFeature: | ||
def __init__(self): | ||
pass | ||
|
||
@classmethod | ||
def INPUT_TYPES(s): | ||
return { | ||
"required": { | ||
"float_list": ("FLOAT", {"forceInput": True}), # 'forceInput' used for custom types like FEATURE | ||
"original_feature": ("FEATURE", {"forceInput": True}), # The original FEATURE object to combine with the float list for the output | ||
}, | ||
} | ||
|
||
RETURN_TYPES = ("FEATURE",) # Define the output as a FEATURE object | ||
FUNCTION = "convert_float_list_to_feature" | ||
CATEGORY = "💜Akatz Nodes/Utils" | ||
|
||
DESCRIPTION = """ | ||
AK_FloatListToFlexFeature: | ||
This node converts a list of float values into a FEATURE type input. | ||
- original_feature: A custom Feature object with attributes: | ||
- original_feature.get_value_at_frame(i): Method that returns a float value for frame `i`. | ||
- original_feature.frame_count: The total number of frames in the feature. | ||
""" | ||
|
||
def convert_float_list_to_feature(self, float_list, original_feature): | ||
""" | ||
Convert the values from a list of floats to a FEATURE object. | ||
Args: | ||
- float_list: A list of floats. | ||
- original_feature: The Feature object which contains frame values. | ||
Returns: | ||
- original_feature: A FEATURE object with the same frame values as the float list. | ||
""" | ||
|
||
# Convert the float list to a numpy array | ||
array = np.array(float_list) | ||
|
||
# Get frame count from the number of elements in the array | ||
frame_count = array.shape[0] | ||
|
||
original_feature.frame_count = frame_count | ||
original_feature.data = array | ||
|
||
# Return the feature as a tuple (since ComfyUI expects tuples) | ||
return (original_feature,) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
class AK_FadeBetweenBatches: | ||
@classmethod | ||
def INPUT_TYPES(cls): | ||
return { | ||
"required": { | ||
"image1": ("IMAGE",), # First image batch | ||
"image2": ("IMAGE",), # Second image batch | ||
"overlap_frames": ("INT", {"default": 10, "min": 1}), # Number of frames to overlap/fade | ||
} | ||
} | ||
|
||
RETURN_TYPES = ("IMAGE",) | ||
FUNCTION = "fade_batches" | ||
CATEGORY = "💜Akatz Nodes/Utils" | ||
DESCRIPTION = """ | ||
# AK Fade Between Batches | ||
This node takes two image batches and blends them together by transitioning | ||
with overlapping frames. The output is a single image batch where: | ||
- image1 starts, | ||
- a number of overlap frames fade between image1 and image2, | ||
- image2 finishes. | ||
""" | ||
|
||
def fade_batches(self, image1, image2, overlap_frames): | ||
# Check if image1 or image2 is None or empty | ||
if image1 is None or image1.numel() == 0: | ||
if image2 is None or image2.numel() == 0: | ||
raise ValueError("Both image1 and image2 are None or empty.") | ||
return (image2,) | ||
|
||
if image2 is None or image2.numel() == 0: | ||
return (image1,) | ||
|
||
# Ensure image2 has the same height, width, and channels as image1 | ||
if image1.shape[1:] != image2.shape[1:]: | ||
# Resize image2 to match image1 dimensions (using bilinear interpolation) | ||
image2 = F.interpolate(image2.permute(0, 3, 1, 2), size=image1.shape[1:3], mode='bilinear', align_corners=False) | ||
image2 = image2.permute(0, 2, 3, 1) # Re-permute to get back to [B, H, W, C] | ||
|
||
batch_size1 = image1.shape[0] | ||
batch_size2 = image2.shape[0] | ||
|
||
# Calculate the number of frames outside of the overlap | ||
non_overlap1 = batch_size1 - overlap_frames | ||
non_overlap2 = batch_size2 - overlap_frames | ||
|
||
if non_overlap1 < 0 or non_overlap2 < 0: | ||
raise ValueError("Overlap frames exceed the batch sizes of the input images.") | ||
|
||
# Create the transition frames by fading between image1 and image2 | ||
transition_frames = [] | ||
for i in range(overlap_frames): | ||
alpha_step = 1 / ((overlap_frames + 2) - 1) # Alpha will go from 0 to 1 | ||
alpha = alpha_step * (i + 1) | ||
fade_frame = (1 - alpha) * image1[non_overlap1 + i] + alpha * image2[i] | ||
transition_frames.append(fade_frame) | ||
|
||
# Stack the frames into a single tensor (batch) | ||
result_batch = torch.cat( | ||
(image1[:non_overlap1], # First part from image1 | ||
torch.stack(transition_frames), # Transition frames | ||
image2[overlap_frames:]), # Remaining frames from image2 | ||
dim=0 | ||
) | ||
return (result_batch,) | ||
|
||
# image1 = inputs['0_image'] | ||
# image2 = inputs['1_image'] | ||
# overlap_frames = inputs['2_int'] | ||
|
||
# instance = AK_FadeBetweenBatches() | ||
# outputs[0] = instance.fade_batches(image1, image2, overlap_frames)[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
|
||
class AK_SplitImageBatch: | ||
@classmethod | ||
def INPUT_TYPES(cls): | ||
return { | ||
"required": { | ||
"image_batch": ("IMAGE",), # The input image batch to be split | ||
"split_index": ("INT", {"default": 1, "min": 1}), # The index at which to split the batch | ||
"split_batch_index": ("INT", {"default": 0, "min": 0, "max": 1}), # Whether to return the first or second batch | ||
} | ||
} | ||
|
||
RETURN_TYPES = ("IMAGE",) | ||
FUNCTION = "split_image_batch" | ||
CATEGORY = "💜Akatz Nodes/Utils" | ||
DESCRIPTION = """ | ||
# AK Split Image Batch | ||
This node splits an input image batch into two at a given index, and returns one of the partitions. | ||
- split_index: The index at which to split the image batch. | ||
- split_batch_index: 0 to return the first partition, 1 to return the second partition. | ||
""" | ||
|
||
def split_image_batch(self, image_batch, split_index, split_batch_index): | ||
# Validate the split index to be within range | ||
batch_size = image_batch.shape[0] | ||
if split_index < 1 or split_index >= batch_size: | ||
raise ValueError(f"split_index must be between 1 and {batch_size - 1}.") | ||
|
||
# Split the image batch into two parts | ||
batch1 = image_batch[:split_index] # First part up to split_index | ||
batch2 = image_batch[split_index:] # Second part from split_index to the end | ||
|
||
# Select which batch to return based on split_batch_index (0 or 1) | ||
if split_batch_index == 0: | ||
return (batch1,) | ||
elif split_batch_index == 1: | ||
return (batch2,) | ||
else: | ||
raise ValueError("split_batch_index must be either 0 or 1.") |