# autodiff.py
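"""A minimal reverse-mode automatic differentiation engine built on numpy.

Each Tensor records the op and the parent tensors that produced it, which
forms the computation graph; Tensor.backward() walks this graph recursively
and accumulates each node's gradient into its .grad attribute.
"""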
import numpy
class Tensor:
    # Opt out of numpy's ufunc machinery so that numpy arrays defer to
    # Tensor's reflected operators (e.g. __rsub__) instead of handling the
    # operation themselves; see
    # https://stackoverflow.com/questions/58408999/how-to-call-rsub-properly-with-a-numpy-array
    __array_ufunc__ = None

    def __init__(self, data, from_tensors=None, op=None, grad=None):
        self.data = data  # the raw value (a scalar or a numpy array)
        self.from_tensors = from_tensors  # parent tensors: the computation-graph history
        self.op = op  # the operation that produced this tensor
        # gradient, zero-initialised with the same shape as data
        if grad:
            self.grad = grad
        else:
            self.grad = numpy.zeros(self.data.shape) if isinstance(self.data, numpy.ndarray) else 0

    def __add__(self, other):
        # dispatch on whether other is a Tensor or a constant
        return add.forward([self, other]) if isinstance(other, Tensor) else add_with_const.forward([self, other])

    def __sub__(self, other):
        # if other is a constant, reuse constant addition with -other
        return sub.forward([self, other]) if isinstance(other, Tensor) else add_with_const.forward([self, -other])

    def __rsub__(self, other):
        # constant - tensor is handled by rsub_with_const
        return rsub_with_const.forward([self, other])

    def __mul__(self, other):
        # dispatch on whether other is a Tensor or a constant
        return mul.forward([self, other]) if isinstance(other, Tensor) else mul_with_const.forward([self, other])

    def __truediv__(self, other):
        # tensor / constant reduces to multiplication by 1/other
        return div.forward([self, other]) if isinstance(other, Tensor) else mul_with_const.forward([self, 1 / other])

    def __rtruediv__(self, other):
        # constant / tensor is handled by rdiv_with_const
        return rdiv_with_const.forward([self, other])

    def __neg__(self):
        # negation is 0 - tensor, i.e. __rsub__ with 0
        return self.__rsub__(0)

    def matmul(self, other):
        # 1-D numpy arrays are not supported: transposing [1, 2] yields
        # [1, 2] again, so reshape such inputs to [[1, 2]] first
        return mul_with_matrix.forward([self, other])

    def sum(self):
        return sum.forward([self])

    def mean(self):
        # the mean is the sum divided by the number of elements
        return sum.forward([self]) / self.data.size

    def log(self):
        return log.forward([self])

    def exp(self):
        return exp.forward([self])

    def backward(self, grad=None):
        # if no upstream gradient is given, this is the output node:
        # initialise the gradient to ones with the same shape as data
        if grad is None:
            self.grad = grad = numpy.ones(self.data.shape) if isinstance(self.data, numpy.ndarray) else 1
        # if op is None this tensor is a leaf (input) node and has no
        # from_tensors; otherwise compute the gradients of its parents
        if self.op:
            grad = self.op.backward(self.from_tensors, grad)
        if self.from_tensors:
            for i in range(len(grad)):
                tensor = self.from_tensors[i]
                # accumulate into the parent's gradient, because a tensor
                # may take part in more than one operation
                tensor.grad += grad[i]
                # continue the backward pass through the parent
                tensor.backward(grad[i])

    def zero_grad(self):
        # reset the gradient; during training this should be called once per batch
        self.grad = numpy.zeros(self.data.shape) if isinstance(self.data, numpy.ndarray) else 0

    __radd__ = __add__
    __rmul__ = __mul__
class OP:
    def forward(self, from_tensors):
        pass

    def backward(self, from_tensors, grad):
        pass
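
# A sketch of how to extend the framework (an illustrative addition, not part
# of the original file): a new operator subclasses OP, builds a Tensor in
# forward() that records its parents and the op, and returns one gradient per
# parent in backward(). ReLU here is a hypothetical example op; following the
# file's convention, a singleton instance is created so it can be called as
# relu.forward([x]).
class ReLU(OP):
    def forward(self, from_tensors):
        # elementwise max(x, 0)
        return Tensor(numpy.maximum(from_tensors[0].data, 0), from_tensors, self)

    def backward(self, from_tensors, grad):
        # the gradient passes through only where the input was positive
        return [grad * (from_tensors[0].data > 0)]

relu = ReLU()
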
class Add(OP):
    def forward(self, from_tensors):
        return Tensor(from_tensors[0].data + from_tensors[1].data, from_tensors, self)

    def backward(self, from_tensors, grad):
        return [grad, grad]


class AddWithConst(OP):
    def forward(self, from_tensors):
        return Tensor(from_tensors[0].data + from_tensors[1], from_tensors, self)

    def backward(self, from_tensors, grad):
        return [grad]


class Sub(OP):
    def forward(self, from_tensors):
        return Tensor(from_tensors[0].data - from_tensors[1].data, from_tensors, self)

    def backward(self, from_tensors, grad):
        return [grad, -grad]


class RSubWithConst(OP):
    def forward(self, from_tensors):
        return Tensor(-(from_tensors[0].data - from_tensors[1]), from_tensors, self)

    def backward(self, from_tensors, grad):
        return [-grad]


class Mul(OP):
    def forward(self, from_tensors):
        return Tensor(from_tensors[0].data * from_tensors[1].data, from_tensors, self)

    def backward(self, from_tensors, grad):
        return [from_tensors[1].data * grad, from_tensors[0].data * grad]


class MulWithConst(OP):
    def forward(self, from_tensors):
        return Tensor(from_tensors[0].data * from_tensors[1], from_tensors, self)

    def backward(self, from_tensors, grad):
        return [from_tensors[1] * grad]


class MulWithMatrix(OP):
    def forward(self, from_tensors):
        return Tensor(numpy.matmul(from_tensors[0].data, from_tensors[1].data), from_tensors, self)

    def backward(self, from_tensors, grad):
        # Useful formula: if Y = AB, then dA = dY B^T and dB = A^T dY
        return [numpy.matmul(grad, from_tensors[1].data.T), numpy.matmul(from_tensors[0].data.T, grad)]


class Div(OP):
    def forward(self, from_tensors):
        return Tensor(from_tensors[0].data / from_tensors[1].data, from_tensors, self)

    def backward(self, from_tensors, grad):
        return [grad / from_tensors[1].data,
                -grad * from_tensors[0].data / (from_tensors[1].data * from_tensors[1].data)]


class RDivWithConst(OP):
    def forward(self, from_tensors):
        return Tensor(from_tensors[1] / from_tensors[0].data, from_tensors, self)

    def backward(self, from_tensors, grad):
        return [-grad * from_tensors[1] / (from_tensors[0].data * from_tensors[0].data)]


class Sum(OP):
    def forward(self, from_tensors):
        return Tensor(numpy.sum(from_tensors[0].data), from_tensors, self)

    def backward(self, from_tensors, grad):
        return [grad * numpy.ones(from_tensors[0].data.shape)]


class Exp(OP):
    def forward(self, from_tensors):
        return Tensor(numpy.exp(from_tensors[0].data), from_tensors, self)

    def backward(self, from_tensors, grad):
        return [grad * numpy.exp(from_tensors[0].data)]


class Log(OP):
    def forward(self, from_tensors):
        return Tensor(numpy.log(from_tensors[0].data), from_tensors, self)

    def backward(self, from_tensors, grad):
        return [grad / from_tensors[0].data]
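
# Singleton instances of each op; the Tensor operator methods above dispatch
# to these.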
add = Add()
add_with_const = AddWithConst()
sub = Sub()
rsub_with_const = RSubWithConst()
mul = Mul()
mul_with_const = MulWithConst()
mul_with_matrix = MulWithMatrix()
div = Div()
rdiv_with_const = RDivWithConst()
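# note: this instance shadows Python's builtin sum within this module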
sum = Sum()
exp = Exp()
log = Log()
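

# Usage sketch (an illustrative addition, not part of the original file).
# The expected values in the comments follow from the chain rule.
if __name__ == "__main__":
    # Scalar case: y = a * b + a, so dy/da = b + 1 and dy/db = a.
    a = Tensor(2.0)
    b = Tensor(3.0)
    y = a * b + a
    y.backward()
    print(a.grad)  # 4.0
    print(b.grad)  # 2.0

    # Matrix case: y = sum(x @ w), so dy/dw is the column sums of x,
    # and each row of dy/dx equals w.T.
    x = Tensor(numpy.array([[1.0, 2.0], [3.0, 4.0]]))
    w = Tensor(numpy.array([[1.0], [1.0]]))
    y = x.matmul(w).sum()
    y.backward()
    print(w.grad)  # [[4.] [6.]]
    print(x.grad)  # [[1. 1.] [1. 1.]]

    # Gradients accumulate across backward passes, so reset them before
    # reusing a tensor in a new computation.
    x.zero_grad()
    w.zero_grad()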