diff --git a/examples/ra25/ra25.go b/examples/ra25/ra25.go index b872d7660..6a1688126 100644 --- a/examples/ra25/ra25.go +++ b/examples/ra25/ra25.go @@ -13,6 +13,7 @@ package main //go:generate core generate -add-types import ( + "fmt" "log" "os" @@ -20,10 +21,14 @@ import ( "cogentcore.org/core/base/mpi" "cogentcore.org/core/base/randx" "cogentcore.org/core/core" + "cogentcore.org/core/enums" "cogentcore.org/core/icons" "cogentcore.org/core/math32" "cogentcore.org/core/math32/vecint" + "cogentcore.org/core/plot" "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/datafs" + "cogentcore.org/core/tensor/stats/stats" "cogentcore.org/core/tensor/table" "cogentcore.org/core/tree" "github.com/emer/axon/v2/axon" @@ -31,7 +36,6 @@ import ( "github.com/emer/emergent/v2/egui" "github.com/emer/emergent/v2/emer" "github.com/emer/emergent/v2/env" - "github.com/emer/emergent/v2/estats" "github.com/emer/emergent/v2/etime" "github.com/emer/emergent/v2/looper" "github.com/emer/emergent/v2/netview" @@ -49,15 +53,6 @@ func main() { } } -// Times are the looping time levels for running and statistics. -type Times int32 //enums:enum - -const ( - Trial Times = iota - Epoch - Run -) - // LoopPhase is the phase of loop processing for given time. type LoopPhase int32 //enums:enum @@ -207,12 +202,6 @@ type Sim struct { // contains looper control loops for running sim Loops *looper.Stacks `new-window:"+" display:"no-inline"` - // contains computed statistic values - Stats estats.Stats `new-window:"+"` - - // Contains all the logs and information about the logs.' - // Logs elog.Logs `new-window:"+"` - // the training patterns to use Pats *table.Table `display:"no-inline"` @@ -222,6 +211,18 @@ type Sim struct { // netview update parameters ViewUpdate netview.ViewUpdate `display:"add-fields"` + // Root is the root data dir. + Root *datafs.Data `display:"-"` + + // StatFuncs are statistics functions, per stat, handles everything. + StatFuncs []func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) `display:"-"` + + // Stats has the stats dir. + Stats *datafs.Data `display:"-"` + + // Current has the current stats values. + Current *datafs.Data `display:"-"` + // manages all the gui elements GUI egui.GUI `display:"-"` @@ -232,9 +233,9 @@ type Sim struct { // New creates new blank elements and initializes defaults func (ss *Sim) New() { econfig.Config(&ss.Config, "config.toml") + ss.Root, _ = datafs.NewDir("Root") ss.Net = axon.NewNetwork("RA25") ss.Params.Config(ParamSets, ss.Config.Params.Sheet, ss.Config.Params.Tag, ss.Net) - ss.Stats.Init() ss.Pats = table.New() ss.RandSeeds.Init(100) // max 100 runs ss.InitRandSeed(0) @@ -249,8 +250,8 @@ func (ss *Sim) ConfigAll() { ss.OpenPats() ss.ConfigEnv() ss.ConfigNet(ss.Net) - ss.ConfigLogs() ss.ConfigLoops() + ss.ConfigStats() if ss.Config.Params.SaveAll { ss.Config.Params.SaveAll = false ss.Net.SaveParamsSnapshot(&ss.Params.Params, &ss.Config, ss.Config.Params.Good) @@ -339,9 +340,9 @@ func (ss *Sim) ApplyParams() { // Init restarts the run, and initializes everything, including network weights // and resets the epoch log table func (ss *Sim) Init() { - if ss.Config.GUI { - ss.Stats.SetString("RunName", ss.Params.RunName(0)) // in case user interactively changes tag - } + // if ss.Config.GUI { + // ss.Stats.SetString("RunName", ss.Params.RunName(0)) // in case user interactively changes tag + // } ss.Loops.ResetCounters() ss.InitRandSeed(0) // ss.ConfigEnv() // re-config env just in case a different set of patterns was @@ -412,8 +413,10 @@ func (ss *Sim) ConfigLoops() { } }) - ///////////////////////////////////////////// - // Logging + //////// Logging + + ls.AddOnStartToAll("StatsStart", ss.StatsStart) + ls.AddOnEndToAll("StatsStep", ss.StatsStep) // ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { // axon.LogTestErrors(&ss.Logs) @@ -426,8 +429,6 @@ func (ss *Sim) ConfigLoops() { // } // }) // - // ls.AddOnEndToAll("Log", ss.Log) - // axon.LooperResetLogBelow(ls, &ss.Logs) // ls.Loop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() { // trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur @@ -475,7 +476,7 @@ func (ss *Sim) ApplyInputs() { for di := uint32(0); di < ctx.NData; di++ { ev.Step() // note: must save env state for logging / stats due to data parallel re-use of same env - ss.Stats.SetStringDi("TrialName", int(di), ev.TrialName.Cur) + // ss.Stats.SetStringDi("TrialName", int(di), ev.TrialName.Cur) // todo: for _, lnm := range lays { ly := ss.Net.LayerByName(lnm) pats := ev.State(ly.Name) @@ -537,16 +538,109 @@ func (ss *Sim) OpenPats() { } } -//////////////////////////////////////////////////////////////////////////////////////////// -// Stats +//////// Stats + +func (ss *Sim) AddStat(f func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase)) { + ss.StatFuncs = append(ss.StatFuncs, f) +} + +func (ss *Sim) StatsStart(lmd, ltm enums.Enum) { + lmode := lmd.(etime.Modes) + ltime := ltm.(etime.Times) + ss.RunStats(lmode, ltime, Start) +} + +func (ss *Sim) StatsStep(lmd, ltm enums.Enum) { + lmode := lmd.(etime.Modes) + ltime := ltm.(etime.Times) + if ltime >= etime.Epoch { + fmt.Println("running stats:", lmode, ltime) + } + ss.RunStats(lmode, ltime, Step) +} + +func (ss *Sim) RunStats(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) { + for _, sf := range ss.StatFuncs { + sf(lmode, ltime, lphase) + } +} -// InitStats initializes all the statistics. -// called at start of new run func (ss *Sim) InitStats() { - // ss.Stats.SetFloat("UnitErr", 0.0) - // ss.Stats.SetFloat("PhaseDiff", 0.0) - // ss.Stats.SetString("TrialName", "") - // ss.Logs.InitErrStats() // inits TrlErr, FirstZero, LastZero, NZero + for mode, st := range ss.Loops.Stacks { + for _, tm := range st.Order { + ctm := tm.(etime.Times) + ss.RunStats(mode.(etime.Modes), ctm, Start) + } + } +} + +func (ss *Sim) ConfigStats() { + ss.Stats, _ = ss.Root.Mkdir("Stats") + ss.Current, _ = ss.Stats.Mkdir("Current") + for mode, st := range ss.Loops.Stacks { + for _, tm := range st.Order { + ctm := tm.(etime.Times) + ss.AddStat(func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) { + if ltime > ctm { // don't record counter for time above it + return + } + name := tm.String() // name of stat = time + modeDir := ss.Stats.RecycleDir(lmode.String()) + timeDir := modeDir.RecycleDir(ltime.String()) + tsr := datafs.Value[int](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0) + }) + plot.SetStylersTo(tsr, ps) + } + return + } + ctv := ss.Loops.Stacks[mode].Loops[tm].Counter.Cur + datafs.Scalar[int](ss.Current, name).SetInt1D(ctv, 0) + tsr.AppendRowInt(ctv) + }) + } + } + // note: it is essential to only have 1 per func + // so generic names can be used for everything. + ss.AddStat(func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) { + name := "UnitErr" + modeDir := ss.Stats.RecycleDir(lmode.String()) + timeDir := modeDir.RecycleDir(ltime.String()) + tsr := datafs.Value[float64](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0).SetMax(1) + s.On = true + }) + plot.SetStylersTo(tsr, ps) + } + return + } + switch ltime { + case etime.Trial: + out := ss.Net.LayerByName("Output") + stat := out.PctUnitErr(ss.Net.Context()) + for di := 0; di < ss.Config.Run.NData; di++ { + // todo: current needs di tensor + datafs.Scalar[float64](ss.Current, name).SetFloat(stat[di], 0) + tsr.AppendRowFloat(stat[di]) + } + case etime.Epoch: + subd := modeDir.RecycleDir(etime.Trial.String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + case etime.Run: + subd := modeDir.RecycleDir(etime.Epoch.String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + } + }) } // StatCounters saves current counters to Stats, so they are available for logging etc @@ -574,7 +668,7 @@ func (ss *Sim) NetViewCounters(tm etime.Times) { ss.TrialStats(di) // get trial stats for current di } ss.StatCounters(di) - ss.ViewUpdate.Text = ss.Stats.Print([]string{"Run", "Epoch", "Trial", "Di", "TrialName", "Cycle", "UnitErr", "TrlErr", "PhaseDiff"}) + // ss.ViewUpdate.Text = ss.Stats.Print([]string{"Run", "Epoch", "Trial", "Di", "TrialName", "Cycle", "UnitErr", "TrlErr", "PhaseDiff"}) } // TrialStats computes the trial-level statistics. @@ -672,6 +766,8 @@ func (ss *Sim) Log(mode etime.Modes, time etime.Times) { func (ss *Sim) ConfigGUI() { title := "Axon Random Associator" ss.GUI.MakeBody(ss, "ra25", title, `This demonstrates a basic Axon model. See emergent on GitHub.

