-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.go
97 lines (86 loc) · 2.14 KB
/
model.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*
Package vae implements the Auto-Encoding Variational Bayes algorithm
(variational auto-encoder) described in https://arxiv.org/pdf/1312.6114.pdf
*/
package vae
import (
"go-ml.dev/pkg/base/model"
"go-ml.dev/pkg/base/model/hyperopt"
"go-ml.dev/pkg/nn"
"go-ml.dev/pkg/nn/mx"
"reflect"
)
const (
	// LatentCol is the default name of the latent feature: the tensor produced
	// by the encoder and consumed by the decoder.
	LatentCol = "Latent"

	// DefaultBatchSize is the default batch size used for auto-encoder training.
	DefaultBatchSize = 32
)
/*
Model is the configuration of a Variational Auto-Encoder.

The comment on each field names the default used when the field is left
unset. NOTE(review): those defaults are applied outside this declaration,
presumably in train; confirm there.
*/
type Model struct {
// Hidden is the size of the hidden layer; half of the input width by default.
Hidden int
// Latent is the size of the latent layer (the encoder output and decoder input).
Latent int
// Feature is the name of the latent tensor that the encoder outputs and the
// decoder reads; vae.LatentCol by default.
Feature string
// Predicted is the name of the decoder's generative output;
// model.PredictedCol by default.
Predicted string
// Context is the MXNet execution context; mx.CPU by default.
Context mx.Context
// BatchSize is the training batch size; vae.DefaultBatchSize by default.
BatchSize int
// Seed is the random generator seed; random by default.
Seed int
// Optimizer is the optimizer configuration; nn.Adam{Lr:0.001} by default.
Optimizer nn.OptimizerConf
// Width is the input width. It is normally calculated from the features.
Width int
// Beta is the β-VAE hyper-parameter (presumably the weight applied to the
// KL-divergence term; confirm in train).
Beta float64
}
// Feed binds this model configuration to the dataset ds and returns it as a
// model.FatModel. No work is done here: train is called with the receiver
// copy, ds, and the supplied workout only when the returned function runs.
func (e Model) Feed(ds model.Dataset) model.FatModel {
	fat := func(w model.Workout) (*model.Report, error) {
		return train(e, ds, w)
	}
	return fat
}
// Names of the collections that hold the parts of a variational auto-encoder.
const (
	// EncoderCollection is the name of the collection containing the encoder model.
	EncoderCollection = "encoder"
	// DecoderCollection is the name of the collection containing the decoder model.
	DecoderCollection = "decoder"
	// RecoderCollection is the name of the collection containing the recoder model.
	RecoderCollection = "recoder"
)
// ModelFunc adapts Apply for hyper-parameter optimization. It returns a copy
// of the model with params applied, exposed as a model.HungryModel.
func (e Model) ModelFunc(params hyperopt.Params) model.HungryModel {
	tuned := e.Apply(params)
	return tuned
}
// Apply returns a copy of the model with the tunable hyper-parameters taken
// from params.
//
// Only the Hidden, Latent, Beta and Seed fields are exposed to hyperopt.Apply.
// Each is registered under its field name as a reflect.Value of a pointer
// into the receiver copy, presumably so hyperopt.Apply can write into it.
// The receiver is a value, so the caller's Model is never modified.
func (e Model) Apply(params hyperopt.Params) Model {
	targets := make(map[string]reflect.Value, 4)
	targets["Hidden"] = reflect.ValueOf(&e.Hidden)
	targets["Latent"] = reflect.ValueOf(&e.Latent)
	targets["Beta"] = reflect.ValueOf(&e.Beta)
	targets["Seed"] = reflect.ValueOf(&e.Seed)
	hyperopt.Apply(params, targets)
	return e
}