-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimizers.py
174 lines (146 loc) · 6.3 KB
/
optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import Tuple
import flax
import jax
import jax.numpy as jnp # JAX NumPy
import chex
from flax import linen as nn # Linen API
import training
def compute_fans(shape: Tuple[int,...]):
    """Return the ``(fan_in, fan_out)`` pair for a weight of the given shape.

    * rank 0: ``(1, 1)`` (constants).
    * rank 1: both fans equal the single dimension.
    * rank 2: ``(shape[0], shape[1])``.
    * rank >= 3: treated as a convolution kernel laid out as
      ``(*receptive_field, input_depth, output_depth)``; each depth is
      multiplied by the product of the leading dimensions.

    Args:
      shape: sequence of integer dimensions.

    Returns:
      Tuple ``(fan_in, fan_out)`` of Python ints.
    """
    ndim = len(shape)
    if ndim == 0:
        return 1, 1
    if ndim == 1:
        width = int(shape[0])
        return width, width
    if ndim == 2:
        return int(shape[0]), int(shape[1])
    receptive = 1
    for extent in shape[:-2]:
        receptive *= extent
    return int(shape[-2] * receptive), int(shape[-1] * receptive)
def sample_crossentropy_hessian(predictions, samples):
    """Multiply ``samples`` by a square-root factor of the softmax cross-entropy Hessian.

    With ``p = softmax(predictions)`` over the last axis, the Hessian of the
    cross-entropy w.r.t. the logits is ``H = diag(p) - p p^T``. The factor
    ``A = diag(sqrt(p)) - p sqrt(p)^T`` satisfies ``A A^T = H`` (because
    ``sum(p) == 1``). This returns ``A @ samples`` along the last axis
    without materializing ``A``. If ``samples`` are i.i.d. standard normal,
    the result therefore has covariance ``H``.
    """
    probs = nn.activation.softmax(predictions)
    root_probs = jnp.sqrt(probs)
    scaled = root_probs * samples
    # sqrt(p)^T @ samples, kept as a trailing singleton axis for broadcasting.
    projection = jnp.sum(scaled, axis=-1, keepdims=True)
    return scaled - probs * projection
def kalman_blockwise_trace_transformation(fading: float, lr: float) -> training.NaturalGradientTransformation:
    """Build a natural-gradient transformation with one scalar FIM-trace estimate per parameter leaf.

    Every parameter leaf keeps a scalar ``fim_trace``:

    * ``transform_update`` returns ``-update / fim_trace * lr`` for each leaf and then
      multiplies every trace by ``fading``.
    * ``consume_sample`` adds ``lr * sum(sample ** 2) / param.size`` to each leaf's trace.

    Args:
      fading: multiplicative decay applied to every trace estimate on each ``transform_update``.
      lr: step size; also scales how much each information sample increases the trace.

    Returns:
      ``training.NaturalGradientTransformation(init, transform_update, consume_sample)``.
    """

    @flax.struct.dataclass
    class State:
        # Same tree structure as the params; one scalar trace estimate per leaf.
        fim_trace: chex.ArrayTree

    def init_fim_trace(param: jax.Array) -> jax.Array:
        """Initial trace estimate for one leaf: ``1 / var(param)``, or ``fan_in`` when the variance is 0.

        ``jnp.where`` is used instead of a Python ``if``. This keeps the result an array of the
        variance's dtype in both cases, and it lets ``init`` be traced under ``jax.jit``.
        """
        variance = jnp.var(param)
        positive = variance > 0
        fan_in, _ = compute_fans(param.shape)
        # Divide by a safe value so the unselected branch never produces inf/nan.
        safe_variance = jnp.where(positive, variance, jnp.ones_like(variance))
        return jnp.where(positive, 1.0 / safe_variance, fan_in)

    def init(params: chex.ArrayTree):
        """Create the initial state with one trace estimate per parameter leaf."""
        return State(fim_trace=jax.tree_util.tree_map(init_fim_trace, params))

    def transform_update(updates, state: State, params=None):
        """Precondition ``updates`` by the per-leaf trace, then decay the traces by ``fading``.

        ``params`` is accepted for interface compatibility and is not used here.
        """
        updates = jax.tree_util.tree_map(
            lambda u, information: -u / information * lr, updates, state.fim_trace)
        fim_trace = jax.tree_util.tree_map(lambda i: i * fading, state.fim_trace)
        return updates, state.replace(fim_trace=fim_trace)

    def consume_sample(information_samples, state: State, params=None):
        """Add ``lr * ||sample||^2 / param.size`` to each leaf's trace estimate.

        ``params`` is required in practice. It is zipped leaf-for-leaf with the samples to obtain
        each block's size, and a ``None`` value does not match the samples' tree structure.
        """
        def consume_sample_block(sample, fim_trace, param):
            size = float(jnp.size(param))
            return fim_trace + lr * jnp.sum(jnp.multiply(sample, sample)) / size

        fim_trace = jax.tree_util.tree_map(
            consume_sample_block, information_samples, state.fim_trace, params)
        return state.replace(fim_trace=fim_trace)

    return training.NaturalGradientTransformation(init, transform_update, consume_sample)
