Skip to content

Commit

Permalink
Init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
yakovlev committed Jun 11, 2024
0 parents commit 9c1763b
Show file tree
Hide file tree
Showing 4 changed files with 1,035 additions and 0 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# ReDimNet

This is oficial implementation for neural network architecture presented in paper [Reshape Dimensions Network for Speaker Recognition]().

## Update
* 2024.07.15 Adding model builder and pretrained weights for: `b0`, `b2`, `b3`, `b5`, `b6` model sizes.

## Introduction

We introduce Reshape Dimensions Network (ReDimNet), a novel neural network architecture for spectrogram (audio) processing, specifically for extracting utterance-level speaker representations. ReDimNet reshapes dimensionality between 2D feature maps and 1D signal representations, enabling the integration of 1D and 2D blocks within a single model. This architecture maintains the volume of channel-timestep-frequency outputs across both 1D and 2D blocks, ensuring efficient aggregation of residual feature maps. ReDimNet scales across various model sizes, from 1 to 15 million parameters and 0.5 to 20 GMACs. Our experiments show that ReDimNet achieves state-of-the-art performance in speaker recognition while reducing computational complexity and model size compared to existing systems.

<p align="center">
<img src="redimnet_scheme.png" alt="Sample" width="500">
<p align="center">
<em>ReDimNet architecture</em>
</p>
</p>

## Usage

### Requirement
PyTorch>=2.0
### Examples
```
import torch
# To load pretrained on vox2 model without Large-Margin finetuning
model = torch.hub.load('IDRnD/ReDimNet', 'b0', pretrained=True, finetuned=False)
# To load pretrained on vox2 model with Large-Margin finetuning:
model = torch.hub.load('IDRnD/ReDimNet', 'b0', pretrained=False, finetuned=True)
```
38 changes: 38 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json
import torch
from redimnet import ReDimNetWrap
dependencies = ['torch','torchaudio']

URL_TEMPLATE = "https://github.com/IDRnD/ReDimNet/releases/weights/download/{model_name}"

def load_custom(size, pretrained=False, finetuned=True):
model_prefix = 'lm' if finetuned else 'ptn'
assert size in [f'b{i}' for i in range(7)]

model_name = f'redimnet-{size}-vox2-{model_prefix}.pt'
url = URL_TEMPLATE.format(model_name = model_name)

full_state_dict = torch.hub.load_state_dict_from_url(url, progress=True)

model_config = full_state_dict['model_config']
state_dict = full_state_dict['state_dict']
model = ReDimNetWrap(**model_config)
if pretrained or finetuned:
load_res = model.load_state_dict(state_dict)
print(f"load_res : {load_res}")
return model

def b0(pretrained=False, finetuned=True):
return load_custom('b0', pretrained=pretrained, finetuned=finetuned)

def b2(pretrained=False, finetuned=True):
return load_custom('b2', pretrained=pretrained, finetuned=finetuned)

def b3(pretrained=False, finetuned=True):
return load_custom('b3', pretrained=pretrained, finetuned=finetuned)

def b5(pretrained=False, finetuned=True):
return load_custom('b5', pretrained=pretrained, finetuned=finetuned)

def b6(pretrained=False, finetuned=True):
return load_custom('b6', pretrained=pretrained, finetuned=finetuned)
Loading

0 comments on commit 9c1763b

Please sign in to comment.