From 0ccaa877248075e0e72aeff4d998f00d11d9ad3a Mon Sep 17 00:00:00 2001 From: "Randall C. O'Reilly" Date: Fri, 22 Mar 2024 12:22:01 -0700 Subject: [PATCH] vspatch: was missing the non-reward DA that down-trains NR responses -- this is critical: its absence really impairs learning. It is now obvious that the main issue is the need for opponent D1 / D2 pathways, so that D2 learns when not to respond and D1 learns when to respond. --- examples/vspatch/vspatch.go | 53 ++++++++++++++++++++------------- examples/vspatch/vspatch_env.go | 4 +-- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/examples/vspatch/vspatch.go b/examples/vspatch/vspatch.go index 1565c43ca..cda4c32c3 100644 --- a/examples/vspatch/vspatch.go +++ b/examples/vspatch/vspatch.go @@ -332,27 +332,24 @@ func (ss *Sim) ApplyPVLV(ev *VSPatchEnv, trial int, di uint32) { if trial == ev.NTrials-1 { axon.SetGlbV(ctx, di, axon.GvACh, 1) - ss.ApplyRew(ev, di) + ss.ApplyRew(di, ev.Rew) } else { - axon.GlobalSetRew(ctx, di, 0, false) // no rew + ss.ApplyRew(di, 0) axon.SetGlbV(ctx, di, axon.GvACh, 0) - axon.SetGlbV(ctx, di, axon.GvDA, 0) } } // ApplyRew applies reward -func (ss *Sim) ApplyRew(ev *VSPatchEnv, di uint32) { - // note: not using RPE here at this point - rew := ev.Rew - ss.SetRew(rew, di) -} - -func (ss *Sim) SetRew(rew float32, di uint32) { +func (ss *Sim) ApplyRew(di uint32, rew float32) { ctx := &ss.Context pv := &ss.Net.PVLV + if rew != 0 { + axon.GlobalSetRew(ctx, di, rew, true) + } else { + axon.GlobalSetRew(ctx, di, rew, false) + } vsp := axon.GlbV(ctx, uint32(di), axon.GvRewPred) dap := rew - vsp - axon.GlobalSetRew(ctx, di, rew, true) if rew > 0 { pv.SetUS(ctx, di, axon.Positive, 0, 1) } else if rew < 0 { @@ -387,7 +384,9 @@ func (ss *Sim) NewRun() { func (ss *Sim) InitStats() { ss.Stats.SetFloat("Rew", 0) ss.Stats.SetFloat("RewPred", 0) - ss.Stats.SetFloat("RPE", 0) + ss.Stats.SetFloat("RewPred_NR", 0) + ss.Stats.SetFloat("DA", 0) + ss.Stats.SetFloat("DA_NR", 0) } // StatCounters saves current 
counters to Stats, so they are available for logging etc @@ -414,19 +413,26 @@ func (ss *Sim) NetViewCounters(tm etime.Times) { ss.TrialStats(di) // get trial stats for current di } ss.StatCounters(di) - ss.ViewUpdt.Text = ss.Stats.Print([]string{"Run", "Epoch", "Sequence", "Trial", "Di", "TrialName", "Cycle", "Rew", "RewPred", "RPE"}) + ss.ViewUpdt.Text = ss.Stats.Print([]string{"Run", "Epoch", "Sequence", "Trial", "Di", "TrialName", "Cycle", "Rew", "RewPred", "RewPred_NR", "DA", "DA_NR"}) } // TrialStats records the trial-level statistics func (ss *Sim) TrialStats(di int) { ctx := &ss.Context + diu := uint32(di) ev := ss.Envs.ByModeDi(ctx.Mode, di).(*VSPatchEnv) ss.Stats.SetInt("Cond", ev.Sequence.Cur) + hasRew := (axon.GlbV(ctx, diu, axon.GvHasRew) > 0) ss.Stats.SetFloat32("Rew", ev.Rew) - ev.RewPred = axon.GlbV(ctx, uint32(di), axon.GvRewPred) - ev.RPE = axon.GlbV(ctx, uint32(di), axon.GvDA) - ss.Stats.SetFloat32("RewPred", ev.RewPred) - ss.Stats.SetFloat32("RPE", ev.RPE) + ev.RewPred = axon.GlbV(ctx, diu, axon.GvRewPred) + ev.DA = axon.GlbV(ctx, diu, axon.GvDA) + if hasRew { + ss.Stats.SetFloat32("RewPred", ev.RewPred) + ss.Stats.SetFloat32("DA", ev.DA) + } else { + ss.Stats.SetFloat32("RewPred_NR", ev.RewPred) + ss.Stats.SetFloat32("DA_NR", ev.DA) + } } ////////////////////////////////////////////////////////////////////////////// @@ -446,7 +452,13 @@ func (ss *Sim) ConfigLogs() { li.Range.Max = 1.2 li = ss.Logs.AddStatAggItem("RewPred", etime.Run, etime.Epoch, etime.Sequence) li.Range.Max = 1.2 - li = ss.Logs.AddStatAggItem("RPE", etime.Run, etime.Epoch, etime.Sequence) + ss.Logs.AddStatAggItem("RewPred_NR", etime.Run, etime.Epoch, etime.Sequence) + li = ss.Logs.AddStatAggItem("DA", etime.Run, etime.Epoch, etime.Sequence) + li.Range.Min = -0.5 + li.Range.Max = 1 + li.FixMin = true + li.FixMax = true + li = ss.Logs.AddStatAggItem("DA_NR", etime.Run, etime.Epoch, etime.Sequence) li.Range.Min = -0.5 li.Range.Max = 1 li.FixMin = true @@ -455,14 +467,15 @@ func 
(ss *Sim) ConfigLogs() { // axon.LogAddDiagnosticItems(&ss.Logs, ss.Net, etime.Epoch, etime.Trial) - ss.Logs.PlotItems("Rew", "RewPred", "RPE") + ss.Logs.PlotItems("Rew", "RewPred", "RewPred_NR", "DA", "DA_NR") ss.Logs.CreateTables() ss.Logs.SetContext(&ss.Stats, ss.Net) // don't plot certain combinations we don't use - // ss.Logs.NoPlot(etime.Train, etime.Cycle) + ss.Logs.NoPlot(etime.Train, etime.Cycle) ss.Logs.NoPlot(etime.Train, etime.Trial) + ss.Logs.NoPlot(etime.Test, etime.Cycle) ss.Logs.NoPlot(etime.Test, etime.Trial) ss.Logs.NoPlot(etime.Test, etime.Run) // note: Analyze not plotted by default diff --git a/examples/vspatch/vspatch_env.go b/examples/vspatch/vspatch_env.go index 6a6221350..366de0367 100644 --- a/examples/vspatch/vspatch_env.go +++ b/examples/vspatch/vspatch_env.go @@ -69,8 +69,8 @@ type VSPatchEnv struct { // reward prediction from model RewPred float32 `edit:"-"` - // reward prediction error: Rew - RewPred - RPE float32 `edit:"-"` + // DA = reward prediction error: Rew - RewPred + DA float32 `edit:"-"` } func (ev *VSPatchEnv) Name() string {