-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmulti_inference.py
39 lines (27 loc) · 1.15 KB
/
multi_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from multi_utils import criterion
from torcheval.metrics.functional import binary_auroc
from tqdm import tqdm
import gc
@torch.inference_mode()
def inference(model, optimizer, dataloader, epoch, local_rank):
    """Run one evaluation pass over ``dataloader`` and report loss and AUROC.

    The model is switched to eval mode and the whole pass runs under
    ``torch.inference_mode`` (no autograd tracking).

    Args:
        model: Callable module; invoked as ``model(images)``.
        optimizer: Only read, never stepped. Its first param group's ``lr``
            is shown in the progress bar.
        dataloader: Sized iterable yielding dict batches with ``'image'``
            and ``'target'`` tensor entries.
        epoch: Value shown in the progress bar; not used in any computation.
        local_rank: Device argument handed to ``Tensor.to`` for each batch.

    Returns:
        ``(epoch_loss, epoch_auroc)``: the batch-size-weighted averages of
        the per-batch loss and the per-batch AUROC. Both are ``nan`` when the
        dataloader yields no batches.

    NOTE(review): the AUROC is an average of per-batch AUROCs, which is not
    the same quantity as one AUROC computed over all predictions of the
    epoch. Kept as-is so reported numbers stay comparable with earlier runs.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    running_auroc = 0.0
    # Defined before the loop so an empty dataloader returns NaN instead of
    # failing with UnboundLocalError at the return statement.
    epoch_loss = float("nan")
    epoch_auroc = float("nan")

    bar = tqdm(dataloader, total=len(dataloader))
    for data in bar:
        images = data['image'].to(local_rank, dtype=torch.float32)
        targets = data['target'].to(local_rank, dtype=torch.float32)

        batch_size = images.size(0)

        outputs = model(images).squeeze()
        if outputs.dim() == 0:
            # A single-sample batch squeezes all the way down to a scalar;
            # restore the batch axis so the shape lines up with `targets`
            # (presumably shape (batch,) -- confirm against the dataset).
            outputs = outputs.unsqueeze(0)

        loss = criterion(outputs, targets)
        auroc = binary_auroc(input=outputs, target=targets).item()

        # Weight by batch size so a smaller final batch does not skew the
        # running averages (assumes `criterion` returns a per-batch mean).
        running_loss += loss.item() * batch_size
        running_auroc += auroc * batch_size
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size
        epoch_auroc = running_auroc / dataset_size

        bar.set_postfix(epoch=epoch, val_loss=epoch_loss, val_auroc=epoch_auroc,
                        LR=optimizer.param_groups[0]['lr'])

    gc.collect()  # garbage collector

    return epoch_loss, epoch_auroc