-
Notifications
You must be signed in to change notification settings - Fork 74
/
Copy pathmodel_lls.go
62 lines (50 loc) · 1.26 KB
/
model_lls.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package optimus
import (
"fmt"
"io"
"io/ioutil"
"github.com/cdipaolo/goml/base"
"github.com/cdipaolo/goml/linear"
"go.uber.org/zap"
)
// llsModelConfig holds the hyperparameters for the linear least-squares
// (LLS) model. The yaml tags name the configuration keys; the default tags
// carry fallback values. NOTE(review): whatever loader applies the default
// tags is not visible in this file.
type llsModelConfig struct {
	// Alpha is passed to linear.NewLeastSquares as its alpha argument
	// (goml's gradient step size / learning rate).
	Alpha float64 `yaml:"alpha" default:"1e-3"`
	// Regularization is passed to linear.NewLeastSquares as its
	// regularization parameter.
	Regularization float64 `yaml:"regularization" default:"6.0"`
	// MaxIterations is passed to linear.NewLeastSquares as its iteration cap.
	MaxIterations int `yaml:"max_iterations" default:"1000"`
}
// Config returns the configuration struct itself, as a pointer boxed in an
// empty interface. Callers can therefore populate it in place (presumably a
// YAML decoder fills it via the struct's yaml tags — confirm at the caller).
func (m *llsModelConfig) Config() interface{} {
	var cfg interface{} = m
	return cfg
}
// Create builds a fresh, untrained llsModel from this configuration.
//
// The configuration is copied by value, so later changes to m do not affect
// the returned model. The model's output writer starts as ioutil.Discard.
// The logger argument is accepted but not used.
func (m *llsModelConfig) Create(log *zap.SugaredLogger) Model {
	snapshot := *m
	mdl := new(llsModel)
	mdl.cfg = snapshot
	mdl.output = ioutil.Discard
	return mdl
}
// llsModel is an untrained linear least-squares model. Train fits it using
// the hyperparameters in cfg.
type llsModel struct {
	// cfg is a value copy of the configuration the model was created from.
	cfg llsModelConfig
	// output is the writer intended for training output (set to
	// ioutil.Discard by Create). NOTE(review): Train does not currently use
	// it — its assignment to the goml model is commented out.
	output io.Writer
}
// Train fits a linear least-squares model to the given data using goml's
// batch gradient ascent (base.BatchGA), with the configured Alpha,
// Regularization and MaxIterations.
//
// Each row of trainingSet is one example's feature vector, and expectation[i]
// is the target for row i. The inputs are validated before they reach goml.
// An empty training set, a row/target count mismatch, or rows of differing
// length return an error rather than risking an index-out-of-range panic or
// silently ignored targets inside the optimizer. Errors from goml's Learn
// are wrapped with context.
func (m *llsModel) Train(trainingSet [][]float64, expectation []float64) (TrainedModel, error) {
	if err := validateLLSInput(trainingSet, expectation); err != nil {
		return nil, err
	}
	model := linear.NewLeastSquares(base.BatchGA, m.cfg.Alpha, m.cfg.Regularization, m.cfg.MaxIterations, trainingSet, expectation)
	// NOTE(review): m.output is never applied to the goml model because the
	// assignment below is commented out. Training output is therefore not
	// redirected to m.output. Confirm whether the vendored goml version
	// exposes an Output field before re-enabling it.
	//model.Output = m.output
	if err := model.Learn(); err != nil {
		return nil, fmt.Errorf("training least squares model: %w", err)
	}
	return &trainedLLSModel{model: model}, nil
}

// validateLLSInput reports whether trainingSet/expectation form a
// non-empty, rectangular data set with exactly one target per row.
func validateLLSInput(trainingSet [][]float64, expectation []float64) error {
	if len(trainingSet) == 0 {
		return fmt.Errorf("empty training set")
	}
	if len(trainingSet) != len(expectation) {
		return fmt.Errorf("training set has %d rows but %d expectations", len(trainingSet), len(expectation))
	}
	features := len(trainingSet[0])
	for i, row := range trainingSet {
		if len(row) != features {
			return fmt.Errorf("training row %d has %d features, want %d", i, len(row), features)
		}
	}
	return nil
}
// trainedLLSModel wraps a goml least-squares model that Train has fitted
// (Learn returned without error) and serves predictions from it.
type trainedLLSModel struct {
	// model is the fitted goml model that Predict delegates to.
	model *linear.LeastSquares
}
// Predict returns the fitted model's scalar prediction for feature vector vec.
//
// goml returns predictions as a slice, and only its first element is used.
// An error from goml is returned unchanged. An empty prediction slice is
// reported as an error instead of being indexed.
func (m *trainedLLSModel) Predict(vec []float64) (float64, error) {
	out, err := m.model.Predict(vec)
	switch {
	case err != nil:
		return 0.0, err
	case len(out) == 0:
		return 0.0, fmt.Errorf("no prediction made")
	default:
		return out[0], nil
	}
}