-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdifferentiate.py
141 lines (95 loc) · 3.23 KB
/
differentiate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from weakref import WeakValueDictionary
from functools import wraps
import numpy as np
def varize(function):
    """Let a binary ``Variable`` operator accept plain numbers as its operand.

    Numeric operands -- Python ``int``/``float`` (``bool`` included, as an
    ``int`` subclass) and NumPy scalars -- are wrapped in a fixed constant
    ``Variable`` whose name is ``str(other)``. ``Variable`` operands (including
    every op node, which subclasses ``Variable``) pass through unchanged.
    Any other operand type makes the operator return ``NotImplemented``, so
    Python can try the other operand's reflected method or raise ``TypeError``.
    """
    @wraps(function)
    def op(self, other):
        if isinstance(other, (int, float, np.number)):
            other = Variable(str(other), other, True)
        elif not isinstance(other, Variable):
            # Previously such operands fell through and failed later with an
            # AttributeError on ``other.name``.
            return NotImplemented
        return function(self, other)
    return op
class Variable:
    """A named node in a symbolic expression graph.

    A plain ``Variable`` is a leaf holding its value in ``val``. Subclasses
    (``Add``, ``Mul``, ``Div``, ...) combine other nodes. Every instance
    registers itself by ``name`` in the class-wide registry ``nametovar``, and
    ``__call__`` uses that registry to assign input values before evaluating.

    Arithmetic operators build new graph nodes instead of computing numbers.
    Numeric operands are wrapped into constants by ``varize``, and the
    reflected forms (``2 * x``, ``1 - x``, ``1 / x``) are supported as well.
    """

    # name -> most recently created node with that name. Values are weak
    # references, so an entry disappears once nothing else holds the node, and
    # a later node with the same name replaces the earlier entry.
    nametovar = WeakValueDictionary()

    def __init__(self, name, val=None, fixed=False):
        self.__class__.nametovar[name] = self
        # Marks constants created by the library; stored but not consulted by
        # any method in this module.
        self.fixed = fixed
        self.name = name
        self.val = val

    def __call__(self, varvalues):
        """Assign ``varvalues`` (name -> value), then evaluate this node.

        Values are written onto the registered nodes' ``val`` attributes and
        persist after the call. Raises ``KeyError`` for a name that is not
        currently registered.

        NOTE(review): lookup is by name, so if two nodes share a name only the
        most recently created one receives the value -- which may not be the
        one referenced by this graph.
        """
        for k, v in varvalues.items():
            self.__class__.nametovar[k].val = v
        return self.forward()

    def forward(self):
        """Return this leaf's current value."""
        return self.val

    def backward(self, name):
        """Return d(self)/d(name) as a new constant node: 1 if this node is
        named ``name``, otherwise 0."""
        if name == self.name:
            return Variable(name='1', val=1, fixed=True)
        return Variable(name='0', val=0, fixed=True)

    def __str__(self):
        return self.name

    @varize
    def __add__(self, other):
        return Add(self, other, name='({} + {})'.format(self.name, other.name))

    @varize
    def __sub__(self, other):
        # a - b is represented as a + b * (-1).
        return Add(self,
                   other * Variable(name='-1', val=-1, fixed=True),
                   name='({} - {})'.format(self.name, other.name))

    @varize
    def __mul__(self, other):
        return Mul(self, other, name='({} * {})'.format(self.name, other.name))

    @varize
    def __truediv__(self, other):
        return Div(self, other, name='({} / {})'.format(self.name, other.name))

    # Reflected operators: reached when the left operand is a number, e.g.
    # ``2 * x``. ``varize`` has already wrapped ``other`` into a Variable.
    @varize
    def __radd__(self, other):
        return Add(other, self, name='({} + {})'.format(other.name, self.name))

    @varize
    def __rsub__(self, other):
        return other - self

    @varize
    def __rmul__(self, other):
        return Mul(other, self, name='({} * {})'.format(other.name, self.name))

    @varize
    def __rtruediv__(self, other):
        return Div(other, self, name='({} / {})'.format(other.name, self.name))
