Skip to content

Commit

Permalink
Updated with audio conversion to and from salt audio format to normal…
Browse files Browse the repository at this point in the history
… audio
  • Loading branch information
akatz-ai committed Aug 22, 2024
1 parent 5a40206 commit ca8d846
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 2 deletions.
6 changes: 5 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
from .ak_normalize_mask_color import AK_NormalizeMaskColor
from .ak_audioreactive_dilation_mask import AK_AudioreactiveDilationMask
from .ak_audioreactive_dynamic_dilation_mask import AK_AudioreactiveDynamicDilationMask
from .ak_convert_audio_to_salt_audio import AK_ConvertAudioToSaltAudio
from .ak_convert_salt_audio_to_audio import AK_ConvertSaltAudioToAudio

# Registry of the nodes provided by this package. Each entry maps a node
# identifier to a dict holding the implementing class ("class") and a
# human-readable name ("name").
# NOTE(review): the "AK_NormalizeMaskImage" key points at AK_NormalizeMaskColor;
# the mismatch looks historical and the key is left as-is because it may be
# referenced elsewhere -- confirm before renaming.
NODE_CONFIG = {
    "AK_AnimatedDilationMaskLinear": {"class": AK_AnimatedDilationMaskLinear, "name": "AK Dilate Mask Linear"},
    "AK_IPAdapterCustomWeights": {"class": AK_IPAdapterCustomWeights, "name": "AK IPAdapter Custom Weights"},
    "AK_NormalizeMaskImage": {"class": AK_NormalizeMaskColor, "name": "AK Normalize Mask Color"},
    "AK_AudioreactiveDilationMask": {"class": AK_AudioreactiveDilationMask, "name": "AK Audioreactive Dilate Mask"},
    "AK_AudioreactiveDynamicDilationMask": {"class": AK_AudioreactiveDynamicDilationMask, "name": "AK Audioreactive Dynamic Dilate Mask"},
    "AK_ConvertAudioToSaltAudio": {"class": AK_ConvertAudioToSaltAudio, "name": "AK Convert Audio To Salt Audio"},
    "AK_ConvertSaltAudioToAudio": {"class": AK_ConvertSaltAudioToAudio, "name": "AK Convert Salt Audio To Audio"},
}

def generate_node_mappings(node_config):
Expand Down
70 changes: 70 additions & 0 deletions ak_convert_audio_to_salt_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@

import torch
import io
import wave
import numpy as np
from collections.abc import Mapping

class AK_ConvertAudioToSaltAudio:
    """Node that serializes an audio mapping into in-memory WAV bytes.

    The node reads ``audio['waveform']`` and ``audio['sample_rate']`` and
    returns the samples encoded as a 16-bit PCM WAV file held in a ``bytes``
    object.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required ``audio`` input of type ``AUDIO``."""
        return {
            "required": {
                "audio": ("AUDIO",),
            },
        }

    CATEGORY = "💜Akatz Nodes"
    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "convert_tensor_to_audio_bytes"
    DESCRIPTION = """
# Converts a PyTorch tensor representing audio data into raw audio bytes in WAV format.
Parameters:
- audio: LazyAudioMap-like object containing the waveform and sample rate
Returns:
- audio_bytes: Raw audio bytes in WAV format
"""

    def convert_tensor_to_audio_bytes(self, audio, num_channels=None):
        """Encode ``audio['waveform']`` as a 16-bit PCM WAV byte string.

        Parameters:
        - audio: Mapping with a ``'waveform'`` tensor of shape
          (1, num_channels, num_samples) or (num_channels, num_samples) and a
          ``'sample_rate'`` entry.
        - num_channels: Channel count to write into the WAV header. When
          ``None`` (default) it is taken from the waveform itself, so mono and
          stereo inputs both produce a consistent file.

        Returns:
        - A 1-tuple ``(audio_bytes,)`` holding the complete WAV file content.

        Raises:
        - ValueError: if the waveform is not 2-D after dropping the batch
          dimension, or if an explicit ``num_channels`` disagrees with the
          waveform's channel dimension.
        """
        waveform = audio['waveform']
        sample_rate = audio['sample_rate']

        # Drop the leading batch dimension when present.
        if waveform.dim() == 3:
            waveform = waveform.squeeze(0)
        if waveform.dim() != 2:
            raise ValueError(
                "expected waveform of shape (num_channels, num_samples) or "
                f"(1, num_channels, num_samples), got {tuple(waveform.shape)}"
            )

        # The header must describe the data actually written; a mismatch
        # yields a file whose frames are mis-grouped on playback.
        actual_channels = waveform.shape[0]
        if num_channels is None:
            num_channels = actual_channels
        elif num_channels != actual_channels:
            raise ValueError(
                f"num_channels={num_channels} does not match the waveform's "
                f"{actual_channels} channel(s)"
            )

        # (num_channels, num_samples) -> (num_samples, num_channels) so the
        # C-order byte dump below interleaves channels frame by frame.
        # detach()/cpu() make the conversion work for tensors that track
        # gradients or live on an accelerator.
        samples = waveform.detach().cpu().transpose(0, 1).numpy()

        # Clamp before scaling: values outside [-1, 1] would otherwise
        # overflow the int16 range and corrupt the signal.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype('int16')

        byte_io = io.BytesIO()
        with wave.open(byte_io, 'wb') as wave_file:
            wave_file.setnchannels(num_channels)
            wave_file.setsampwidth(2)  # 2 bytes per sample (16-bit PCM)
            wave_file.setframerate(sample_rate)
            wave_file.writeframes(pcm.tobytes())

        return (byte_io.getvalue(),)
