Skip to content

Commit

Permalink
feat: tests for adapter API
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Nov 4, 2024
1 parent d877fb3 commit 5b2983e
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,58 @@ impl Adapter {
self.inner.ptr.as_ptr()
}
}

#[cfg(test)]
mod tests {
use std::fs;

use super::Adapter;
use crate::{RunOptions, Session, Tensor};

#[test]
fn test_lora() -> crate::Result<()> {
let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
let lora = Adapter::from_file("tests/data/adapter.orl", None)?;

let mut run_options = RunOptions::new()?;
run_options.add_adapter(&lora)?;

let output: Tensor<f32> = model
.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?
.remove("output")
.expect("")
.downcast()?;
let (_, output) = output.extract_raw_tensor();
assert_eq!(output[0], 154.0);
assert_eq!(output[1], 176.0);
assert_eq!(output[2], 198.0);
assert_eq!(output[3], 220.0);

Ok(())
}

#[test]
fn test_lora_from_memory() -> crate::Result<()> {
let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;

let lora_bytes = fs::read("tests/data/adapter.orl").expect("");
let lora = Adapter::from_memory(&lora_bytes, None)?;
drop(lora_bytes);

let mut run_options = RunOptions::new()?;
run_options.add_adapter(&lora)?;

let output: Tensor<f32> = model
.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?
.remove("output")
.expect("")
.downcast()?;
let (_, output) = output.extract_raw_tensor();
assert_eq!(output[0], 154.0);
assert_eq!(output[1], 176.0);
assert_eq!(output[2], 198.0);
assert_eq!(output[3], 220.0);

Ok(())
}
}
Binary file added tests/data/adapter.orl
Binary file not shown.
Binary file added tests/data/lora_model.onnx
Binary file not shown.
84 changes: 84 additions & 0 deletions tools/test-data/generate-models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import inspect

import numpy as np
import onnx
import onnx.helper as G
import onnxruntime as ort

def make_tensor_from_np(name: str, arr: np.ndarray) -> onnx.TensorProto:
match arr.dtype:
case np.float32:
dtype = onnx.TensorProto.FLOAT
shape = list(arr.shape)
return G.make_tensor(name, dtype, shape, arr)

factories = []
def model_factory(func):
model_name = f'{func.__name__}.onnx'
def wrapper():
try:
model = func()
except Exception as e:
print(f'Failed to create `{model_name}`: {e}')
return
if isinstance(model, onnx.GraphProto):
model = G.make_model(model, opset_imports=[onnx.OperatorSetIdProto(domain=None, version=21)])
try:
onnx.checker.check_model(model)
except Exception as e:
print(f'`{model_name}` is invalid: {e}')
onnx.save_model(model, f'tests/data/{model_name}')
factories.append(wrapper)
return wrapper
def misc_factory(func):
def wrapper():
try:
func()
except Exception as e:
pass
factories.append(wrapper)
return wrapper

class Models:
@model_factory
def lora_model():
input_x = G.make_tensor_value_info('input', onnx.TensorProto.FLOAT, [4, 4])

lora_param_a_input = G.make_tensor_value_info('lora_param_a', onnx.TensorProto.FLOAT, [4, 'dim'])
lora_param_b_input = G.make_tensor_value_info('lora_param_b', onnx.TensorProto.FLOAT, ['dim', 4])

output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [4, 4])

weight_x = make_tensor_from_np('weight_x', np.array(range(1, 17)).reshape(4, 4).astype(np.float32))

lora_param_a = make_tensor_from_np('lora_param_a', np.zeros([4, 0], dtype=np.float32))
lora_param_b = make_tensor_from_np('lora_param_b', np.zeros([0, 4], dtype=np.float32))

matmul_x = G.make_node('MatMul', ['input', 'weight_x'], ['mm_output_x'])
matmul_a = G.make_node('MatMul', ['input', 'lora_param_a'], ['mm_output_a'])
matmul_b = G.make_node('MatMul', ['mm_output_a', 'lora_param_b'], ['mm_output_b'])
add_node = G.make_node('Add', ['mm_output_x', 'mm_output_b'], ['output'])

return G.make_graph(
nodes=[matmul_x, matmul_a, matmul_b, add_node],
inputs=[input_x, lora_param_a_input, lora_param_b_input],
outputs=[output],
initializer=[weight_x, lora_param_a, lora_param_b],
name='lora_test'
)

@misc_factory
def lora_adapter():
param_a = ort.OrtValue.ortvalue_from_numpy(np.array([[3], [4], [5], [6]], dtype=np.float32))
param_b = ort.OrtValue.ortvalue_from_numpy(np.array([[7, 8, 9, 10]], dtype=np.float32))

adapter = ort.AdapterFormat()
adapter.set_parameters({
'lora_param_a': param_a,
'lora_param_b': param_b
})
adapter.export_adapter('tests/data/adapter.orl')

if __name__ == '__main__':
for model_factory in factories:
model_factory()

0 comments on commit 5b2983e

Please sign in to comment.