Skip to content

Commit

Permalink
added xunit test framework and the first test
Browse files Browse the repository at this point in the history
  • Loading branch information
rasmusbergpalm committed Sep 23, 2012
1 parent c296831 commit 2de71dd
Show file tree
Hide file tree
Showing 40 changed files with 2,550 additions and 19 deletions.
38 changes: 19 additions & 19 deletions NN/nnchecknumgrad.m
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
function nnchecknumgrad(net, x, y)
function nnchecknumgrad(nn, x, y)
epsilon = 1e-4;
er = 1e-9;
n = net.n;
n = nn.n;
for l = 1 : (n - 1)
for i = 1 : size(net.W{l}, 1)
for j = 1 : size(net.W{l}, 2)
net_m = net; net_p = net;
net_m.W{l}(i, j) = net.W{l}(i, j) - epsilon;
net_p.W{l}(i, j) = net.W{l}(i, j) + epsilon;
net_m = nnff(net_m, x, y);
net_p = nnff(net_p, x, y);
dW = (net_p.L - net_m.L) / (2 * epsilon);
e = abs(dW - net.dW{l}(i, j));
for i = 1 : size(nn.W{l}, 1)
for j = 1 : size(nn.W{l}, 2)
nn_m = nn; nn_p = nn;
nn_m.W{l}(i, j) = nn.W{l}(i, j) - epsilon;
nn_p.W{l}(i, j) = nn.W{l}(i, j) + epsilon;
nn_m = nnff(nn_m, x, y);
nn_p = nnff(nn_p, x, y);
dW = (nn_p.L - nn_m.L) / (2 * epsilon);
e = abs(dW - nn.dW{l}(i, j));
if e > er
error('numerical gradient checking failed');
end
end
end

for i = 1 : size(net.b{l}, 1)
net_m = net; net_p = net;
net_m.b{l}(i) = net.b{l}(i) - epsilon;
net_p.b{l}(i) = net.b{l}(i) + epsilon;
net_m = nnff(net_m, x, y);
net_p = nnff(net_p, x, y);
db = (net_p.L - net_m.L) / (2 * epsilon);
e = abs(db - net.db{l}(i));
for i = 1 : size(nn.b{l}, 1)
nn_m = nn; nn_p = nn;
nn_m.b{l}(i) = nn.b{l}(i) - epsilon;
nn_p.b{l}(i) = nn.b{l}(i) + epsilon;
nn_m = nnff(nn_m, x, y);
nn_p = nnff(nn_p, x, y);
db = (nn_p.L - nn_m.L) / (2 * epsilon);
e = abs(db - nn.db{l}(i));
if e > er
error('numerical gradient checking failed');
end
Expand Down
3 changes: 3 additions & 0 deletions tests/runalltests.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
clear all; close all; clc;
addpath(genpath('../.'))
runtests
7 changes: 7 additions & 0 deletions tests/test_nn_gradients_are_numerically_correct.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function test_nn_gradients_are_numerically_correct
nn = nnsetup([5 2]);
batch_x = rand(20, 5);
batch_y = rand(20, 2);
nn = nnff(nn, batch_x, batch_y);
nn = nnbp(nn);
nnchecknumgrad(nn, batch_x, batch_y);
24 changes: 24 additions & 0 deletions util/xunit/+xunit/+utils/Contents.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
% UTILS Utility package for MATLAB xUnit Test Framework
%
% Array Comparison
% compareFloats - Compare floating-point arrays using tolerance
%
% Test Case Discovery Functions
% isTestCaseSubclass - True for name of TestCase subclass
%
% String Functions
% arrayToString - Convert array to string for display
% comparisonMessage - Assertion message string for comparing two arrays
% containsRegexp - True if string contains regular expression
% isSetUpString - True for string that looks like a setup function
% isTearDownString - True for string that looks like teardown function
% isTestString - True for string that looks like a test function
% stringToCellArray - Convert string to cell array of strings
%
% Miscellaneous Functions
% generateDoc - Publish test scripts in mtest/doc
% parseFloatAssertInputs - Common input-parsing logic for several functions

% Steven L. Eddins
% Copyright 2008-2009 The MathWorks, Inc.

96 changes: 96 additions & 0 deletions util/xunit/+xunit/+utils/arrayToString.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
function s = arrayToString(A)
%arrayToString Convert array to string for display.
% S = arrayToString(A) converts the array A into a string suitable for
% including in assertion messages. Small arrays are converted using disp(A).
% Large arrays are displayed similar to the way structure field values display
% using disp.