80 changes: 80 additions & 0 deletions ak_convert_salt_audio_to_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@

import torch
import io
import wave
import numpy as np
from collections.abc import Mapping

class AK_ConvertSaltAudioToAudio:
    """Node that decodes in-memory WAV bytes into a waveform mapping.

    The result exposes a float32 ``'waveform'`` tensor of shape
    (1, num_channels, num_samples) and the file's ``'sample_rate'``.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required ``audio`` input of type ``AUDIO``."""
        return {
            "required": {
                "audio": ("AUDIO",),
            },
        }

    CATEGORY = "💜Akatz Nodes"
    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "convert_audio_bytes_to_lazy_audio_map"
    DESCRIPTION = """
# Converts raw audio bytes (in WAV format) into a LazyAudioMap-like format.
Parameters:
- audio_bytes: Raw audio bytes in WAV format
Returns:
- lazy_audio_map: A LazyAudioMap-like object containing the waveform and sample rate
"""

    def convert_audio_bytes_to_lazy_audio_map(self, audio):
        """Decode 16-bit PCM WAV bytes into a ``LazyAudioMap``.

        Parameters:
        - audio: Raw bytes of a complete WAV file.

        Returns:
        - A 1-tuple holding a ``LazyAudioMap`` with ``'waveform'`` (float32
          tensor shaped (1, num_channels, num_samples), samples divided by
          32767) and ``'sample_rate'``.

        Raises:
        - ValueError: if the WAV data is not 16-bit; other sample widths
          cannot be reinterpreted as int16 without corrupting the signal.
        """
        with wave.open(io.BytesIO(audio), 'rb') as wave_file:
            num_channels = wave_file.getnchannels()
            sample_width = wave_file.getsampwidth()
            sample_rate = wave_file.getframerate()
            if sample_width != 2:
                raise ValueError(
                    "only 16-bit PCM WAV data is supported, got "
                    f"{8 * sample_width}-bit samples"
                )
            audio_frames = wave_file.readframes(wave_file.getnframes())

        # Frames are interleaved: one row per frame, one column per channel.
        audio_np = np.frombuffer(audio_frames, dtype='int16').reshape(-1, num_channels)

        # Scale to roughly [-1, 1]; torch.tensor copies out of the read-only
        # buffer returned by np.frombuffer.
        audio_tensor = torch.tensor(audio_np, dtype=torch.float32) / 32767.0

        # (num_samples, num_channels) -> (1, num_channels, num_samples)
        audio_tensor = audio_tensor.transpose(0, 1).unsqueeze(0)

        return (LazyAudioMap(audio_tensor, sample_rate),)


class LazyAudioMap(Mapping):
    """Read-only mapping exposing the 'waveform' and 'sample_rate' keys.

    Defined at module level so every conversion returns instances of the same
    type instead of a class re-created on each call.
    """

    def __init__(self, waveform, sample_rate):
        self._dict = {
            'waveform': waveform,
            'sample_rate': sample_rate
        }

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-akatz-nodes"
description = "Simple custom node pack for nodes I use in my workflows. Includes Dilate Mask Linear for animating masks."
version = "1.4.1"
license = {file = "LICENSE"}

[project.urls]
Expand Down

0 comments on commit ca8d846

Please sign in to comment.