Skip to content

Commit

Permalink
Fixed dbnexample. fixes rasmusbergpalm#5. Thanks.
Browse files Browse the repository at this point in the history
  • Loading branch information
rasmusbergpalm committed Oct 16, 2012
1 parent 5f8e1af commit b829694
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 11 deletions.
5 changes: 2 additions & 3 deletions DBN/dbnexamples.m
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

nn = dbnunfoldtonn(dbn);
nn = dbnunfoldtonn(dbn, 10);

nn.alpha = 1;
nn.lambda = 1e-4;
Expand All @@ -36,6 +36,5 @@
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

%disp([num2str(er * 100) '% error']);
printf('%5.2f% error', 100 * er)
disp([num2str(er * 100) '% error']);
figure; visualize(nn.W{1}', 1);
8 changes: 4 additions & 4 deletions DBN/dbnunfoldtonn.m
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function nn = dbnunfoldtonn(dbn)
function nn = dbnunfoldtonn(dbn, outputsize)
%DBNUNFOLDTONN Unfolds a DBN to a NN
% Takes a DBN structure, traverses all layers and assigns upwards weights
% and biases to an equally sized NN structure which it returns
% dbnunfoldtonn(dbn, outputsize ) returns the unfolded dbn with a final
% layer of size outputsize added.

nn = nnsetup(dbn.sizes);
nn = nnsetup([dbn.sizes outputsize]);
for i = 1 : numel(dbn.rbm)
nn.W{i} = dbn.rbm{i}.W;
nn.b{i} = dbn.rbm{i}.c;
Expand Down
2 changes: 1 addition & 1 deletion DBN/rbmtrain.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
m = size(x, 1);
numbatches = m / opts.batchsize;

assert(rem(numbatches, 1) ~= 0, 'numbatches not integer');
assert(rem(numbatches, 1) == 0, 'numbatches not integer');


for i = 1 : opts.numepochs
Expand Down
4 changes: 1 addition & 3 deletions NN/nntrain.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

numbatches = m / batchsize;

if rem(numbatches, 1) ~= 0
error('numbatches not integer');
end
assert(rem(numbatches, 1) == 0, 'numbatches not integer');

nn.rL = [];
n = 1;
Expand Down

0 comments on commit b829694

Please sign in to comment.