-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhubconf.py
29 lines (24 loc) · 1.15 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# This file fetches pretrained ReDimNet model parameters (checkpoints).
# For the list of available checkpoints, see https://github.com/IDRnD/ReDimNet/releases
import os
import sys
import json
import torch
# Make the bundled ``redimnet`` package importable relative to this file
# (presumably needed when the file is loaded via torch.hub — confirm).
sys.path.append(os.path.dirname(__file__))

from redimnet.model import ReDimNetWrap

# Packages required by the entry points in this file; ``dependencies`` is the
# module-level name torch.hub looks for in a hubconf.py.
dependencies = ['torch', 'torchaudio']

# Release-asset URL; ``{model_name}`` is filled with the checkpoint filename,
# e.g. "M-vb2+vox2+cnc-ft_mix.pt" (see load_custom).
URL_TEMPLATE = "https://github.com/IDRnD/ReDimNet/releases/download/latest/{model_name}"
def load_custom(model_name='M', train_type='ft_mix', dataset='vb2+vox2+cnc'):
    """Download a ReDimNet checkpoint and build the model it describes.

    The checkpoint filename is ``f"{model_name}-{dataset}-{train_type}.pt"``
    and is downloaded from the release URL built from ``URL_TEMPLATE``.

    Args:
        model_name: Variant prefix of the checkpoint filename (e.g. "M", "S").
        train_type: Training-recipe suffix of the checkpoint (e.g. "ft_mix").
        dataset: Training-data tag of the checkpoint (e.g. "vb2+vox2+cnc").

    Returns:
        A ``ReDimNetWrap`` constructed from the checkpoint's ``model_config``,
        with the checkpoint's ``state_dict`` loaded when ``train_type`` is
        not None.
    """
    filename = f'{model_name}-{dataset}-{train_type}.pt'
    url = URL_TEMPLATE.format(model_name=filename)
    # map_location='cpu' lets checkpoints saved from GPU tensors load on
    # CPU-only hosts; the freshly built model is on CPU anyway and
    # load_state_dict copies the values into its parameters.
    checkpoint = torch.hub.load_state_dict_from_url(
        url, progress=True, map_location='cpu'
    )
    model_config = checkpoint['model_config']
    state_dict = checkpoint['state_dict']
    model = ReDimNetWrap(**model_config)
    # NOTE(review): with train_type=None the filename above ends in
    # "-None.pt", so this branch is only skipped if such an asset exists —
    # confirm whether "architecture only" loading was intended.
    if train_type is not None:
        load_res = model.load_state_dict(state_dict)
        print(f"load_res : {load_res}")
    return model
def ReDimNet(model_name="S", train_type='ft_mix', dataset='vb2+vox2+cnc'):
    """Hub entry point: return a ReDimNet model built by ``load_custom``.

    This is a thin alias that forwards every argument unchanged. Note that its
    default variant is "S", whereas ``load_custom`` itself defaults to "M".
    """
    return load_custom(
        model_name=model_name,
        train_type=train_type,
        dataset=dataset,
    )