Commit 9c1763b by yakovlev, Jun 11, 2024 (initial commit, 0 parents): 4 changed files, 1,035 additions, 0 deletions.
# ReDimNet

This is the official implementation of the neural network architecture presented in the paper [Reshape Dimensions Network for Speaker Recognition]().

## Update
* 2024.07.15 Added a model builder and pretrained weights for the `b0`, `b2`, `b3`, `b5`, and `b6` model sizes.
## Introduction

We introduce Reshape Dimensions Network (ReDimNet), a novel neural network architecture for spectrogram (audio) processing, specifically for extracting utterance-level speaker representations. ReDimNet reshapes dimensionality between 2D feature maps and 1D signal representations, enabling the integration of 1D and 2D blocks within a single model. This architecture maintains the volume of channel-timestep-frequency outputs across both 1D and 2D blocks, ensuring efficient aggregation of residual feature maps. ReDimNet scales across various model sizes, from 1 to 15 million parameters and 0.5 to 20 GMACs. Our experiments show that ReDimNet achieves state-of-the-art performance in speaker recognition while reducing computational complexity and model size compared to existing systems.
<p align="center">
<img src="redimnet_scheme.png" alt="Sample" width="500">
<p align="center">
<em>ReDimNet architecture</em>
</p>
</p>

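The reshape idea described in the introduction can be sketched in plain Python. The sizes below are illustrative assumptions, not values from the paper: a 2D feature map of shape (C, F, T) and a 1D representation of shape (C*F, T) hold the same channel-timestep-frequency volume, which is what lets 1D and 2D blocks share residual feature maps.

```python
# Illustrative sketch only: a 2D map (C channels x F frequency bins x
# T timesteps) and a 1D representation (C * F "channels" x T timesteps)
# contain the same number of values, so reshaping between them is lossless.
C, F, T = 16, 72, 100  # assumed example sizes, not from the paper

# 2D feature map
map_2d = [[[0.0] * T for _ in range(F)] for _ in range(C)]

# "Reshape to 1D": merge the channel and frequency axes into one axis
sig_1d = [row for channel in map_2d for row in channel]

# The volume of values is preserved by the reshape
assert len(sig_1d) == C * F
assert len(sig_1d[0]) == T
```

In the actual model this would be a tensor reshape, not nested lists; the point is only that no values are created or discarded when switching between the 1D and 2D views.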
## Usage

### Requirements
PyTorch>=2.0

### Examples
```python
import torch

# Load the model pretrained on VoxCeleb2 (vox2) without Large-Margin finetuning
model = torch.hub.load('IDRnD/ReDimNet', 'b0', pretrained=True, finetuned=False)

# Load the model pretrained on VoxCeleb2 (vox2) with Large-Margin finetuning
model = torch.hub.load('IDRnD/ReDimNet', 'b0', pretrained=False, finetuned=True)
```
```python
import torch

from redimnet import ReDimNetWrap

# torch.hub entry point: packages required to load the models below
dependencies = ['torch', 'torchaudio']

URL_TEMPLATE = "https://github.com/IDRnD/ReDimNet/releases/weights/download/{model_name}"


def load_custom(size, pretrained=False, finetuned=True):
    # 'lm' checkpoints include Large-Margin finetuning, 'ptn' ones do not
    model_prefix = 'lm' if finetuned else 'ptn'
    assert size in [f'b{i}' for i in range(7)]

    model_name = f'redimnet-{size}-vox2-{model_prefix}.pt'
    url = URL_TEMPLATE.format(model_name=model_name)

    full_state_dict = torch.hub.load_state_dict_from_url(url, progress=True)

    model_config = full_state_dict['model_config']
    state_dict = full_state_dict['state_dict']
    model = ReDimNetWrap(**model_config)
    if pretrained or finetuned:
        load_res = model.load_state_dict(state_dict)
        print(f"load_res : {load_res}")
    return model


def b0(pretrained=False, finetuned=True):
    return load_custom('b0', pretrained=pretrained, finetuned=finetuned)


def b2(pretrained=False, finetuned=True):
    return load_custom('b2', pretrained=pretrained, finetuned=finetuned)


def b3(pretrained=False, finetuned=True):
    return load_custom('b3', pretrained=pretrained, finetuned=finetuned)


def b5(pretrained=False, finetuned=True):
    return load_custom('b5', pretrained=pretrained, finetuned=finetuned)


def b6(pretrained=False, finetuned=True):
    return load_custom('b6', pretrained=pretrained, finetuned=finetuned)
```
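As `load_custom` shows, each released checkpoint follows the fixed naming pattern `redimnet-{size}-vox2-{lm|ptn}.pt`. A minimal helper reproducing that naming (the function `checkpoint_name` is ours for illustration; only `load_custom` exists in the repository):

```python
def checkpoint_name(size: str, finetuned: bool = True) -> str:
    """Reproduce the checkpoint filename pattern used by load_custom.

    Illustrative helper, not part of the repository.
    """
    # 'lm' = Large-Margin finetuned, 'ptn' = pretrained only
    prefix = 'lm' if finetuned else 'ptn'
    return f'redimnet-{size}-vox2-{prefix}.pt'


print(checkpoint_name('b0'))                   # redimnet-b0-vox2-lm.pt
print(checkpoint_name('b5', finetuned=False))  # redimnet-b5-vox2-ptn.pt
```

This is the filename substituted into `URL_TEMPLATE` before the weights are fetched with `torch.hub.load_state_dict_from_url`.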