optax_Test.py
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
# id_print is a way to print from inside JAX code even after it has been JIT-compiled.
from jax.experimental.host_callback import id_print

from jax_main import GradDescent
NUM_TRAIN_STEPS = 1_000

# Set up the gradient-descent problem and inspect the truth field and the mock data.
holder = GradDescent(12, include_field=True, autorun=False, plot_direc="optax_plots")
holder.check_field(holder.truth_field, "truth_field", show=True, save=True)
holder.check_field(holder.data, "data", show=True, save=True)

initial_params = holder.s_field
TRAINING_DATA = holder.data
def loss(params: optax.Params) -> jnp.ndarray:
    loss_value = holder.chi_sq_jax(params)
    # id_print(loss_value)
    return loss_value
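
# Optional sanity check (a hedged sketch, not part of the original run): the loss
# should be a scalar and its gradient should have the same shape as the parameter field.
#   l0, g0 = jax.value_and_grad(loss)(initial_params)
#   assert l0.shape == () and g0.shape == initial_params.shape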
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        loss_value, grads = jax.value_and_grad(loss)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i in range(NUM_TRAIN_STEPS):
        params, opt_state, loss_value = step(params, opt_state)
        print(params)
    return params
# Finally, fit the parametrized field using the RMSProp optimizer provided by optax.
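# Optional (hedged sketch): a fixed learning rate of 10 is aggressive; optax also
# accepts a schedule in place of a constant, e.g.
#   schedule = optax.exponential_decay(init_value=10.0, transition_steps=100, decay_rate=0.9)
#   optimizer = optax.rmsprop(learning_rate=schedule)
# The schedule values above are illustrative assumptions, not tuned settings.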
optimizer = optax.rmsprop(learning_rate=10)
params = fit(initial_params, optimizer)
# Visualize the fitted field and save it to disk.
plt.imshow(jnp.reshape(params, (256, 256)))
plt.colorbar()
plt.savefig("output.png")
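
# A quick end-of-run check (a small sketch added here, assuming chi_sq_jax accepts the
# fitted parameters unchanged): report the final chi^2 loss alongside the saved image.
print("final loss:", loss(params))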