class Op(Variable):
    """Base class for binary operation nodes with operands ``a`` and ``b``.

    All arguments after ``a`` and ``b`` (notably ``name``) are forwarded to
    ``Variable.__init__``.
    """

    def __init__(self, a, b, *args, **kwargs):
        super(Op, self).__init__(*args, **kwargs)
        self.a = a
        self.b = b
class MonoOp(Variable):
    """Base class for unary operation nodes with a single operand ``a``.

    All arguments after ``a`` (notably ``name``) are forwarded to
    ``Variable.__init__``.
    """

    def __init__(self, a, *args, **kwargs):
        super(MonoOp, self).__init__(*args, **kwargs)
        self.a = a
def checkself(function):
    """Decorate an op node's ``backward`` with a self-derivative shortcut.

    If the node itself is named ``name``, the derivative is the constant 1
    and the wrapped rule is skipped. Otherwise the wrapped rule (sum rule,
    product rule, ...) is called.
    """
    @wraps(function)
    def backward(self, name):
        # Zero-argument super() only works in functions defined inside a
        # class body (it needs the implicit __class__ cell). This wrapper is
        # defined at module level, so super() raised RuntimeError. Call the
        # base implementation explicitly instead.
        v = Variable.backward(self, name)
        if v.val:
            return v
        return function(self, name)
    return backward
class Add(Op):
    """Sum node: ``a + b``."""

    def forward(self):
        left = self.a.forward()
        right = self.b.forward()
        return left + right

    @checkself
    def backward(self, name):
        # Sum rule: d(a + b) = da + db.
        da = self.a.backward(name)
        db = self.b.backward(name)
        return da + db
class Mul(Op):
    """Product node: ``a * b``."""

    def forward(self):
        left = self.a.forward()
        right = self.b.forward()
        return left * right

    @checkself
    def backward(self, name):
        # Product rule: d(a * b) = b * da + a * db.
        da = self.a.backward(name)
        first = self.b * da
        db = self.b.backward(name)
        second = self.a * db
        return first + second
class Div(Op):
    """Quotient node: ``a / b``."""

    def forward(self):
        numerator = self.a.forward()
        denominator = self.b.forward()
        return numerator / denominator

    @checkself
    def backward(self, name):
        # Quotient rule: d(a / b) = (b * da - a * db) / b**2.
        da = self.a.backward(name)
        first = self.b * da
        db = self.b.backward(name)
        second = self.a * db
        top = first - second
        bottom = self.b * self.b
        return top / bottom
class Heaviside(MonoOp):
    """Step node: 1 where ``a > 0``, else 0 (multiplied by 1 to get ints)."""

    def forward(self):
        value = self.a.forward()
        return (value > 0) * 1

    @checkself
    def backward(self, name):
        # The step has zero slope everywhere except at 0, where it is not
        # differentiable; that point is deliberately ignored.
        return Variable('0', 0, True)
class ReLU(MonoOp):
    """Rectified-linear node: ``np.maximum(a, 0)``."""

    def forward(self):
        return np.maximum(self.a.forward(), 0)

    @checkself
    def backward(self, name):
        # d relu(a) = H(a) * da.  MonoOp.__init__ forwards its remaining
        # arguments to Variable.__init__, which requires ``name``, so the old
        # ``Heaviside(self.a)`` always raised TypeError.
        step = Heaviside(self.a, name='H({})'.format(self.a.name))
        return step * self.a.backward(name)
class Log(MonoOp):
    """Natural-logarithm node: ``np.log(a)``."""

    def forward(self):
        return np.log(self.a.forward())

    @checkself
    def backward(self, name):
        # d log(a) = da / a.  The previous version called
        # ``self.a.backward()`` without ``name`` and always raised TypeError.
        return self.a.backward(name) / self.a