Skip to content

Commit

Permalink
reduce tensor sizes for 'slow' tests
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Dec 5, 2023
1 parent c6e0642 commit 9deca1b
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 187 deletions.
59 changes: 31 additions & 28 deletions dfdx/src/nn/layers/conv1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,47 +174,50 @@ mod tests {
fn test_grouped_forward_sizes() {
let dev: TestDevice = Default::default();

let x = dev.ones::<Rank2<16, 10>>();
let x = dev.ones::<Rank2<2, 4>>();

let m = dev.build_module::<TestDtype>(<Conv1DConstConfig<16, 32, 3, 1, 0, 1>>::default());
let _: Tensor<Rank3<32, 16, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
let m = dev.build_module::<TestDtype>(<Conv1DConstConfig<2, 4, 3, 1, 0, 1>>::default());
let _: Tensor<Rank3<4, 2, 3>, _, _> = m.weight;
let _: Tensor<Rank2<4, 2>, _, _> = m.forward(x.clone());

let m =
dev.build_module::<TestDtype>(<Conv1DConstConfig<16, 32, 3, 1, 0, 1, 2>>::default());
let _: Tensor<Rank3<32, 8, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
let m = dev.build_module::<TestDtype>(<Conv1DConstConfig<2, 4, 3, 1, 0, 1, 2>>::default());
let _: Tensor<Rank3<4, 1, 3>, _, _> = m.weight;
let _: Tensor<Rank2<4, 2>, _, _> = m.forward(x.clone());

let m =
dev.build_module::<TestDtype>(<Conv1DConstConfig<16, 32, 3, 1, 0, 1, 4>>::default());
let _: Tensor<Rank3<32, 4, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
let x = dev.ones::<Rank2<4, 4>>();

let m =
dev.build_module::<TestDtype>(<Conv1DConstConfig<16, 32, 3, 1, 0, 1, 8>>::default());
let _: Tensor<Rank3<32, 2, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
let m = dev.build_module::<TestDtype>(<Conv1DConstConfig<4, 8, 3, 1, 0, 1, 4>>::default());
let _: Tensor<Rank3<8, 1, 3>, _, _> = m.weight;
let _: Tensor<Rank2<8, 2>, _, _> = m.forward(x.clone());

let x = dev.ones::<Rank2<8, 3>>();

let m = dev.build_module::<TestDtype>(<Conv1DConstConfig<8, 16, 3, 1, 0, 1, 8>>::default());
let _: Tensor<Rank3<16, 1, 3>, _, _> = m.weight;
let _: Tensor<Rank2<16, 1>, _, _> = m.forward(x.clone());

let x = dev.ones::<Rank2<16, 3>>();

let m =
dev.build_module::<TestDtype>(<Conv1DConstConfig<16, 32, 3, 1, 0, 1, 16>>::default());
let _: Tensor<Rank3<32, 1, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x);
let _: Tensor<Rank2<32, 1>, _, _> = m.forward(x);
}

#[rustfmt::skip]
#[test]
fn test_forward_4d_sizes() {
let dev: TestDevice = Default::default();
let x = dev.zeros::<Rank3<5, 3, 10>>();
let _: Tensor<Rank3<5, 2, 8>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 4, 8>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 4, 3>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 4, 9>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 4, 2>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 4, 7>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 4, 4>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 2, 4>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 2>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 2, 3>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 3>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 2, 10>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 1, 1>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 2, 12>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 1, 2>>::default()).forward(x.clone());
let _: Tensor<Rank3<5, 2, 6>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 2, 2>>::default()).forward(x.clone());
let x = dev.zeros::<Rank3<2, 3, 4>>();
let _: Tensor<Rank3<2, 2, 2>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 4, 2>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 4, 3>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 4, 3>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 4, 2>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 4, 1>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 4, 4>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 2, 1>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 2>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 2, 1>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 3>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 2, 4>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 1, 1>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 2, 6>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 1, 2>>::default()).forward(x.clone());
let _: Tensor<Rank3<2, 2, 3>, _, _, _> = dev.build_module::<TestDtype>(<Conv1DConstConfig<3, 2, 3, 2, 2>>::default()).forward(x.clone());
}

