training.m
%% Set-Up
close all;
clear all;
clc;
addpath("Helpers\", "Data\", "Reset\", "Step\", "Reward\")

% Observation: a 13x24x8 EMG feature tensor; action: one of six gesture classes.
obsInfo = rlNumericSpec([13 24 8], Name='emgFeatures');
actInfo = rlFiniteSetSpec([1 2 3 4 5 6], Name='gestures');

% Two agents (DQN and DDQN) share the same observation and action specs.
obsInfos = {obsInfo, obsInfo};
actInfos = {actInfo, actInfo};
env = rlMultiAgentFunctionEnv(obsInfos, actInfos, @stepFnc, @resetFnc);
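
% stepFnc and resetFnc live on the Step\ and Reset\ paths and are not shown in
% this file. As a hedged sketch (assumed from the rlMultiAgentFunctionEnv
% interface, not copied from those files), they are expected to look like:
%
%   function [initialObs, info] = resetFnc()
%       initialObs = {obs1, obs2};   % one 13x24x8 feature tensor per agent
%       info = [];                   % carried into the first stepFnc call
%   end
%
%   function [nextObs, reward, isDone, info] = stepFnc(action, info)
%       % action:  {a1, a2}  gesture chosen by each agent
%       % nextObs: {o1, o2}  next feature tensor per agent
%       % reward:  {r1, r2}  per-agent reward
%       % isDone:  scalar flag ending the episode for all agents
%   end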
% Shared representation options for both critics (GPU training, Adam, clipped gradients).
NNOptions = rlRepresentationOptions( ...
    LearnRate=1e-6, ...
    GradientThreshold=1, ...
    Optimizer="adam", ...
    GradientThresholdMethod="l2norm", ...
    UseDevice="gpu");

% DQN critic: Q-value representation wrapping the CNN returned by targetCNN().
cnnDQN = targetCNN();
criticDQN = rlQValueRepresentation( ...
    cnnDQN, ...
    obsInfo, ...
    actInfo, ...
    'Observation', 'state', ...
    NNOptions);

% DDQN critic: a second copy of the same CNN architecture.
cnnDDQN = targetCNN();
criticDDQN = rlQValueRepresentation( ...
    cnnDDQN, ...
    obsInfo, ...
    actInfo, ...
    'Observation', 'state', ...
    NNOptions);
%% Params
% Hyperparameters shared by both agents.
targetSmoothFactor = 1e-3;       % soft-update rate for the target network
miniBatchSize = 32;              % experiences sampled per gradient step
numStepsLookAhead = 1;           % n-step return horizon
discountFactor = 0.98;           % reward discount (gamma)
experienceBufferLength = 50;     % replay buffer capacity
epsilonDecay = 1e-5;             % decay rate of the epsilon-greedy exploration
%% DQN
% Standard DQN agent (UseDoubleDQN disabled).
agentOptionsDQN = rlDQNAgentOptions( ...
    UseDoubleDQN=false, ...
    TargetSmoothFactor=targetSmoothFactor, ...
    MiniBatchSize=miniBatchSize, ...
    NumStepsToLookAhead=numStepsLookAhead, ...
    DiscountFactor=discountFactor, ...
    SaveExperienceBufferWithAgent=true, ...
    ExperienceBufferLength=experienceBufferLength);
agentOptionsDQN.EpsilonGreedyExploration.EpsilonDecay = epsilonDecay;
agentDQN = rlDQNAgent(criticDQN, agentOptionsDQN);
%% DDQN
% Identical configuration except UseDoubleDQN=true (double DQN target computation).
agentOptionsDDQN = rlDQNAgentOptions( ...
    UseDoubleDQN=true, ...
    TargetSmoothFactor=targetSmoothFactor, ...
    MiniBatchSize=miniBatchSize, ...
    NumStepsToLookAhead=numStepsLookAhead, ...
    DiscountFactor=discountFactor, ...
    SaveExperienceBufferWithAgent=true, ...
    ExperienceBufferLength=experienceBufferLength);
agentOptionsDDQN.EpsilonGreedyExploration.EpsilonDecay = epsilonDecay;
agentDDQN = rlDQNAgent(criticDDQN, agentOptionsDDQN);
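
% Optional sanity check (sketch, commented out): query both untrained agents
% with a random observation of the correct size before starting a long run.
% obsSample = {rand(obsInfo.Dimension)};
% getAction(agentDQN, obsSample);
% getAction(agentDDQN, obsSample);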
%% Training
maxEpisodes = prod([150 1 100]); % maxSamples * maxUsers * maxIterations
trainOpts = rlMultiAgentTrainingOptions( ...
    AgentGroups={[1 2]}, ...               % both agents in one group
    LearningStrategy="decentralized", ...  % each agent learns from its own experiences
    MaxEpisodes=maxEpisodes, ...
    ScoreAveragingWindowLength=150, ...
    StopTrainingCriteria="EpisodeCount", ...
    StopTrainingValue=maxEpisodes, ...
    SaveAgentCriteria="EpisodeFrequency", ...
    SaveAgentValue=1000, ...               % snapshot every 1000 episodes
    Verbose=true, ...
    SaveAgentDirectory=pwd + "\Snapshots", ...
    Plots="training-progress");

% Clear any persistent state held inside the step/reset functions before training.
clear resetFnc;
clear stepFnc;
results = train([agentDQN, agentDDQN], env, trainOpts);
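
% (Sketch) train returns one training-statistics object per agent; the
% EpisodeReward field is assumed here for a quick learning-curve comparison.
% figure; hold on;
% plot(results(1).EpisodeReward);   % DQN
% plot(results(2).EpisodeReward);   % DDQN
% legend("DQN", "DDQN");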
%% Save agents
save("results.mat", "results")
save("agentDQN.mat", "agentDQN");
save("agentDDQN.mat", "agentDDQN");
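
%% (Optional) Inference sketch
% Hedged example of reusing a saved agent for gesture prediction. The random
% featureWindow is a placeholder for a real pre-computed 13x24x8 EMG feature tensor.
% load("agentDDQN.mat", "agentDDQN");
% featureWindow = {rand(obsInfo.Dimension)};
% predictedGesture = getAction(agentDDQN, featureWindow);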