`) + ss.GUI.FS = ss.Root + ss.GUI.DataRoot = "Root" ss.GUI.CycleUpdateInterval = 10 nv := ss.GUI.AddNetView("Network") @@ -683,6 +779,7 @@ func (ss *Sim) ConfigGUI() { nv.SceneXYZ().Camera.Pose.Pos.Set(0, 1, 2.75) // more "head on" than default which is more "top down" nv.SceneXYZ().Camera.LookAt(math32.Vec3(0, 0, 0), math32.Vec3(0, 1, 0)) + ss.GUI.UpdateFiles() // ss.GUI.AddPlots(title, &ss.Logs) ss.GUI.FinalizeGUI(false) // if ss.Config.Run.GPU { diff --git a/examples/ra25/ra25.goal b/examples/ra25/ra25.goal index 7bc7faf49..31e80edb0 100644 --- a/examples/ra25/ra25.goal +++ b/examples/ra25/ra25.goal @@ -11,6 +11,7 @@ package main //go:generate core generate -add-types import ( + "fmt" "log" "os" @@ -18,10 +19,14 @@ import ( "cogentcore.org/core/base/mpi" "cogentcore.org/core/base/randx" "cogentcore.org/core/core" + "cogentcore.org/core/enums" "cogentcore.org/core/icons" "cogentcore.org/core/math32" "cogentcore.org/core/math32/vecint" + "cogentcore.org/core/plot" "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/datafs" + "cogentcore.org/core/tensor/stats/stats" "cogentcore.org/core/tensor/table" "cogentcore.org/core/tree" "github.com/emer/axon/v2/axon" @@ -29,7 +34,6 @@ import ( "github.com/emer/emergent/v2/egui" "github.com/emer/emergent/v2/emer" "github.com/emer/emergent/v2/env" - "github.com/emer/emergent/v2/estats" "github.com/emer/emergent/v2/etime" "github.com/emer/emergent/v2/looper" "github.com/emer/emergent/v2/netview" @@ -47,15 +51,6 @@ func main() { } } -// Times are the looping time levels for running and statistics. -type Times int32 //enums:enum - -const ( - Trial Times = iota - Epoch - Run -) - // LoopPhase is the phase of loop processing for given time. type LoopPhase int32 //enums:enum @@ -205,12 +200,6 @@ type Sim struct { // contains looper control loops for running sim Loops *looper.Stacks `new-window:"+" display:"no-inline"` - // contains computed statistic values - Stats estats.Stats `new-window:"+"` - - // Contains all the logs and information about the logs.' - // Logs elog.Logs `new-window:"+"` - // the training patterns to use Pats *table.Table `display:"no-inline"` @@ -220,6 +209,18 @@ type Sim struct { // netview update parameters ViewUpdate netview.ViewUpdate `display:"add-fields"` + // Root is the root data dir. + Root *datafs.Data `display:"-"` + + // StatFuncs are statistics functions, per stat, handles everything. + StatFuncs []func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) `display:"-"` + + // Stats has the stats dir. + Stats *datafs.Data `display:"-"` + + // Current has the current stats values. + Current *datafs.Data `display:"-"` + // manages all the gui elements GUI egui.GUI `display:"-"` @@ -230,9 +231,9 @@ type Sim struct { // New creates new blank elements and initializes defaults func (ss *Sim) New() { econfig.Config(&ss.Config, "config.toml") + ss.Root, _ = datafs.NewDir("Root") ss.Net = axon.NewNetwork("RA25") ss.Params.Config(ParamSets, ss.Config.Params.Sheet, ss.Config.Params.Tag, ss.Net) - ss.Stats.Init() ss.Pats = table.New() ss.RandSeeds.Init(100) // max 100 runs ss.InitRandSeed(0) @@ -247,8 +248,8 @@ func (ss *Sim) ConfigAll() { ss.OpenPats() ss.ConfigEnv() ss.ConfigNet(ss.Net) - ss.ConfigLogs() ss.ConfigLoops() + ss.ConfigStats() if ss.Config.Params.SaveAll { ss.Config.Params.SaveAll = false ss.Net.SaveParamsSnapshot(&ss.Params.Params, &ss.Config, ss.Config.Params.Good) @@ -337,9 +338,9 @@ func (ss *Sim) ApplyParams() { // Init restarts the run, and initializes everything, including network weights // and resets the epoch log table func (ss *Sim) Init() { - if ss.Config.GUI { - ss.Stats.SetString("RunName", ss.Params.RunName(0)) // in case user interactively changes tag - } + // if ss.Config.GUI { + // ss.Stats.SetString("RunName", ss.Params.RunName(0)) // in case user interactively changes tag + // } ss.Loops.ResetCounters() ss.InitRandSeed(0) // ss.ConfigEnv() // re-config env just in case a different set of patterns was @@ -410,8 +411,10 @@ func (ss *Sim) ConfigLoops() { } }) - ///////////////////////////////////////////// - // Logging + //////// Logging + + ls.AddOnStartToAll("StatsStart", ss.StatsStart) + ls.AddOnEndToAll("StatsStep", ss.StatsStep) // ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { // axon.LogTestErrors(&ss.Logs) @@ -424,8 +427,6 @@ func (ss *Sim) ConfigLoops() { // } // }) // - // ls.AddOnEndToAll("Log", ss.Log) - // axon.LooperResetLogBelow(ls, &ss.Logs) // ls.Loop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() { // trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur @@ -473,7 +474,7 @@ func (ss *Sim) ApplyInputs() { for di := uint32(0); di < ctx.NData; di++ { ev.Step() // note: must save env state for logging / stats due to data parallel re-use of same env - ss.Stats.SetStringDi("TrialName", int(di), ev.TrialName.Cur) + // ss.Stats.SetStringDi("TrialName", int(di), ev.TrialName.Cur) // todo: for _, lnm := range lays { ly := ss.Net.LayerByName(lnm) pats := ev.State(ly.Name) @@ -535,16 +536,109 @@ func (ss *Sim) OpenPats() { } } -//////////////////////////////////////////////////////////////////////////////////////////// -// Stats +//////// Stats + +func (ss *Sim) AddStat(f func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase)) { + ss.StatFuncs = append(ss.StatFuncs, f) +} + +func (ss *Sim) StatsStart(lmd, ltm enums.Enum) { + lmode := lmd.(etime.Modes) + ltime := ltm.(etime.Times) + ss.RunStats(lmode, ltime, Start) +} + +func (ss *Sim) StatsStep(lmd, ltm enums.Enum) { + lmode := lmd.(etime.Modes) + ltime := ltm.(etime.Times) + if ltime >= etime.Epoch { + fmt.Println("running stats:", lmode, ltime) + } + ss.RunStats(lmode, ltime, Step) +} + +func (ss *Sim) RunStats(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) { + for _, sf := range ss.StatFuncs { + sf(lmode, ltime, lphase) + } +} -// InitStats initializes all the statistics. -// called at start of new run func (ss *Sim) InitStats() { - // ss.Stats.SetFloat("UnitErr", 0.0) - // ss.Stats.SetFloat("PhaseDiff", 0.0) - // ss.Stats.SetString("TrialName", "") - // ss.Logs.InitErrStats() // inits TrlErr, FirstZero, LastZero, NZero + for mode, st := range ss.Loops.Stacks { + for _, tm := range st.Order { + ctm := tm.(etime.Times) + ss.RunStats(mode.(etime.Modes), ctm, Start) + } + } +} + +func (ss *Sim) ConfigStats() { + ss.Stats, _ = ss.Root.Mkdir("Stats") + ss.Current, _ = ss.Stats.Mkdir("Current") + for mode, st := range ss.Loops.Stacks { + for _, tm := range st.Order { + ctm := tm.(etime.Times) + ss.AddStat(func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) { + if ltime > ctm { // don't record counter for time above it + return + } + name := tm.String() // name of stat = time + modeDir := ss.Stats.RecycleDir(lmode.String()) + timeDir := modeDir.RecycleDir(ltime.String()) + tsr := datafs.Value[int](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0) + }) + plot.SetStylersTo(tsr, ps) + } + return + } + ctv := ss.Loops.Stacks[mode].Loops[tm].Counter.Cur + datafs.Scalar[int](ss.Current, name).SetInt1D(ctv, 0) + tsr.AppendRowInt(ctv) + }) + } + } + // note: it is essential to only have 1 per func + // so generic names can be used for everything. + ss.AddStat(func(lmode etime.Modes, ltime etime.Times, lphase LoopPhase) { + name := "UnitErr" + modeDir := ss.Stats.RecycleDir(lmode.String()) + timeDir := modeDir.RecycleDir(ltime.String()) + tsr := datafs.Value[float64](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0).SetMax(1) + s.On = true + }) + plot.SetStylersTo(tsr, ps) + } + return + } + switch ltime { + case etime.Trial: + out := ss.Net.LayerByName("Output") + stat := out.PctUnitErr(ss.Net.Context()) + for di := 0; di < ss.Config.Run.NData; di++ { + // todo: current needs di tensor + datafs.Scalar[float64](ss.Current, name).SetFloat(stat[di], 0) + tsr.AppendRowFloat(stat[di]) + } + case etime.Epoch: + subd := modeDir.RecycleDir(etime.Trial.String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + case etime.Run: + subd := modeDir.RecycleDir(etime.Epoch.String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + } + }) } // StatCounters saves current counters to Stats, so they are available for logging etc @@ -572,7 +666,7 @@ func (ss *Sim) NetViewCounters(tm etime.Times) { ss.TrialStats(di) // get trial stats for current di } ss.StatCounters(di) - ss.ViewUpdate.Text = ss.Stats.Print([]string{"Run", "Epoch", "Trial", "Di", "TrialName", "Cycle", "UnitErr", "TrlErr", "PhaseDiff"}) + // ss.ViewUpdate.Text = ss.Stats.Print([]string{"Run", "Epoch", "Trial", "Di", "TrialName", "Cycle", "UnitErr", "TrlErr", "PhaseDiff"}) } // TrialStats computes the trial-level statistics. @@ -662,6 +756,8 @@ func (ss *Sim) Log(mode etime.Modes, time etime.Times) { func (ss *Sim) ConfigGUI() { title := "Axon Random Associator" ss.GUI.MakeBody(ss, "ra25", title, `This demonstrates a basic Axon model. See emergent on GitHub.

`) + ss.GUI.FS = ss.Root + ss.GUI.DataRoot = "Root" ss.GUI.CycleUpdateInterval = 10 nv := ss.GUI.AddNetView("Network") @@ -673,6 +769,7 @@ func (ss *Sim) ConfigGUI() { nv.SceneXYZ().Camera.Pose.Pos.Set(0, 1, 2.75) // more "head on" than default which is more "top down" nv.SceneXYZ().Camera.LookAt(math32.Vec3(0, 0, 0), math32.Vec3(0, 1, 0)) + ss.GUI.UpdateFiles() // ss.GUI.AddPlots(title, &ss.Logs) ss.GUI.FinalizeGUI(false) // if ss.Config.Run.GPU {