#[test]
Expand Down Expand Up @@ -248,7 +251,7 @@ mod tests {
let weight_init = m.weight.clone();

let mut opt = crate::nn::optim::Sgd::new(&m, Default::default());
let out = m.forward(dev.sample_normal::<Rank3<8, 2, 28>>().leaky_trace());
let out = m.forward(dev.sample_normal::<Rank3<4, 2, 4>>().leaky_trace());
let g = out.square().mean().backward();

assert_ne!(g.get(&m.weight).array(), [[[TestDtype::zero(); 3]; 2]; 4]);
Expand Down
67 changes: 36 additions & 31 deletions dfdx/src/nn/layers/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,48 +197,53 @@ mod tests {
fn test_grouped_forward_sizes() {
let dev: TestDevice = Default::default();

let x = dev.zeros::<Rank3<16, 10, 10>>();
let x = dev.zeros::<Rank3<2, 4, 3>>();

let m =
dev.build_module::<TestDtype>(<Conv2DConstConfig<16, 32, 3, 1, 0, 1, 1>>::default());
let _: Tensor<Rank4<32, 16, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());
let m = dev.build_module::<TestDtype>(<Conv2DConstConfig<2, 4, 3, 1, 0, 1, 1>>::default());
let _: Tensor<Rank4<4, 2, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<4, 2, 1>, _, _> = m.forward(x.clone());

let m =
dev.build_module::<TestDtype>(<Conv2DConstConfig<16, 32, 3, 1, 0, 1, 2>>::default());
let _: Tensor<Rank4<32, 8, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());
let m = dev.build_module::<TestDtype>(<Conv2DConstConfig<2, 4, 3, 1, 0, 1, 2>>::default());
let _: Tensor<Rank4<4, 1, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<4, 2, 1>, _, _> = m.forward(x.clone());

let m =
dev.build_module::<TestDtype>(<Conv2DConstConfig<16, 32, 3, 1, 0, 1, 4>>::default());
let _: Tensor<Rank4<32, 4, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());
let x = dev.zeros::<Rank3<4, 4, 3>>();

let m =
dev.build_module::<TestDtype>(<Conv2DConstConfig<16, 32, 3, 1, 0, 1, 8>>::default());
let _: Tensor<Rank4<32, 2, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());
let m = dev.build_module::<TestDtype>(<Conv2DConstConfig<4, 8, 3, 1, 0, 1, 4>>::default());
let _: Tensor<Rank4<8, 1, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<8, 2, 1>, _, _> = m.forward(x.clone());

let x = dev.zeros::<Rank3<8, 4, 3>>();

let m = dev.build_module::<TestDtype>(<Conv2DConstConfig<8, 16, 3, 1, 0, 1, 8>>::default());
let _: Tensor<Rank4<16, 1, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<16, 2, 1>, _, _> = m.forward(x.clone());

let x = dev.zeros::<Rank3<16, 4, 3>>();

let m =
dev.build_module::<TestDtype>(<Conv2DConstConfig<16, 32, 3, 1, 0, 1, 16>>::default());
let _: Tensor<Rank4<32, 1, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x);
let _: Tensor<Rank3<32, 2, 1>, _, _> = m.forward(x);
}

