refactor/rcublas: deduplicate test code, add setup+teardown
The setup and teardown methods should become a proc macro,
which would avoid all of these per-test chore changes.
drahnr committed Apr 9, 2020
1 parent 360bf19 commit e7615c8
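As a hedged sketch of that follow-up (nothing here is part of this commit: the attribute name cublas_test, the companion proc-macro crate, and its wiring are all assumptions), an attribute macro could expand each test into the setup/body/teardown sequence, so the per-test calls below would disappear:

// Hypothetical companion crate with proc-macro = true, depending on syn and quote.
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn};

#[proc_macro_attribute]
pub fn cublas_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let func = parse_macro_input!(item as ItemFn);
    let (attrs, sig, block) = (&func.attrs, &func.sig, &func.block);
    let expanded = quote! {
        #(#attrs)*
        #[test]
        #[serial_test::serial]
        #sig {
            crate::chore::test_setup();
            // Run the original body; guarantee teardown even if it panics.
            let outcome = ::std::panic::catch_unwind(
                ::std::panic::AssertUnwindSafe(|| #block),
            );
            crate::chore::test_teardown();
            if let Err(payload) = outcome {
                ::std::panic::resume_unwind(payload);
            }
        }
    };
    expanded.into()
}

With such a macro, a test like create_context would carry a single #[cublas_test] attribute instead of the explicit test_setup()/test_teardown() pair added throughout this commit.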
Showing 8 changed files with 99 additions and 60 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions rcublas/cublas/Cargo.toml
@@ -27,6 +27,7 @@ coaster = { path = "../../coaster", default-features = false, features = [
"native"
], version = "0.1" }
serial_test = "0.4"
env_logger = "0.7"

[features]
dev = []
13 changes: 13 additions & 0 deletions rcublas/cublas/src/api/context.rs
@@ -169,24 +169,35 @@ impl Context {
mod test {
use super::*;
use super::super::PointerMode;
use crate::chore::*;

#[test]
#[serial_test::serial]
fn create_context() {
test_setup();

Context::new().unwrap();

test_teardown();
}

#[test]
#[serial_test::serial]
fn default_pointer_mode_is_host() {
test_setup();

let ctx = Context::new().unwrap();
let mode = ctx.pointer_mode().unwrap();
assert_eq!(PointerMode::Host, mode);

test_teardown();
}

#[test]
#[serial_test::serial]
fn can_set_pointer_mode() {
test_setup();

let mut context = Context::new().unwrap();
// set to Device
context.set_pointer_mode(PointerMode::Device).unwrap();
@@ -196,5 +207,7 @@ mod test {
context.set_pointer_mode(PointerMode::Host).unwrap();
let mode2 = context.pointer_mode().unwrap();
assert_eq!(PointerMode::Host, mode2);

test_teardown();
}
}
61 changes: 29 additions & 32 deletions rcublas/cublas/src/api/level1.rs
@@ -266,38 +266,14 @@ mod test {
use crate::API;
use crate::api::context::Context;
use crate::api::enums::PointerMode;
use crate::co::backend::{Backend, IBackend};
use crate::co::framework::IFramework;
use crate::co::frameworks::{Cuda, Native};
use crate::co::frameworks::native::flatbox::FlatBox;
use crate::co::tensor::SharedTensor;

fn get_native_backend() -> Backend<Native> {
Backend::<Native>::default().unwrap()
}
fn get_cuda_backend() -> Backend<Cuda> {
Backend::<Cuda>::default().unwrap()
}

fn write_to_memory<T: Copy>(mem: &mut FlatBox, data: &[T]) {
let mem_buffer = mem.as_mut_slice::<T>();
for (index, datum) in data.iter().enumerate() {
mem_buffer[index] = *datum;
}
}

fn filled_tensor<B: IBackend, T: Copy>(_backend: &B, n: usize, val: T) -> SharedTensor<T> {
let mut x = SharedTensor::<T>::new(&vec![n]);
let values: &[T] = &::std::iter::repeat(val)
.take(x.capacity())
.collect::<Vec<T>>();
write_to_memory(x.write_only(get_native_backend().device()).unwrap(), values);
x
}
use crate::chore::*;

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_asum() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -306,13 +282,9 @@
let val = 2f32;
let x = filled_tensor(&native, n as usize, val);



// set up result
let mut result = SharedTensor::<f32>::new(&vec![1]);



{
let cuda_mem = x.read(cuda.device()).unwrap();
let cuda_mem_result = result.write_only(cuda.device()).unwrap();
@@ -327,11 +299,15 @@

let native_res = result.read(native.device()).unwrap();
assert_eq!(&[40f32], native_res.as_slice::<f32>());

test_teardown();
}

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_axpy() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -364,11 +340,15 @@

let native_y = y.read(native.device()).unwrap();
assert_eq!(&[7f32, 7f32, 7f32, 7f32, 7f32], native_y.as_slice::<f32>());

test_teardown();
}

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_copy() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -396,11 +376,15 @@

let native_y = y.read(native.device()).unwrap();
assert_eq!(&[2f32, 2f32, 2f32, 2f32, 2f32], native_y.as_slice::<f32>());

test_teardown();
}

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_dot() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -430,14 +414,17 @@
}
}


let native_result = result.read(native.device()).unwrap();
assert_eq!(&[40f32], native_result.as_slice::<f32>());

test_teardown();
}

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_nrm2() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -465,11 +452,15 @@

let native_result = result.read(native.device()).unwrap();
assert_eq!(&[3f32], native_result.as_slice::<f32>());

test_teardown();
}

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_scal() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -496,11 +487,15 @@

let native_x = x.read(native.device()).unwrap();
assert_eq!(&[5f32, 5f32, 5f32], native_x.as_slice::<f32>());

test_teardown();
}

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_swap() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -530,5 +525,7 @@

let native_y = y.read(native.device()).unwrap();
assert_eq!(&[2f32, 2f32, 2f32, 2f32, 2f32], native_y.as_slice::<f32>());

test_teardown();
}
}
32 changes: 5 additions & 27 deletions rcublas/cublas/src/api/level3.rs
@@ -98,38 +98,14 @@ mod test {
use crate::API;
use crate::api::context::Context;
use crate::api::enums::PointerMode;
use crate::co::backend::{Backend, IBackend};
use crate::co::framework::IFramework;
use crate::co::frameworks::{Cuda, Native};
use crate::co::frameworks::native::flatbox::FlatBox;
use crate::co::tensor::SharedTensor;

fn get_native_backend() -> Backend<Native> {
Backend::<Native>::default().unwrap()
}
fn get_cuda_backend() -> Backend<Cuda> {
Backend::<Cuda>::default().unwrap()
}

fn write_to_memory<T: Copy>(mem: &mut FlatBox, data: &[T]) {
let mem_buffer = mem.as_mut_slice::<T>();
for (index, datum) in data.iter().enumerate() {
mem_buffer[index] = *datum;
}
}

fn filled_tensor<B: IBackend, T: Copy>(_backend: &B, n: usize, val: T) -> SharedTensor<T> {
let mut x = SharedTensor::<T>::new(&vec![n]);
let values: &[T] = &::std::iter::repeat(val)
.take(x.capacity())
.collect::<Vec<T>>();
write_to_memory(x.write_only(get_native_backend().device()).unwrap(), values);
x
}
use crate::chore::*;

#[test]
#[serial_test::serial]
fn use_cuda_memory_for_gemm() {
test_setup();

let native = get_native_backend();
let cuda = get_cuda_backend();

@@ -202,5 +178,7 @@
&[28f32, 7f32, 7f32, 28f32, 7f32, 7f32, 28f32, 7f32, 7f32],
native_c.as_slice::<f32>()
);

test_teardown();
}
}
10 changes: 9 additions & 1 deletion rcublas/cublas/src/api/util.rs
@@ -3,7 +3,7 @@ use crate::{API, Error};
use super::Context;
use super::PointerMode;
use lazy_static::lazy_static;
use log::{debug,warn,info};
use log::debug;
use std::collections::HashSet;
use std::convert::AsRef;
use std::convert::TryFrom;
@@ -182,6 +182,8 @@ mod test {
#[test]
#[serial_test::serial]
fn manual_context_creation() {
crate::chore::test_setup();

unsafe {
let handle = API::ffi_create().unwrap();
API::ffi_destroy(handle).unwrap();
@@ -191,16 +193,21 @@
#[test]
#[serial_test::serial]
fn default_pointer_mode_is_host() {
crate::chore::test_setup();

unsafe {
let context = Context::new().unwrap();
let mode = API::ffi_get_pointer_mode(*context.id_c()).unwrap();
assert_eq!(cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST, mode);
}
crate::chore::test_teardown();
}

#[test]
#[serial_test::serial]
fn can_set_pointer_mode() {
crate::chore::test_setup();

unsafe {
let context = Context::new().unwrap();
API::ffi_set_pointer_mode(
@@ -216,5 +223,6 @@
let mode2 = API::ffi_get_pointer_mode(*context.id_c()).unwrap();
assert_eq!(cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST, mode2);
}
crate::chore::test_teardown();
}
}
38 changes: 38 additions & 0 deletions rcublas/cublas/src/chore.rs
@@ -0,0 +1,38 @@
use crate::co::backend::{Backend, IBackend};
use crate::co::frameworks::{Cuda, Native};
use crate::co::frameworks::native::flatbox::FlatBox;
use crate::co::tensor::SharedTensor;
use env_logger;

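/// Initialise logging for tests. try_init() is used (and its Result ignored)
/// because the serialised tests share one process, so the logger may already
/// have been installed by an earlier test.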
pub fn test_setup() {
    let _ = env_logger::builder().is_test(true).try_init();
}

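/// Teardown hook, currently a no-op. Kept so every test calls a symmetric
/// setup/teardown pair until the planned proc macro absorbs the boilerplate.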
pub fn test_teardown() {}

pub fn get_native_backend() -> Backend<Native> {
Backend::<Native>::default().unwrap()
}

pub fn get_cuda_backend() -> Backend<Cuda> {
Backend::<Cuda>::default().unwrap()
}

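/// Copies data element-wise into the flat host buffer backing a tensor.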
pub fn write_to_memory<T: Copy>(mem: &mut FlatBox, data: &[T]) {
let mem_buffer = mem.as_mut_slice::<T>();
for (index, datum) in data.iter().enumerate() {
mem_buffer[index] = *datum;
}
}

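/// Builds an n-element tensor filled with val. The backend parameter is
/// unused and only anchors type inference at call sites; the data itself is
/// written through a fresh native backend.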
pub fn filled_tensor<B: IBackend, T: Copy>(_backend: &B, n: usize, val: T) -> SharedTensor<T> {
let mut x = SharedTensor::<T>::new(&vec![n]);
let values: &[T] = &::std::iter::repeat(val)
.take(x.capacity())
.collect::<Vec<T>>();
write_to_memory(x.write_only(get_native_backend().device()).unwrap(), values);
x
}
3 changes: 3 additions & 0 deletions rcublas/cublas/src/lib.rs
@@ -14,3 +14,6 @@ pub struct API;

pub mod api;
pub mod error;

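// Shared test helpers (logger setup, backends, tensor fill); test builds only.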
#[cfg(test)]
pub(crate) mod chore;
