-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
38 lines (29 loc) · 1.02 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import random

import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
class TrainDataset(Dataset):
    """Patch dataset backed by an HDF5 training file.

    The file must contain the datasets ``"hr"``, ``"lr"`` and ``"mask"``.
    Each row of ``mask`` is ``(idx, x, y, z)``. Item ``i`` pairs:

    * the 5x5x5 window of ``lr[idx]`` centred on ``(x, y, z)``, and
    * the 2x2x2 block of ``hr[idx]`` starting at ``(2x, 2y, 2z)``.

    Both patches have their last axis moved to the front.
    NOTE(review): the 2x offset presumably means ``hr`` is ``lr`` upsampled
    by a factor of 2 -- confirm against the data-preparation script.

    This is a ``Dataset`` (not a ``DataLoader``): wrap it in a
    ``DataLoader`` to batch/shuffle. It is safe to use with
    ``num_workers > 0`` because each worker process opens its own HDF5
    handle (h5py handles cannot be pickled or shared across a fork).
    """

    def __init__(self, path="./data/train.h5"):
        """Open the HDF5 file at *path* (default: the historical location)."""
        self.path = path
        self._open()

    def _open(self):
        """(Re)open the HDF5 file in the current process and bind its datasets."""
        # Any handle inherited from a parent process is simply dropped here
        # rather than closed explicitly; the child gets its own fresh handle.
        self._file = h5py.File(self.path, "r")
        self.hr = self._file["hr"]
        self.lr = self._file["lr"]
        self.mask = self._file["mask"]
        # Remember the owning process so a forked DataLoader worker can detect
        # it inherited the parent's handle and reopen its own.
        self._pid = os.getpid()

    def _ensure_open(self):
        """Reopen the file if it was dropped (pickle/close) or belongs to another process."""
        if self.mask is None or self._pid != os.getpid():
            self._open()

    def close(self):
        """Close the underlying HDF5 file; later item/len access reopens it lazily."""
        if self._file is not None:
            self._file.close()
        self._file = None
        self.hr = self.lr = self.mask = None

    def __getstate__(self):
        # h5py objects cannot be pickled (required for spawn-based DataLoader
        # workers), so ship only the path and reopen in the receiving process.
        state = self.__dict__.copy()
        for key in ("_file", "hr", "lr", "mask"):
            state[key] = None
        return state

    def __getitem__(self, index):
        """Return ``(lr_patch, hr_patch)`` for mask row *index*.

        ``permute(3, 0, 1, 2)`` requires each sliced patch to be 4-D, so
        ``lr``/``hr`` are 5-D (sample, 3 spatial axes, trailing axis).
        NOTE(review): the trailing axis is presumably channels, making this a
        channels-last -> channels-first move; confirm the stored layout.
        """
        self._ensure_open()
        idx, x, y, z = self.mask[index]
        # NOTE(review): assumes 2 <= x, y, z and x + 3 <= the lr extent along
        # each axis; otherwise the window silently comes back short or empty.
        # Confirm the mask generator guarantees this.
        patch_hr = torch.from_numpy(
            self.hr[idx, x * 2:x * 2 + 2, y * 2:y * 2 + 2, z * 2:z * 2 + 2]
        )
        patch_lr = torch.from_numpy(
            self.lr[idx, x - 2:x + 3, y - 2:y + 3, z - 2:z + 3]
        )
        return patch_lr.permute(3, 0, 1, 2), patch_hr.permute(3, 0, 1, 2)

    def __len__(self):
        """Number of rows in ``mask``, i.e. number of training patches."""
        self._ensure_open()
        return self.mask.shape[0]
class TestDataset(Dataset):
    """Patch dataset over an in-memory array *lr*.

    Item ``i`` is the 5x5x5 window of *lr* centred on ``mask[i] == (x, y, z)``,
    with its last axis moved to the front, returned together with
    ``mask[i]`` itself (presumably so the caller can place each prediction
    back into the volume -- confirm with the inference script).

    This is a ``Dataset`` (not a ``DataLoader``): wrap it in a
    ``DataLoader`` to batch it.
    """

    def __init__(self, lr, mask):
        # *lr* goes through torch.from_numpy, so it must be a numpy array.
        # It is sliced on 3 axes then permuted over 4, so it must be 4-D.
        # NOTE(review): presumably (X, Y, Z, C) -- confirm.
        # *mask* rows unpack into exactly three coordinates.
        self.lr = lr
        self.mask = mask

    def __getitem__(self, index):
        """Return ``(window, centre)`` for mask row *index*."""
        center = self.mask[index]
        x, y, z = center
        # NOTE(review): assumes 2 <= x, y, z and x + 3 <= the lr extent along
        # each axis; otherwise the window silently comes back short or empty.
        window = self.lr[x - 2:x + 3, y - 2:y + 3, z - 2:z + 3]
        return torch.from_numpy(window).permute(3, 0, 1, 2), center

    def __len__(self):
        """Number of rows in ``mask``, i.e. number of patches."""
        return self.mask.shape[0]