#[rustfmt::skip]
#[test]
fn test_forward_4d_sizes() {
let dev: TestDevice = Default::default();
let x = dev.zeros::<Rank4<5, 3, 10, 10>>();
let _: Tensor<Rank4<5, 2, 8, 8>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 4, 8, 8>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 4, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 4, 9, 9>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 4, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 4, 7, 7>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 4, 4>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 4, 4>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 3, 3>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 10, 10>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 1, 1>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 12, 12>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 1, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 6, 6>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 2, 2>>::default()).forward(x.clone());
let x = dev.zeros::<Rank4<2, 3, 4, 4>>();
let _: Tensor<Rank4<2, 2, 2, 2>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 4, 2, 2>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 4, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 4, 3, 3>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 4, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 4, 1, 1>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 4, 4>>::default()).forward(x.clone());
let x = dev.zeros::<Rank4<2, 3, 8, 8>>();
let _: Tensor<Rank4<2, 2, 3, 3>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 2, 2, 2>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 3>>::default()).forward(x.clone());
let x = dev.zeros::<Rank4<2, 3, 4, 4>>();
let _: Tensor<Rank4<2, 2, 4, 4>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 1, 1>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 2, 6, 6>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 1, 2>>::default()).forward(x.clone());
let x = dev.zeros::<Rank4<2, 3, 8, 8>>();
let _: Tensor<Rank4<2, 2, 5, 5>, _, _, _> = dev.build_module::<TestDtype>(<Conv2DConstConfig<3, 2, 3, 2, 2>>::default()).forward(x.clone());
}

#[test]
Expand Down Expand Up @@ -267,17 +272,17 @@ mod tests {
fn test_conv_with_optimizer() {
let dev: TestDevice = Default::default();

let mut m = dev.build_module::<TestDtype>(Conv2DConstConfig::<2, 4, 3>::default());
let mut m = dev.build_module::<TestDtype>(Conv2DConstConfig::<2, 3, 2>::default());

let weight_init = m.weight.clone();

let mut opt = crate::nn::optim::Sgd::new(&m, Default::default());
let out = m.forward(dev.sample_normal::<Rank4<8, 2, 28, 28>>().leaky_trace());
let out = m.forward(dev.sample_normal::<Rank4<2, 2, 4, 4>>().leaky_trace());
let g = out.square().mean().backward();

assert_ne!(
g.get(&m.weight).array(),
[[[[TestDtype::zero(); 3]; 3]; 2]; 4]
[[[[TestDtype::zero(); 2]; 2]; 2]; 3]
);

opt.update(&mut m, &g).expect("unused params");
Expand Down
30 changes: 19 additions & 11 deletions dfdx/src/nn/layers/conv_trans2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,24 @@ mod tests {
#[test]
fn test_forward_4d_sizes() {
let dev: TestDevice = Default::default();
let x = dev.zeros::<Rank4<5, 3, 8, 8>>();
let _: Tensor<Rank4<5, 2, 10, 10>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 4, 10, 10>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 4, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 4, 9, 9>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 4, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 4, 11, 11>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 4, 4>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 17, 17>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 24, 24>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 8, 8>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 1, 1>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 6, 6>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 1, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<5, 2, 13, 13>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 2, 2>>::default()).forward(x.clone());

let x = dev.zeros::<Rank4<2, 3, 3, 3>>();

let _: Tensor<Rank4<2, 2, 5, 5>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 4, 5, 5>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 4, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 4, 4, 4>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 4, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 4, 6, 6>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 4, 4>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 2, 7, 7>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 2>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 2, 9, 9>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 3>>::default()).forward(x.clone());
let _: Tensor<Rank4<2, 2, 3, 3>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 1, 1>>::default()).forward(x.clone());

let x = dev.zeros::<Rank4<2, 3, 5, 5>>();

let _: Tensor<Rank4<2, 2, 3, 3>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 1, 2>>::default()).forward(x.clone());

let x = dev.zeros::<Rank4<2, 3, 3, 3>>();

let _: Tensor<Rank4<2, 2, 3, 3>, _, _, _> = dev.build_module::<TestDtype>(<ConvTrans2DConstConfig<3, 2, 3, 2, 2>>::default()).forward(x.clone());
}

#[test]
Expand Down Expand Up @@ -225,7 +233,7 @@ mod tests {
let weight_init = m.weight.clone();

let mut opt = crate::nn::optim::Sgd::new(&m, Default::default());
let out = m.forward(dev.sample_normal::<Rank4<8, 2, 28, 28>>().leaky_trace());
let out = m.forward(dev.sample_normal::<Rank4<2, 2, 4, 4>>().leaky_trace());
let g = out.square().mean().backward();

assert_ne!(
Expand Down
Loading

0 comments on commit 9deca1b

Please sign in to comment.