-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathextract_vq_features.py
142 lines (112 loc) · 5.15 KB
/
extract_vq_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Extract features (VQGAN token indices) from a dataset and save them to disk
import os
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from Network.vq_model import VQ_models
from Dataset.dataloader import get_data
# Allow float32 matmuls to use faster reduced-precision internal kernels
# (e.g. TF32) where the hardware supports them, trading a little precision for speed.
torch.set_float32_matmul_precision('high')
def ddp_setup():
    """Initialise the NCCL process group and pin this process to its local GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable,
    presumably set by the launcher (e.g. torchrun).
    """
    init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
def launch_multi_main(args):
    """Run ``main`` inside one process of a multi-GPU (DDP) job.

    Sets ``args.device`` to this process's local GPU index (``LOCAL_RANK``) and
    ``args.is_master`` to True for local rank 0, then runs ``main``.

    The process group is destroyed in a ``finally`` block, so it is released
    even when ``main`` raises.
    """
    ddp_setup()
    try:
        args.device = int(os.environ["LOCAL_RANK"])
        args.is_master = args.device == 0
        main(args)
    finally:
        destroy_process_group()
def tensor2pil(image):
    """Transform an image tensor with values in [-1, 1] into a PIL image.

    Args:
        image: channels-first tensor (C, H, W), as implied by the permute below.
            Values are mapped from [-1, 1] to [0, 255]; anything outside that
            range is clipped.

    Returns:
        A PIL image built from the (H, W, C) uint8 array. Fractional pixel
        values are truncated, not rounded, by the uint8 cast.
    """
    image = ((image + 1) / 2) * 255
    # detach() so tensors that require grad can still be converted to numpy.
    image = image.detach().permute(1, 2, 0).clip(0, 255).cpu().numpy().astype(np.uint8)
    return Image.fromarray(image)
class Extractor:
    """Encode dataset images into VQGAN token grids and pickle them to disk.

    Each sample is written under ``<dest_folder>/<split>/NNNNNNN.pth`` as a dict
    ``{"code": uint16 array of shape (patch_size, patch_size), "y": label}``.
    """

    # Splits accepted by extract_and_save, mapped to the attribute holding the loader.
    _SPLITS = {"Train": "train_data", "Eval": "test_data"}

    def __init__(self, args):
        """Load the VQGAN and build the train/eval data loaders.

        Args:
            args: parsed command-line namespace. Reads data, data_folder,
                img_size, f_factor, bsize, num_workers, plus everything
                get_network reads.
        """
        self.args = args
        self.ae = self.get_network("vqgan-llama")  # Load VQGAN
        # Number of tokens along each side of the image after downsampling.
        self.patch_size = self.args.img_size // self.args.f_factor
        self.train_data, self.test_data = get_data(
            args.data, img_size=args.img_size, data_folder=args.data_folder,
            bsize=args.bsize, num_workers=args.num_workers, is_multi_gpus=False, seed=-1,
        )

    def get_network(self, archi):
        """Build the tokenizer network *archi* in inference mode.

        Args:
            archi: architecture name; only "vqgan-llama" is supported.

        Returns:
            The model loaded from ``args.vqgan_folder``, in eval mode and on
            ``args.device``. It is compiled when ``args.compile`` is set.

        Raises:
            ValueError: if *archi* is not a supported architecture.
        """
        if archi != "vqgan-llama":
            # Fail loudly here rather than returning None and crashing later.
            raise ValueError(f"Unknown network architecture: {archi!r}")

        model = VQ_models[f"VQ-{self.args.f_factor}"](codebook_size=16384, codebook_embed_dim=8)
        checkpoint = torch.load(self.args.vqgan_folder, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        model = model.eval()
        model = model.to(self.args.device)

        if self.args.compile:
            model = torch.compile(model)

        if self.args.is_multi_gpus:
            # Constructing DDP syncs rank 0's parameters to every rank. The
            # wrapper is then discarded, since inference needs no gradient sync.
            model = DDP(model, device_ids=[self.args.device])
            model = model.module

        if self.args.is_master:
            n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"Size of model {archi}: {n_params / 10 ** 6:.3f}M")
        return model

    @torch.no_grad()
    def extract_and_save(self, split):
        """Tokenize every image of *split* and save one pickle file per sample.

        Args:
            split: "Train" or "Eval".

        Raises:
            ValueError: if *split* is not a supported split name.
        """
        if split not in self._SPLITS:
            raise ValueError(f"Unknown split {split!r}; expected one of {sorted(self._SPLITS)}")

        # Every rank holds the full, unsharded loader (get_data is called with
        # is_multi_gpus=False) and would write the same file names. The result
        # is redundant work, and conflicting contents if the loader shuffles.
        # So only the master writes.
        # TODO: shard the data across ranks to actually benefit from extra GPUs.
        if not self.args.is_master:
            return

        loader = getattr(self, self._SPLITS[split])
        # create the folder if it does not exist
        root = os.path.join(self.args.dest_folder, split)
        os.makedirs(root, exist_ok=True)

        cpt = 0
        for img, y in tqdm(loader, leave=False):
            bsize = img.size(0)
            img = img.to(self.args.device)
            # VQGAN encoding img to tokens (indices are the last item of the info triple).
            _, _, [_, _, code] = self.ae.encode(img)
            code = code.reshape(bsize, self.patch_size, self.patch_size)
            # The codebook size (16384) fits in uint16.
            code = code.detach().cpu().numpy().astype(np.uint16)
            # save each code; despite the .pth suffix the files are plain pickles
            for i in range(bsize):
                output_dict = {"code": code[i], "y": y[i].item()}
                name = os.path.join(root, f"{cpt:07d}.pth")
                with open(name, 'wb') as f:
                    pickle.dump(output_dict, f)
                cpt += 1
def main(args):
    """Build an Extractor and dump VQGAN tokens for the Train split, then the Eval split."""
    extractor = Extractor(args)
    for split_name in ("Train", "Eval"):
        extractor.extract_and_save(split=split_name)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="imagenet", help="")
    parser.add_argument("--data-folder", type=str, default="", help="data source")
    parser.add_argument("--dest-folder", type=str, default="", help="data destination")
    parser.add_argument("--vqgan-folder", type=str, default="", help="vqgan folder")
    parser.add_argument("--bsize", type=int, default=128, help="batch size")
    parser.add_argument("--img-size", type=int, default=256, help="image size")
    parser.add_argument("--f-factor", type=int, default=8, help="downsize factor for tokenizer")
    parser.add_argument("--num-workers", type=int, default=8, help="number of workers for loading")
    parser.add_argument("--compile", action='store_true', help="compile the network, pytorch 2.0")
    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    world_size = torch.cuda.device_count()
    # The DDP path reads LOCAL_RANK (set by a distributed launcher such as
    # torchrun). When the script is started with plain `python` on a
    # multi-GPU machine, LOCAL_RANK is absent, so run as a single process
    # instead of crashing.
    if world_size > 1 and "LOCAL_RANK" in os.environ:
        print(f"{world_size} GPU(s) found, launch multi-gpus training")
        args.is_multi_gpus = True
        launch_multi_main(args)
    else:
        print(f"{world_size} GPU found")
        args.is_master = True
        args.is_multi_gpus = False
        main(args)