-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspeaker_embed_redimnet.py
81 lines (62 loc) · 2.76 KB
/
speaker_embed_redimnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torchaudio
import glob
from pathlib import Path
import numpy as np
from tqdm import tqdm
import pickle
# (기존에 사용하던 ReDimNetWrap import는 더 이상 사용하지 않으므로 주석처리할 수 있습니다.)
# from redimnet.model import ReDimNetWrap
def load_audio(audio_path, target_sample_rate=16000):
"""오디오 파일을 로드하고 필요한 경우 리샘플링합니다."""
waveform, sample_rate = torchaudio.load(audio_path)
# 모노로 변환 (스테레오인 경우)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# 리샘플링
if sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
waveform = resampler(waveform)
return waveform
def main():
# CUDA 사용 가능 여부 확인
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# 모델 로드 - hubconf.py의 load_custom() 사용하여 pretrained 파라미터 불러오기
from hubconf import load_custom
model_name = 'S'
train_type = 'ft_mix'
dataset= 'vb2+vox2+cnc'
model = load_custom(model_name=model_name, train_type=train_type, dataset=dataset).to(device)
# 평가 모드로 전환
model.eval()
# 오디오 파일 경로 가져오기 (모든 파일 처리)
audio_files = glob.glob('./audio/*.wav')
print(f"Found {len(audio_files)} audio files")
# 임베딩을 저장할 딕셔너리
embeddings_dict = {}
# 각 오디오 파일에 대해 추론 수행
with torch.no_grad():
for audio_path in tqdm(audio_files):
try:
# 오디오 로드
waveform = load_audio(audio_path)
# 배치 차원 추가
waveform = waveform.unsqueeze(0).to(device)
# 추론
embeddings = model(waveform)
# CPU로 이동하고 numpy 배열로 변환
embeddings_np = embeddings.cpu().numpy()
# 파일 이름을 키로 사용하여 딕셔너리에 저장
file_name = Path(audio_path).stem
embeddings_dict[file_name] = embeddings_np
except Exception as e:
print(f"Error processing {audio_path}: {str(e)}")
# 임베딩을 pickle 파일로 저장
output_path = f'speaker_embed_{model_name}_{train_type}.pkl'
with open(output_path, 'wb') as f:
pickle.dump(embeddings_dict, f)
print(f"\nEmbeddings saved to {output_path}")
print(f"Total processed files: {len(embeddings_dict)}")
if __name__ == "__main__":
main()