Skip to content

Commit

Permalink
vspatch: was missing the non-reward (NR) dopamine (DA) signal that down-trains NR responses …
Browse files Browse the repository at this point in the history
…-- this signal is critical, and its absence really impairs learning. It is now obvious that the main issue is the need for opponent D1 / D2 pathways, so that D2 learns when not to respond and D1 learns when to respond.
  • Loading branch information
rcoreilly committed Mar 22, 2024
1 parent 7e53ae4 commit 0ccaa87
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 22 deletions.
53 changes: 33 additions & 20 deletions examples/vspatch/vspatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,27 +332,24 @@ func (ss *Sim) ApplyPVLV(ev *VSPatchEnv, trial int, di uint32) {

if trial == ev.NTrials-1 {
axon.SetGlbV(ctx, di, axon.GvACh, 1)
ss.ApplyRew(ev, di)
ss.ApplyRew(di, ev.Rew)
} else {
axon.GlobalSetRew(ctx, di, 0, false) // no rew
ss.ApplyRew(di, 0)
axon.SetGlbV(ctx, di, axon.GvACh, 0)
axon.SetGlbV(ctx, di, axon.GvDA, 0)
}
}

// ApplyRew applies reward
func (ss *Sim) ApplyRew(ev *VSPatchEnv, di uint32) {
// note: not using RPE here at this point
rew := ev.Rew
ss.SetRew(rew, di)
}

func (ss *Sim) SetRew(rew float32, di uint32) {
func (ss *Sim) ApplyRew(di uint32, rew float32) {
ctx := &ss.Context
pv := &ss.Net.PVLV
if rew != 0 {
axon.GlobalSetRew(ctx, di, rew, true)
} else {
axon.GlobalSetRew(ctx, di, rew, false)
}
vsp := axon.GlbV(ctx, uint32(di), axon.GvRewPred)
dap := rew - vsp
axon.GlobalSetRew(ctx, di, rew, true)
if rew > 0 {
pv.SetUS(ctx, di, axon.Positive, 0, 1)
} else if rew < 0 {
Expand Down Expand Up @@ -387,7 +384,9 @@ func (ss *Sim) NewRun() {
func (ss *Sim) InitStats() {
ss.Stats.SetFloat("Rew", 0)
ss.Stats.SetFloat("RewPred", 0)
ss.Stats.SetFloat("RPE", 0)
ss.Stats.SetFloat("RewPred_NR", 0)
ss.Stats.SetFloat("DA", 0)
ss.Stats.SetFloat("DA_NR", 0)
}

// StatCounters saves current counters to Stats, so they are available for logging etc
Expand All @@ -414,19 +413,26 @@ func (ss *Sim) NetViewCounters(tm etime.Times) {
ss.TrialStats(di) // get trial stats for current di
}
ss.StatCounters(di)
ss.ViewUpdt.Text = ss.Stats.Print([]string{"Run", "Epoch", "Sequence", "Trial", "Di", "TrialName", "Cycle", "Rew", "RewPred", "RPE"})
ss.ViewUpdt.Text = ss.Stats.Print([]string{"Run", "Epoch", "Sequence", "Trial", "Di", "TrialName", "Cycle", "Rew", "RewPred", "RewPred_NR", "DA", "DA_NR"})
}

// TrialStats records the trial-level statistics
func (ss *Sim) TrialStats(di int) {
ctx := &ss.Context
diu := uint32(di)
ev := ss.Envs.ByModeDi(ctx.Mode, di).(*VSPatchEnv)
ss.Stats.SetInt("Cond", ev.Sequence.Cur)
hasRew := (axon.GlbV(ctx, diu, axon.GvHasRew) > 0)
ss.Stats.SetFloat32("Rew", ev.Rew)
ev.RewPred = axon.GlbV(ctx, uint32(di), axon.GvRewPred)
ev.RPE = axon.GlbV(ctx, uint32(di), axon.GvDA)
ss.Stats.SetFloat32("RewPred", ev.RewPred)
ss.Stats.SetFloat32("RPE", ev.RPE)
ev.RewPred = axon.GlbV(ctx, diu, axon.GvRewPred)
ev.DA = axon.GlbV(ctx, diu, axon.GvDA)
if hasRew {
ss.Stats.SetFloat32("RewPred", ev.RewPred)
ss.Stats.SetFloat32("DA", ev.DA)
} else {
ss.Stats.SetFloat32("RewPred_NR", ev.RewPred)
ss.Stats.SetFloat32("DA_NR", ev.DA)
}
}

//////////////////////////////////////////////////////////////////////////////
Expand All @@ -446,7 +452,13 @@ func (ss *Sim) ConfigLogs() {
li.Range.Max = 1.2
li = ss.Logs.AddStatAggItem("RewPred", etime.Run, etime.Epoch, etime.Sequence)
li.Range.Max = 1.2
li = ss.Logs.AddStatAggItem("RPE", etime.Run, etime.Epoch, etime.Sequence)
ss.Logs.AddStatAggItem("RewPred_NR", etime.Run, etime.Epoch, etime.Sequence)
li = ss.Logs.AddStatAggItem("DA", etime.Run, etime.Epoch, etime.Sequence)
li.Range.Min = -0.5
li.Range.Max = 1
li.FixMin = true
li.FixMax = true
li = ss.Logs.AddStatAggItem("DA_NR", etime.Run, etime.Epoch, etime.Sequence)
li.Range.Min = -0.5
li.Range.Max = 1
li.FixMin = true
Expand All @@ -455,14 +467,15 @@ func (ss *Sim) ConfigLogs() {

// axon.LogAddDiagnosticItems(&ss.Logs, ss.Net, etime.Epoch, etime.Trial)

ss.Logs.PlotItems("Rew", "RewPred", "RPE")
ss.Logs.PlotItems("Rew", "RewPred", "RewPred_NR", "DA", "DA_NR")

ss.Logs.CreateTables()

ss.Logs.SetContext(&ss.Stats, ss.Net)
// don't plot certain combinations we don't use
// ss.Logs.NoPlot(etime.Train, etime.Cycle)
ss.Logs.NoPlot(etime.Train, etime.Cycle)
ss.Logs.NoPlot(etime.Train, etime.Trial)
ss.Logs.NoPlot(etime.Test, etime.Cycle)
ss.Logs.NoPlot(etime.Test, etime.Trial)
ss.Logs.NoPlot(etime.Test, etime.Run)
// note: Analyze not plotted by default
Expand Down
4 changes: 2 additions & 2 deletions examples/vspatch/vspatch_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ type VSPatchEnv struct {
// reward prediction from model
RewPred float32 `edit:"-"`

// reward prediction error: Rew - RewPred
RPE float32 `edit:"-"`
// DA = reward prediction error: Rew - RewPred
DA float32 `edit:"-"`
}

func (ev *VSPatchEnv) Name() string {
Expand Down

0 comments on commit 0ccaa87

Please sign in to comment.