-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathparamoperation.py
36 lines (26 loc) · 1010 Bytes
/
paramoperation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from operation import *
class ParamOperation(Operation):
    """An ``Operation`` that also owns a parameter and computes its gradient.

    Besides the input gradient, ``backward`` computes ``self.param_grad``
    through ``_param_grad``. Every subclass must implement ``_param_grad``.
    ``backward`` also calls ``_input_grad``, which is presumably declared on
    ``Operation`` and must likewise be provided by subclasses — confirm
    against operation.py.
    """

    def __init__(self, param):
        """Run the base ``Operation`` initializer, then store *param*.

        Args:
            param: The parameter used by this operation. ``backward`` checks
                that ``param_grad`` has the same shape as it.
        """
        super().__init__()
        self.param = param

    def backward(self, output_grad):
        """Compute and return the gradient with respect to this op's input.

        As a side effect, stores ``self.input_grad`` and ``self.param_grad``.
        Each step is validated with ``assert_same_shape``.

        Args:
            output_grad: Gradient flowing back into this operation's output.
                It must have the same shape as ``self.output``.

        Returns:
            ``self.input_grad``, which must match the shape of ``self.input_``.
        """
        # NOTE(review): self.output and self.input_ are presumably set by the
        # base class's forward pass — confirm against operation.py.
        assert_same_shape(self.output, output_grad)

        self.input_grad = self._input_grad(output_grad)
        self.param_grad = self._param_grad(output_grad)

        assert_same_shape(self.input_, self.input_grad)
        assert_same_shape(self.param, self.param_grad)

        return self.input_grad

    def _param_grad(self, output_grad):
        """Return the gradient w.r.t. ``self.param``.

        Every subclass of ParamOperation must override this method.
        """
        raise NotImplementedError()