def kalman_blockwise_spectral_transformation(fading: float, lr: float, kernel_rank: int, buffer_rank: int, rng) -> training.NaturalGradientTransformation:
    """Build a natural-gradient transformation with a low-rank-plus-isotropic FIM estimate per leaf.

    For every parameter leaf, the Fisher information is approximated as
    ``basis^T basis + trace * I`` with ``trace = kernel_trace + samples_trace / size``.
    ``basis`` has ``kernel_rank`` mutually orthogonal rows whose squared norms are stored in
    ``kernel``.

    Incoming information samples are buffered, up to ``buffer_rank`` per leaf. Until the buffer
    is compressed, they contribute only to the isotropic ``samples_trace``. Compression
    re-factors ``[basis; samples]`` through the eigendecomposition of its Gram matrix and keeps
    the top ``kernel_rank`` directions. The mass of the discarded eigenvalues is spread
    isotropically into ``kernel_trace``.

    Args:
      fading: NOTE(review): accepted but currently unused. Unlike the trace variant, no decay
        is applied to the estimate; confirm whether this is intended.
      lr: step size applied to the preconditioned update.
      kernel_rank: number of low-rank directions kept per leaf.
      buffer_rank: number of samples buffered per leaf before compression.
      rng: PRNG key stored in the state. NOTE(review): currently unused.

    Returns:
      ``training.NaturalGradientTransformation(init, transform_update, consume_sample)``.
    """
    import math

    @flax.struct.dataclass
    class InformationState:
        kernel_trace: float  # isotropic trace: initial estimate plus spill from compressions
        samples_trace: float  # sum of squared norms of samples buffered since the last compression
        basis: jax.Array  # shape (kernel_rank,) + param.shape
        samples: jax.Array  # shape (buffer_rank,) + param.shape; slots [0, rank) are live
        rank: int  # number of buffered samples in this block
        kernel: jax.Array  # shape (kernel_rank,); squared norms of the basis rows

    @flax.struct.dataclass
    class State:
        fim: chex.ArrayTree  # params-structured tree of InformationState
        rng_key: jax.random.PRNGKey
        rank: int  # number of samples buffered since the last compression

    def init_fim_trace(param: jax.Array):
        """Initial isotropic trace for one leaf: ``1 / var(param)``, or ``fan_in`` when the variance is 0."""
        variance = jnp.var(param)
        if (variance > 0):
            return float(1.0 / variance)
        fan_in, fan_out = compute_fans(param.shape)
        return float(fan_in)

    def init(params: chex.ArrayTree):
        """Create an empty low-rank estimate (zero basis and kernel) plus an initial trace per leaf."""
        def init_block(param):
            kernel_trace = init_fim_trace(param)
            basis = jnp.zeros((kernel_rank,) + param.shape, dtype='float32')
            samples = jnp.zeros((buffer_rank,) + param.shape, dtype='float32')
            kernel = jnp.zeros([kernel_rank], dtype='float32')
            return InformationState(kernel_trace=kernel_trace, samples_trace=0.0, basis=basis, samples=samples, kernel=kernel, rank=0)
        return State(fim=jax.tree_util.tree_map(init_block, params), rng_key=rng, rank=0)

    def is_information_state(node) -> bool:
        """Treat InformationState nodes as leaves when mapping over ``State.fim``."""
        return isinstance(node, InformationState)

    def augment_samples_block(sample: jax.Array, fim: InformationState) -> InformationState:
        """Store ``sample`` in the next free buffer slot and add its squared norm to ``samples_trace``."""
        samples = fim.samples.at[fim.rank, ...].set(sample)
        trace = jnp.tensordot(sample, sample, sample.ndim)
        return fim.replace(samples=samples, samples_trace=fim.samples_trace + trace, rank=fim.rank + 1)

    @jax.jit
    def compress_samples_block(fim: InformationState) -> InformationState:
        """Fold one block's buffered samples into its low-rank basis and reset the buffer."""
        sum_dims = list(range(1, fim.basis.ndim))
        # Block size taken from the buffer shape, (kernel_rank,) + param.shape, so that
        # compression does not need the params tree.
        size = float(math.prod(fim.basis.shape[1:]))
        stacked = jnp.concatenate([fim.basis, fim.samples], 0)
        gram = jnp.tensordot(stacked, stacked, [sum_dims, sum_dims])
        s, v = jnp.linalg.eigh(gram)  # eigenvalues in ascending order
        s = jnp.maximum(s, 0.0)
        # Explicit split index: `s[:-kernel_rank]` would select nothing (and `v[:, -kernel_rank:]`
        # everything) when kernel_rank == 0.
        split = s.shape[0] - kernel_rank
        kernel_spill = jnp.sum(s[:split]) / size
        # Rows of v_k^T @ stacked are mutually orthogonal with squared norms s_k.
        basis = jnp.tensordot(v[:, split:], stacked, axes=[[0], [0]])
        # The stale `samples` buffer is left in place. Every slot is rewritten before the
        # next compression, because compression only runs when the buffer is full.
        return fim.replace(rank=0,
                           kernel=s[split:],
                           basis=basis,
                           kernel_trace=fim.kernel_trace + kernel_spill,
                           samples_trace=0.0)

    @jax.jit
    def transform_update_block(update, fim: InformationState):
        """Return ``-lr * (basis^T basis + trace * I)^{-1} update`` for one block.

        The identity used below holds because the basis rows are mutually orthogonal, with
        squared norms ``kernel``, which compression guarantees:
        ``(F + tI)^{-1} u = (u - B^T diag(1 / (kernel + t)) B u) / t``.
        """
        projected = jnp.tensordot(fim.basis, update, update.ndim)
        size = float(jnp.size(update))
        trace = fim.kernel_trace + fim.samples_trace / size
        projected = projected / (fim.kernel + trace)
        return (jnp.tensordot(fim.basis, projected, [[0], [0]]) - update) * lr / trace

    @jax.jit
    def compress_samples(state: State) -> State:
        """Compress the sample buffers of every block."""
        fim = jax.tree_util.tree_map(compress_samples_block, state.fim, is_leaf=is_information_state)
        return state.replace(fim=fim, rank=0)

    def transform_update(updates, state: State, params=None):
        """Precondition ``updates``, then compress the buffers if they are full.

        ``params`` is accepted for interface compatibility. Block sizes come from the state,
        so ``None`` is fine.
        """
        updates = jax.tree_util.tree_map(
            transform_update_block, updates, state.fim)
        # `>=` is defensive: augment_samples never lets the buffer exceed buffer_rank.
        state = jax.lax.cond(state.rank >= buffer_rank,
                             compress_samples, lambda s: s,
                             state)
        return updates, state

    @jax.jit
    def augment_samples(information_samples, state: State):
        """Append one sample per block, compressing first if the buffers are already full.

        Without this guard, a write at index ``buffer_rank`` is an out-of-bounds ``.at[].set``.
        JAX silently drops such writes, and ``rank`` would then skip past the compression
        trigger for good.
        """
        state = jax.lax.cond(state.rank >= buffer_rank,
                             compress_samples, lambda s: s,
                             state)
        fim = jax.tree_util.tree_map(
            augment_samples_block, information_samples, state.fim)
        return state.replace(fim=fim, rank=state.rank + 1)

    def consume_sample(information_samples, state: State, params=None):
        """Buffer one information sample per parameter leaf; ``params`` is unused."""
        return augment_samples(information_samples, state)

    return training.NaturalGradientTransformation(init, transform_update, consume_sample)