-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathadv_test.py
77 lines (62 loc) · 3.01 KB
/
adv_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import argparse
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
torch.autograd.set_detect_anomaly(True)
from models import resnet_cifar
# CIFAR-10 per-channel normalization statistics, one value per image channel.
cifar10_mean, cifar10_std = (
    (0.4914, 0.4822, 0.4465),  # per-channel mean
    (0.247, 0.243, 0.261),  # per-channel standard deviation
)
class InputNormalization(nn.Module):
    """Apply per-channel ``(x - mean) / std`` as the first layer of a model.

    AutoAttack perturbs and clips inputs in [0, 1] pixel space, and
    ``--epsilon`` is measured in that space. The normalization the classifier
    expects must therefore happen inside the model, not in the data pipeline.

    Args:
        mean: per-channel means, one entry per input channel.
        std: per-channel standard deviations, one entry per input channel.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Buffers follow the module through .cuda()/.to() and are not trained.
        self.register_buffer('mean', torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer('std', torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        # The (1, C, 1, 1) buffers broadcast over dim 1, so x is expected to be
        # (batch, channels, height, width).
        return (x - self.mean) / self.std


def result_path(save_dir, version, n_ex, epsilon, individual=False):
    """Return the ``.pth`` path under which attack results are saved.

    Args:
        save_dir: output directory.
        version: AutoAttack version string passed via ``--version``.
        n_ex: number of examples that were evaluated.
        epsilon: perturbation budget, written with 5 decimals.
        individual: True for results of ``run_standard_evaluation_individual``.
    """
    kind = 'individual_1' if individual else '1'
    name = 'aa_{}_{}_{}_eps_{:.5f}.pth'.format(version, kind, n_ex, epsilon)
    return os.path.join(save_dir, name)


def parse_args():
    """Parse and validate the command-line options."""
    parser = argparse.ArgumentParser(
        description='Evaluate a CIFAR-10 ResNet-18 checkpoint with AutoAttack.')
    parser.add_argument('--norm', type=str, default='Linf')
    parser.add_argument('--epsilon', type=float, default=8./255.)
    parser.add_argument('--model', type=str, default='./model_test.pt')
    parser.add_argument('--n_ex', type=int, default=1000)
    parser.add_argument('--individual', action='store_true')
    parser.add_argument('--save_dir', type=str, default='./results')
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--log_path', type=str, default='./log_file.txt')
    parser.add_argument('--version', type=str, default='standard')
    parser.add_argument('--data_dir', type=str,
                        default='/home/vamshi/datasets/CIFAR_10_data/',
                        help='root directory of the CIFAR-10 dataset')
    parser.add_argument('--detect_anomaly', action='store_true',
                        help='enable autograd anomaly detection (debug only, slow)')
    args = parser.parse_args()
    if args.n_ex <= 0:
        parser.error('--n_ex must be a positive integer')
    return args


def load_model(path):
    """Build ResNet-18 from ``path`` and prepend CIFAR-10 normalization.

    The checkpoint is read as a dict holding a ``'state_dict'`` entry. The
    returned model takes [0, 1] images, is on the GPU, and is in eval mode.
    """
    net = resnet_cifar.resnet18(num_classes=10, pretrained=False)
    # map_location lets checkpoints saved on any device load before .cuda().
    ckpt = torch.load(path, map_location='cpu')
    net.load_state_dict(ckpt['state_dict'])
    # Same mean/std the data pipeline used to apply, so clean predictions
    # are unchanged while the attack now operates in [0, 1] pixel space.
    model = nn.Sequential(InputNormalization(cifar10_mean, cifar10_std), net)
    model.cuda()
    model.eval()
    return model


def load_cifar10_test(data_dir, n_ex):
    """Return the first ``n_ex`` CIFAR-10 test images and their labels.

    Images are un-normalized ``ToTensor()`` outputs; normalization is part of
    the model returned by ``load_model``.
    """
    dataset = datasets.CIFAR10(root=data_dir, train=False,
                               transform=transforms.ToTensor(), download=False)
    # Only read the examples that will actually be attacked.
    subset = data.Subset(dataset, range(min(n_ex, len(dataset))))
    loader = data.DataLoader(subset, batch_size=1000, shuffle=False, num_workers=0)
    images, labels = [], []
    for x, y in loader:  # single pass over the data
        images.append(x)
        labels.append(y)
    return torch.cat(images, 0), torch.cat(labels, 0)


def main():
    """Run AutoAttack on the checkpoint and save the adversarial results."""
    args = parse_args()
    # The module-level set_detect_anomaly(True) near the imports makes every
    # backward pass of the attacks slower; keep it only when explicitly asked.
    torch.autograd.set_detect_anomaly(args.detect_anomaly)

    model = load_model(args.model)
    x_test, y_test = load_cifar10_test(args.data_dir, args.n_ex)
    os.makedirs(args.save_dir, exist_ok=True)

    # load attack
    from autoattack import AutoAttack
    adversary = AutoAttack(model, norm=args.norm, eps=args.epsilon,
                           log_path=args.log_path, version=args.version)

    # example of custom version
    if args.version == 'custom':
        adversary.attacks_to_run = ['apgd-ce']
        adversary.apgd.n_restarts = 1

    # run attack and save images
    with torch.no_grad():
        if not args.individual:
            adv_complete = adversary.run_standard_evaluation(
                x_test, y_test, bs=args.batch_size)
            torch.save({'adv_complete': adv_complete},
                       result_path(args.save_dir, args.version,
                                   adv_complete.shape[0], args.epsilon))
        else:
            # individual version, each attack is run on all test points
            adv_complete = adversary.run_standard_evaluation_individual(
                x_test, y_test, bs=args.batch_size)
            torch.save(adv_complete,
                       result_path(args.save_dir, args.version,
                                   x_test.shape[0], args.epsilon,
                                   individual=True))


if __name__ == '__main__':
    main()