% Steven L. Eddins
% Copyright 2009 The MathWorks, Inc.

if isTooBigToDisp(A)
s = dispAsStructField(A);
else
s = dispAsArray(A);
end

%===============================================================================
function tf = isTooBigToDisp(A)
% Use a heuristic to determine if the array is to convert to a string using
% disp. The heuristic is based on the size of the array in bytes, as reported
% by the whos function.

whos_output = whos('A');
byte_threshold = 1000;
tf = whos_output.bytes > byte_threshold;

%===============================================================================
function s = dispAsArray(A)
% Convert A to a string using disp. Remove leading and trailing blank lines.

s = evalc('disp(A)');
if isempty(s)
% disp displays nothing for some kinds of empty arrays.
s = dispAsStructField(A);
else
s = postprocessDisp(s);
end

%===============================================================================
function s = dispAsStructField(A)
% Convert A to a string using structure field display.

b.A = A;
s = evalc('disp(b)');
s = postprocessStructDisp(s);

%===============================================================================
function out = postprocessDisp(in)
% Remove leading and trailing blank lines from input string. Don't include a
% newline at the end.

lines = xunit.utils.stringToCellArray(in);

% Remove leading blank lines.
lines = removeLeadingBlankLines(lines);

% Remove trailing blank lines.
while ~isempty(lines) && isBlankLine(lines{end})
lines(end) = [];
end

% Convert cell of strings to single string with newlines. Don't put a newline
% at the end.
out = sprintf('%s\n', lines{1:end-1});
out = [out, lines{end}];

%===============================================================================
function out = postprocessStructDisp(in)
% Return the portion of the display string to the right of the colon in the
% output of the first structure field. Input is a string.

lines = xunit.utils.stringToCellArray(in);

% Remove leading blank lines
lines = removeLeadingBlankLines(lines);

line = lines{1};
idx = find(line == ':');
out = line((idx+2):end); % struct fields display with blank space following colon

%===============================================================================
function out = removeLeadingBlankLines(in)
% Input and output are cell arrays of strings.

out = in;
while ~isempty(out) && isBlankLine(out{1})
out(1) = [];
end

%===============================================================================
function tf = isBlankLine(line)
% Input is a string.

tf = all(isspace(line));


