extract_melfilter.py
# Mean/variance-normalized log-Mel filter bank energies
import torch
import torchaudio

class AudioFeatureExtractor:
    def __init__(self,
                 sample_rate=16000,
                 n_fft=512,
                 win_length=400,      # 25 ms window at 16 kHz
                 hop_length=240,      # 15 ms hop at 16 kHz
                 n_mels=72,
                 f_min=20,
                 f_max=7600,
                 pre_emphasis=0.97,
                 ref_level_db=20,     # stored but not used below
                 min_level_db=-100,   # stored but not used below
                 window_type='hann',
                 mel_scale='htk',
                 norm_type='per_feature',
                 eps=1e-6):
        self.pre_emphasis = pre_emphasis
        self.ref_level_db = ref_level_db
        self.min_level_db = min_level_db
        self.eps = eps
        # Configure the Mel spectrogram transform
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=getattr(torch, f'{window_type}_window'),
            mel_scale=mel_scale,
            normalized=True
        )
        self.norm_type = norm_type

    def pre_emphasis_filter(self, x):
        # y[n] = x[n] - pre_emphasis * x[n-1]; boosts high frequencies.
        # Expects x of shape (channels, time); the first sample is kept as-is.
        return torch.cat((x[:, 0:1], x[:, 1:] - self.pre_emphasis * x[:, :-1]), dim=1)

    def normalize(self, x):
        if self.norm_type == 'per_feature':
            # Per-Mel-bin mean/std over the time axis
            mean = torch.mean(x, dim=-1, keepdim=True)
            std = torch.std(x, dim=-1, keepdim=True)
        else:  # 'all_features': a single global mean/std
            mean = torch.mean(x)
            std = torch.std(x)
        return (x - mean) / (std + self.eps)

    def extract_features(self, waveform):
        # Apply pre-emphasis
        emphasized = self.pre_emphasis_filter(waveform)
        # Compute the Mel spectrogram
        mel_spec = self.mel_transform(emphasized)
        # Log-Mel conversion (eps avoids log(0))
        log_mel = torch.log(mel_spec + self.eps)
        # Mean/variance normalization
        normalized = self.normalize(log_mel)
        return normalized
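

# Minimal usage sketch (not part of the original file): run a mono 16 kHz
# waveform of shape (channels, time) through the extractor. A random tensor
# stands in for real audio here; with torchaudio's default center=True
# padding, the number of output frames is time // hop_length + 1.
if __name__ == "__main__":
    extractor = AudioFeatureExtractor()
    waveform = torch.randn(1, 16000)  # 1 second of stand-in 16 kHz audio
    features = extractor.extract_features(waveform)
    print(features.shape)  # torch.Size([1, 72, 67]): (channels, n_mels, frames)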