flax_qdense.py
# IMSL Lab - University of Notre Dame
# Author: Clemens JS Schaefer
# Copied code from
# https://github.com/google/flax/blob/master/flax/linen/linear.py and
# modified to accommodate noise and quantization
from typing import (
    Any,
    Callable,
    Sequence,
)

from flax.linen.module import Module, compact
from flax.linen.initializers import lecun_normal, zeros

import ml_collections

import jax
import jax.numpy as jnp


default_kernel_init = lecun_normal()

Array = Any
Dtype = Any
PRNGKey = Any
Shape = Sequence[int]

class QuantDense(Module):
  """A linear transformation applied over the last dimension of the input.

  Attributes:
    features: the number of output features.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    config: bit widths and other quantization configuration.
    bits: bit width passed to the configured quantizers (default: 8).
    quant_act_sign: `sign` flag passed to the activation quantizer.
    g_scale: gradient scale passed to the configured quantizers.
  """

  features: int
  use_bias: bool = True
  dtype: Any = jnp.float32
  precision: Any = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
  config: dict = ml_collections.FrozenConfigDict({})
  bits: int = 8
  quant_act_sign: bool = True
  g_scale: float = 0.

  @compact
  def __call__(self, inputs: Array, rng: Any = None) -> Array:
    """Applies a linear transformation to the inputs along the last
    dimension.

    Args:
      inputs: The nd-array to be transformed.
      rng: optional PRNG key threaded through the custom gradient computation.

    Returns:
      The transformed input.
    """
    inputs = jnp.asarray(inputs, self.dtype)

    kernel = self.param(
        "kernel", self.kernel_init, (inputs.shape[-1], self.features)
    )
    kernel = jnp.asarray(kernel, self.dtype)

    # Quantization of the weights, if a weight quantizer is configured.
    if "weight" in self.config:
      if self.bits is not None:
        kernel_fwd = self.config.weight(
            bits=self.bits, g_scale=self.g_scale)(kernel)
      else:
        kernel_fwd = self.config.weight(
            g_scale=self.g_scale)(kernel)
    else:
      kernel_fwd = kernel

    # Quantization of the activations, if an activation quantizer is
    # configured.
    if "act" in self.config:
      if self.bits is not None:
        inpt_fwd = self.config.act(bits=self.bits, g_scale=self.g_scale)(
            inputs, sign=self.quant_act_sign)
      else:
        inpt_fwd = self.config.act(g_scale=self.g_scale)(
            inputs, sign=self.quant_act_sign)
    else:
      inpt_fwd = inputs

    @jax.custom_vjp
    def dot_general(inpt_fwd: Array, kernel_fwd: Array, inpt_bwd: Array,
                    kernel_bwd: Array, rng: PRNGKey) -> Array:
      # The forward product uses the (possibly quantized) activations and
      # weights.
      return jnp.dot(inpt_fwd, kernel_fwd)

    def dot_general_fwd(
        inpt_fwd: Array, kernel_fwd: Array, inpt_bwd: Array, kernel_bwd: Array,
        rng: PRNGKey
    ) -> Array:
      if rng is not None:
        rng, prng = jax.random.split(rng, 2)
      else:
        prng = None
      # Residuals carry the full-precision tensors (and a fresh PRNG key) to
      # the backward pass.
      return dot_general(inpt_fwd, kernel_fwd, inpt_bwd, kernel_bwd, rng), (
          inpt_bwd,
          kernel_bwd,
          prng,
      )

    def dot_general_bwd(res: tuple, g: Array) -> tuple:
      # Straight-through-style backward pass: gradients are formed with the
      # unquantized inputs and weights stored in the residuals.
      (
          inpt,
          kernel,
          rng,
      ) = res
      g_inpt = g_weight = g
      g_inpt_fwd = jnp.dot(g_inpt, jnp.transpose(kernel))
      g_kernel_fwd = jnp.dot(jnp.transpose(inpt), g_weight)
      return (g_inpt_fwd, g_kernel_fwd, None, None, None)

    dot_general.defvjp(dot_general_fwd, dot_general_bwd)

    y = dot_general(inpt_fwd, kernel_fwd, inputs, kernel, rng)

    if self.use_bias:
      bias = self.param("bias", self.bias_init, (self.features,))
      bias = jnp.asarray(bias, self.dtype)

      # Quantization of the bias, if a bias quantizer is configured.
      if "bias" in self.config:
        if self.bits is not None:
          bias = self.config.bias(
              bits=self.bits, g_scale=self.g_scale,
              maxabs_w=jnp.max(jnp.abs(kernel)))(bias)
        else:
          bias = self.config.bias(
              g_scale=self.g_scale,
              maxabs_w=jnp.max(jnp.abs(kernel)))(bias)
      y = y + bias

    return y
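

# ---------------------------------------------------------------------------
# Usage sketch: with the default (empty) FrozenConfigDict no quantizers are
# configured, so QuantDense reduces to a plain dense layer and the call below
# only exercises parameter creation and the custom-VJP matmul. In a quantized
# setup the config is expected to map "weight", "act" and "bias" to quantizer
# factories taking `bits` and `g_scale` (interface inferred from `__call__`).
if __name__ == "__main__":
  layer = QuantDense(features=4)
  x = jnp.ones((2, 8))
  variables = layer.init(jax.random.PRNGKey(0), x)
  y = layer.apply(variables, x)
  print(y.shape)  # (2, 4)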