-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpipeline_diffusion_diff_grids_SNO.py
453 lines (362 loc) · 17.8 KB
/
pipeline_diffusion_diff_grids_SNO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
import warnings
warnings.filterwarnings('ignore')
import os
import jax
import sys
import optax
import cloudpickle
import itertools
import equinox as eqx
import jax.numpy as jnp
import scipy.stats as st
import numpy as np
import matplotlib.pyplot as plt
import argparse
from transforms.interpolation import linear_interpolation
from transforms import library_of_transforms as lft
from solvers import FD_2D
from tqdm import tqdm
from jax import device_put, clear_caches
from functools import partial
from jax.scipy.sparse.linalg import cg
from transforms import utilities, cheb
from collections import namedtuple
from jax.lax import scan, dot_general
from architectures import DilResNet, fSNO
from jax import config, random, grad, vmap, tree_map, tree_leaves, jit
from jax.experimental import sparse as jsparse
from scipy.interpolate import interpn
config.update("jax_enable_x64", True)
# %matplotlib inline
# %config InlineBackend.figure_format='retina'
@jit
def spsolve_scan(carry, n):
A, r = carry
A_bcsr = jsparse.BCSR.from_bcoo(A)
Ar = jsparse.linalg.spsolve(A_bcsr.data[n], A_bcsr.indices[n], A_bcsr.indptr[n], r[n])
return [A, r], Ar
def res_func(A, B, res):
_, Ar = scan(spsolve_scan, [A, res], jnp.arange(A.shape[0]))
Ar = jnp.array(Ar)
B_Ar = jsparse.bcoo_dot_general(A, B - Ar, dimension_numbers=((2, 1), (0, 0)))
numerator = jnp.sqrt(jnp.einsum('bi, bi -> b', B - Ar, B_Ar))
denominator = jnp.sqrt(jnp.einsum('bi, bi -> b', Ar, res))
value = numerator / denominator
return value
def random_polynomial_2D(x, y, coeff):
res = 0
for i, j in itertools.product(range(coeff.shape[0]), repeat=2):
res += coeff[i, j]*jnp.exp(2*jnp.pi*x*i*1j)*jnp.exp(2*jnp.pi*y*j*1j)/(1+i+j)**2
res = jnp.real(res)
return res
def get_rhs(key):
coeff = random.normal(key, (5, 5), dtype=jnp.complex128)
rhs = lambda x, y, c=coeff: random_polynomial_2D(x, y, c)
return rhs
def get_k(key):
coeff = random.normal(key, (5, 5), dtype=jnp.complex128)
k = lambda x, y, c=coeff: random_polynomial_2D(x, y, c) + 10
return k
def dataset(grid, N_samples, key):
keys = random.split(key, (N_samples, 2))
A, rhs = [], []
for key in keys:
rhs_sample, A_sample = FD_2D(grid, [get_k(key[1]), get_rhs(key[0])])
A.append(A_sample.reshape(1, grid**2, -1))
rhs.append(rhs_sample)
A = jsparse.bcoo_concatenate(A, dimension=0)
return A, jnp.array(rhs)
def get_exact_solution(A, rhs, grid, N_samples):
A_bcsr = jsparse.BCSR.from_bcoo(A)
u_exact = jnp.stack([jsparse.linalg.spsolve(A_bcsr.data[n], A_bcsr.indices[n], A_bcsr.indptr[n], rhs[n].reshape(-1,)) for n in range(N_samples)]).reshape(N_samples, grid, grid)
return u_exact
def get_SNO(key, features_train, grid):
input = features_train
D = len(features_train.shape[1:])
# Parameters of model
polynomials = ["Real_Fourier", ] * D
parameters = [[0.1, 0.1],] * D
M_keep = [20, ] * D
N_points = features_train.shape[1:]
grids = N_points
data = {
"polynomials": polynomials,
"parameters": parameters,
"grids": grids,
"M_keep": M_keep,
"N_points": N_points
}
synthesis = utilities.get_operators("synthesis", **data)
analysis = utilities.get_operators("analysis", **data)
kernel_size = 3
n_conv_layers = 3
cell = lambda features, key: DilResNet.DilatedConvBlock([features,]*(n_conv_layers + 1), [[1,]*D, ]*n_conv_layers, [[kernel_size,]*D, ]*n_conv_layers, key, activation=lambda x: x)
input_shape = input.shape
N_features_out = 1
if grid <= 64:
N_features = max(48, grid)
else:
N_features = min(80, grid // 1.5)
N_layers = 4
model = fSNO.fSNO(input_shape, N_features_out, N_layers, N_features, cell, key)
# Parameters of training
N_epoch = 200
if grid <= 64:
batch_size = 16
else:
batch_size = 512 // grid
learning_rate = 5e-4
model_data = {
"model_name": "fSNO",
"model": model
}
optimization_specification = {
"learning_rate": learning_rate,
"compute_loss": lambda carry, indices: compute_loss_scan(carry, indices, analysis, synthesis),
"make_step": lambda carry, indices, A, r, error, optim: make_step_scan(carry, indices, A, r, error, optim, analysis, synthesis),
"N_epochs": N_epoch,
"batch_size": batch_size,
"res_func": lambda A, model, input: res_func(A, model, input),
"analysis": analysis,
"synthesis": synthesis
}
return model_data, optimization_specification
def FCG(A, features, model, N_iter, m_max, optimization_specification, analysis=None, synthesis=None, eps=1e-30, count_values=False, j=0):
poly_type = "Real_Fourier"
samples = features.shape[0]
n = features.shape[-1]
X = jnp.zeros((samples, n, N_iter+1))
R = jnp.zeros((samples, n, N_iter+1))
P_list, S_list = [], []
X = X.at[:, :, 0].set(random.normal(random.PRNGKey(j), (samples, n)))
R = R.at[:, :, 0].set(features - jsparse.bcoo_dot_general(A, X[:, :, 0], dimension_numbers=((2, 1), (0, 0))))
grid = int(n**0.5)
h = 1. / grid
values = []
for idx in range(N_iter):
if type(model) != type(lambda x: x):
train_data = R[:, :, idx].reshape(-1, grid, grid)
train_data = uniform_to_polinomial(train_data, grid, poly_type)
norm = jnp.linalg.norm(train_data, axis=(1,2))
train_data = jnp.einsum('bij,b->bij', train_data, 1./norm)[:, None]
output = vmap(model, in_axes=(0, None, None))(train_data, analysis, synthesis)[:, 0]
output = polinomial_to_uniform(jnp.einsum('bij, b->bij', output, norm), grid, poly_type).reshape(-1, grid**2)
U = output
else:
history_train = []
U = vmap(model)(R[:, :, idx])
if count_values:
value = optimization_specification['res_func'](A, U, R[:, :, idx])
values.append(value)
P = U
for k in range(len(P_list)):
alpha = - jnp.einsum('bj, bj->b', S_list[k], U) / (jnp.einsum('bj, bj->b', S_list[k], P_list[k]) + eps)
P += jnp.einsum('b, bj->bj', alpha, P_list[k])
S = jsparse.bcoo_dot_general(A, P, dimension_numbers=((2, 1),(0, 0)))
beta = jnp.einsum('bj, bj -> b', P, R[:, :, idx]) / (jnp.einsum('bj, bj -> b', S, P) + eps)
X = X.at[:, :, idx+1].set(X[:, :, idx] + jnp.einsum('b, bj->bj', beta, P))
R = R.at[:, :, idx+1].set(R[:, :, idx] - jnp.einsum('b, bj->bj', beta, S))
if (idx % (m_max+1) == m_max) | (idx % (m_max + 1) == 0):
P_list = []
S_list = []
P_list.append(P)
S_list.append(S)
return P, R, X, values
def anal_synth(grid):
polynomials = ['Real_Fourier', ] * 2
parameters = [[0.1, 0.1],] * 2
M_keep = [20, ] * 2
N_points = [grid] * 2
grids = N_points
data = {
"polynomials": polynomials,
"parameters": parameters,
"grids": grids,
"M_keep": M_keep,
"N_points": N_points
}
synthesis = utilities.get_operators("synthesis", **data)
analysis = utilities.get_operators("analysis", **data)
return analysis, synthesis
@jit
def compute_loss_scan(carry, indices, analysis, synthesis):
model, A, x, error, N_repeats = carry
A, x, error = A[indices // N_repeats], x[indices], error[indices]
B = vmap(lambda z: model(z, analysis, synthesis), in_axes=(0,))(x[:, None, :])[:, 0].reshape(x.shape[0], -1)
B_e = jsparse.bcoo_dot_general(A, B - error, dimension_numbers=((2, 1), (0, 0)))
A_e = jsparse.bcoo_dot_general(A, error, dimension_numbers=((2, 1), (0, 0)))
return carry, jnp.mean(jnp.sqrt(jnp.einsum('bi, bi -> b', B - error, B_e) / jnp.einsum('bi, bi -> b', error, A_e)))
def compute_loss(model, A, x, error, analysis, synthesis):
B = vmap(lambda z: model(z, analysis, synthesis), in_axes=(0,))(x[:, None, :])[:, 0].reshape(x.shape[0], -1)
B_e = jsparse.bcoo_dot_general(A, B - error, dimension_numbers=((2, 1), (0, 0)))
A_e = jsparse.bcoo_dot_general(A, error, dimension_numbers=((2, 1), (0, 0)))
return jnp.mean(jnp.sqrt(jnp.einsum('bi, bi -> b', B - error, B_e) / jnp.einsum('bi, bi -> b', error, A_e)))
compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)
@eqx.filter_jit
def make_step_scan(carry, indices, A, r, error, optim, analysis, synthesis):
model, opt_state, N_repeats = carry
loss, grads = compute_loss_and_grads(model, A[indices // N_repeats], r[indices], error[indices], analysis, synthesis)
updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
model = eqx.apply_updates(model, updates)
return [model, opt_state, N_repeats], loss
def train_on_epoch(key, batch_size, A, model, x, error, opt_state, make_step, N_repeats):
N_samples = len(x)
list_of_indices = jnp.linspace(0, N_samples-1, N_samples, dtype=jnp.int64)
n_batches = N_samples // batch_size
make_step_ = lambda carry, indices: make_step(carry, indices, A, x, error)
carry = [model, opt_state, N_repeats]
n = random.choice(key, list_of_indices, shape = (n_batches, batch_size))
data, epoch_loss = scan(make_step_, carry, n)
model = data[0]
opt_state = data[1]
return epoch_loss, model, opt_state
def test_on_epoch(key, batch_size, A, model, x, error, compute_loss, N_repeats):
N_samples = len(x)
list_of_indices = jnp.linspace(0, N_samples-1, N_samples, dtype=jnp.int64)
n_batches = N_samples // batch_size
n = random.choice(key, list_of_indices, shape = (n_batches, batch_size))
carry = [model, A, x, error, N_repeats]
data, epoch_loss = scan(compute_loss, carry, n)
return epoch_loss
def train_model(model, A, x, error, optimization_specification, N_repeats):
model = model
history = []
history_test = []
c = x.shape[0] // optimization_specification['batch_size']
keys = [value * c for value in np.arange(50, 1000, 50)]
values = [0.5, ] * len(keys)
dict_lr = dict(zip(keys, values))
sc = optax.piecewise_constant_schedule(optimization_specification['learning_rate'], dict_lr)
optimizer = optax.adamw(sc, weight_decay=1e-2)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
make_step = lambda a, b, A, r, error: optimization_specification['make_step'](a, b, A, r, error, optimizer)
for it in tqdm(range(optimization_specification['N_epochs'])):
key = random.PRNGKey(it)
loss, model, opt_state = train_on_epoch(key, optimization_specification['batch_size'], A, model, x, error, opt_state, make_step, N_repeats)
history.append(loss)
return model, history
def get_dataset_rhs(state, key, grid):
rhs = state[0]
rhs = get_rhs(key)
x = jnp.linspace(0, 1, num=grid+2)[1:-1]
y = jnp.linspace(0, 1, num=grid+2)[1:-1]
xx, yy = jnp.meshgrid(x, y)
state = [rhs(xx, yy)]
return state, state
def generate_res_errors(grid, N_samples, N_repeats, A):
residuals = []
error = []
for i in range(N_repeats):
keys = random.split(random.PRNGKey(i), 2)
X_random = random.normal(keys[0], (N_samples, grid**2))
initial = [jnp.zeros(shape=(grid, grid))]
get_dataset_rhs_ = lambda x, k: get_dataset_rhs(x, k, grid)
rhs_random = scan(get_dataset_rhs_, initial, random.split(keys[1], N_samples))[1][0].reshape(-1, grid**2)
residuals_random = rhs_random - jsparse.bcoo_dot_general(A, X_random, dimension_numbers=((2, 1),(0, 0)))
norm = jnp.linalg.norm(residuals_random, axis=1)
residuals_random = jnp.einsum('bij, b -> bij', residuals_random.reshape(-1, grid, grid), 1./norm)
u_exact = get_exact_solution(A, rhs_random, grid=grid, N_samples=N_samples)
error_random = u_exact - X_random.reshape((N_samples, -1, grid))
error_random = jnp.einsum('bij, b -> bij', error_random, 1./norm).reshape(N_samples, -1)
residuals.append(residuals_random)
error.append(error_random)
residuals = jnp.concatenate(residuals, axis=0)
error = jnp.concatenate(error, axis=0)
return residuals, error
@jit
def interpolate_2D(carry, n):
data, grids_in, grids_out = carry
data_intrp = vmap(jnp.interp, in_axes=(None, None, 1), out_axes=0)(grids_out[0], grids_in[0], data[n])
data_intrp = vmap(jnp.interp, in_axes=(None, None, 1), out_axes=0)(grids_out[1], grids_in[1], data_intrp)
return [data, grids_in, grids_out], data_intrp
def uniform_to_polinomial(data, grid, poly_type):
grids_in = [jnp.linspace(0, 1, grid+2)[1:-1], jnp.linspace(0, 1, grid+2)[1:-1]]
grids_out = [lft.poly_data[poly_type]["nodes"](grid+1, [0, 0])[1:]]*2
_, data_intrp = scan(interpolate_2D, [data, grids_in, grids_out], jnp.arange(data.shape[0]))
return jnp.array(data_intrp)
def polinomial_to_uniform(data, grid, poly_type):
grids_out = [jnp.linspace(0, 1, grid+2)[1:-1], jnp.linspace(0, 1, grid+2)[1:-1]]
grids_in = [lft.poly_data[poly_type]["nodes"](grid+1, [0, 0])[1:]]*2
_, data_intrp = scan(interpolate_2D, [data, grids_in, grids_out], jnp.arange(data.shape[0]))
return jnp.array(data_intrp)
def save_data(model, values, R, history, path, experiment_name):
data = {
"loss_model": values,
"R_model": R,
"history": history,
}
with open(path + f"{experiment_name}_data.npz", "wb") as f:
jnp.savez(f, **data)
if type(model) != type(lambda x: x):
with open(path + f"{experiment_name}_model", "wb") as f:
cloudpickle.dump(model, f)
def main(model_type, train_generation, grid, N_samples, N_repeats, m_max, path):
h = 1. / grid
if model_type == 'SNO':
key = random.PRNGKey(2)
N_samples = grid
A_train, rhs_train = dataset(grid=grid, N_samples=N_samples, key=key)
if train_generation == 'FCG':
model_ = lambda x: x
optimization_specification = {"res_func": lambda A, B, input: res_func(A, B, input)}
R_, X_ = [], []
for j in tqdm(range(N_samples // 2)):
_, R, X, _ = FCG(A_train[j*2:(j+1)*2], rhs_train[j*2:(j+1)*2], model=model_, N_iter=N_repeats-1, m_max=m_max, optimization_specification=optimization_specification, count_values=False, j=j)
R_.append(device_put(R, device=jax.devices("cpu")[0]))
X_.append(device_put(X, device=jax.devices("cpu")[0]))
del R, X
clear_caches()
R_ = jnp.concatenate(R_, axis=0)
X_ = jnp.concatenate(X_, axis=0)
residuals = R_.transpose(0, 2, 1).reshape(-1, grid, grid)
norm = jnp.linalg.norm(residuals, axis=(1, 2))
residuals = jnp.einsum('bij, b -> bij', residuals, 1./norm)
u_exact = device_put(get_exact_solution(A_train, rhs_train, grid=grid, N_samples=N_samples), device=jax.devices("cpu")[0])
del rhs_train
clear_caches()
error = (jnp.repeat(u_exact.reshape(-1, grid**2)[:, :, None], N_repeats, axis=2) - X_).transpose(0, 2, 1).reshape((N_samples * N_repeats, -1, grid))
error = jnp.einsum('bij, b -> bij', error, 1./norm).reshape(N_samples * N_repeats, -1)
del R_, X_, u_exact
clear_caches()
else:
residuals, error = generate_res_errors(grid, N_samples, N_repeats, A_train)
model_data, optimization_specification = get_SNO(random.PRNGKey(40), residuals[0][None, ...], grid)
model, history = train_model(model_data['model'], A_train, device_put(residuals, device=jax.devices("gpu")[0]),
device_put(error, device=jax.devices("gpu")[0]), optimization_specification, N_repeats)
del residuals, error, A_train
clear_caches()
key = random.PRNGKey(12)
N_samples = 20
grid_fine = grid * 4
A_test, rhs_test = dataset(grid=grid_fine, N_samples=N_samples, key=key)
N_iter = min(int(grid_fine * 3.5), 300)
analysis, synthesis = anal_synth(grid_fine)
A_bcsr = jsparse.BCSR.from_bcoo(A_test)
u_exact = jnp.stack([jsparse.linalg.spsolve(A_bcsr.data[n], A_bcsr.indices[n], A_bcsr.indptr[n], rhs_test[n].reshape(-1,)) for n in range(N_samples)])
del A_bcsr
clear_caches()
R, values = [], []
for j in tqdm(range(N_samples//4)):
_, R_, _, values_ = FCG(A_test[j*4:(j+1)*4], rhs_test[j*4:(j+1)*4], model=model, N_iter=N_iter, m_max=m_max, optimization_specification=optimization_specification, analysis=analysis, synthesis=synthesis, count_values=True, j=j)
R.append(R_)
values.append(jnp.array(values_))
del R_, values_
clear_caches()
R = jnp.concatenate(R, axis=0)
values = jnp.concatenate(values, axis=1).mean(axis=1)
save_data(model, values, R, history, path, f'{model_type}_{train_generation}_{grid}-{grid_fine}_{N_samples}_{N_repeats}_{m_max}')
print(f'Done for {model_type}_{train_generation}_{grid}-{grid_fine}_{N_samples}_{N_repeats}_{m_max}')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='pipeline')
parser.add_argument('--model', type=str, help='model type') # SNO
parser.add_argument('--gentype', type=str, help='generation type') # FCG, random
parser.add_argument('--nsamp', type=int, help='n samples') # N_samples
parser.add_argument('--nrep', type=int, help='number of repeats') # N_iter for FCG(I)
parser.add_argument('--cuda', type=int, help='device cuda') # cuda
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
grids = [32, 64]
path = f'./Elliptic/SNO/diff_grid_'
m_max = 20
for grid in grids:
main(args.model, args.gentype, grid, args.nsamp, args.nrep, m_max, path)