-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathblock.go
89 lines (73 loc) · 1.43 KB
/
block.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package nn
import (
"go4ml.xyz/nn/mx"
)
// Block is a composable piece of a network graph. Combine takes the symbol
// that feeds into this block and returns the symbol this block produces,
// which lets blocks be chained, concatenated, stacked, or wrapped in
// residual connections by the combinators in this file.
type Block interface {
	Combine(input *mx.Symbol) *mx.Symbol
}
// Combine builds the symbol graph for nn, starting from a fresh mx.Input()
// symbol, and returns the resulting output symbol.
//
// symbolMu is held for the whole build, so concurrent calls to Combine are
// serialized. resetSymbolId(0) is called before the input symbol is created.
// NOTE(review): presumably this restarts symbol numbering so that repeated
// builds are deterministic; confirm against resetSymbolId.
func Combine(nn Block) *mx.Symbol {
	symbolMu.Lock()
	defer symbolMu.Unlock()

	// The reset must happen before mx.Input() in case creating the input
	// symbol consumes an id.
	resetSymbolId(0)
	input := mx.Input()
	return nn.Combine(input)
}
// BlockConnect chains its child blocks so that the output symbol of each
// block is the input symbol of the next one.
type BlockConnect struct {
	blocks []Block
}

// Combine threads s through the child blocks in order and returns the last
// symbol produced. An empty chain returns s unchanged.
//
// Nil children are treated as identity and skipped, in the same way
// BlockConcat ignores nil children. Without this check, a nil entry would
// panic on a method call through a nil interface value.
func (bc *BlockConnect) Combine(s *mx.Symbol) *mx.Symbol {
	for _, b := range bc.blocks {
		if b == nil {
			continue
		}
		s = b.Combine(s)
	}
	return s
}

// Sequence returns a block that applies the given blocks one after another,
// feeding each block's output into the next. Nil blocks are skipped.
func Sequence(b ...Block) Block {
	return &BlockConnect{b}
}
// BlockConcat feeds the same input symbol to each of its child blocks and
// joins their outputs with mx.Concat.
// NOTE(review): the concatenation axis is whatever mx.Concat uses; confirm there.
type BlockConcat struct {
	blocks []Block
}

// Combine applies every non-nil child block to s and passes the results to
// mx.Concat in the order the children were given. Nil children are skipped.
func (bc *BlockConcat) Combine(s *mx.Symbol) *mx.Symbol {
	outputs := make([]*mx.Symbol, 0, len(bc.blocks))
	for i := range bc.blocks {
		child := bc.blocks[i]
		if child == nil {
			continue
		}
		outputs = append(outputs, child.Combine(s))
	}
	return mx.Concat(outputs...)
}

// Concat returns a block that applies each of the given blocks to the same
// input and concatenates their outputs. Nil blocks are ignored.
func Concat(b ...Block) Block {
	return &BlockConcat{blocks: b}
}
// BlockStack feeds the same input symbol to each of its child blocks and
// stacks their outputs. It uses mx.Stack1 when axis1 is set and mx.Stack
// otherwise.
// NOTE(review): the field name suggests mx.Stack1 stacks along axis 1;
// confirm against the mx package.
type BlockStack struct {
	blocks []Block
	axis1  bool
}

// Combine applies every non-nil child block to s and stacks the results in
// the order the children were given.
//
// Nil children are skipped, matching BlockConcat. Previously a nil child
// caused a panic from a method call on a nil interface value.
func (bc *BlockStack) Combine(s *mx.Symbol) *mx.Symbol {
	outputs := make([]*mx.Symbol, 0, len(bc.blocks))
	for _, v := range bc.blocks {
		if v != nil {
			outputs = append(outputs, v.Combine(s))
		}
	}
	if bc.axis1 {
		return mx.Stack1(outputs...)
	}
	return mx.Stack(outputs...)
}

// TransStack returns a block that applies each of the given blocks to the
// same input and combines the outputs with mx.Stack1. Nil blocks are skipped.
func TransStack(b ...Block) Block {
	return &BlockStack{blocks: b, axis1: true}
}

// Stack returns a block that applies each of the given blocks to the same
// input and combines the outputs with mx.Stack. Nil blocks are skipped.
func Stack(b ...Block) Block {
	return &BlockStack{blocks: b, axis1: false}
}
// ResidualBlock wraps each of its child blocks in its own residual (skip)
// connection, applied one after another. For every child f, the running
// symbol x becomes mx.Add(x, f(x)).
type ResidualBlock struct {
	blocks []Block
}

// Residual returns a block that applies each given block as a separate
// residual step, x = x + block(x), in order.
// NOTE(review): this is not x + Sequence(blocks...)(x); confirm that a
// per-block residual is the intended semantics when more than one block is
// passed.
func Residual(a ...Block) Block {
	return &ResidualBlock{blocks: a}
}

// Combine runs the residual steps starting from a and returns the final
// symbol. With no child blocks it returns a unchanged.
func (rcb *ResidualBlock) Combine(a *mx.Symbol) *mx.Symbol {
	x := a
	for i := 0; i < len(rcb.blocks); i++ {
		branch := rcb.blocks[i].Combine(x)
		x = mx.Add(x, branch)
	}
	return x
}