From 461ba3e4d1c4a82eb786d685b8e4c1481c9f5964 Mon Sep 17 00:00:00 2001
From: Kai O'Reilly
Date: Sat, 9 Mar 2024 14:38:39 -0800
Subject: [PATCH] doc and python formatting updates
---
examples/hip/hip.py | 1290 +++++++++++++++++++++++++++++------------
examples/ra25/ra25.py | 709 ++++++++++++++++------
2 files changed, 1457 insertions(+), 542 deletions(-)
diff --git a/examples/hip/hip.py b/examples/hip/hip.py
index 64590e7c7..a7b1e28d9 100755
--- a/examples/hip/hip.py
+++ b/examples/hip/hip.py
@@ -6,19 +6,48 @@
# hip project
-from leabra import go, leabra, emer, relpos, eplot, env, agg, patgen, prjn, etable, efile, split, etensor, params, netview, rand, erand, gi, giv, pygiv, pyparams, mat32, hip, evec, simat, metric
-
-import importlib as il #il.reload(ra25) -- doesn't seem to work for reasons unknown
+from leabra import (
+ go,
+ leabra,
+ emer,
+ relpos,
+ eplot,
+ env,
+ agg,
+ patgen,
+ prjn,
+ etable,
+ efile,
+ split,
+ etensor,
+ params,
+ netview,
+ rand,
+ erand,
+ gi,
+ giv,
+ pygiv,
+ pyparams,
+ mat32,
+ hip,
+ evec,
+ simat,
+ metric,
+)
+
+import importlib as il # il.reload(ra25) -- doesn't seem to work for reasons unknown
import io, sys, getopt
from datetime import datetime, timezone
# OuterLoopParams are the parameters to run for outer crossed factor testing
# var OuterLoopParams = []string{"SmallHip", "MedHip"} //, "BigHip"}
-OuterLoopParams = go.Slice_string(["MedHip"]) #, "BigHip"}
+OuterLoopParams = go.Slice_string(["MedHip"]) # , "BigHip"}
# InnerLoopParams are the parameters to run for inner crossed factor testing
# var InnerLoopParams = []string{"List020", "List040", "List050", "List060", "List070", "List080"} // , "List100"}
-InnerLoopParams = go.Slice_string(["List040", "List080", "List120", "List160", "List200"]) # , "List100"}
+InnerLoopParams = go.Slice_string(
+ ["List040", "List080", "List120", "List160", "List200"]
+) # , "List100"}
# import numpy as np
# import matplotlib
@@ -26,7 +55,7 @@
# import matplotlib.pyplot as plt
# plt.rcParams['svg.fonttype'] = 'none' # essential for not rendering fonts as paths
-# this will become Sim later..
+# this will become Sim later..
TheSim = 1
# LogPrec is precision for saving float values in logs
@@ -35,20 +64,24 @@
# note: we cannot use methods for callbacks from Go -- must be separate functions
# so below are all the callbacks from the GUI toolbar actions
+
def InitCB(recv, send, sig, data):
TheSim.Init()
TheSim.UpdateClassView()
TheSim.vp.SetNeedsFullRender()
+
def TrainCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
TheSim.ToolBar.UpdateActions()
TheSim.Train()
+
def StopCB(recv, send, sig, data):
TheSim.Stop()
+
def StepTrialCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
@@ -57,18 +90,21 @@ def StepTrialCB(recv, send, sig, data):
TheSim.UpdateClassView()
TheSim.vp.SetNeedsFullRender()
+
def StepEpochCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
TheSim.ToolBar.UpdateActions()
TheSim.TrainEpoch()
+
def StepRunCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
TheSim.ToolBar.UpdateActions()
TheSim.TrainRun()
+
def TestTrialCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
@@ -77,6 +113,7 @@ def TestTrialCB(recv, send, sig, data):
TheSim.UpdateClassView()
TheSim.vp.SetNeedsFullRender()
+
def TestItemCB2(recv, send, sig, data):
win = gi.Window(handle=recv)
vp = win.WinViewport2D()
@@ -84,9 +121,20 @@ def TestItemCB2(recv, send, sig, data):
if sig != gi.DialogAccepted:
return
val = gi.StringPromptDialogValue(dlg)
- idxs = TheSim.TestEnv.Table.RowsByString("Name", val, True, True) # contains, ignoreCase
+ idxs = TheSim.TestEnv.Table.RowsByString(
+ "Name", val, True, True
+ ) # contains, ignoreCase
if len(idxs) == 0:
- gi.PromptDialog(vp, gi.DlgOpts(Title="Name Not Found", Prompt="No patterns found containing: " + val), True, False, go.nil, go.nil)
+ gi.PromptDialog(
+ vp,
+ gi.DlgOpts(
+ Title="Name Not Found", Prompt="No patterns found containing: " + val
+ ),
+ True,
+ False,
+ go.nil,
+ go.nil,
+ )
else:
if not TheSim.IsRunning:
TheSim.IsRunning = True
@@ -95,10 +143,21 @@ def TestItemCB2(recv, send, sig, data):
TheSim.IsRunning = False
vp.SetNeedsFullRender()
+
def TestItemCB(recv, send, sig, data):
win = gi.Window(handle=recv)
- gi.StringPromptDialog(win.WinViewport2D(), "", "Test Item",
- gi.DlgOpts(Title="Test Item", Prompt="Enter the Name of a given input pattern to test (case insensitive, contains given string."), win, TestItemCB2)
+ gi.StringPromptDialog(
+ win.WinViewport2D(),
+ "",
+ "Test Item",
+ gi.DlgOpts(
+ Title="Test Item",
+ Prompt="Enter the Name of a given input pattern to test (case insensitive, contains given string.",
+ ),
+ win,
+ TestItemCB2,
+ )
+
def TestAllCB(recv, send, sig, data):
if not TheSim.IsRunning:
@@ -106,35 +165,46 @@ def TestAllCB(recv, send, sig, data):
TheSim.ToolBar.UpdateActions()
TheSim.RunTestAll()
+
def ResetRunLogCB(recv, send, sig, data):
TheSim.RunLog.SetNumRows(0)
TheSim.RunPlot.Update()
+
def NewRndSeedCB(recv, send, sig, data):
TheSim.NewRndSeed()
+
def ReadmeCB(recv, send, sig, data):
- gi.TheApp.OpenURL("https://github.com/emer/leabra/blob/master/examples/hip/README.md")
+ gi.TheApp.OpenURL(
+ "https://github.com/emer/leabra/blob/master/examples/hip/README.md"
+ )
+
def FilterSSE(et, row):
- return etable.Table(handle=et).CellFloat("SSE", row) > 0 # include error trials
+ return etable.Table(handle=et).CellFloat("SSE", row) > 0 # include error trials
+
def AggIfGt0(idx, val):
return val > 0
+
def AggIfEq0(idx, val):
return val == 0
+
def UpdtFuncNotRunning(act):
act.SetActiveStateUpdt(not TheSim.IsRunning)
-
+
+
def UpdtFuncRunning(act):
act.SetActiveStateUpdt(TheSim.IsRunning)
-
-#####################################################
+
+#####################################################
# Sim
+
class HipParams(pygiv.ClassViewObj):
"""
see def_params.go for the default params, and params.go for user-saved versions
@@ -144,7 +214,9 @@ class HipParams(pygiv.ClassViewObj):
def __init__(self):
super(HipParams, self).__init__()
self.ECSize = evec.Vec2i()
- self.SetTags("ECSize", 'desc:"size of EC in terms of overall pools (outer dimension)"')
+ self.SetTags(
+ "ECSize", 'desc:"size of EC in terms of overall pools (outer dimension)"'
+ )
self.ECPool = evec.Vec2i()
self.SetTags("ECPool", 'desc:"size of one EC pool"')
self.CA1Pool = evec.Vec2i()
@@ -164,9 +236,15 @@ def __init__(self):
self.ECPctAct = float()
self.SetTags("ECPctAct", 'desc:"percent activation in EC pool"')
self.MossyDel = float()
- self.SetTags("MossyDel", 'desc:"delta in mossy effective strength between minus and plus phase"')
+ self.SetTags(
+ "MossyDel",
+ 'desc:"delta in mossy effective strength between minus and plus phase"',
+ )
self.MossyDelTest = float()
- self.SetTags("MossyDelTest", 'desc:"delta in mossy strength for testing (relative to base param)"')
+ self.SetTags(
+ "MossyDelTest",
+ 'desc:"delta in mossy strength for testing (relative to base param)"',
+ )
def Update(hp):
hp.DGSize.X = int(float(hp.CA3Size.X) * hp.DGRatio)
@@ -180,14 +258,15 @@ def Defaults(hp):
hp.DGRatio = 1.5
# ratio
- hp.DGPCon = 0.25 # .35 is sig worse, .2 learns faster but AB recall is worse
+ hp.DGPCon = 0.25 # .35 is sig worse, .2 learns faster but AB recall is worse
hp.CA3PCon = 0.25
- hp.MossyPCon = 0.02 # .02 > .05 > .01 (for small net)
+ hp.MossyPCon = 0.02 # .02 > .05 > .01 (for small net)
hp.ECPctAct = 0.2
- hp.MossyDel = 4 # 4 > 2 -- best is 4 del on 4 rel baseline
- hp.MossyDelTest = 3 # for rel = 4: 3 > 2 > 0 > 4 -- 4 is very bad -- need a small amount..
-
+ hp.MossyDel = 4 # 4 > 2 -- best is 4 del on 4 rel baseline
+ hp.MossyDelTest = (
+ 3 # for rel = 4: 3 > 2 > 0 > 4 -- 4 is very bad -- need a small amount..
+ )
class PatParams(pygiv.ClassViewObj):
@@ -200,20 +279,31 @@ def __init__(self):
self.ListSize = int()
self.SetTags("ListSize", 'desc:"number of A-B, A-C patterns each"')
self.MinDiffPct = float()
- self.SetTags("MinDiffPct", 'desc:"minimum difference between item random patterns, as a proportion (0-1) of total active"')
+ self.SetTags(
+ "MinDiffPct",
+ 'desc:"minimum difference between item random patterns, as a proportion (0-1) of total active"',
+ )
self.DriftCtxt = bool()
- self.SetTags("DriftCtxt", 'desc:"use drifting context representations -- otherwise does bit flips from prototype"')
+ self.SetTags(
+ "DriftCtxt",
+ 'desc:"use drifting context representations -- otherwise does bit flips from prototype"',
+ )
self.CtxtFlipPct = float()
- self.SetTags("CtxtFlipPct", 'desc:"proportion (0-1) of active bits to flip for each context pattern, relative to a prototype, for non-drifting"')
+ self.SetTags(
+ "CtxtFlipPct",
+ 'desc:"proportion (0-1) of active bits to flip for each context pattern, relative to a prototype, for non-drifting"',
+ )
self.DriftPct = float()
- self.SetTags("DriftPct", 'desc:"percentage of active bits that drift, per step, for drifting context"')
+ self.SetTags(
+ "DriftPct",
+ 'desc:"percentage of active bits that drift, per step, for drifting context"',
+ )
def Defaults(pp):
- pp.ListSize = 20 # 10 is too small to see issues..
+ pp.ListSize = 20 # 10 is too small to see issues..
pp.MinDiffPct = 0.5
- pp.CtxtFlipPct = .25
- pp.DriftPct = .2
-
+ pp.CtxtFlipPct = 0.25
+ pp.DriftPct = 0.2
class Sim(pygiv.ClassViewObj):
@@ -246,17 +336,29 @@ def __init__(self):
self.TestLure = etable.Table()
self.SetTags("TestLure", 'view:"no-inline" desc:"Lure testing patterns to use"')
self.TrainAll = etable.Table()
- self.SetTags("TrainAll", 'view:"no-inline" desc:"all training patterns -- for pretrain"')
+ self.SetTags(
+ "TrainAll", 'view:"no-inline" desc:"all training patterns -- for pretrain"'
+ )
self.TrnTrlLog = etable.Table()
- self.SetTags("TrnTrlLog", 'view:"no-inline" desc:"training trial-level log data"')
+ self.SetTags(
+ "TrnTrlLog", 'view:"no-inline" desc:"training trial-level log data"'
+ )
self.TrnEpcLog = etable.Table()
- self.SetTags("TrnEpcLog", 'view:"no-inline" desc:"training epoch-level log data"')
+ self.SetTags(
+ "TrnEpcLog", 'view:"no-inline" desc:"training epoch-level log data"'
+ )
self.TstEpcLog = etable.Table()
- self.SetTags("TstEpcLog", 'view:"no-inline" desc:"testing epoch-level log data"')
+ self.SetTags(
+ "TstEpcLog", 'view:"no-inline" desc:"testing epoch-level log data"'
+ )
self.TstTrlLog = etable.Table()
- self.SetTags("TstTrlLog", 'view:"no-inline" desc:"testing trial-level log data"')
+ self.SetTags(
+ "TstTrlLog", 'view:"no-inline" desc:"testing trial-level log data"'
+ )
self.TstCycLog = etable.Table()
- self.SetTags("TstCycLog", 'view:"no-inline" desc:"testing cycle-level log data"')
+ self.SetTags(
+ "TstCycLog", 'view:"no-inline" desc:"testing cycle-level log data"'
+ )
self.RunLog = etable.Table()
self.SetTags("RunLog", 'view:"no-inline" desc:"summary log of each run"')
self.RunStats = etable.Table()
@@ -264,13 +366,21 @@ def __init__(self):
self.TstStats = etable.Table()
self.SetTags("TstStats", 'view:"no-inline" desc:"testing stats"')
self.SimMats = {}
- self.SetTags("SimMats", 'view:"no-inline" desc:"similarity matrix results for layers"')
+ self.SetTags(
+ "SimMats", 'view:"no-inline" desc:"similarity matrix results for layers"'
+ )
self.Params = params.Sets()
self.SetTags("Params", 'view:"no-inline" desc:"full collection of param sets"')
self.ParamSet = str()
- self.SetTags("ParamSet", 'desc:"which set of *additional* parameters to use -- always applies Base and optionaly this next if set"')
+ self.SetTags(
+ "ParamSet",
+ 'desc:"which set of *additional* parameters to use -- always applies Base and optionaly this next if set"',
+ )
self.Tag = str()
- self.SetTags("Tag", 'desc:"extra tag string to add to any file names output from sim (e.g., weights files, log files, params)"')
+ self.SetTags(
+ "Tag",
+ 'desc:"extra tag string to add to any file names output from sim (e.g., weights files, log files, params)"',
+ )
self.MaxRuns = int(10)
self.SetTags("MaxRuns", 'desc:"maximum number of model runs to perform"')
self.MaxEpcs = int(30)
@@ -278,68 +388,143 @@ def __init__(self):
self.PreTrainEpcs = int(5)
self.SetTags("PreTrainEpcs", 'desc:"number of epochs to run for pretraining"')
self.NZeroStop = int(1)
- self.SetTags("NZeroStop", 'desc:"if a positive number, training will stop after this many epochs with zero mem errors"')
+ self.SetTags(
+ "NZeroStop",
+ 'desc:"if a positive number, training will stop after this many epochs with zero mem errors"',
+ )
self.TrainEnv = env.FixedTable()
- self.SetTags("TrainEnv", 'desc:"Training environment -- contains everything about iterating over input / output patterns over training"')
+ self.SetTags(
+ "TrainEnv",
+ 'desc:"Training environment -- contains everything about iterating over input / output patterns over training"',
+ )
self.TestEnv = env.FixedTable()
- self.SetTags("TestEnv", 'desc:"Testing environment -- manages iterating over testing"')
+ self.SetTags(
+ "TestEnv", 'desc:"Testing environment -- manages iterating over testing"'
+ )
self.Time = leabra.Time()
self.SetTags("Time", 'desc:"leabra timing parameters and state"')
self.ViewOn = True
- self.SetTags("ViewOn", 'desc:"whether to update the network view while running"')
+ self.SetTags(
+ "ViewOn", 'desc:"whether to update the network view while running"'
+ )
self.TrainUpdt = leabra.TimeScales.AlphaCycle
- self.SetTags("TrainUpdt", 'desc:"at what time scale to update the display during training? Anything longer than Epoch updates at Epoch in this model"')
+ self.SetTags(
+ "TrainUpdt",
+ 'desc:"at what time scale to update the display during training? Anything longer than Epoch updates at Epoch in this model"',
+ )
self.TestUpdt = leabra.TimeScales.AlphaCycle
- self.SetTags("TestUpdt", 'desc:"at what time scale to update the display during testing? Anything longer than Epoch updates at Epoch in this model"')
+ self.SetTags(
+ "TestUpdt",
+ 'desc:"at what time scale to update the display during testing? Anything longer than Epoch updates at Epoch in this model"',
+ )
self.TestInterval = int(1)
- self.SetTags("TestInterval", 'desc:"how often to run through all the test patterns, in terms of training epochs -- can use 0 or -1 for no testing"')
+ self.SetTags(
+ "TestInterval",
+ 'desc:"how often to run through all the test patterns, in terms of training epochs -- can use 0 or -1 for no testing"',
+ )
self.MemThr = float(0.34)
- self.SetTags("MemThr", 'desc:"threshold to use for memory test -- if error proportion is below this number, it is scored as a correct trial"')
+ self.SetTags(
+ "MemThr",
+ 'desc:"threshold to use for memory test -- if error proportion is below this number, it is scored as a correct trial"',
+ )
# statistics: note use float64 as that is best for etable.Table
self.TestNm = str()
- self.SetTags("TestNm", 'inactive:"+" desc:"what set of patterns are we currently testing"')
+ self.SetTags(
+ "TestNm",
+ 'inactive:"+" desc:"what set of patterns are we currently testing"',
+ )
self.Mem = float()
- self.SetTags("Mem", 'inactive:"+" desc:"whether current trial\'s ECout met memory criterion"')
+ self.SetTags(
+ "Mem",
+ 'inactive:"+" desc:"whether current trial\'s ECout met memory criterion"',
+ )
self.TrgOnWasOffAll = float()
- self.SetTags("TrgOnWasOffAll", 'inactive:"+" desc:"current trial\'s proportion of bits where target = on but ECout was off ( < 0.5), for all bits"')
+ self.SetTags(
+ "TrgOnWasOffAll",
+ 'inactive:"+" desc:"current trial\'s proportion of bits where target = on but ECout was off ( < 0.5), for all bits"',
+ )
self.TrgOnWasOffCmp = float()
- self.SetTags("TrgOnWasOffCmp", 'inactive:"+" desc:"current trial\'s proportion of bits where target = on but ECout was off ( < 0.5), for only completion bits that were not active in ECin"')
+ self.SetTags(
+ "TrgOnWasOffCmp",
+ 'inactive:"+" desc:"current trial\'s proportion of bits where target = on but ECout was off ( < 0.5), for only completion bits that were not active in ECin"',
+ )
self.TrgOffWasOn = float()
- self.SetTags("TrgOffWasOn", 'inactive:"+" desc:"current trial\'s proportion of bits where target = off but ECout was on ( > 0.5)"')
+ self.SetTags(
+ "TrgOffWasOn",
+ 'inactive:"+" desc:"current trial\'s proportion of bits where target = off but ECout was on ( > 0.5)"',
+ )
self.TrlSSE = float()
self.SetTags("TrlSSE", 'inactive:"+" desc:"current trial\'s sum squared error"')
self.TrlAvgSSE = float()
- self.SetTags("TrlAvgSSE", 'inactive:"+" desc:"current trial\'s average sum squared error"')
+ self.SetTags(
+ "TrlAvgSSE",
+ 'inactive:"+" desc:"current trial\'s average sum squared error"',
+ )
self.TrlCosDiff = float()
- self.SetTags("TrlCosDiff", 'inactive:"+" desc:"current trial\'s cosine difference"')
+ self.SetTags(
+ "TrlCosDiff", 'inactive:"+" desc:"current trial\'s cosine difference"'
+ )
self.EpcSSE = float()
- self.SetTags("EpcSSE", 'inactive:"+" desc:"last epoch\'s total sum squared error"')
+ self.SetTags(
+ "EpcSSE", 'inactive:"+" desc:"last epoch\'s total sum squared error"'
+ )
self.EpcAvgSSE = float()
- self.SetTags("EpcAvgSSE", 'inactive:"+" desc:"last epoch\'s average sum squared error (average over trials, and over units within layer)"')
+ self.SetTags(
+ "EpcAvgSSE",
+ 'inactive:"+" desc:"last epoch\'s average sum squared error (average over trials, and over units within layer)"',
+ )
self.EpcPctErr = float()
- self.SetTags("EpcPctErr", 'inactive:"+" desc:"last epoch\'s percent of trials that had SSE > 0 (subject to .5 unit-wise tolerance)"')
+ self.SetTags(
+ "EpcPctErr",
+ 'inactive:"+" desc:"last epoch\'s percent of trials that had SSE > 0 (subject to .5 unit-wise tolerance)"',
+ )
self.EpcPctCor = float()
- self.SetTags("EpcPctCor", 'inactive:"+" desc:"last epoch\'s percent of trials that had SSE == 0 (subject to .5 unit-wise tolerance)"')
+ self.SetTags(
+ "EpcPctCor",
+ 'inactive:"+" desc:"last epoch\'s percent of trials that had SSE == 0 (subject to .5 unit-wise tolerance)"',
+ )
self.EpcCosDiff = float()
- self.SetTags("EpcCosDiff", 'inactive:"+" desc:"last epoch\'s average cosine difference for output layer (a normalized error measure, maximum of 1 when the minus phase exactly matches the plus)"')
+ self.SetTags(
+ "EpcCosDiff",
+ 'inactive:"+" desc:"last epoch\'s average cosine difference for output layer (a normalized error measure, maximum of 1 when the minus phase exactly matches the plus)"',
+ )
self.EpcPerTrlMSec = float()
- self.SetTags("EpcPerTrlMSec", 'inactive:"+" desc:"how long did the epoch take per trial in wall-clock milliseconds"')
+ self.SetTags(
+ "EpcPerTrlMSec",
+ 'inactive:"+" desc:"how long did the epoch take per trial in wall-clock milliseconds"',
+ )
self.FirstZero = int()
- self.SetTags("FirstZero", 'inactive:"+" desc:"epoch at when Mem err first went to zero"')
+ self.SetTags(
+ "FirstZero", 'inactive:"+" desc:"epoch at when Mem err first went to zero"'
+ )
self.NZero = int()
- self.SetTags("NZero", 'inactive:"+" desc:"number of epochs in a row with zero Mem err"')
+ self.SetTags(
+ "NZero", 'inactive:"+" desc:"number of epochs in a row with zero Mem err"'
+ )
# internal state - view:"-"
self.SumSSE = float()
- self.SetTags("SumSSE", 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"')
+ self.SetTags(
+ "SumSSE",
+ 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"',
+ )
self.SumAvgSSE = float()
- self.SetTags("SumAvgSSE", 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"')
+ self.SetTags(
+ "SumAvgSSE",
+ 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"',
+ )
self.SumCosDiff = float()
- self.SetTags("SumCosDiff", 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"')
+ self.SetTags(
+ "SumCosDiff",
+ 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"',
+ )
self.CntErr = int()
- self.SetTags("CntErr", 'view:"-" inactive:"+" desc:"sum of errs to increment as we go through epoch"')
+ self.SetTags(
+ "CntErr",
+ 'view:"-" inactive:"+" desc:"sum of errs to increment as we go through epoch"',
+ )
self.Win = 0
self.SetTags("Win", 'view:"-" desc:"main GUI window"')
self.NetView = 0
@@ -371,9 +556,15 @@ def __init__(self):
self.RunFile = 0
self.SetTags("RunFile", 'view:"-" desc:"log file"')
self.TmpVals = go.Slice_float32()
- self.SetTags("TmpVals", 'view:"-" desc:"temp slice for holding values -- prevent mem allocs"')
+ self.SetTags(
+ "TmpVals",
+ 'view:"-" desc:"temp slice for holding values -- prevent mem allocs"',
+ )
self.LayStatNms = go.Slice_string(["ECin", "ECout", "DG", "CA3", "CA1"])
- self.SetTags("LayStatNms", 'view:"-" desc:"names of layers to collect more detailed stats on (avg act, etc)"')
+ self.SetTags(
+ "LayStatNms",
+ 'view:"-" desc:"names of layers to collect more detailed stats on (avg act, etc)"',
+ )
self.TstNms = go.Slice_string(["AB", "AC", "Lure"])
self.SetTags("TstNms", 'view:"-" desc:"names of test tables"')
self.SimMatStats = go.Slice_string(["Within", "Between"])
@@ -383,24 +574,33 @@ def __init__(self):
self.ValsTsrs = {}
self.SetTags("ValsTsrs", 'view:"-" desc:"for holding layer values"')
self.SaveWts = False
- self.SetTags("SaveWts", 'view:"-" desc:"for command-line run only, auto-save final weights after each run"')
+ self.SetTags(
+ "SaveWts",
+ 'view:"-" desc:"for command-line run only, auto-save final weights after each run"',
+ )
self.PreTrainWts = ""
self.SetTags("PreTrainWts", 'view:"-" desc:"name of pretrained wts file"')
self.NoGui = False
self.SetTags("NoGui", 'view:"-" desc:"if true, runing in no GUI mode"')
self.LogSetParams = False
- self.SetTags("LogSetParams", 'view:"-" desc:"if true, print message for all params that are set"')
+ self.SetTags(
+ "LogSetParams",
+ 'view:"-" desc:"if true, print message for all params that are set"',
+ )
self.IsRunning = False
self.SetTags("IsRunning", 'view:"-" desc:"true if sim is running"')
self.StopNow = False
self.SetTags("StopNow", 'view:"-" desc:"flag to stop running"')
self.NeedsNewRun = False
- self.SetTags("NeedsNewRun", 'view:"-" desc:"flag to initialize NewRun if last one finished"')
+ self.SetTags(
+ "NeedsNewRun",
+ 'view:"-" desc:"flag to initialize NewRun if last one finished"',
+ )
self.RndSeed = int(2)
self.SetTags("RndSeed", 'view:"-" desc:"the current random seed"')
self.LastEpcTime = 0
self.SetTags("LastEpcTime", 'view:"-" desc:"timer for last epoch"')
- self.vp = 0
+ self.vp = 0
self.SetTags("vp", 'view:"-" desc:"viewport"')
def InitParams(ss):
@@ -414,7 +614,7 @@ def InitParams(ss):
def Defaults(ss):
ss.Hip.Defaults()
ss.Pat.Defaults()
- ss.Time.CycPerQtr = 25 # note: key param - 25 seems like it is actually fine?
+ ss.Time.CycPerQtr = 25 # note: key param - 25 seems like it is actually fine?
ss.Update()
def Update(ss):
@@ -436,18 +636,20 @@ def Config(ss):
ss.ConfigRunLog(ss.RunLog)
def ConfigEnv(ss):
- if ss.MaxRuns == 0: # allow user override
+ if ss.MaxRuns == 0: # allow user override
ss.MaxRuns = 10
- if ss.MaxEpcs == 0: # allow user override
+ if ss.MaxEpcs == 0: # allow user override
ss.MaxEpcs = 30
ss.NZeroStop = 1
- ss.PreTrainEpcs = 5 # seems sufficient?
+ ss.PreTrainEpcs = 5 # seems sufficient?
ss.TrainEnv.Nm = "TrainEnv"
ss.TrainEnv.Dsc = "training params and state"
ss.TrainEnv.Table = etable.NewIdxView(ss.TrainAB)
ss.TrainEnv.Validate()
- ss.TrainEnv.Run.Max = ss.MaxRuns # note: we are not setting epoch max -- do that manually
+ ss.TrainEnv.Run.Max = (
+ ss.MaxRuns
+ ) # note: we are not setting epoch max -- do that manually
ss.TestEnv.Nm = "TestEnv"
ss.TestEnv.Dsc = "testing params and state"
@@ -471,21 +673,51 @@ def SetEnv(ss, trainAC):
def ConfigNet(ss, net):
net.InitName(net, "Hip_bench")
hp = ss.Hip
- inl = net.AddLayer4D("Input", hp.ECSize.Y, hp.ECSize.X, hp.ECPool.Y, hp.ECPool.X, emer.Input)
- ecin = net.AddLayer4D("ECin", hp.ECSize.Y, hp.ECSize.X, hp.ECPool.Y, hp.ECPool.X, emer.Hidden)
- ecout = net.AddLayer4D("ECout", hp.ECSize.Y, hp.ECSize.X, hp.ECPool.Y, hp.ECPool.X, emer.Target)
- ca1 = net.AddLayer4D("CA1", hp.ECSize.Y, hp.ECSize.X, hp.CA1Pool.Y, hp.CA1Pool.X, emer.Hidden)
+ inl = net.AddLayer4D(
+ "Input", hp.ECSize.Y, hp.ECSize.X, hp.ECPool.Y, hp.ECPool.X, emer.Input
+ )
+ ecin = net.AddLayer4D(
+ "ECin", hp.ECSize.Y, hp.ECSize.X, hp.ECPool.Y, hp.ECPool.X, emer.Hidden
+ )
+ ecout = net.AddLayer4D(
+ "ECout", hp.ECSize.Y, hp.ECSize.X, hp.ECPool.Y, hp.ECPool.X, emer.Target
+ )
+ ca1 = net.AddLayer4D(
+ "CA1", hp.ECSize.Y, hp.ECSize.X, hp.CA1Pool.Y, hp.CA1Pool.X, emer.Hidden
+ )
dg = net.AddLayer2D("DG", hp.DGSize.Y, hp.DGSize.X, emer.Hidden)
ca3 = net.AddLayer2D("CA3", hp.CA3Size.Y, hp.CA3Size.X, emer.Hidden)
ecin.SetClass("EC")
ecout.SetClass("EC")
- ecin.SetRelPos(relpos.Rel(Rel= relpos.RightOf, Other= "Input", YAlign= relpos.Front, Space= 2))
- ecout.SetRelPos(relpos.Rel(Rel= relpos.RightOf, Other= "ECin", YAlign= relpos.Front, Space= 2))
- dg.SetRelPos(relpos.Rel(Rel= relpos.Above, Other= "Input", YAlign= relpos.Front, XAlign= relpos.Left, Space= 0))
- ca3.SetRelPos(relpos.Rel(Rel= relpos.Above, Other= "DG", YAlign= relpos.Front, XAlign= relpos.Left, Space= 0))
- ca1.SetRelPos(relpos.Rel(Rel= relpos.RightOf, Other= "CA3", YAlign= relpos.Front, Space= 2))
+ ecin.SetRelPos(
+ relpos.Rel(Rel=relpos.RightOf, Other="Input", YAlign=relpos.Front, Space=2)
+ )
+ ecout.SetRelPos(
+ relpos.Rel(Rel=relpos.RightOf, Other="ECin", YAlign=relpos.Front, Space=2)
+ )
+ dg.SetRelPos(
+ relpos.Rel(
+ Rel=relpos.Above,
+ Other="Input",
+ YAlign=relpos.Front,
+ XAlign=relpos.Left,
+ Space=0,
+ )
+ )
+ ca3.SetRelPos(
+ relpos.Rel(
+ Rel=relpos.Above,
+ Other="DG",
+ YAlign=relpos.Front,
+ XAlign=relpos.Left,
+ Space=0,
+ )
+ )
+ ca1.SetRelPos(
+ relpos.Rel(Rel=relpos.RightOf, Other="CA3", YAlign=relpos.Front, Space=2)
+ )
onetoone = prjn.NewOneToOne()
pool1to1 = prjn.NewPoolOneToOne()
@@ -511,8 +743,10 @@ def ConfigNet(ss, net):
pj = net.ConnectLayersPrjn(ecin, dg, ppathDG, emer.Forward, hip.CHLPrjn())
pj.SetClass("HippoCHL")
- if True: # toggle for bcm vs. ppath
- pj = net.ConnectLayersPrjn(ecin, ca3, ppathCA3, emer.Forward, hip.EcCa1Prjn())
+ if True: # toggle for bcm vs. ppath
+ pj = net.ConnectLayersPrjn(
+ ecin, ca3, ppathCA3, emer.Forward, hip.EcCa1Prjn()
+ )
pj.SetClass("PPath")
pj = net.ConnectLayersPrjn(ca3, ca3, full, emer.Lateral, hip.EcCa1Prjn())
pj.SetClass("PPath")
@@ -529,19 +763,21 @@ def ConfigNet(ss, net):
pj.SetClass("HippoCHL")
else:
# note: this requires lrate = 1.0 or maybe 1.2, doesn't work *nearly* as well
- pj = net.ConnectLayers(ca3, ca1, full, emer.Forward) # default con
+ pj = net.ConnectLayers(ca3, ca1, full, emer.Forward) # default con
# pj.SetClass("HippoCHL")
# Mossy fibers
mossy = prjn.NewUnifRnd()
mossy.PCon = hp.MossyPCon
- pj = net.ConnectLayersPrjn(dg, ca3, mossy, emer.Forward, hip.CHLPrjn()) # no learning
+ pj = net.ConnectLayersPrjn(
+ dg, ca3, mossy, emer.Forward, hip.CHLPrjn()
+ ) # no learning
pj.SetClass("HippoCHL")
# using 4 threads total (rest on 0)
dg.SetThread(1)
ca3.SetThread(2)
- ca1.SetThread(3) # this has the most
+ ca1.SetThread(3) # this has the most
# note: if you wanted to change a layer type from e.g., Target to Compare, do this:
# outLay.SetType(emer.Compare)
@@ -549,18 +785,17 @@ def ConfigNet(ss, net):
# and thus removes error-driven learning -- but stats are still computed.
net.Defaults()
- ss.SetParams("Network", ss.LogSetParams) # only set Network params
+ ss.SetParams("Network", ss.LogSetParams) # only set Network params
net.Build()
net.InitWts()
def ReConfigNet(ss):
ss.ConfigPats()
- ss.Net = leabra.Network() # start over with new network
+ ss.Net = leabra.Network() # start over with new network
ss.ConfigNet(ss.Net)
if ss.NetView != 0:
ss.NetView.SetNet(ss.Net)
- ss.NetView.Update() # issue #41 closed
-
+ ss.NetView.Update() # issue #41 closed
def Init(ss):
"""
@@ -590,9 +825,21 @@ def Counters(ss, train):
and add a few tabs at the end to allow for expansion..
"""
if train:
- return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TrainEnv.Trial.Cur, ss.Time.Cycle, ss.TrainEnv.TrialName.Cur)
+ return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (
+ ss.TrainEnv.Run.Cur,
+ ss.TrainEnv.Epoch.Cur,
+ ss.TrainEnv.Trial.Cur,
+ ss.Time.Cycle,
+ ss.TrainEnv.TrialName.Cur,
+ )
else:
- return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TestEnv.Trial.Cur, ss.Time.Cycle, ss.TestEnv.TrialName.Cur)
+ return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (
+ ss.TrainEnv.Run.Cur,
+ ss.TrainEnv.Epoch.Cur,
+ ss.TestEnv.Trial.Cur,
+ ss.Time.Cycle,
+ ss.TestEnv.TrialName.Cur,
+ )
def UpdateView(ss, train):
if ss.NetView != 0 and ss.NetView.IsVisible():
@@ -610,7 +857,7 @@ def AlphaCyc(ss, train):
"""
if ss.Win != 0:
- ss.Win.PollEvents() # this is essential for GUI responsiveness while running
+ ss.Win.PollEvents() # this is essential for GUI responsiveness while running
viewUpdt = ss.TrainUpdt.value
if not train:
viewUpdt = ss.TestUpdt.value
@@ -635,11 +882,11 @@ def AlphaCyc(ss, train):
ca3FmDg.WtScale.Rel = dgwtscale - ss.Hip.MossyDel
if train:
- ecout.SetType(emer.Target) # clamp a plus phase during testing
+ ecout.SetType(emer.Target) # clamp a plus phase during testing
else:
- ecout.SetType(emer.Compare) # don't clamp
+ ecout.SetType(emer.Compare) # don't clamp
- ecout.UpdateExtFlags() # call this after updating type
+ ecout.UpdateExtFlags() # call this after updating type
ss.Net.AlphaCycInit()
ss.Time.AlphaCycStart()
@@ -651,32 +898,34 @@ def AlphaCyc(ss, train):
ss.Time.CycleInc()
if ss.ViewOn:
if viewUpdt == leabra.Cycle:
- if cyc != ss.Time.CycPerQtr-1: # will be updated by quarter
+ if cyc != ss.Time.CycPerQtr - 1: # will be updated by quarter
ss.UpdateView(train)
if viewUpdt == leabra.FastSpike:
- if (cyc+1)%10 == 0:
+ if (cyc + 1) % 10 == 0:
ss.UpdateView(train)
- if qtr == 1: # Second, Third Quarters: CA1 is driven by CA3 recall
+ if qtr == 1: # Second, Third Quarters: CA1 is driven by CA3 recall
ca1FmECin.WtScale.Abs = 0
ca1FmCa3.WtScale.Abs = 1
if train:
ca3FmDg.WtScale.Rel = dgwtscale
else:
- ca3FmDg.WtScale.Rel = dgwtscale - ss.Hip.MossyDelTest # testing
+ ca3FmDg.WtScale.Rel = dgwtscale - ss.Hip.MossyDelTest # testing
- ss.Net.GScaleFmAvgAct() # update computed scaling factors
- ss.Net.InitGInc() # scaling params change, so need to recompute all netins
- if qtr == 3: # Fourth Quarter: CA1 back to ECin drive only
+ ss.Net.GScaleFmAvgAct() # update computed scaling factors
+ ss.Net.InitGInc() # scaling params change, so need to recompute all netins
+ if qtr == 3: # Fourth Quarter: CA1 back to ECin drive only
ca1FmECin.WtScale.Abs = 1
ca1FmCa3.WtScale.Abs = 0
- ss.Net.GScaleFmAvgAct() # update computed scaling factors
- ss.Net.InitGInc() # scaling params change, so need to recompute all netins
- if train: # clamp ECout from ECin
- ecin.UnitVals(ss.TmpVals, "Act") # note: could use input instead -- not much diff
+ ss.Net.GScaleFmAvgAct() # update computed scaling factors
+ ss.Net.InitGInc() # scaling params change, so need to recompute all netins
+ if train: # clamp ECout from ECin
+ ecin.UnitVals(
+ ss.TmpVals, "Act"
+ ) # note: could use input instead -- not much diff
ecout.ApplyExt1D32(ss.TmpVals)
ss.Net.QuarterFinal(ss.Time)
- if qtr+1 == 3:
- ss.MemStats(train) # must come after QuarterFinal
+ if qtr + 1 == 3:
+ ss.MemStats(train) # must come after QuarterFinal
ss.Time.QuarterInc()
if ss.ViewOn:
@@ -686,7 +935,7 @@ def AlphaCyc(ss, train):
if qtr >= 2:
ss.UpdateView(train)
- ca3FmDg.WtScale.Rel = dgwtscale # restore
+ ca3FmDg.WtScale.Rel = dgwtscale # restore
ca1FmCa3.WtScale.Abs = 1
if train:
@@ -695,8 +944,7 @@ def AlphaCyc(ss, train):
ss.UpdateView(train)
if not train:
if ss.TstCycPlot != 0:
- ss.TstCycPlot.GoUpdate() # make sure up-to-date at end
-
+ ss.TstCycPlot.GoUpdate() # make sure up-to-date at end
def ApplyInputs(ss, en):
"""
@@ -708,7 +956,7 @@ def ApplyInputs(ss, en):
ss.Net.InitExt()
lays = go.Slice_string(["Input", "ECout"])
- for lnm in lays :
+ for lnm in lays:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
pats = en.State(ly.Nm)
if pats != 0:
@@ -731,15 +979,19 @@ def TrainTrial(ss):
ss.LogTrnEpc(ss.TrnEpcLog)
if ss.ViewOn and ss.TrainUpdt.value > leabra.AlphaCycle:
ss.UpdateView(True)
- if ss.TestInterval > 0 and epc%ss.TestInterval == 0: # note: epc is *next* so won't trigger first time
+ if (
+ ss.TestInterval > 0 and epc % ss.TestInterval == 0
+ ): # note: epc is *next* so won't trigger first time
ss.TestAll()
- learned = (ss.NZeroStop > 0 and ss.NZero >= ss.NZeroStop)
- if ss.TrainEnv.Table.Table.MetaData["name"] == "TrainAB" and (learned or epc == ss.MaxEpcs/2):
+ learned = ss.NZeroStop > 0 and ss.NZero >= ss.NZeroStop
+ if ss.TrainEnv.Table.Table.MetaData["name"] == "TrainAB" and (
+ learned or epc == ss.MaxEpcs / 2
+ ):
ss.TrainEnv.Table = etable.NewIdxView(ss.TrainAC)
learned = False
- if learned or epc >= ss.MaxEpcs: # done with training..
+ if learned or epc >= ss.MaxEpcs: # done with training..
ss.RunEnd()
- if ss.TrainEnv.Run.Incr(): # we are done!
+ if ss.TrainEnv.Run.Incr(): # we are done!
ss.StopNow = True
return
else:
@@ -747,8 +999,8 @@ def TrainTrial(ss):
return
ss.ApplyInputs(ss.TrainEnv)
- ss.AlphaCyc(True) # train
- ss.TrialStats(True) # accumulate
+ ss.AlphaCyc(True) # train
+ ss.TrialStats(True) # accumulate
ss.LogTrnTrl(ss.TrnTrlLog)
def PreTrainTrial(ss):
@@ -768,13 +1020,13 @@ def PreTrainTrial(ss):
ss.LogTrnEpc(ss.TrnEpcLog)
if ss.ViewOn and ss.TrainUpdt.value > leabra.AlphaCycle:
ss.UpdateView(True)
- if epc >= ss.PreTrainEpcs: # done with training..
+ if epc >= ss.PreTrainEpcs: # done with training..
ss.StopNow = True
return
ss.ApplyInputs(ss.TrainEnv)
- ss.AlphaCyc(True) # train
- ss.TrialStats(True) # accumulate
+ ss.AlphaCyc(True) # train
+ ss.TrialStats(True) # accumulate
ss.LogTrnTrl(ss.TrnTrlLog)
def RunEnd(ss):
@@ -785,7 +1037,7 @@ def RunEnd(ss):
if ss.SaveWts:
fnm = ss.WeightsFileName()
print("Saving Weights to: %s\n" % fnm)
- ss.Net.SaveWtsJSON(gi.FileName(fnm))
+ ss.Net.SaveWtsJSON(gi.Filename(fnm))
def NewRun(ss):
"""
@@ -847,8 +1099,8 @@ def MemStats(ss, train):
nn = ecout.Shape().Len()
trgOnWasOffAll = 0.0
trgOnWasOffCmp = 0.0
- trgOffWasOn = 0.0 # should have been off
- cmpN = 0.0 # completion target
+ trgOffWasOn = 0.0 # should have been off
+ cmpN = 0.0 # completion target
trgOnN = 0.0
trgOffN = 0.0
actMi = ecout.UnitVarIdx("ActM")
@@ -856,15 +1108,15 @@ def MemStats(ss, train):
actQ1i = ecout.UnitVarIdx("ActQ1")
for ni in range(nn):
actm = ecout.UnitVal1D(actMi, ni)
- trg = ecout.UnitVal1D(targi, ni) # full pattern target
+ trg = ecout.UnitVal1D(targi, ni) # full pattern target
inact = ecin.UnitVal1D(actQ1i, ni)
- if trg < 0.5: # trgOff
+ if trg < 0.5: # trgOff
trgOffN += 1
if actm > 0.5:
trgOffWasOn += 1
- else: # trgOn
+ else: # trgOn
trgOnN += 1
- if inact < 0.5: # missing in ECin -- completion target
+ if inact < 0.5: # missing in ECin -- completion target
cmpN += 1
if actm < 0.5:
trgOnWasOffAll += 1
@@ -874,13 +1126,13 @@ def MemStats(ss, train):
trgOnWasOffAll += 1
trgOnWasOffAll /= trgOnN
trgOffWasOn /= trgOffN
- if train: # no cmp
+ if train: # no cmp
if trgOnWasOffAll < ss.MemThr and trgOffWasOn < ss.MemThr:
ss.Mem = 1
else:
ss.Mem = 0
- else: # test
- if cmpN > 0: # should be
+ else: # test
+ if cmpN > 0: # should be
trgOnWasOffCmp /= cmpN
if trgOnWasOffCmp < ss.MemThr and trgOffWasOn < ss.MemThr:
ss.Mem = 1
@@ -900,7 +1152,7 @@ def TrialStats(ss, accum):
"""
outLay = leabra.Layer(ss.Net.LayerByName("ECout"))
ss.TrlCosDiff = float(outLay.CosDiff.Cos)
- ss.TrlSSE = outLay.SSE(0.5) # 0.5 = per-unit tolerance -- right side of .5
+ ss.TrlSSE = outLay.SSE(0.5) # 0.5 = per-unit tolerance -- right side of .5
ss.TrlAvgSSE = ss.TrlSSE / len(outLay.Neurons)
if accum:
ss.SumSSE += ss.TrlSSE
@@ -1092,7 +1344,7 @@ def SetParams(ss, sheet, setMsg):
sps = ss.ParamSet.split()
for ps in sps:
ss.SetParamsSet(ps, sheet, setMsg)
-
+
def SetParamsSet(ss, setNm, sheet, setMsg):
"""
SetParamsSet sets the params for given params.Set name.
@@ -1105,10 +1357,10 @@ def SetParamsSet(ss, setNm, sheet, setMsg):
if "Network" in pset.Sheets:
netp = pset.SheetByNameTry("Network")
ss.Net.ApplyParams(netp, setMsg)
-
+
if sheet == "" or sheet == "Sim":
if "Sim" in pset.Sheets:
- simp= pset.SheetByNameTry("Sim")
+ simp = pset.SheetByNameTry("Sim")
pyparams.ApplyParams(ss, simp, setMsg)
if sheet == "" or sheet == "Hip":
@@ -1122,7 +1374,7 @@ def SetParamsSet(ss, setNm, sheet, setMsg):
pyparams.ApplyParams(ss.Pat, simp, setMsg)
def OpenPat(ss, dt, fname, name, desc):
- err = dt.OpenCSV(gi.FileName(fname), etable.Tab)
+ err = dt.OpenCSV(gi.Filename(fname), etable.Tab)
if err != 0:
log.Println(err)
return
@@ -1136,19 +1388,31 @@ def ConfigPats(ss):
npats = ss.Pat.ListSize
pctAct = hp.ECPctAct
minDiff = ss.Pat.MinDiffPct
- nOn = patgen.NFmPct(pctAct, plY*plX)
+ nOn = patgen.NFmPct(pctAct, plY * plX)
ctxtflip = patgen.NFmPct(ss.Pat.CtxtFlipPct, nOn)
patgen.AddVocabEmpty(ss.PoolVocab, "empty", npats, plY, plX)
- patgen.AddVocabPermutedBinary(ss.PoolVocab, "A", npats, plY, plX, pctAct, minDiff)
- patgen.AddVocabPermutedBinary(ss.PoolVocab, "B", npats, plY, plX, pctAct, minDiff)
- patgen.AddVocabPermutedBinary(ss.PoolVocab, "C", npats, plY, plX, pctAct, minDiff)
- patgen.AddVocabPermutedBinary(ss.PoolVocab, "lA", npats, plY, plX, pctAct, minDiff)
- patgen.AddVocabPermutedBinary(ss.PoolVocab, "lB", npats, plY, plX, pctAct, minDiff)
- patgen.AddVocabPermutedBinary(ss.PoolVocab, "ctxt", 3, plY, plX, pctAct, minDiff)
+ patgen.AddVocabPermutedBinary(
+ ss.PoolVocab, "A", npats, plY, plX, pctAct, minDiff
+ )
+ patgen.AddVocabPermutedBinary(
+ ss.PoolVocab, "B", npats, plY, plX, pctAct, minDiff
+ )
+ patgen.AddVocabPermutedBinary(
+ ss.PoolVocab, "C", npats, plY, plX, pctAct, minDiff
+ )
+ patgen.AddVocabPermutedBinary(
+ ss.PoolVocab, "lA", npats, plY, plX, pctAct, minDiff
+ )
+ patgen.AddVocabPermutedBinary(
+ ss.PoolVocab, "lB", npats, plY, plX, pctAct, minDiff
+ )
+ patgen.AddVocabPermutedBinary(
+ ss.PoolVocab, "ctxt", 3, plY, plX, pctAct, minDiff
+ )
for i in range(12):
lst = int(i / 4)
- ctxtNm = "ctxt%d" % (i+1)
+ ctxtNm = "ctxt%d" % (i + 1)
tsr = patgen.AddVocabRepeat(ss.PoolVocab, ctxtNm, npats, "ctxt", lst)
patgen.FlipBitsRows(tsr, ctxtflip, ctxtflip, 1, 0)
# todo: also support drifting
@@ -1158,25 +1422,130 @@ def ConfigPats(ss):
ecY = hp.ECSize.Y
ecX = hp.ECSize.X
- patgen.InitPats(ss.TrainAB, "TrainAB", "TrainAB Pats", "Input", "ECout", npats, ecY, ecX, plY, plX)
- patgen.MixPats(ss.TrainAB, ss.PoolVocab, "Input", go.Slice_string(["A", "B", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]))
- patgen.MixPats(ss.TrainAB, ss.PoolVocab, "ECout", go.Slice_string(["A", "B", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]))
+ patgen.InitPats(
+ ss.TrainAB,
+ "TrainAB",
+ "TrainAB Pats",
+ "Input",
+ "ECout",
+ npats,
+ ecY,
+ ecX,
+ plY,
+ plX,
+ )
+ patgen.MixPats(
+ ss.TrainAB,
+ ss.PoolVocab,
+ "Input",
+ go.Slice_string(["A", "B", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]),
+ )
+ patgen.MixPats(
+ ss.TrainAB,
+ ss.PoolVocab,
+ "ECout",
+ go.Slice_string(["A", "B", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]),
+ )
- patgen.InitPats(ss.TestAB, "TestAB", "TestAB Pats", "Input", "ECout", npats, ecY, ecX, plY, plX)
- patgen.MixPats(ss.TestAB, ss.PoolVocab, "Input", go.Slice_string(["A", "empty", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]))
- patgen.MixPats(ss.TestAB, ss.PoolVocab, "ECout", go.Slice_string(["A", "B", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]))
+ patgen.InitPats(
+ ss.TestAB,
+ "TestAB",
+ "TestAB Pats",
+ "Input",
+ "ECout",
+ npats,
+ ecY,
+ ecX,
+ plY,
+ plX,
+ )
+ patgen.MixPats(
+ ss.TestAB,
+ ss.PoolVocab,
+ "Input",
+ go.Slice_string(["A", "empty", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]),
+ )
+ patgen.MixPats(
+ ss.TestAB,
+ ss.PoolVocab,
+ "ECout",
+ go.Slice_string(["A", "B", "ctxt1", "ctxt2", "ctxt3", "ctxt4"]),
+ )
- patgen.InitPats(ss.TrainAC, "TrainAC", "TrainAC Pats", "Input", "ECout", npats, ecY, ecX, plY, plX)
- patgen.MixPats(ss.TrainAC, ss.PoolVocab, "Input", go.Slice_string(["A", "C", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]))
- patgen.MixPats(ss.TrainAC, ss.PoolVocab, "ECout", go.Slice_string(["A", "C", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]))
+ patgen.InitPats(
+ ss.TrainAC,
+ "TrainAC",
+ "TrainAC Pats",
+ "Input",
+ "ECout",
+ npats,
+ ecY,
+ ecX,
+ plY,
+ plX,
+ )
+ patgen.MixPats(
+ ss.TrainAC,
+ ss.PoolVocab,
+ "Input",
+ go.Slice_string(["A", "C", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]),
+ )
+ patgen.MixPats(
+ ss.TrainAC,
+ ss.PoolVocab,
+ "ECout",
+ go.Slice_string(["A", "C", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]),
+ )
- patgen.InitPats(ss.TestAC, "TestAC", "TestAC Pats", "Input", "ECout", npats, ecY, ecX, plY, plX)
- patgen.MixPats(ss.TestAC, ss.PoolVocab, "Input", go.Slice_string(["A", "empty", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]))
- patgen.MixPats(ss.TestAC, ss.PoolVocab, "ECout", go.Slice_string(["A", "C", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]))
+ patgen.InitPats(
+ ss.TestAC,
+ "TestAC",
+ "TestAC Pats",
+ "Input",
+ "ECout",
+ npats,
+ ecY,
+ ecX,
+ plY,
+ plX,
+ )
+ patgen.MixPats(
+ ss.TestAC,
+ ss.PoolVocab,
+ "Input",
+ go.Slice_string(["A", "empty", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]),
+ )
+ patgen.MixPats(
+ ss.TestAC,
+ ss.PoolVocab,
+ "ECout",
+ go.Slice_string(["A", "C", "ctxt5", "ctxt6", "ctxt7", "ctxt8"]),
+ )
- patgen.InitPats(ss.TestLure, "TestLure", "TestLure Pats", "Input", "ECout", npats, ecY, ecX, plY, plX)
- patgen.MixPats(ss.TestLure, ss.PoolVocab, "Input", go.Slice_string(["lA", "empty", "ctxt9", "ctxt10", "ctxt11", "ctxt12"])) # arbitrary ctxt here
- patgen.MixPats(ss.TestLure, ss.PoolVocab, "ECout", go.Slice_string(["lA", "lB", "ctxt9", "ctxt10", "ctxt11", "ctxt12"])) # arbitrary ctxt here
+ patgen.InitPats(
+ ss.TestLure,
+ "TestLure",
+ "TestLure Pats",
+ "Input",
+ "ECout",
+ npats,
+ ecY,
+ ecX,
+ plY,
+ plX,
+ )
+ patgen.MixPats(
+ ss.TestLure,
+ ss.PoolVocab,
+ "Input",
+ go.Slice_string(["lA", "empty", "ctxt9", "ctxt10", "ctxt11", "ctxt12"]),
+ ) # arbitrary ctxt here
+ patgen.MixPats(
+ ss.TestLure,
+ ss.PoolVocab,
+ "ECout",
+ go.Slice_string(["lA", "lB", "ctxt9", "ctxt10", "ctxt11", "ctxt12"]),
+ ) # arbitrary ctxt here
ss.TrainAll = ss.TrainAB.Clone()
ss.TrainAll.AppendRows(ss.TrainAC)
@@ -1217,7 +1586,14 @@ def WeightsFileName(ss):
"""
WeightsFileName returns default current weights file name
"""
- return ss.Net.Nm + "_" + ss.RunName() + "_" + ss.RunEpochName(ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur) + ".wts"
+ return (
+ ss.Net.Nm
+ + "_"
+ + ss.RunName()
+ + "_"
+ + ss.RunEpochName(ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur)
+ + ".wts"
+ )
def LogFileName(ss, lognm):
"""
@@ -1262,16 +1638,18 @@ def ConfigTrnTrlLog(ss, dt):
nt = ss.TestEnv.Table.Len()
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
- etable.Column("Trial", etensor.INT64, go.nil, go.nil),
- etable.Column("TrialName", etensor.STRING, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("Mem", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("TrgOnWasOff", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("TrgOffWasOn", etensor.FLOAT64, go.nil, go.nil)]
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
+ etable.Column("Trial", etensor.INT64, go.nil, go.nil),
+ etable.Column("TrialName", etensor.STRING, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("Mem", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("TrgOnWasOff", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("TrgOffWasOn", etensor.FLOAT64, go.nil, go.nil),
+ ]
)
dt.SetFromSchema(sch, nt)
@@ -1296,16 +1674,16 @@ def ConfigTrnTrlPlot(ss, plt, dt):
def LogTrnEpc(ss, dt):
"""
- LogTrnEpc adds data from current epoch to the TrnEpcLog table.
- computes epoch averages prior to logging.
+ LogTrnEpc adds data from current epoch to the TrnEpcLog table.
+ computes epoch averages prior to logging.
- # this is triggered by increment so use previous value
+ # this is triggered by increment so use previous value
"""
row = dt.Rows
dt.SetNumRows(row + 1)
epc = ss.TrainEnv.Epoch.Prv
- nt = float(ss.TrainEnv.Table.Len()) # number of trials in view
+ nt = float(ss.TrainEnv.Table.Len()) # number of trials in view
ss.EpcSSE = ss.SumSSE / nt
ss.SumSSE = 0
@@ -1333,9 +1711,11 @@ def LogTrnEpc(ss, dt):
dt.SetCellFloat("TrgOnWasOff", row, agg.Mean(tix, "TrgOnWasOff")[0])
dt.SetCellFloat("TrgOffWasOn", row, agg.Mean(tix, "TrgOffWasOn")[0])
- for lnm in ss.LayStatNms :
+ for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
- dt.SetCellFloat(ly.Nm+" ActAvg", row, float(ly.Pools[0].ActAvg.ActPAvgEff))
+ dt.SetCellFloat(
+ ly.Nm + " ActAvg", row, float(ly.Pools[0].ActAvg.ActPAvgEff)
+ )
# note: essential to use Go version of update when called from another goroutine
if ss.TrnEpcPlot != 0:
@@ -1353,19 +1733,21 @@ def ConfigTrnEpcLog(ss, dt):
dt.SetMetaData("precision", str(LogPrec))
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("Mem", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("TrgOnWasOff", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("TrgOffWasOn", etensor.FLOAT64, go.nil, go.nil)]
- )
- for lnm in ss.LayStatNms :
- sch.append( etable.Column(lnm + " ActAvg", etensor.FLOAT64, go.nil, go.nil))
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("Mem", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("TrgOnWasOff", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("TrgOffWasOn", etensor.FLOAT64, go.nil, go.nil),
+ ]
+ )
+ for lnm in ss.LayStatNms:
+ sch.append(etable.Column(lnm + " ActAvg", etensor.FLOAT64, go.nil, go.nil))
dt.SetFromSchema(sch, 0)
def ConfigTrnEpcPlot(ss, plt, dt):
@@ -1381,25 +1763,33 @@ def ConfigTrnEpcPlot(ss, plt, dt):
plt.SetColParams("PctCor", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
- plt.SetColParams("Mem", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) # default plot
- plt.SetColParams("TrgOnWasOff", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) # default plot
- plt.SetColParams("TrgOffWasOn", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) # default plot
+ plt.SetColParams(
+ "Mem", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1
+ ) # default plot
+ plt.SetColParams(
+ "TrgOnWasOff", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1
+ ) # default plot
+ plt.SetColParams(
+ "TrgOffWasOn", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1
+ ) # default plot
- for lnm in ss.LayStatNms :
- plt.SetColParams(lnm+" ActAvg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 0.5)
+ for lnm in ss.LayStatNms:
+ plt.SetColParams(
+ lnm + " ActAvg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 0.5
+ )
return plt
def LogTstTrl(ss, dt):
"""
- LogTstTrl adds data from current trial to the TstTrlLog table.
- # this is triggered by increment so use previous value
- log always contains number of testing items
+ LogTstTrl adds data from current trial to the TstTrlLog table.
+ # this is triggered by increment so use previous value
+ log always contains number of testing items
"""
epc = ss.TrainEnv.Epoch.Prv
trl = ss.TestEnv.Trial.Cur
row = dt.Rows
- if ss.TestNm == "AB" and trl == 0: # reset at start
+ if ss.TestNm == "AB" and trl == 0: # reset at start
row = 0
dt.SetNumRows(row + 1)
@@ -1416,15 +1806,15 @@ def LogTstTrl(ss, dt):
dt.SetCellFloat("TrgOnWasOff", row, ss.TrgOnWasOffCmp)
dt.SetCellFloat("TrgOffWasOn", row, ss.TrgOffWasOn)
- for lnm in ss.LayStatNms :
+ for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
- dt.SetCellFloat(ly.Nm+" ActM.Avg", row, float(ly.Pools[0].ActM.Avg))
+ dt.SetCellFloat(ly.Nm + " ActM.Avg", row, float(ly.Pools[0].ActM.Avg))
- for lnm in ss.LayStatNms :
+ for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
tsr = ss.ValsTsr(lnm)
ly.UnitValsTensor(tsr, "Act")
- dt.SetCellTensor(lnm+"Act", row, tsr)
+ dt.SetCellTensor(lnm + "Act", row, tsr)
# note: essential to use Go version of update when called from another goroutine
if ss.TstTrlPlot != 0:
@@ -1439,25 +1829,29 @@ def ConfigTstTrlLog(ss, dt):
dt.SetMetaData("read-only", "true")
dt.SetMetaData("precision", str(LogPrec))
- nt = ss.TestEnv.Table.Len() # number in view
+ nt = ss.TestEnv.Table.Len() # number in view
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
- etable.Column("TestNm", etensor.STRING, go.nil, go.nil),
- etable.Column("Trial", etensor.INT64, go.nil, go.nil),
- etable.Column("TrialName", etensor.STRING, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("Mem", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("TrgOnWasOff", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("TrgOffWasOn", etensor.FLOAT64, go.nil, go.nil)]
- )
- for lnm in ss.LayStatNms :
- sch.append( etable.Column(lnm + " ActM.Avg", etensor.FLOAT64, go.nil, go.nil))
- for lnm in ss.LayStatNms :
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
+ etable.Column("TestNm", etensor.STRING, go.nil, go.nil),
+ etable.Column("Trial", etensor.INT64, go.nil, go.nil),
+ etable.Column("TrialName", etensor.STRING, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("Mem", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("TrgOnWasOff", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("TrgOffWasOn", etensor.FLOAT64, go.nil, go.nil),
+ ]
+ )
+ for lnm in ss.LayStatNms:
+ sch.append(
+ etable.Column(lnm + " ActM.Avg", etensor.FLOAT64, go.nil, go.nil)
+ )
+ for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
- sch.append( etable.Column(lnm + "Act", etensor.FLOAT64, ly.Shp.Shp, go.nil))
+ sch.append(etable.Column(lnm + "Act", etensor.FLOAT64, ly.Shp.Shp, go.nil))
dt.SetFromSchema(sch, nt)
@@ -1465,7 +1859,7 @@ def ConfigTstTrlPlot(ss, plt, dt):
plt.Params.Title = "Hippocampus Test Trial Plot"
plt.Params.XAxisCol = "TrialName"
plt.Params.Type = eplot.Bar
- plt.SetTable(dt) # this sets defaults so set params after
+ plt.SetTable(dt) # this sets defaults so set params after
plt.Params.XAxisRot = 45
# order of params: on, fixMin, min, fixMax, max
plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
@@ -1481,10 +1875,12 @@ def ConfigTstTrlPlot(ss, plt, dt):
plt.SetColParams("TrgOnWasOff", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1)
plt.SetColParams("TrgOffWasOn", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1)
- for lnm in ss.LayStatNms :
- plt.SetColParams(lnm+" ActM.Avg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 0.5)
- for lnm in ss.LayStatNms :
- plt.SetColParams(lnm+" Act", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
+ for lnm in ss.LayStatNms:
+ plt.SetColParams(
+ lnm + " ActM.Avg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 0.5
+ )
+ for lnm in ss.LayStatNms:
+ plt.SetColParams(lnm + " Act", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
return plt
@@ -1500,7 +1896,7 @@ def RepsAnalysis(ss):
ss.SimMats[lnm] = sm
else:
sm = ss.SimMats[lnm]
- sm.TableColStd(acts, lnm+"Act", "TrialName", True, metric.Correlation)
+ sm.TableColStd(acts, lnm + "Act", "TrialName", True, metric.Correlation)
def SimMatStat(ss, lnm):
"""
@@ -1561,27 +1957,27 @@ def LogTstEpc(ss, dt):
trix = etable.NewIdxView(trl)
spl = split.GroupBy(trix, go.Slice_string(["TestNm"]))
- for ts in ss.TstStatNms :
+ for ts in ss.TstStatNms:
split.Agg(spl, ts, agg.AggMean)
ss.TstStats = spl.AggsToTable(etable.ColNameOnly)
for ri in range(ss.TstStats.Rows):
tst = ss.TstStats.CellString("TestNm", ri)
- for ts in ss.TstStatNms :
- dt.SetCellFloat(tst+" "+ts, row, ss.TstStats.CellFloat(ts, ri))
+ for ts in ss.TstStatNms:
+ dt.SetCellFloat(tst + " " + ts, row, ss.TstStats.CellFloat(ts, ri))
- for lnm in ss.LayStatNms :
+ for lnm in ss.LayStatNms:
# win, btn = ss.SimMatStat(lnm)
win = 0
btn = 0
- for ts in ss.SimMatStats :
+ for ts in ss.SimMatStats:
if ts == "Within":
- dt.SetCellFloat(lnm+" "+ts, row, win)
+ dt.SetCellFloat(lnm + " " + ts, row, win)
else:
- dt.SetCellFloat(lnm+" "+ts, row, btn)
+ dt.SetCellFloat(lnm + " " + ts, row, btn)
# base zero on testing performance!
- curAB = (ss.TrainEnv.Table.Table.MetaData["name"] == "TrainAB")
+ curAB = ss.TrainEnv.Table.Table.MetaData["name"] == "TrainAB"
mem = float()
if curAB:
mem = dt.CellFloat("AB Mem", row)
@@ -1610,27 +2006,33 @@ def ConfigTstEpcLog(ss, dt):
dt.SetMetaData("precision", str(LogPrec))
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
- etable.Column("PerTrlMSec", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil)]
- )
- for tn in ss.TstNms :
- for ts in ss.TstStatNms :
- sch.append( etable.Column(tn + " " + ts, etensor.FLOAT64, go.nil, go.nil))
- for lnm in ss.LayStatNms :
- for ts in ss.SimMatStats :
- sch.append( etable.Column(lnm + " " + ts, etensor.FLOAT64, go.nil, go.nil))
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
+ etable.Column("PerTrlMSec", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ ]
+ )
+ for tn in ss.TstNms:
+ for ts in ss.TstStatNms:
+ sch.append(
+ etable.Column(tn + " " + ts, etensor.FLOAT64, go.nil, go.nil)
+ )
+ for lnm in ss.LayStatNms:
+ for ts in ss.SimMatStats:
+ sch.append(
+ etable.Column(lnm + " " + ts, etensor.FLOAT64, go.nil, go.nil)
+ )
dt.SetFromSchema(sch, 0)
def ConfigTstEpcPlot(ss, plt, dt):
plt.Params.Title = "Hippocampus Testing Epoch Plot"
plt.Params.XAxisCol = "Epoch"
- plt.SetTable(dt) # this sets defaults so set params after
+ plt.SetTable(dt) # this sets defaults so set params after
# order of params: on, fixMin, min, fixMax, max
plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
@@ -1641,15 +2043,21 @@ def ConfigTstEpcPlot(ss, plt, dt):
plt.SetColParams("PctCor", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
- for tn in ss.TstNms :
- for ts in ss.TstStatNms :
+ for tn in ss.TstNms:
+ for ts in ss.TstStatNms:
if ts == "Mem":
- plt.SetColParams(tn+" "+ts, eplot.On, eplot.FixMin, 0, eplot.FixMax, 1)
+ plt.SetColParams(
+ tn + " " + ts, eplot.On, eplot.FixMin, 0, eplot.FixMax, 1
+ )
else:
- plt.SetColParams(tn+" "+ts, eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
- for lnm in ss.LayStatNms :
- for ts in ss.SimMatStats :
- plt.SetColParams(lnm+" "+ts, eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 1)
+ plt.SetColParams(
+ tn + " " + ts, eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1
+ )
+ for lnm in ss.LayStatNms:
+ for ts in ss.SimMatStats:
+ plt.SetColParams(
+ lnm + " " + ts, eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 1
+ )
return plt
def LogTstCyc(ss, dt, cyc):
@@ -1661,12 +2069,12 @@ def LogTstCyc(ss, dt, cyc):
dt.SetNumRows(cyc + 1)
dt.SetCellFloat("Cycle", cyc, float(cyc))
- for lnm in ss.LayStatNms :
+ for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
- dt.SetCellFloat(ly.Nm+" Ge.Avg", cyc, float(ly.Pools[0].Inhib.Ge.Avg))
- dt.SetCellFloat(ly.Nm+" Act.Avg", cyc, float(ly.Pools[0].Inhib.Act.Avg))
+ dt.SetCellFloat(ly.Nm + " Ge.Avg", cyc, float(ly.Pools[0].Inhib.Ge.Avg))
+ dt.SetCellFloat(ly.Nm + " Act.Avg", cyc, float(ly.Pools[0].Inhib.Act.Avg))
- if cyc%10 == 0: # too slow to do every cyc
+ if cyc % 10 == 0: # too slow to do every cyc
# note: essential to use Go version of update when called from another goroutine
if ss.TstCycPlot != 0:
ss.TstCycPlot.GoUpdate()
@@ -1677,13 +2085,11 @@ def ConfigTstCycLog(ss, dt):
dt.SetMetaData("read-only", "true")
dt.SetMetaData("precision", str(LogPrec))
- np = 100 # max cycles
- sch = etable.Schema(
- [etable.Column("Cycle", etensor.INT64, go.nil, go.nil)]
- )
- for lnm in ss.LayStatNms :
- sch.append( etable.Column(lnm + " Ge.Avg", etensor.FLOAT64, go.nil, go.nil))
- sch.append( etable.Column(lnm + " Act.Avg", etensor.FLOAT64, go.nil, go.nil))
+ np = 100 # max cycles
+ sch = etable.Schema([etable.Column("Cycle", etensor.INT64, go.nil, go.nil)])
+ for lnm in ss.LayStatNms:
+ sch.append(etable.Column(lnm + " Ge.Avg", etensor.FLOAT64, go.nil, go.nil))
+ sch.append(etable.Column(lnm + " Act.Avg", etensor.FLOAT64, go.nil, go.nil))
dt.SetFromSchema(sch, np)
def ConfigTstCycPlot(ss, plt, dt):
@@ -1692,9 +2098,13 @@ def ConfigTstCycPlot(ss, plt, dt):
plt.SetTable(dt)
# order of params: on, fixMin, min, fixMax, max
plt.SetColParams("Cycle", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
- for lnm in ss.LayStatNms :
- plt.SetColParams(lnm+" Ge.Avg", eplot.On, eplot.FixMin, 0, eplot.FixMax, .5)
- plt.SetColParams(lnm+" Act.Avg", eplot.On, eplot.FixMin, 0, eplot.FixMax, .5)
+ for lnm in ss.LayStatNms:
+ plt.SetColParams(
+ lnm + " Ge.Avg", eplot.On, eplot.FixMin, 0, eplot.FixMax, 0.5
+ )
+ plt.SetColParams(
+ lnm + " Act.Avg", eplot.On, eplot.FixMin, 0, eplot.FixMax, 0.5
+ )
return plt
def LogRun(ss, dt):
@@ -1706,17 +2116,17 @@ def LogRun(ss, dt):
if epcix.Len() == 0:
return
- run = ss.TrainEnv.Run.Cur # this is NOT triggered by increment yet -- use Cur
+ run = ss.TrainEnv.Run.Cur # this is NOT triggered by increment yet -- use Cur
row = dt.Rows
dt.SetNumRows(row + 1)
# compute mean over last N epochs for run level
nlast = 1
- if nlast > epcix.Len()-1:
+ if nlast > epcix.Len() - 1:
nlast = epcix.Len() - 1
- epcix.Idxs = epcix.Idxs[epcix.Len()-nlast:]
+ epcix.Idxs = epcix.Idxs[epcix.Len() - nlast :]
- params = ss.RunName() # includes tag
+ params = ss.RunName() # includes tag
fzero = ss.FirstZero
if fzero < 0:
@@ -1732,12 +2142,12 @@ def LogRun(ss, dt):
dt.SetCellFloat("PctCor", row, agg.Mean(epcix, "PctCor")[0])
dt.SetCellFloat("CosDiff", row, agg.Mean(epcix, "CosDiff")[0])
- for tn in ss.TstNms :
- for ts in ss.TstStatNms :
+ for tn in ss.TstNms:
+ for ts in ss.TstStatNms:
nm = tn + " " + ts
dt.SetCellFloat(nm, row, agg.Mean(epcix, nm)[0])
- for lnm in ss.LayStatNms :
- for ts in ss.SimMatStats :
+ for lnm in ss.LayStatNms:
+ for ts in ss.SimMatStats:
nm = lnm + " " + ts
dt.SetCellFloat(nm, row, agg.Mean(epcix, nm)[0])
@@ -1758,22 +2168,28 @@ def ConfigRunLog(ss, dt):
dt.SetMetaData("precision", str(LogPrec))
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Params", etensor.STRING, go.nil, go.nil),
- etable.Column("NEpochs", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("FirstZero", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil)]
- )
- for tn in ss.TstNms :
- for ts in ss.TstStatNms :
- sch.append( etable.Column(tn + " " + ts, etensor.FLOAT64, go.nil, go.nil))
- for lnm in ss.LayStatNms :
- for ts in ss.SimMatStats :
- sch.append( etable.Column(lnm + " " + ts, etensor.FLOAT64, go.nil, go.nil))
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Params", etensor.STRING, go.nil, go.nil),
+ etable.Column("NEpochs", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("FirstZero", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ ]
+ )
+ for tn in ss.TstNms:
+ for ts in ss.TstStatNms:
+ sch.append(
+ etable.Column(tn + " " + ts, etensor.FLOAT64, go.nil, go.nil)
+ )
+ for lnm in ss.LayStatNms:
+ for ts in ss.SimMatStats:
+ sch.append(
+ etable.Column(lnm + " " + ts, etensor.FLOAT64, go.nil, go.nil)
+ )
dt.SetFromSchema(sch, 0)
def ConfigRunPlot(ss, plt, dt):
@@ -1790,15 +2206,21 @@ def ConfigRunPlot(ss, plt, dt):
plt.SetColParams("PctCor", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
- for tn in ss.TstNms :
- for ts in ss.TstStatNms :
+ for tn in ss.TstNms:
+ for ts in ss.TstStatNms:
if ts == "Mem":
- plt.SetColParams(tn+" "+ts, eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) # default plot
+ plt.SetColParams(
+ tn + " " + ts, eplot.On, eplot.FixMin, 0, eplot.FixMax, 1
+ ) # default plot
else:
- plt.SetColParams(tn+" "+ts, eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
- for lnm in ss.LayStatNms :
- for ts in ss.SimMatStats :
- plt.SetColParams(lnm+" "+ts, eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 1)
+ plt.SetColParams(
+ tn + " " + ts, eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1
+ )
+ for lnm in ss.LayStatNms:
+ for ts in ss.SimMatStats:
+ plt.SetColParams(
+ lnm + " " + ts, eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 1
+ )
return plt
def LogRunStats(ss):
@@ -1808,14 +2230,14 @@ def LogRunStats(ss):
dt = ss.RunLog
runix = etable.NewIdxView(dt)
spl = split.GroupBy(runix, go.Slice_string(["Params"]))
- for tn in ss.TstNms :
+ for tn in ss.TstNms:
nm = tn + " " + "Mem"
split.Desc(spl, nm)
split.Desc(spl, "FirstZero")
split.Desc(spl, "NEpochs")
- for lnm in ss.LayStatNms :
- for ts in ss.SimMatStats :
- split.Desc(spl, lnm+" "+ts)
+ for lnm in ss.LayStatNms:
+ for ts in ss.SimMatStats:
+ split.Desc(spl, lnm + " " + ts)
ss.RunStats = spl.AggsToTable(etable.AddAggName)
if ss.RunStatsPlot != 0:
ss.ConfigRunStatsPlot(ss.RunStatsPlot, ss.RunStats)
@@ -1832,21 +2254,27 @@ def ConfigRunStatsPlot(ss, plt, dt):
cp.ErrCol = "AB Mem:Sem"
cp = plt.SetColParams("AC Mem:Mean", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1)
cp.ErrCol = "AC Mem:Sem"
- cp = plt.SetColParams("FirstZero:Mean", eplot.On, eplot.FixMin, 0, eplot.FixMax, 30)
+ cp = plt.SetColParams(
+ "FirstZero:Mean", eplot.On, eplot.FixMin, 0, eplot.FixMax, 30
+ )
cp.ErrCol = "FirstZero:Sem"
- cp = plt.SetColParams("NEpochs:Mean", eplot.On, eplot.FixMin, 0, eplot.FixMax, 30)
+ cp = plt.SetColParams(
+ "NEpochs:Mean", eplot.On, eplot.FixMin, 0, eplot.FixMax, 30
+ )
cp.ErrCol = "NEpochs:Sem"
return plt
- def ConfigGui(ss):
+ def ConfigGUI(ss):
"""
- ConfigGui configures the GoGi gui interface for this simulation,
+ ConfigGUI configures the GoGi gui interface for this simulation,
"""
width = 1600
height = 1200
gi.SetAppName("hip_bench")
- gi.SetAppAbout('This demonstrates a basic Hippocampus model in Leabra. See emergent on GitHub.
')
+ gi.SetAppAbout(
+ 'This demonstrates a basic Hippocampus model in Leabra. See emergent on GitHub.'
+ )
win = gi.NewMainWindow("hip_bench", "Hippocampus AB-AC", width, height)
ss.Win = win
@@ -1902,42 +2330,147 @@ def ConfigGui(ss):
tv.AddTab(plt, "RunPlot")
ss.RunPlot = ss.ConfigRunPlot(plt, ss.RunLog)
- split.SetSplitsList(go.Slice_float32([.2, .8]))
+ split.SetSplitsList(go.Slice_float32([0.2, 0.8]))
recv = win.This()
- tbar.AddAction(gi.ActOpts(Label="Init", Icon="update", Tooltip="Initialize everything including network weights, and start over. Also applies current params.", UpdateFunc=UpdtFuncNotRunning), recv, InitCB)
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Init",
+ Icon="update",
+ Tooltip="Initialize everything including network weights, and start over. Also applies current params.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ InitCB,
+ )
- tbar.AddAction(gi.ActOpts(Label="Train", Icon="run", Tooltip="Starts the network training, picking up from wherever it may have left off. If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.", UpdateFunc=UpdtFuncNotRunning), recv, TrainCB)
-
- tbar.AddAction(gi.ActOpts(Label="Stop", Icon="stop", Tooltip="Interrupts running. Hitting Train again will pick back up where it left off.", UpdateFunc=UpdtFuncRunning), recv, StopCB)
-
- tbar.AddAction(gi.ActOpts(Label="Step Trial", Icon="step-fwd", Tooltip="Advances one training trial at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepTrialCB)
-
- tbar.AddAction(gi.ActOpts(Label="Step Epoch", Icon="fast-fwd", Tooltip="Advances one epoch (complete set of training patterns) at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepEpochCB)
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Train",
+ Icon="run",
+ Tooltip="Starts the network training, picking up from wherever it may have left off. If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TrainCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Stop",
+ Icon="stop",
+ Tooltip="Interrupts running. Hitting Train again will pick back up where it left off.",
+ UpdateFunc=UpdtFuncRunning,
+ ),
+ recv,
+ StopCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Step Trial",
+ Icon="step-fwd",
+ Tooltip="Advances one training trial at a time.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ StepTrialCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Step Epoch",
+ Icon="fast-fwd",
+ Tooltip="Advances one epoch (complete set of training patterns) at a time.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ StepEpochCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Step Run",
+ Icon="fast-fwd",
+ Tooltip="Advances one full training Run at a time.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ StepRunCB,
+ )
- tbar.AddAction(gi.ActOpts(Label="Step Run", Icon="fast-fwd", Tooltip="Advances one full training Run at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepRunCB)
-
tbar.AddSeparator("test")
-
- tbar.AddAction(gi.ActOpts(Label="Test Trial", Icon="step-fwd", Tooltip="Runs the next testing trial.", UpdateFunc=UpdtFuncNotRunning), recv, TestTrialCB)
-
- tbar.AddAction(gi.ActOpts(Label="Test Item", Icon="step-fwd", Tooltip="Prompts for a specific input pattern name to run, and runs it in testing mode.", UpdateFunc=UpdtFuncNotRunning), recv, TestItemCB)
-
- tbar.AddAction(gi.ActOpts(Label="Test All", Icon="fast-fwd", Tooltip="Tests all of the testing trials.", UpdateFunc=UpdtFuncNotRunning), recv, TestAllCB)
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Test Trial",
+ Icon="step-fwd",
+ Tooltip="Runs the next testing trial.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TestTrialCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Test Item",
+ Icon="step-fwd",
+ Tooltip="Prompts for a specific input pattern name to run, and runs it in testing mode.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TestItemCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Test All",
+ Icon="fast-fwd",
+ Tooltip="Tests all of the testing trials.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TestAllCB,
+ )
tbar.AddSeparator("log")
-
+
# tbar.AddAction(gi.ActOpts(Label= "Env", Icon= "gear", Tooltip= "select training input patterns: AB or AC."), win.This(),
# funcrecv, send, sig, data:
# giv.CallMethod(ss, "SetEnv", vp))
- tbar.AddAction(gi.ActOpts(Label="Reset RunLog", Icon="reset", Tooltip="Resets the accumulated log of all Runs, which are tagged with the ParamSet used"), recv, ResetRunLogCB)
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Reset RunLog",
+ Icon="reset",
+ Tooltip="Resets the accumulated log of all Runs, which are tagged with the ParamSet used",
+ ),
+ recv,
+ ResetRunLogCB,
+ )
tbar.AddSeparator("misc")
-
- tbar.AddAction(gi.ActOpts(Label="New Seed", Icon="new", Tooltip="Generate a new initial random seed to get different results. By default, Init re-establishes the same initial seed every time."), recv, NewRndSeedCB)
- tbar.AddAction(gi.ActOpts(Label="README", Icon="file-markdown", Tooltip="Opens your browser on the README file that contains instructions for how to run this model."), recv, ReadmeCB)
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="New Seed",
+ Icon="new",
+ Tooltip="Generate a new initial random seed to get different results. By default, Init re-establishes the same initial seed every time.",
+ ),
+ recv,
+ NewRndSeedCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="README",
+ Icon="file-markdown",
+ Tooltip="Opens your browser on the README file that contains instructions for how to run this model.",
+ ),
+ recv,
+ ReadmeCB,
+ )
# main menu
appnm = gi.AppName()
@@ -1979,10 +2512,10 @@ def TwoFactorRun(ss):
for inf in InnerLoopParams:
ss.Tag = usetag + otf + "_" + inf
print("running: " + ss.Tag)
- rand.Seed(ss.RndSeed) # each run starts at same seed..
+ rand.Seed(ss.RndSeed) # each run starts at same seed..
ss.SetParamsSet(otf, "", ss.LogSetParams)
ss.SetParamsSet(inf, "", ss.LogSetParams)
- ss.ReConfigNet() # note: this applies Base params to Network
+ ss.ReConfigNet() # note: this applies Base params to Network
ss.ConfigEnv()
ss.StopNow = False
ss.PreTrain()
@@ -1990,22 +2523,38 @@ def TwoFactorRun(ss):
ss.Train()
ss.Tag = tag
+
# TheSim is the overall state for this simulation
TheSim = Sim()
-
+
+
def usage():
- print(sys.argv[0] + " --params= --tag= --setparams --wts --epclog=0 --runlog=0 --nogui")
+ print(
+ sys.argv[0]
+ + " --params= --tag= --setparams --wts --epclog=0 --runlog=0 --nogui"
+ )
print("\t pyleabra -i %s to run in interactive, gui mode" % sys.argv[0])
- print("\t --params= additional params to apply on top of Base (name must be in loaded Params")
- print("\t --tag= tag is appended to file names to uniquely identify this run")
- print("\t --note= user note -- describe the run params etc")
+ print(
+ "\t --params= additional params to apply on top of Base (name must be in loaded Params"
+ )
+ print(
+ "\t --tag= tag is appended to file names to uniquely identify this run"
+ )
+ print("\t --note= user note -- describe the run params etc")
print("\t --runs= number of runs to do")
print("\t --epcs= number of epochs per run")
print("\t --setparams show the parameter values that are set")
print("\t --wts save final trained weights after every run")
- print("\t --epclog=0/False turn off save training epoch log data to file named by param set, tag")
- print("\t --runlog=0/False turn off save run log data to file named by param set, tag")
- print("\t --nogui if no other args needed, this prevents running under the gui")
+ print(
+ "\t --epclog=0/False turn off save training epoch log data to file named by param set, tag"
+ )
+ print(
+ "\t --runlog=0/False turn off save run log data to file named by param set, tag"
+ )
+ print(
+ "\t --nogui if no other args needed, this prevents running under the gui"
+ )
+
def main(argv):
TheSim.Config()
@@ -2014,15 +2563,30 @@ def main(argv):
TheSim.NoGui = len(argv) > 1
saveEpcLog = True
saveRunLog = True
-
+
try:
- opts, args = getopt.getopt(argv,"h:",["params=","tag=","note=","runs=","epcs=","setparams","wts","epclog=","runlog=","nogui"])
+ opts, args = getopt.getopt(
+ argv,
+ "h:",
+ [
+ "params=",
+ "tag=",
+ "note=",
+ "runs=",
+ "epcs=",
+ "setparams",
+ "wts",
+ "epclog=",
+ "runlog=",
+ "nogui",
+ ],
+ )
except getopt.GetoptError:
usage()
sys.exit(2)
for opt, arg in opts:
# print("opt: %s arg: %s" % (opt, arg))
- if opt == '-h':
+ if opt == "-h":
usage()
sys.exit()
elif opt == "--tag":
@@ -2048,15 +2612,15 @@ def main(argv):
TheSim.NoGui = True
TheSim.Init()
-
+
if TheSim.NoGui:
if saveEpcLog:
- fnm = TheSim.LogFileName("epc")
+ fnm = TheSim.LogFileName("epc")
print("Saving test epoch log to: %s" % fnm)
TheSim.TstEpcFile = efile.Create(fnm)
-
+
if saveRunLog:
- fnm = TheSim.LogFileName("run")
+ fnm = TheSim.LogFileName("run")
print("Saving run log to: %s" % fnm)
TheSim.RunFile = efile.Create(fnm)
@@ -2064,15 +2628,17 @@ def main(argv):
TheSim.TwoFactorRun()
fnm = TheSim.LogFileName("runs")
TheSim.RunStats.SaveCSV(fnm, etable.Tab, etable.Headers)
-
+
else:
- TheSim.ConfigGui()
- print("Note: run pyleabra -i hip_bench.py to run in interactive mode, or just pyleabra, then 'import ra25'")
+ TheSim.ConfigGUI()
+ print(
+ "Note: run pyleabra -i hip_bench.py to run in interactive mode, or just pyleabra, then 'import ra25'"
+ )
print("for non-gui background running, here are the args:")
usage()
import code
- code.interact(local=locals())
-main(sys.argv[1:])
+ code.interact(local=locals())
+main(sys.argv[1:])
diff --git a/examples/ra25/ra25.py b/examples/ra25/ra25.py
index 0aa86c2ec..59ef91d00 100755
--- a/examples/ra25/ra25.py
+++ b/examples/ra25/ra25.py
@@ -5,7 +5,7 @@
# license that can be found in the LICENSE file.
# use:
-# pyleabra -i ra25.py
+# pyleabra -i ra25.py
# to run in gui interactive mode from the command line (or pyleabra, import ra25)
# see main function at the end for startup args
@@ -13,7 +13,7 @@
# * install gopy, currently in fork at https://github.com/goki/gopy
# e.g., 'go get github.com/goki/gopy -u ./...' and then cd to that package
# and do 'go install'
-# * go to the python directory in this emergent repository, read README.md there, and
+# * go to the python directory in this emergent repository, read README.md there, and
# type 'make' -- if that works, then type make install (may need sudo)
# * cd back here, and run 'pyemergent' which was installed into /usr/local/bin
# * then type 'import ra25' and this should run
@@ -21,9 +21,32 @@
# labra25ra runs a simple random-associator 5x5 = 25 four-layer leabra network
-from leabra import go, leabra, emer, relpos, eplot, env, agg, patgen, prjn, etable, efile, split, etensor, params, netview, rand, erand, gi, giv, pygiv, pyparams, mat32
-
-import importlib as il #il.reload(ra25) -- doesn't seem to work for reasons unknown
+from leabra import (
+ go,
+ leabra,
+ emer,
+ relpos,
+ eplot,
+ env,
+ agg,
+ patgen,
+ prjn,
+ etable,
+ efile,
+ split,
+ etensor,
+ params,
+ netview,
+ rand,
+ erand,
+ gi,
+ giv,
+ pygiv,
+ pyparams,
+ mat32,
+)
+
+import importlib as il # il.reload(ra25) -- doesn't seem to work for reasons unknown
import io, sys, getopt
from datetime import datetime, timezone
@@ -39,7 +62,7 @@
# pandas. Support for easy migration between these is forthcoming.
# import pandas as pd
-# this will become Sim later..
+# this will become Sim later..
TheSim = 1
# LogPrec is precision for saving float values in logs
@@ -48,20 +71,24 @@
# note: we cannot use methods for callbacks from Go -- must be separate functions
# so below are all the callbacks from the GUI toolbar actions
+
def InitCB(recv, send, sig, data):
TheSim.Init()
TheSim.UpdateClassView()
TheSim.vp.SetNeedsFullRender()
+
def TrainCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
TheSim.ToolBar.UpdateActions()
TheSim.Train()
+
def StopCB(recv, send, sig, data):
TheSim.Stop()
+
def StepTrialCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
@@ -70,18 +97,21 @@ def StepTrialCB(recv, send, sig, data):
TheSim.UpdateClassView()
TheSim.vp.SetNeedsFullRender()
+
def StepEpochCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
TheSim.ToolBar.UpdateActions()
TheSim.TrainEpoch()
+
def StepRunCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
TheSim.ToolBar.UpdateActions()
TheSim.TrainRun()
+
def TestTrialCB(recv, send, sig, data):
if not TheSim.IsRunning:
TheSim.IsRunning = True
@@ -90,6 +120,7 @@ def TestTrialCB(recv, send, sig, data):
TheSim.UpdateClassView()
TheSim.vp.SetNeedsFullRender()
+
def TestItemCB2(recv, send, sig, data):
win = gi.Window(handle=recv)
vp = win.WinViewport2D()
@@ -97,9 +128,20 @@ def TestItemCB2(recv, send, sig, data):
if sig != gi.DialogAccepted:
return
val = gi.StringPromptDialogValue(dlg)
- idxs = TheSim.TestEnv.Table.RowsByString("Name", val, True, True) # contains, ignoreCase
+ idxs = TheSim.TestEnv.Table.RowsByString(
+ "Name", val, True, True
+ ) # contains, ignoreCase
if len(idxs) == 0:
- gi.PromptDialog(vp, gi.DlgOpts(Title="Name Not Found", Prompt="No patterns found containing: " + val), True, False, go.nil, go.nil)
+ gi.PromptDialog(
+ vp,
+ gi.DlgOpts(
+ Title="Name Not Found", Prompt="No patterns found containing: " + val
+ ),
+ True,
+ False,
+ go.nil,
+ go.nil,
+ )
else:
if not TheSim.IsRunning:
TheSim.IsRunning = True
@@ -108,10 +150,21 @@ def TestItemCB2(recv, send, sig, data):
TheSim.IsRunning = False
vp.SetNeedsFullRender()
+
def TestItemCB(recv, send, sig, data):
win = gi.Window(handle=recv)
- gi.StringPromptDialog(win.WinViewport2D(), "", "Test Item",
- gi.DlgOpts(Title="Test Item", Prompt="Enter the Name of a given input pattern to test (case insensitive, contains given string."), win, TestItemCB2)
+ gi.StringPromptDialog(
+ win.WinViewport2D(),
+ "",
+ "Test Item",
+ gi.DlgOpts(
+ Title="Test Item",
+ Prompt="Enter the Name of a given input pattern to test (case insensitive, contains given string.",
+ ),
+ win,
+ TestItemCB2,
+ )
+
def TestAllCB(recv, send, sig, data):
if not TheSim.IsRunning:
@@ -119,29 +172,38 @@ def TestAllCB(recv, send, sig, data):
TheSim.ToolBar.UpdateActions()
TheSim.RunTestAll()
+
def ResetRunLogCB(recv, send, sig, data):
TheSim.RunLog.SetNumRows(0)
TheSim.RunPlot.Update()
+
def NewRndSeedCB(recv, send, sig, data):
TheSim.NewRndSeed()
+
def ReadmeCB(recv, send, sig, data):
- gi.TheApp.OpenURL("https://github.com/emer/leabra/blob/master/examples/ra25/README.md")
+ gi.TheApp.OpenURL(
+ "https://github.com/emer/leabra/blob/master/examples/ra25/README.md"
+ )
+
def FilterSSE(et, row):
- return etable.Table(handle=et).CellFloat("SSE", row) > 0 # include error trials
+ return etable.Table(handle=et).CellFloat("SSE", row) > 0 # include error trials
+
def UpdtFuncNotRunning(act):
act.SetActiveStateUpdt(not TheSim.IsRunning)
-
+
+
def UpdtFuncRunning(act):
act.SetActiveStateUpdt(TheSim.IsRunning)
-
-#####################################################
+
+#####################################################
# Sim
+
class Sim(pygiv.ClassViewObj):
"""
Sim encapsulates the entire simulation model, and we define all the
@@ -154,21 +216,38 @@ class Sim(pygiv.ClassViewObj):
def __init__(self):
super(Sim, self).__init__()
self.Net = leabra.Network()
- self.SetTags("Net", 'view:"no-inline" desc:"the network -- click to view / edit parameters for layers, prjns, etc"')
+ self.SetTags(
+ "Net",
+ 'view:"no-inline" desc:"the network -- click to view / edit parameters for layers, prjns, etc"',
+ )
self.Pats = etable.Table()
self.SetTags("Pats", 'view:"no-inline" desc:"the training patterns to use"')
self.TrnEpcLog = etable.Table()
- self.SetTags("TrnEpcLog", 'view:"no-inline" desc:"training epoch-level log data"')
+ self.SetTags(
+ "TrnEpcLog", 'view:"no-inline" desc:"training epoch-level log data"'
+ )
self.TstEpcLog = etable.Table()
- self.SetTags("TstEpcLog", 'view:"no-inline" desc:"testing epoch-level log data"')
+ self.SetTags(
+ "TstEpcLog", 'view:"no-inline" desc:"testing epoch-level log data"'
+ )
self.TstTrlLog = etable.Table()
- self.SetTags("TstTrlLog", 'view:"no-inline" desc:"testing trial-level log data"')
+ self.SetTags(
+ "TstTrlLog", 'view:"no-inline" desc:"testing trial-level log data"'
+ )
self.TstErrLog = etable.Table()
- self.SetTags("TstErrLog", 'view:"no-inline" desc:"log of all test trials where errors were made"')
+ self.SetTags(
+ "TstErrLog",
+ 'view:"no-inline" desc:"log of all test trials where errors were made"',
+ )
self.TstErrStats = etable.Table()
- self.SetTags("TstErrStats", 'view:"no-inline" desc:"stats on test trials where errors were made"')
+ self.SetTags(
+ "TstErrStats",
+ 'view:"no-inline" desc:"stats on test trials where errors were made"',
+ )
self.TstCycLog = etable.Table()
- self.SetTags("TstCycLog", 'view:"no-inline" desc:"testing cycle-level log data"')
+ self.SetTags(
+ "TstCycLog", 'view:"no-inline" desc:"testing cycle-level log data"'
+ )
self.RunLog = etable.Table()
self.SetTags("RunLog", 'view:"no-inline" desc:"summary log of each run"')
self.RunStats = etable.Table()
@@ -176,67 +255,132 @@ def __init__(self):
self.Params = params.Sets()
self.SetTags("Params", 'view:"no-inline" desc:"full collection of param sets"')
self.ParamSet = str()
- self.SetTags("ParamSet", 'desc:"which set of *additional* parameters to use -- always applies Base and optionaly this next if set -- can use multiple names separated by spaces (don\'t put spaces in ParamSet names!)"')
+ self.SetTags(
+ "ParamSet",
+ 'desc:"which set of *additional* parameters to use -- always applies Base and optionaly this next if set -- can use multiple names separated by spaces (don\'t put spaces in ParamSet names!)"',
+ )
self.Tag = str()
- self.SetTags("Tag", 'desc:"extra tag string to add to any file names output from sim (e.g., weights files, log files, params for run)"')
+ self.SetTags(
+ "Tag",
+ 'desc:"extra tag string to add to any file names output from sim (e.g., weights files, log files, params for run)"',
+ )
self.MaxRuns = int(10)
self.SetTags("MaxRuns", 'desc:"maximum number of model runs to perform"')
self.MaxEpcs = int(50)
self.SetTags("MaxEpcs", 'desc:"maximum number of epochs to run per model run"')
self.NZeroStop = int(5)
- self.SetTags("NZeroStop", 'desc:"if a positive number, training will stop after this many epochs with zero SSE"')
+ self.SetTags(
+ "NZeroStop",
+ 'desc:"if a positive number, training will stop after this many epochs with zero SSE"',
+ )
self.TrainEnv = env.FixedTable()
- self.SetTags("TrainEnv", 'desc:"Training environment -- contains everything about iterating over input / output patterns over training"')
+ self.SetTags(
+ "TrainEnv",
+ 'desc:"Training environment -- contains everything about iterating over input / output patterns over training"',
+ )
self.TestEnv = env.FixedTable()
- self.SetTags("TestEnv", 'desc:"Testing environment -- manages iterating over testing"')
+ self.SetTags(
+ "TestEnv", 'desc:"Testing environment -- manages iterating over testing"'
+ )
self.Time = leabra.Time()
self.SetTags("Time", 'desc:"leabra timing parameters and state"')
self.ViewOn = True
- self.SetTags("ViewOn", 'desc:"whether to update the network view while running"')
+ self.SetTags(
+ "ViewOn", 'desc:"whether to update the network view while running"'
+ )
self.TrainUpdt = leabra.TimeScales.AlphaCycle
- self.SetTags("TrainUpdt", 'desc:"at what time scale to update the display during training? Anything longer than Epoch updates at Epoch in this model"')
+ self.SetTags(
+ "TrainUpdt",
+ 'desc:"at what time scale to update the display during training? Anything longer than Epoch updates at Epoch in this model"',
+ )
self.TestUpdt = leabra.TimeScales.Cycle
- self.SetTags("TestUpdt", 'desc:"at what time scale to update the display during testing? Anything longer than Epoch updates at Epoch in this model"')
+ self.SetTags(
+ "TestUpdt",
+ 'desc:"at what time scale to update the display during testing? Anything longer than Epoch updates at Epoch in this model"',
+ )
self.TestInterval = int(5)
- self.SetTags("TestInterval", 'desc:"how often to run through all the test patterns, in terms of training epochs -- can use 0 or -1 for no testing"')
+ self.SetTags(
+ "TestInterval",
+ 'desc:"how often to run through all the test patterns, in terms of training epochs -- can use 0 or -1 for no testing"',
+ )
self.LayStatNms = go.Slice_string(["Hidden1", "Hidden2", "Output"])
- self.SetTags("LayStatNms", 'desc:"names of layers to collect more detailed stats on (avg act, etc)"')
+ self.SetTags(
+ "LayStatNms",
+ 'desc:"names of layers to collect more detailed stats on (avg act, etc)"',
+ )
# statistics: note use float64 as that is best for etable.Table
self.TrlErr = float()
- self.SetTags("TrlErr", 'inactive:"+" desc:"1 if trial was error, 0 if correct -- based on SSE = 0 (subject to .5 unit-wise tolerance)"')
+ self.SetTags(
+ "TrlErr",
+ 'inactive:"+" desc:"1 if trial was error, 0 if correct -- based on SSE = 0 (subject to .5 unit-wise tolerance)"',
+ )
self.TrlSSE = float()
self.SetTags("TrlSSE", 'inactive:"+" desc:"current trial\'s sum squared error"')
self.TrlAvgSSE = float()
- self.SetTags("TrlAvgSSE", 'inactive:"+" desc:"current trial\'s average sum squared error"')
+ self.SetTags(
+ "TrlAvgSSE",
+ 'inactive:"+" desc:"current trial\'s average sum squared error"',
+ )
self.TrlCosDiff = float()
- self.SetTags("TrlCosDiff", 'inactive:"+" desc:"current trial\'s cosine difference"')
+ self.SetTags(
+ "TrlCosDiff", 'inactive:"+" desc:"current trial\'s cosine difference"'
+ )
self.EpcSSE = float()
- self.SetTags("EpcSSE", 'inactive:"+" desc:"last epoch\'s total sum squared error"')
+ self.SetTags(
+ "EpcSSE", 'inactive:"+" desc:"last epoch\'s total sum squared error"'
+ )
self.EpcAvgSSE = float()
- self.SetTags("EpcAvgSSE", 'inactive:"+" desc:"last epoch\'s average sum squared error (average over trials, and over units within layer)"')
+ self.SetTags(
+ "EpcAvgSSE",
+ 'inactive:"+" desc:"last epoch\'s average sum squared error (average over trials, and over units within layer)"',
+ )
self.EpcPctErr = float()
self.SetTags("EpcPctErr", 'inactive:"+" desc:"last epoch\'s average TrlErr"')
self.EpcPctCor = float()
- self.SetTags("EpcPctCor", 'inactive:"+" desc:"1 - last epoch\'s average TrlErr"')
+ self.SetTags(
+ "EpcPctCor", 'inactive:"+" desc:"1 - last epoch\'s average TrlErr"'
+ )
self.EpcCosDiff = float()
- self.SetTags("EpcCosDiff", 'inactive:"+" desc:"last epoch\'s average cosine difference for output layer (a normalized error measure, maximum of 1 when the minus phase exactly matches the plus)"')
+ self.SetTags(
+ "EpcCosDiff",
+ 'inactive:"+" desc:"last epoch\'s average cosine difference for output layer (a normalized error measure, maximum of 1 when the minus phase exactly matches the plus)"',
+ )
self.EpcPerTrlMSec = float()
- self.SetTags("EpcPerTrlMSec", 'inactive:"+" desc:"how long did the epoch take per trial in wall-clock milliseconds"')
+ self.SetTags(
+ "EpcPerTrlMSec",
+ 'inactive:"+" desc:"how long did the epoch take per trial in wall-clock milliseconds"',
+ )
self.FirstZero = int()
- self.SetTags("FirstZero", 'inactive:"+" desc:"epoch at when SSE first went to zero"')
+ self.SetTags(
+ "FirstZero", 'inactive:"+" desc:"epoch at when SSE first went to zero"'
+ )
self.NZero = int()
- self.SetTags("NZero", 'inactive:"+" desc:"number of epochs in a row with zero SSE"')
+ self.SetTags(
+ "NZero", 'inactive:"+" desc:"number of epochs in a row with zero SSE"'
+ )
# internal state - view:"-"
self.SumErr = float()
- self.SetTags("SumErr", 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"')
+ self.SetTags(
+ "SumErr",
+ 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"',
+ )
self.SumSSE = float()
- self.SetTags("SumSSE", 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"')
+ self.SetTags(
+ "SumSSE",
+ 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"',
+ )
self.SumAvgSSE = float()
- self.SetTags("SumAvgSSE", 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"')
+ self.SetTags(
+ "SumAvgSSE",
+ 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"',
+ )
self.SumCosDiff = float()
- self.SetTags("SumCosDiff", 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"')
+ self.SetTags(
+ "SumCosDiff",
+ 'view:"-" inactive:"+" desc:"sum to increment as we go through epoch"',
+ )
self.Win = 0
self.SetTags("Win", 'view:"-" desc:"main GUI window"')
@@ -261,22 +405,31 @@ def __init__(self):
self.ValsTsrs = {}
self.SetTags("ValsTsrs", 'view:"-" desc:"for holding layer values"')
self.SaveWts = False
- self.SetTags("SaveWts", 'view:"-" desc:"for command-line run only, auto-save final weights after each run"')
+ self.SetTags(
+ "SaveWts",
+ 'view:"-" desc:"for command-line run only, auto-save final weights after each run"',
+ )
self.NoGui = False
self.SetTags("NoGui", 'view:"-" desc:"if true, runing in no GUI mode"')
self.LogSetParams = False
- self.SetTags("LogSetParams", 'view:"-" desc:"if true, print message for all params that are set"')
+ self.SetTags(
+ "LogSetParams",
+ 'view:"-" desc:"if true, print message for all params that are set"',
+ )
self.IsRunning = False
self.SetTags("IsRunning", 'view:"-" desc:"true if sim is running"')
self.StopNow = False
self.SetTags("StopNow", 'view:"-" desc:"flag to stop running"')
self.NeedsNewRun = False
- self.SetTags("NeedsNewRun", 'view:"-" desc:"flag to initialize NewRun if last one finished"')
+ self.SetTags(
+ "NeedsNewRun",
+ 'view:"-" desc:"flag to initialize NewRun if last one finished"',
+ )
self.RndSeed = int(1)
self.SetTags("RndSeed", 'view:"-" desc:"the current random seed"')
self.LastEpcTime = int()
self.SetTags("LastEpcTime", 'view:"-" desc:"timer for last epoch"')
- self.vp = 0
+ self.vp = 0
self.SetTags("vp", 'view:"-" desc:"viewport"')
def InitParams(ss):
@@ -301,9 +454,9 @@ def Config(ss):
ss.ConfigRunLog(ss.RunLog)
def ConfigEnv(ss):
- if ss.MaxRuns == 0: # allow user override
+ if ss.MaxRuns == 0: # allow user override
ss.MaxRuns = 10
- if ss.MaxEpcs == 0: # allow user override
+ if ss.MaxEpcs == 0: # allow user override
ss.MaxEpcs = 50
ss.NZeroStop = 5
@@ -311,7 +464,9 @@ def ConfigEnv(ss):
ss.TrainEnv.Dsc = "training params and state"
ss.TrainEnv.Table = etable.NewIdxView(ss.Pats)
ss.TrainEnv.Validate()
- ss.TrainEnv.Run.Max = ss.MaxRuns # note: we are not setting epoch max -- do that manually
+ ss.TrainEnv.Run.Max = (
+ ss.MaxRuns
+ ) # note: we are not setting epoch max -- do that manually
ss.TestEnv.Nm = "TestEnv"
ss.TestEnv.Dsc = "testing params and state"
@@ -337,7 +492,11 @@ def ConfigNet(ss, net):
# use this to position layers relative to each other
# default is Above, YAlign = Front, XAlign = Center
- hid2.SetRelPos(relpos.Rel(Rel= relpos.RightOf, Other= "Hidden1", YAlign= relpos.Front, Space= 2))
+ hid2.SetRelPos(
+ relpos.Rel(
+ Rel=relpos.RightOf, Other="Hidden1", YAlign=relpos.Front, Space=2
+ )
+ )
# note: see emergent/prjn module for all the options on how to connect
# NewFull returns a new prjn.Full connectivity pattern
@@ -360,7 +519,7 @@ def ConfigNet(ss, net):
# and thus removes error-driven learning -- but stats are still computed.
net.Defaults()
- ss.SetParams("Network", ss.LogSetParams) # only set Network params
+ ss.SetParams("Network", ss.LogSetParams) # only set Network params
net.Build()
net.InitWts()
@@ -373,7 +532,7 @@ def Init(ss):
ss.ConfigEnv()
ss.StopNow = False
- ss.SetParams("", ss.LogSetParams) # all sheets
+ ss.SetParams("", ss.LogSetParams) # all sheets
ss.NewRun()
ss.UpdateView(True)
@@ -391,9 +550,21 @@ def Counters(ss, train):
and add a few tabs at the end to allow for expansion..
"""
if train:
- return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TrainEnv.Trial.Cur, ss.Time.Cycle, ss.TrainEnv.TrialName.Cur)
+ return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (
+ ss.TrainEnv.Run.Cur,
+ ss.TrainEnv.Epoch.Cur,
+ ss.TrainEnv.Trial.Cur,
+ ss.Time.Cycle,
+ ss.TrainEnv.TrialName.Cur,
+ )
else:
- return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TestEnv.Trial.Cur, ss.Time.Cycle, ss.TestEnv.TrialName.Cur)
+ return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (
+ ss.TrainEnv.Run.Cur,
+ ss.TrainEnv.Epoch.Cur,
+ ss.TestEnv.Trial.Cur,
+ ss.Time.Cycle,
+ ss.TestEnv.TrialName.Cur,
+ )
def UpdateView(ss, train):
if ss.NetView != 0 and ss.NetView.IsVisible():
@@ -411,7 +582,7 @@ def AlphaCyc(ss, train):
"""
if ss.Win != 0:
- ss.Win.PollEvents() # this is essential for GUI responsiveness while running
+ ss.Win.PollEvents() # this is essential for GUI responsiveness while running
viewUpdt = ss.TrainUpdt.value
if not train:
viewUpdt = ss.TestUpdt.value
@@ -433,10 +604,10 @@ def AlphaCyc(ss, train):
ss.Time.CycleInc()
if ss.ViewOn:
if viewUpdt == leabra.Cycle:
- if cyc != ss.Time.CycPerQtr-1: # will be updated by quarter
+ if cyc != ss.Time.CycPerQtr - 1: # will be updated by quarter
ss.UpdateView(train)
if viewUpdt == leabra.FastSpike:
- if (cyc+1)%10 == 0:
+ if (cyc + 1) % 10 == 0:
ss.UpdateView(train)
ss.Net.QuarterFinal(ss.Time)
ss.Time.QuarterInc()
@@ -452,8 +623,7 @@ def AlphaCyc(ss, train):
if ss.ViewOn and viewUpdt == leabra.AlphaCycle:
ss.UpdateView(train)
if ss.TstCycPlot != 0 and not train:
- ss.TstCycPlot.GoUpdate() # make sure up-to-date at end
-
+ ss.TstCycPlot.GoUpdate() # make sure up-to-date at end
def ApplyInputs(ss, en):
"""
@@ -464,7 +634,7 @@ def ApplyInputs(ss, en):
"""
lays = ["Input", "Output"]
- for lnm in lays :
+ for lnm in lays:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
pats = en.State(ly.Nm)
if pats != 0:
@@ -488,12 +658,14 @@ def TrainTrial(ss):
ss.LogTrnEpc(ss.TrnEpcLog)
if ss.ViewOn and ss.TrainUpdt.value > leabra.AlphaCycle:
ss.UpdateView(True)
- if ss.TestInterval > 0 and epc%ss.TestInterval == 0: # note: epc is *next* so won't trigger first time
+ if (
+ ss.TestInterval > 0 and epc % ss.TestInterval == 0
+ ): # note: epc is *next* so won't trigger first time
ss.TestAll()
if epc >= ss.MaxEpcs or (ss.NZeroStop > 0 and ss.NZero >= ss.NZeroStop):
# done with training..
ss.RunEnd()
- if ss.TrainEnv.Run.Incr(): # we are done!
+ if ss.TrainEnv.Run.Incr(): # we are done!
ss.StopNow = True
return
else:
@@ -501,8 +673,8 @@ def TrainTrial(ss):
return
ss.ApplyInputs(ss.TrainEnv)
- ss.AlphaCyc(True) # train
- ss.TrialStats(True) # accumulate
+ ss.AlphaCyc(True) # train
+ ss.TrialStats(True) # accumulate
def RunEnd(ss):
"""
@@ -512,7 +684,7 @@ def RunEnd(ss):
if ss.SaveWts:
fnm = ss.WeightsFileName()
print("Saving Weights to: %s\n" % fnm)
- ss.Net.SaveWtsJSON(gi.FileName(fnm))
+ ss.Net.SaveWtsJSON(gi.Filename(fnm))
def NewRun(ss):
"""
@@ -560,7 +732,7 @@ def TrialStats(ss, accum):
"""
out = leabra.Layer(ss.Net.LayerByName("Output"))
ss.TrlCosDiff = float(out.CosDiff.Cos)
- ss.TrlSSE = out.SSE(0.5) # 0.5 = per-unit tolerance -- right side of .5
+ ss.TrlSSE = out.SSE(0.5) # 0.5 = per-unit tolerance -- right side of .5
ss.TrlAvgSSE = ss.TrlSSE / len(out.Neurons)
if ss.TrlSSE > 0:
ss.TrlErr = 1
@@ -720,7 +892,7 @@ def SetParamsSet(ss, setNm, sheet, setMsg):
if sheet == "" or sheet == "Sim":
if "Sim" in pset.Sheets:
- simp= pset.SheetByNameTry("Sim")
+ simp = pset.SheetByNameTry("Sim")
pyparams.ApplyParams(ss, simp, setMsg)
def ConfigPats(ss):
@@ -728,9 +900,21 @@ def ConfigPats(ss):
dt.SetMetaData("name", "TrainPats")
dt.SetMetaData("desc", "Training patterns")
sch = etable.Schema(
- [etable.Column("Name", etensor.STRING, go.nil, go.nil),
- etable.Column("Input", etensor.FLOAT32, go.Slice_int([5, 5]), go.Slice_string(["Y", "X"])),
- etable.Column("Output", etensor.FLOAT32, go.Slice_int([5, 5]), go.Slice_string(["Y", "X"]))]
+ [
+ etable.Column("Name", etensor.STRING, go.nil, go.nil),
+ etable.Column(
+ "Input",
+ etensor.FLOAT32,
+ go.Slice_int([5, 5]),
+ go.Slice_string(["Y", "X"]),
+ ),
+ etable.Column(
+ "Output",
+ etensor.FLOAT32,
+ go.Slice_int([5, 5]),
+ go.Slice_string(["Y", "X"]),
+ ),
+ ]
)
dt.SetFromSchema(sch, 25)
@@ -775,7 +959,14 @@ def WeightsFileName(ss):
"""
WeightsFileName returns default current weights file name
"""
- return ss.Net.Nm + "_" + ss.RunName() + "_" + ss.RunEpochName(ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur) + ".wts"
+ return (
+ ss.Net.Nm
+ + "_"
+ + ss.RunName()
+ + "_"
+ + ss.RunEpochName(ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur)
+ + ".wts"
+ )
def LogFileName(ss, lognm):
"""
@@ -828,7 +1019,7 @@ def LogTrnEpc(ss, dt):
for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
- dt.SetCellFloat(ly.Nm+"_ActAvg", row, float(ly.Pool(0).ActAvg.ActPAvgEff))
+ dt.SetCellFloat(ly.Nm + "_ActAvg", row, float(ly.Pool(0).ActAvg.ActPAvgEff))
if ss.TrnEpcPlot != 0:
ss.TrnEpcPlot.GoUpdate()
@@ -844,14 +1035,16 @@ def ConfigTrnEpcLog(ss, dt):
dt.SetMetaData("precision", str(LogPrec))
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PerTrlMSec", etensor.FLOAT64, go.nil, go.nil)]
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PerTrlMSec", etensor.FLOAT64, go.nil, go.nil),
+ ]
)
for lnm in ss.LayStatNms:
sch.append(etable.Column(lnm + "_ActAvg", etensor.FLOAT64, go.nil, go.nil))
@@ -872,7 +1065,9 @@ def ConfigTrnEpcPlot(ss, plt, dt):
plt.SetColParams("PerTrlMSec", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
for lnm in ss.LayStatNms:
- plt.SetColParams(lnm+"_ActAvg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, .5)
+ plt.SetColParams(
+ lnm + "_ActAvg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 0.5
+ )
return plt
def LogTstTrl(ss, dt):
@@ -901,7 +1096,7 @@ def LogTstTrl(ss, dt):
for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
- dt.SetCellFloat(ly.Nm+" ActM.Avg", row, float(ly.Pool(0).ActM.Avg))
+ dt.SetCellFloat(ly.Nm + " ActM.Avg", row, float(ly.Pool(0).ActM.Avg))
ivt = ss.ValsTsr("Input")
ovt = ss.ValsTsr("Output")
inp.UnitValsTensor(ivt, "Act")
@@ -923,20 +1118,24 @@ def ConfigTstTrlLog(ss, dt):
dt.SetMetaData("read-only", "true")
dt.SetMetaData("precision", str(LogPrec))
- nt = ss.TestEnv.Table.Len() # number in view
+ nt = ss.TestEnv.Table.Len() # number in view
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
- etable.Column("Trial", etensor.INT64, go.nil, go.nil),
- etable.Column("TrialName", etensor.STRING, go.nil, go.nil),
- etable.Column("Err", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil)]
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
+ etable.Column("Trial", etensor.INT64, go.nil, go.nil),
+ etable.Column("TrialName", etensor.STRING, go.nil, go.nil),
+ etable.Column("Err", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ ]
)
- for lnm in ss.LayStatNms :
- sch.append(etable.Column(lnm + " ActM.Avg", etensor.FLOAT64, go.nil, go.nil))
-
+ for lnm in ss.LayStatNms:
+ sch.append(
+ etable.Column(lnm + " ActM.Avg", etensor.FLOAT64, go.nil, go.nil)
+ )
+
sch.append(etable.Column("InAct", etensor.FLOAT64, inp.Shp.Shp, go.nil))
sch.append(etable.Column("OutActM", etensor.FLOAT64, out.Shp.Shp, go.nil))
sch.append(etable.Column("OutActP", etensor.FLOAT64, out.Shp.Shp, go.nil))
@@ -957,7 +1156,9 @@ def ConfigTstTrlPlot(ss, plt, dt):
plt.SetColParams("CosDiff", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1)
for lnm in ss.LayStatNms:
- plt.SetColParams(lnm+" ActM.Avg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, .5)
+ plt.SetColParams(
+ lnm + " ActM.Avg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 0.5
+ )
plt.SetColParams("InAct", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
plt.SetColParams("OutActM", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
@@ -970,7 +1171,7 @@ def LogTstEpc(ss, dt):
trl = ss.TstTrlLog
tix = etable.NewIdxView(trl)
- epc = ss.TrainEnv.Epoch.Prv # ?
+ epc = ss.TrainEnv.Epoch.Prv # ?
# note: this shows how to use agg methods to compute summary data from another
# data table, instead of incrementing on the Sim
@@ -979,11 +1180,11 @@ def LogTstEpc(ss, dt):
dt.SetCellFloat("SSE", row, agg.Sum(tix, "SSE")[0])
dt.SetCellFloat("AvgSSE", row, agg.Mean(tix, "AvgSSE")[0])
dt.SetCellFloat("PctErr", row, agg.Mean(tix, "Err")[0])
- dt.SetCellFloat("PctCor", row, 1-agg.Mean(tix, "Err")[0])
+ dt.SetCellFloat("PctCor", row, 1 - agg.Mean(tix, "Err")[0])
dt.SetCellFloat("CosDiff", row, agg.Mean(tix, "CosDiff")[0])
trlix = etable.NewIdxView(trl)
- trlix.Filter(FilterSSE) # requires separate function
+ trlix.Filter(FilterSSE) # requires separate function
ss.TstErrLog = trlix.NewTable()
@@ -1007,13 +1208,15 @@ def ConfigTstEpcLog(ss, dt):
dt.SetMetaData("precision", str(LogPrec))
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil)]
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Epoch", etensor.INT64, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ ]
)
dt.SetFromSchema(sch, 0)
@@ -1026,8 +1229,12 @@ def ConfigTstEpcPlot(ss, plt, dt):
plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
- plt.SetColParams("PctErr", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) # default plot
- plt.SetColParams("PctCor", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) # default plot
+ plt.SetColParams(
+ "PctErr", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1
+ ) # default plot
+ plt.SetColParams(
+ "PctCor", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1
+ ) # default plot
plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
return plt
@@ -1042,10 +1249,10 @@ def LogTstCyc(ss, dt, cyc):
dt.SetCellFloat("Cycle", cyc, float(cyc))
for lnm in ss.LayStatNms:
ly = leabra.Layer(ss.Net.LayerByName(lnm))
- dt.SetCellFloat(ly.Nm+" Ge.Avg", cyc, float(ly.Pool(0).Inhib.Ge.Avg))
- dt.SetCellFloat(ly.Nm+" Act.Avg", cyc, float(ly.Pool(0).Inhib.Act.Avg))
+ dt.SetCellFloat(ly.Nm + " Ge.Avg", cyc, float(ly.Pool(0).Inhib.Ge.Avg))
+ dt.SetCellFloat(ly.Nm + " Act.Avg", cyc, float(ly.Pool(0).Inhib.Act.Avg))
- if ss.TstCycPlot != 0 and cyc%10 == 0: # too slow to do every cyc
+ if ss.TstCycPlot != 0 and cyc % 10 == 0: # too slow to do every cyc
# note: essential to use Go version of update when called from another goroutine
ss.TstCycPlot.GoUpdate()
@@ -1055,10 +1262,8 @@ def ConfigTstCycLog(ss, dt):
dt.SetMetaData("read-only", "true")
dt.SetMetaData("precision", str(LogPrec))
- np = 100 # max cycles
- sch = etable.Schema(
- [etable.Column("Cycle", etensor.INT64, go.nil, go.nil)]
- )
+ np = 100 # max cycles
+ sch = etable.Schema([etable.Column("Cycle", etensor.INT64, go.nil, go.nil)])
for lnm in ss.LayStatNms:
sch.append(etable.Column(lnm + " Ge.Avg", etensor.FLOAT64, go.nil, go.nil))
sch.append(etable.Column(lnm + " Act.Avg", etensor.FLOAT64, go.nil, go.nil))
@@ -1071,8 +1276,8 @@ def ConfigTstCycPlot(ss, plt, dt):
# order of params: on, fixMin, min, fixMax, max
plt.SetColParams("Cycle", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
for lnm in ss.LayStatNms:
- plt.SetColParams(lnm+" Ge.Avg", True, True, 0, True, .5)
- plt.SetColParams(lnm+" Act.Avg", True, True, 0, True, .5)
+ plt.SetColParams(lnm + " Ge.Avg", True, True, 0, True, 0.5)
+ plt.SetColParams(lnm + " Act.Avg", True, True, 0, True, 0.5)
return plt
def LogRun(ss, dt):
@@ -1083,18 +1288,18 @@ def LogRun(ss, dt):
epcix = etable.NewIdxView(epclog)
if epcix.Len() == 0:
return
-
- run = ss.TrainEnv.Run.Cur # this is NOT triggered by increment yet -- use Cur
+
+ run = ss.TrainEnv.Run.Cur # this is NOT triggered by increment yet -- use Cur
row = dt.Rows
dt.SetNumRows(row + 1)
# compute mean over last N epochs for run level
nlast = 5
- if nlast > epcix.Len()-1:
+ if nlast > epcix.Len() - 1:
nlast = epcix.Len() - 1
- epcix.Idxs = epcix.Idxs[epcix.Len()-nlast:]
+ epcix.Idxs = epcix.Idxs[epcix.Len() - nlast :]
- params = ss.RunName() # includes tag
+ params = ss.RunName() # includes tag
dt.SetCellFloat("Run", row, float(run))
dt.SetCellString("Params", row, params)
@@ -1126,14 +1331,16 @@ def ConfigRunLog(ss, dt):
dt.SetMetaData("precision", str(LogPrec))
sch = etable.Schema(
- [etable.Column("Run", etensor.INT64, go.nil, go.nil),
- etable.Column("Params", etensor.STRING, go.nil, go.nil),
- etable.Column("FirstZero", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
- etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil)]
+ [
+ etable.Column("Run", etensor.INT64, go.nil, go.nil),
+ etable.Column("Params", etensor.STRING, go.nil, go.nil),
+ etable.Column("FirstZero", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("SSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("AvgSSE", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctErr", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("PctCor", etensor.FLOAT64, go.nil, go.nil),
+ etable.Column("CosDiff", etensor.FLOAT64, go.nil, go.nil),
+ ]
)
dt.SetFromSchema(sch, 0)
@@ -1143,7 +1350,9 @@ def ConfigRunPlot(ss, plt, dt):
plt.SetTable(dt)
# order of params: on, fixMin, min, fixMax, max
plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
- plt.SetColParams("FirstZero", eplot.On, eplot.FixMin, 0, eplot.FloatMax, 0) # default plot
+ plt.SetColParams(
+ "FirstZero", eplot.On, eplot.FixMin, 0, eplot.FloatMax, 0
+ ) # default plot
plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
plt.SetColParams("PctErr", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
@@ -1151,15 +1360,17 @@ def ConfigRunPlot(ss, plt, dt):
plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
return plt
- def ConfigGui(ss):
+ def ConfigGUI(ss):
"""
- ConfigGui configures the GoGi gui interface for this simulation,
+ ConfigGUI configures the GoGi gui interface for this simulation,
"""
width = 1600
height = 1200
gi.SetAppName("ra25")
- gi.SetAppAbout('This demonstrates a basic Leabra model. See emergent on GitHub.')
+ gi.SetAppAbout(
+ 'This demonstrates a basic Leabra model. See emergent on GitHub.'
+ )
win = gi.NewMainWindow("ra25", "Leabra Random Associator", width, height)
ss.Win = win
@@ -1193,7 +1404,9 @@ def ConfigGui(ss):
nv.SetNet(ss.Net)
ss.NetView = nv
- nv.Scene().Camera.Pose.Pos.Set(0, 1, 2.75) # more "head on" than default which is more "top down"
+ nv.Scene().Camera.Pose.Pos.Set(
+ 0, 1, 2.75
+ ) # more "head on" than default which is more "top down"
nv.Scene().Camera.LookAt(mat32.Vec3(0, 0, 0), mat32.Vec3(0, 1, 0))
plt = eplot.Plot2D()
@@ -1216,39 +1429,144 @@ def ConfigGui(ss):
tv.AddTab(plt, "RunPlot")
ss.RunPlot = ss.ConfigRunPlot(plt, ss.RunLog)
- split.SetSplitsList(go.Slice_float32([.2, .8]))
+ split.SetSplitsList(go.Slice_float32([0.2, 0.8]))
recv = win.This()
-
- tbar.AddAction(gi.ActOpts(Label="Init", Icon="update", Tooltip="Initialize everything including network weights, and start over. Also applies current params.", UpdateFunc=UpdtFuncNotRunning), recv, InitCB)
-
- tbar.AddAction(gi.ActOpts(Label="Train", Icon="run", Tooltip="Starts the network training, picking up from wherever it may have left off. If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.", UpdateFunc=UpdtFuncNotRunning), recv, TrainCB)
-
- tbar.AddAction(gi.ActOpts(Label="Stop", Icon="stop", Tooltip="Interrupts running. Hitting Train again will pick back up where it left off.", UpdateFunc=UpdtFuncRunning), recv, StopCB)
-
- tbar.AddAction(gi.ActOpts(Label="Step Trial", Icon="step-fwd", Tooltip="Advances one training trial at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepTrialCB)
-
- tbar.AddAction(gi.ActOpts(Label="Step Epoch", Icon="fast-fwd", Tooltip="Advances one epoch (complete set of training patterns) at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepEpochCB)
-
- tbar.AddAction(gi.ActOpts(Label="Step Run", Icon="fast-fwd", Tooltip="Advances one full training Run at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepRunCB)
-
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Init",
+ Icon="update",
+ Tooltip="Initialize everything including network weights, and start over. Also applies current params.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ InitCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Train",
+ Icon="run",
+ Tooltip="Starts the network training, picking up from wherever it may have left off. If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TrainCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Stop",
+ Icon="stop",
+ Tooltip="Interrupts running. Hitting Train again will pick back up where it left off.",
+ UpdateFunc=UpdtFuncRunning,
+ ),
+ recv,
+ StopCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Step Trial",
+ Icon="step-fwd",
+ Tooltip="Advances one training trial at a time.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ StepTrialCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Step Epoch",
+ Icon="fast-fwd",
+ Tooltip="Advances one epoch (complete set of training patterns) at a time.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ StepEpochCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Step Run",
+ Icon="fast-fwd",
+ Tooltip="Advances one full training Run at a time.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ StepRunCB,
+ )
+
tbar.AddSeparator("test")
-
- tbar.AddAction(gi.ActOpts(Label="Test Trial", Icon="step-fwd", Tooltip="Runs the next testing trial.", UpdateFunc=UpdtFuncNotRunning), recv, TestTrialCB)
-
- tbar.AddAction(gi.ActOpts(Label="Test Item", Icon="step-fwd", Tooltip="Prompts for a specific input pattern name to run, and runs it in testing mode.", UpdateFunc=UpdtFuncNotRunning), recv, TestItemCB)
-
- tbar.AddAction(gi.ActOpts(Label="Test All", Icon="fast-fwd", Tooltip="Tests all of the testing trials.", UpdateFunc=UpdtFuncNotRunning), recv, TestAllCB)
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Test Trial",
+ Icon="step-fwd",
+ Tooltip="Runs the next testing trial.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TestTrialCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Test Item",
+ Icon="step-fwd",
+ Tooltip="Prompts for a specific input pattern name to run, and runs it in testing mode.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TestItemCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Test All",
+ Icon="fast-fwd",
+ Tooltip="Tests all of the testing trials.",
+ UpdateFunc=UpdtFuncNotRunning,
+ ),
+ recv,
+ TestAllCB,
+ )
tbar.AddSeparator("log")
-
- tbar.AddAction(gi.ActOpts(Label="Reset RunLog", Icon="reset", Tooltip="Resets the accumulated log of all Runs, which are tagged with the ParamSet used"), recv, ResetRunLogCB)
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="Reset RunLog",
+ Icon="reset",
+ Tooltip="Resets the accumulated log of all Runs, which are tagged with the ParamSet used",
+ ),
+ recv,
+ ResetRunLogCB,
+ )
tbar.AddSeparator("misc")
-
- tbar.AddAction(gi.ActOpts(Label="New Seed", Icon="new", Tooltip="Generate a new initial random seed to get different results. By default, Init re-establishes the same initial seed every time."), recv, NewRndSeedCB)
- tbar.AddAction(gi.ActOpts(Label="README", Icon="file-markdown", Tooltip="Opens your browser on the README file that contains instructions for how to run this model."), recv, ReadmeCB)
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="New Seed",
+ Icon="new",
+ Tooltip="Generate a new initial random seed to get different results. By default, Init re-establishes the same initial seed every time.",
+ ),
+ recv,
+ NewRndSeedCB,
+ )
+
+ tbar.AddAction(
+ gi.ActOpts(
+ Label="README",
+ Icon="file-markdown",
+ Tooltip="Opens your browser on the README file that contains instructions for how to run this model.",
+ ),
+ recv,
+ ReadmeCB,
+ )
# main menu
appnm = gi.AppName()
@@ -1282,17 +1600,32 @@ def ConfigGui(ss):
# TheSim is the overall state for this simulation
TheSim = Sim()
+
def usage():
- print(sys.argv[0] + " --params= --tag= --setparams --wts --epclog=0 --runlog=0 --nogui")
+ print(
+ sys.argv[0]
+ + " --params= --tag= --setparams --wts --epclog=0 --runlog=0 --nogui"
+ )
print("\t pyleabra -i %s to run in interactive, gui mode" % sys.argv[0])
- print("\t --params= additional params to apply on top of Base (name must be in loaded Params")
- print("\t --tag= tag is appended to file names to uniquely identify this run")
+ print(
+ "\t --params= additional params to apply on top of Base (name must be in loaded Params"
+ )
+ print(
+ "\t --tag= tag is appended to file names to uniquely identify this run"
+ )
print("\t --runs= number of runs to do")
print("\t --setparams show the parameter values that are set")
print("\t --wts save final trained weights after every run")
- print("\t --epclog=0/False turn off save training epoch log data to file named by param set, tag")
- print("\t --runlog=0/False turn off save run log data to file named by param set, tag")
- print("\t --nogui if no other args needed, this prevents running under the gui")
+ print(
+ "\t --epclog=0/False turn off save training epoch log data to file named by param set, tag"
+ )
+ print(
+ "\t --runlog=0/False turn off save run log data to file named by param set, tag"
+ )
+ print(
+ "\t --nogui if no other args needed, this prevents running under the gui"
+ )
+
def main(argv):
TheSim.Config()
@@ -1301,15 +1634,28 @@ def main(argv):
TheSim.NoGui = len(argv) > 1
saveEpcLog = True
saveRunLog = True
-
+
try:
- opts, args = getopt.getopt(argv,"h:",["params=","tag=","runs=","setparams","wts","epclog=","runlog=","nogui"])
+ opts, args = getopt.getopt(
+ argv,
+ "h:",
+ [
+ "params=",
+ "tag=",
+ "runs=",
+ "setparams",
+ "wts",
+ "epclog=",
+ "runlog=",
+ "nogui",
+ ],
+ )
except getopt.GetoptError:
usage()
sys.exit(2)
for opt, arg in opts:
# print("opt: %s arg: %s" % (opt, arg))
- if opt == '-h':
+ if opt == "-h":
usage()
sys.exit()
elif opt == "--tag":
@@ -1332,27 +1678,30 @@ def main(argv):
TheSim.NoGui = True
TheSim.Init()
-
+
if TheSim.NoGui:
if saveEpcLog:
- fnm = TheSim.LogFileName("epc")
+ fnm = TheSim.LogFileName("epc")
print("Saving epoch log to: %s" % fnm)
TheSim.TrnEpcFile = efile.Create(fnm)
-
+
if saveRunLog:
- fnm = TheSim.LogFileName("run")
+ fnm = TheSim.LogFileName("run")
print("Saving run log to: %s" % fnm)
TheSim.RunFile = efile.Create(fnm)
-
+
TheSim.Train()
sys.exit(0)
else:
- TheSim.ConfigGui()
- print("Note: run pyleabra -i ra25.py to run in interactive mode, or just pyleabra, then 'import ra25'")
+ TheSim.ConfigGUI()
+ print(
+ "Note: run pyleabra -i ra25.py to run in interactive mode, or just pyleabra, then 'import ra25'"
+ )
print("for non-gui background running, here are the args:")
usage()
import code
+
code.interact(local=locals())
-main(sys.argv[1:])
+main(sys.argv[1:])