128 changes: 128 additions & 0 deletions util/xunit/+xunit/+utils/compareFloats.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
function result = compareFloats(varargin)
%compareFloats Compare floating-point arrays using tolerance.
% result = compareFloats(A, B, compare_type, tol_type, tol, floor_tol)
% compares the floating-point arrays A and B using a tolerance. compare_type
% is either 'elementwise' or 'vector'. tol_type is either 'relative' or
% 'absolute'. tol and floor_tol are the scalar tolerance values.
%
% There are four different tolerance tests used, depending on the comparison
% type and the tolerance type:
%
% 1. Comparison type: 'elementwise' Tolerance type: 'relative'
%
% all( abs(A(:) - B(:)) <= tol * max(abs(A(:)), abs(B(:))) + floor_tol )
%
% 2. Comparison type: 'elementwise' Tolerance type: 'absolute'
%
% all( abs(A(:) - B(:) <= tol )
%
% 3. Comparison type: 'vector' Tolerance type: 'relative'
%
% norm(A(:) - B(:) <= tol * max(norm(A(:)), norm(B(:))) + floor_tol
%
% 4. Comparison type: 'vector' Tolerance type: 'absolute'
%
% norm(A(:) - B(:)) <= tol
%
% Note that floor_tol is not used when the tolerance type is 'absolute'.
%
% compare_type, tol_type, tol, and floor_tol are all optional inputs. The
% default value for compare_type is 'elementwise'. The default value for
% tol_type is 'relative'. If both A and B are double, then the default value
% for tol is sqrt(eps), and the default value for floor_tol is eps. If either
% A or B is single, then the default value for tol is sqrt(eps('single')), and
% the default value for floor_tol is eps('single').
%
% If A or B is complex, then the tolerance test is applied independently to
% the real and imaginary parts.
%
% For elementwise comparisons, compareFloats returns true for two elements
% that are both NaN, or for two infinite elements that have the same sign.
% For vector comparisons, compareFloats returns false if any input elements
% are infinite or NaN.

% Steven L. Eddins
% Copyright 2008-2009 The MathWorks, Inc.

if nargin >= 3
% compare_type specified. Grab it and then use parseFloatAssertInputs to
% process the remaining input arguments.
compare_type = varargin{3};
varargin(3) = [];
if isempty(strcmp(compare_type, {'elementwise', 'vector'}))
error('compareFloats:unrecognizedCompareType', ...
'COMPARE_TYPE must be ''elementwise'' or ''vector''.');
end
else
compare_type = 'elementwise';
end

params = xunit.utils.parseFloatAssertInputs(varargin{:});

A = params.A(:);
B = params.B(:);

switch compare_type
case 'elementwise'
magFcn = @abs;

case 'vector'
magFcn = @norm;

otherwise
error('compareFloats:unrecognizedCompareType', ...
'COMPARE_TYPE must be ''elementwise'' or ''vector''.');
end

switch params.ToleranceType
case 'relative'
coreCompareFcn = @(A, B) magFcn(A - B) <= ...
params.Tolerance * max(magFcn(A), magFcn(B)) + ...
params.FloorTolerance;

case 'absolute'
coreCompareFcn = @(A, B) magFcn(A - B) <= params.Tolerance;

otherwise
error('compareFloats:unrecognizedToleranceType', ...
'TOL_TYPE must be ''relative'' or ''absolute''.');
end

if strcmp(compare_type, 'elementwise')
compareFcn = @(A, B) ( coreCompareFcn(A, B) | bothNaN(A, B) | sameSignInfs(A, B) ) & ...
~oppositeSignInfs(A, B) & ...
~finiteAndInfinite(A, B);
else
compareFcn = @(A, B) coreCompareFcn(A, B) & ...
isfinite(magFcn(A)) & ...
isfinite(magFcn(B));
end

if isreal(A) && isreal(B)
result = compareFcn(A, B);
else
result = compareFcn(real(A), real(B)) & compareFcn(imag(A), imag(B));
end

result = all(result);

%===============================================================================
function out = bothNaN(A, B)

out = isnan(A) & isnan(B);

%===============================================================================
function out = oppositeSignInfs(A, B)

out = isinf(A) & isinf(B) & (sign(A) ~= sign(B));

%===============================================================================
function out = sameSignInfs(A, B)

out = isinf(A) & isinf(B) & (sign(A) == sign(B));

%===============================================================================
function out = finiteAndInfinite(A, B)

out = xor(isinf(A), isinf(B));

33 changes: 33 additions & 0 deletions util/xunit/+xunit/+utils/comparisonMessage.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
function msg = comparisonMessage(user_message, assertion_message, A, B)
%comparisonMessage Generate assertion message when comparing two arrays.
% msg = comparisonMessage(user_message, assertion_message, A, B) returns a
% string appropriate to use in a call to throw inside an assertion function
% that compares two arrays A and B.
%
% The string returned has the following form:
%
% <user_message>
% <assertion_message>
%
% First input:
% <string representation of value of A>
%
% Second input:
% <string representation of value of B>
%
% user_message can be the empty string, '', in which case user_message is
% skipped.

% Steven L. Eddins
% Copyright 2009 The MathWorks, Inc.

msg = sprintf('%s\n\n%s\n%s\n\n%s\n%s', ...
assertion_message, ...
'First input:', ...
xunit.utils.arrayToString(A), ...
'Second input:', ...
xunit.utils.arrayToString(B));

if ~isempty(user_message)
msg = sprintf('%s\n%s', user_message, msg);
end
17 changes: 17 additions & 0 deletions util/xunit/+xunit/+utils/containsRegexp.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
function tf = containsRegexp(str, exp)
%containsRegexp True if string contains regular expression
% TF = containsRegexp(str, exp) returns true if the string str contains the
% regular expression exp. If str is a cell array of strings, then
% containsRegexp tests each string in the cell array, returning the results in
% a logical array with the same size as str.

% Steven L. Eddins
% Copyright 2008-2009 The MathWorks, Inc.

% Convert to canonical input form: A cell array of strings.
if ~iscell(str)
str = {str};
end

matches = regexp(str, exp);
tf = ~cellfun('isempty', matches);
14 changes: 14 additions & 0 deletions util/xunit/+xunit/+utils/generateDoc.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function generateDoc
%generateDoc Publish the example scripts in the doc directory

% Steven L. Eddins
% Copyright 2008-2009 The MathWorks, Inc.

doc_dir = fullfile(fileparts(which('runtests')), '..', 'doc');
addpath(doc_dir);
cd(doc_dir)
mfiles = dir('*.m');
for k = 1:numel(mfiles)
publish(mfiles(k).name);
cd(doc_dir)
end
Loading

0 comments on commit 2de71dd

Please sign in to comment.