# preprocess.py
import argparse
import os
import matplotlib.pyplot as plt
import torch
import torchaudio
from sklearn.model_selection import train_test_split
from torchaudio.transforms import MFCC, Resample, MelSpectrogram


def print_melspec(melspec):
    plt.imshow(melspec)
    plt.show()


def split_data(recording_ids, cycles, labels, prevent_leakage=False):
    if prevent_leakage:
        unique_recording_ids = torch.unique(recording_ids)
        # Stratify on a recording-level label (a recording counts as abnormal
        # if any of its cycles is abnormal), so the stratify array matches the
        # number of unique recordings
        recording_labels = torch.stack([labels[recording_ids == rid].max() for rid in unique_recording_ids])
        train_ids, test_ids = train_test_split(unique_recording_ids, test_size=0.2, stratify=recording_labels, random_state=42)
        # Create masks for selecting data
        train_mask = torch.isin(recording_ids, train_ids)
        test_mask = torch.isin(recording_ids, test_ids)
        # Separate data based on recording_id
        X_train, y_train = cycles[train_mask], labels[train_mask]
        X_test, y_test = cycles[test_mask], labels[test_mask]
    else:
        # Random 80/20 split at cycle level
        X_train, X_test, y_train, y_test = train_test_split(cycles, labels, test_size=0.2, stratify=labels, random_state=42)
    # Further split the train set to obtain a validation set (cycle level)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train, random_state=21)
    # Anomaly detection: train only on normal (label 0) cycles
    X_train, y_train = X_train[y_train == 0], y_train[y_train == 0]
    return X_train, X_val, X_test, y_train, y_val, y_test
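
# Usage sketch (hypothetical toy tensors, purely illustrative): with
# prevent_leakage=True, every cycle of a given recording id lands on exactly
# one side of the train/test boundary.
#
#   rec_ids = torch.arange(10).repeat_interleave(2)  # two cycles per recording
#   feats = torch.randn(20, 13, 128)
#   labs = rec_ids % 2  # alternate normal/abnormal recordings
#   splits = split_data(rec_ids, feats, labs, prevent_leakage=True)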


def wrap_padding(tensor, target_length):
    # Truncate if the tensor is longer than the target length
    if len(tensor) > target_length:
        return tensor[:target_length]
    # If the tensor is shorter than the target length, apply wrap padding
    if len(tensor) < target_length:
        # Calculate how many times the tensor needs to be repeated to reach the target length
        num_repeats = (target_length + len(tensor) - 1) // len(tensor)
        # Tile the tensor
        tiled_tensor = tensor.repeat(num_repeats)
        # Trim the tensor to the target length
        return tiled_tensor[:target_length]
    return tensor


def zero_padding(tensor, target_length):
    tensor_length = len(tensor)
    # Truncate if the tensor is longer than the target length
    if tensor_length > target_length:
        return tensor[:target_length]
    # Zero padding if the tensor is shorter than the target length
    if tensor_length < target_length:
        padding_size = target_length - tensor_length
        zero_padding = torch.zeros(padding_size, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, zero_padding])
    return tensor
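
# Behaviour of the two padding helpers (illustrative doctest-style sketch):
#
#   >>> t = torch.tensor([1., 2., 3.])
#   >>> wrap_padding(t, 7)
#   tensor([1., 2., 3., 1., 2., 3., 1.])
#   >>> zero_padding(t, 5)
#   tensor([1., 2., 3., 0., 0.])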


def process_cycle(waveform, sr):
    # Extract MFCC features (the mel spectrogram alternative is kept for reference)
    # mel_spec = MelSpectrogram(sample_rate=sr, n_mels=64, n_fft=256, hop_length=256 // 2, f_max=2000)(waveform)
    mfcc = MFCC(sample_rate=sr, n_mfcc=13, melkwargs={'n_mels': 64, 'n_fft': 256, 'hop_length': 256 // 2, 'f_max': 2000})(waveform)
    # Convert to dB scale
    mfcc_db = torchaudio.transforms.AmplitudeToDB()(mfcc)
    # Normalize to zero mean and unit variance
    normalized_mfcc = (mfcc_db - torch.mean(mfcc_db)) / torch.std(mfcc_db)
    return normalized_mfcc
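
# Shape sanity check (a sketch, assuming a 5 s cycle at 4 kHz = 20000 samples):
#
#   >>> process_cycle(torch.randn(20000), 4000).shape
#   torch.Size([13, 157])   # 13 MFCCs x 157 frames at hop_length 128
#
# extract_cycles() below keeps only the first 128 time frames.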


def extract_cycles(dataset):
    wav_files = [f for f in os.listdir(dataset) if f.endswith('.wav')]
    cycles = []
    labels = []
    recording_ids = []
    for idx, wav_file in enumerate(wav_files):
        txt_file = wav_file.replace('.wav', '.txt')
        # Load audio file
        waveform, sr = torchaudio.load(os.path.join(dataset, wav_file))
        # Remove the channel dimension (mono sound)
        waveform = waveform.squeeze()
        # One resampler per file, since the source sample rate can differ between files
        resampler = Resample(orig_freq=sr, new_freq=4000)
        # Read annotation
        with open(os.path.join(dataset, txt_file), 'r') as f:
            annotations = f.readlines()
        for line in annotations:
            start, end, crackles, wheezes = line.strip().split('\t')
            # Convert time to sample number
            start, end = float(start) * sr, float(end) * sr
            # A cycle is abnormal if it contains crackles or wheezes
            label = 1 if int(crackles) or int(wheezes) else 0
            # Extract and resample the cycle to 4 kHz
            cycle = waveform[int(start):int(end)]
            cycle = resampler(cycle)
            # Padding or truncating to 5 seconds (5 s * 4000 Hz = 20000 samples)
            excerpt_length = int(5 * 4000)
            cycle = zero_padding(cycle, excerpt_length)
            cycle = process_cycle(cycle, 4000)
            # Keep the first 128 time frames so all cycles share one shape
            cycle = cycle[:, :128]
            recording_ids.append(idx)
            cycles.append(cycle)
            labels.append(label)
    return torch.tensor(recording_ids), torch.stack(cycles), torch.tensor(labels)
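
# Each annotation .txt line is tab-separated, matching the parsing above
# (times in seconds; the values here are made up):
#
#   0.036   2.179   0   0    -> no crackles, no wheezes: label 0
#   2.179   4.612   1   0    -> crackles present: label 1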


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocessing of the ICBHI 2017 dataset for anomaly detection")
    parser.add_argument("--dataset", default="/home/lukas/thesis/dataset", type=str, help="Directory where the original dataset is stored")
    parser.add_argument("--target", default="dataset.pt", type=str, help="Output path to store the processed data")
    # type=bool would treat any non-empty string (including "False") as True; a flag is the reliable way to pass a boolean
    parser.add_argument("--recording_level", action="store_true", help="Split at recording level to prevent leakage between train and test")
    args = parser.parse_args()

    recording_ids, cycles, labels = extract_cycles(args.dataset)
    X_train, X_val, X_test, y_train, y_val, y_test = split_data(recording_ids, cycles, labels, prevent_leakage=args.recording_level)
    torch.save({
        'X_train': X_train,
        'X_val': X_val,
        'X_test': X_test,
        'y_train': y_train,
        'y_val': y_val,
        'y_test': y_test
    }, args.target)
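
# Loading the processed tensors back (a minimal sketch; the path matches the
# default --target above):
#
#   data = torch.load("dataset.pt")
#   X_train, y_train = data['X_train'], data['y_train']
#   # X_train holds only normal cycles, shape (n_cycles, 13, 128)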