forked from rasmusbergpalm/DeepLearnToolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
renamed SPAE (stacked prediction autoencoder to CAE (convoultional au…
…toencoder) which is actually what it is
- Loading branch information
1 parent
2de71dd
commit aecfdaa
Showing
22 changed files
with
350 additions
and
349 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
function cae = caeapplygrads(cae) | ||
cae.sv = 0; | ||
for j = 1 : numel(cae.a) | ||
for i = 1 : numel(cae.i) | ||
% cae.vik{i}{j} = cae.momentum * cae.vik{i}{j} + cae.alpha ./ (cae.sigma + cae.ddik{i}{j}) .* cae.dik{i}{j}; | ||
% cae.vok{i}{j} = cae.momentum * cae.vok{i}{j} + cae.alpha ./ (cae.sigma + cae.ddok{i}{j}) .* cae.dok{i}{j}; | ||
cae.vik{i}{j} = cae.alpha * cae.dik{i}{j}; | ||
cae.vok{i}{j} = cae.alpha * cae.dok{i}{j}; | ||
cae.sv = cae.sv + sum(cae.vik{i}{j}(:) .^ 2); | ||
cae.sv = cae.sv + sum(cae.vok{i}{j}(:) .^ 2); | ||
|
||
cae.ik{i}{j} = cae.ik{i}{j} - cae.vik{i}{j}; | ||
cae.ok{i}{j} = cae.ok{i}{j} - cae.vok{i}{j}; | ||
end | ||
% cae.vb{j} = cae.momentum * cae.vb{j} + cae.alpha / (cae.sigma + cae.ddb{j}) * cae.db{j}; | ||
cae.vb{j} = cae.alpha * cae.db{j}; | ||
cae.sv = cae.sv + sum(cae.vb{j} .^ 2); | ||
|
||
cae.b{j} = cae.b{j} - cae.vb{j}; | ||
end | ||
|
||
for i = 1 : numel(cae.o) | ||
% cae.vc{i} = cae.momentum * cae.vc{i} + cae.alpha / (cae.sigma + cae.ddc{i}) * cae.dc{i}; | ||
cae.vc{i} = cae.alpha * cae.dc{i}; | ||
cae.sv = cae.sv + sum(cae.vc{i} .^ 2); | ||
|
||
cae.c{i} = cae.c{i} - cae.vc{i}; | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
function cae = caebbp(cae) | ||
|
||
%% backprop deltas | ||
for i = 1 : numel(cae.o) | ||
% output delta delta | ||
cae.odd{i} = (cae.o{i} .* (1 - cae.o{i}) .* cae.edgemask) .^ 2; | ||
% delta delta c | ||
cae.ddc{i} = sum(cae.odd{i}(:)) / size(cae.odd{i}, 1); | ||
end | ||
|
||
for j = 1 : numel(cae.a) % calc activation delta deltas | ||
z = 0; | ||
for i = 1 : numel(cae.o) | ||
z = z + convn(cae.odd{i}, flipall(cae.ok{i}{j} .^ 2), 'full'); | ||
end | ||
cae.add{j} = (cae.a{j} .* (1 - cae.a{j})) .^ 2 .* z; | ||
end | ||
|
||
%% calc params delta deltas | ||
ns = size(cae.odd{1}, 1); | ||
for j = 1 : numel(cae.a) | ||
cae.ddb{j} = sum(cae.add{j}(:)) / ns; | ||
for i = 1 : numel(cae.o) | ||
cae.ddok{i}{j} = convn(flipall(cae.a{j} .^ 2), cae.odd{i}, 'valid') / ns; | ||
cae.ddik{i}{j} = convn(cae.add{j}, flipall(cae.i{i} .^ 2), 'valid') / ns; | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
function cae = caebp(cae, y) | ||
|
||
%% backprop deltas | ||
cae.L = 0; | ||
for i = 1 : numel(cae.o) | ||
% error | ||
cae.e{i} = (cae.o{i} - y{i}) .* cae.edgemask; | ||
% loss function | ||
cae.L = cae.L + 1/2 * sum(cae.e{i}(:) .^2 ) / size(cae.e{i}, 1); | ||
% output delta | ||
cae.od{i} = cae.e{i} .* (cae.o{i} .* (1 - cae.o{i})); | ||
|
||
cae.dc{i} = sum(cae.od{i}(:)) / size(cae.e{i}, 1); | ||
end | ||
|
||
for j = 1 : numel(cae.a) % calc activation deltas | ||
z = 0; | ||
for i = 1 : numel(cae.o) | ||
z = z + convn(cae.od{i}, flipall(cae.ok{i}{j}), 'full'); | ||
end | ||
cae.ad{j} = cae.a{j} .* (1 - cae.a{j}) .* z; | ||
end | ||
|
||
%% calc gradients | ||
ns = size(cae.e{1}, 1); | ||
for j = 1 : numel(cae.a) | ||
cae.db{j} = sum(cae.ad{j}(:)) / ns; | ||
for i = 1 : numel(cae.o) | ||
cae.dok{i}{j} = convn(flipall(cae.a{j}), cae.od{i}, 'valid') / ns; | ||
cae.dik{i}{j} = convn(cae.ad{j}, flipall(cae.i{i}), 'valid') / ns; | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function cae = caedown(cae) | ||
pa = cae.a; | ||
pok = cae.ok; | ||
|
||
for i = 1 : numel(cae.o) | ||
z = 0; | ||
for j = 1 : numel(cae.a) | ||
z = z + convn(pa{j}, pok{i}{j}, 'valid'); | ||
end | ||
cae.o{i} = sigm(z + cae.c{i}); | ||
|
||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
function cae = caenumgradcheck(cae, x, y) | ||
epsilon = 1e-4; | ||
er = 1e-6; | ||
disp('performing numerical gradient checking...') | ||
for i = 1 : numel(cae.o) | ||
p_cae = cae; p_cae.c{i} = p_cae.c{i} + epsilon; | ||
m_cae = cae; m_cae.c{i} = m_cae.c{i} - epsilon; | ||
|
||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
|
||
e = abs(d - cae.dc{i}); | ||
if e > er | ||
disp('OUTPUT BIAS numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.dc{i}); | ||
keyboard | ||
end | ||
end | ||
|
||
for a = 1 : numel(cae.a) | ||
|
||
p_cae = cae; p_cae.b{a} = p_cae.b{a} + epsilon; | ||
m_cae = cae; m_cae.b{a} = m_cae.b{a} - epsilon; | ||
|
||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% cae.dok{i}{a}(u) = d; | ||
e = abs(d - cae.db{a}); | ||
if e > er | ||
disp('BIAS numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.db{a}); | ||
keyboard | ||
end | ||
|
||
for i = 1 : numel(cae.o) | ||
for u = 1 : numel(cae.ok{i}{a}) | ||
p_cae = cae; p_cae.ok{i}{a}(u) = p_cae.ok{i}{a}(u) + epsilon; | ||
m_cae = cae; m_cae.ok{i}{a}(u) = m_cae.ok{i}{a}(u) - epsilon; | ||
|
||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% cae.dok{i}{a}(u) = d; | ||
e = abs(d - cae.dok{i}{a}(u)); | ||
if e > er | ||
disp('OUTPUT KERNEL numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.dok{i}{a}(u)); | ||
% keyboard | ||
end | ||
end | ||
end | ||
|
||
for i = 1 : numel(cae.i) | ||
for u = 1 : numel(cae.ik{i}{a}) | ||
p_cae = cae; | ||
m_cae = cae; | ||
p_cae.ik{i}{a}(u) = p_cae.ik{i}{a}(u) + epsilon; | ||
m_cae.ik{i}{a}(u) = m_cae.ik{i}{a}(u) - epsilon; | ||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% cae.dik{i}{a}(u) = d; | ||
e = abs(d - cae.dik{i}{a}(u)); | ||
if e > er | ||
disp('INPUT KERNEL numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.dik{i}{a}(u)); | ||
end | ||
end | ||
end | ||
end | ||
|
||
disp('done') | ||
|
||
end | ||
|
||
function [m_cae, p_cae] = caerun(m_cae, p_cae, x, y) | ||
m_cae = caeup(m_cae, x); m_cae = caedown(m_cae); m_cae = caebp(m_cae, y); | ||
p_cae = caeup(p_cae, x); p_cae = caedown(p_cae); p_cae = caebp(p_cae, y); | ||
end | ||
|
||
%function checknumgrad(cae,what,x,y) | ||
% epsilon = 1e-4; | ||
% er = 1e-9; | ||
% | ||
% for i = 1 : numel(eval(what)) | ||
% if iscell(eval(['cae.' what])) | ||
% checknumgrad(cae,[what '{' num2str(i) '}'], x, y) | ||
% else | ||
% p_cae = cae; | ||
% m_cae = cae; | ||
% eval(['p_cae.' what '(' num2str(i) ')']) = eval([what '(' num2str(i) ')']) + epsilon; | ||
% eval(['m_cae.' what '(' num2str(i) ')']) = eval([what '(' num2str(i) ')']) - epsilon; | ||
% | ||
% m_cae = caeff(m_cae, x); m_cae = caedown(m_cae); m_cae = caebp(m_cae, y); | ||
% p_cae = caeff(p_cae, x); p_cae = caedown(p_cae); p_cae = caebp(p_cae, y); | ||
% | ||
% d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% e = abs(d - eval(['cae.d' what '(' num2str(i) ')'])); | ||
% if e > er | ||
% error('numerical gradient checking failed'); | ||
% end | ||
% end | ||
% end | ||
% | ||
% end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
function cae = caesdlm(cae, opts, m) | ||
%stochastic diagonal levenberg-marquardt | ||
|
||
%first round | ||
if isfield(cae,'ddok') == 0 | ||
cae = caebbp(cae); | ||
end | ||
|
||
%recalculate double grads every opts.ddinterval | ||
if mod(m, opts.ddinterval) == 0 | ||
cae_n = caebbp(cae); | ||
|
||
for ii = 1 : numel(cae.o) | ||
cae.ddc{ii} = opts.ddhist * cae.ddc{ii} + (1 - opts.ddhist) * cae_n.ddc{ii}; | ||
end | ||
|
||
for jj = 1 : numel(cae.a) | ||
cae.ddb{jj} = opts.ddhist * cae.ddb{jj} + (1 - opts.ddhist) * cae_n.ddb{jj}; | ||
for ii = 1 : numel(cae.o) | ||
cae.ddok{ii}{jj} = opts.ddhist * cae.ddok{ii}{jj} + (1 - opts.ddhist) * cae_n.ddok{ii}{jj}; | ||
cae.ddik{ii}{jj} = opts.ddhist * cae.ddik{ii}{jj} + (1 - opts.ddhist) * cae_n.ddik{ii}{jj}; | ||
end | ||
end | ||
|
||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
function cae = caeup(cae, x) | ||
cae.i = x; | ||
|
||
%init temp vars for parrallel processing | ||
pa = cell(size(cae.a)); | ||
pi = cae.i; | ||
pik = cae.ik; | ||
pb = cae.b; | ||
|
||
for j = 1 : numel(cae.a) | ||
z = 0; | ||
for i = 1 : numel(pi) | ||
z = z + convn(pi{i}, pik{i}{j}, 'full'); | ||
end | ||
pa{j} = sigm(z + pb{j}); | ||
|
||
% Max pool. | ||
if ~isequal(cae.scale, [1 1 1]) | ||
pa{j} = max3d(pa{j}, cae.M); | ||
end | ||
|
||
end | ||
cae.a = pa; | ||
|
||
end |
File renamed without changes.
Oops, something went wrong.