Skip to content

Commit

Permalink
feat: LoRA adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Nov 4, 2024
1 parent 26f46ce commit d877fb3
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/adapter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::{
path::Path,
ptr::{self, NonNull},
sync::Arc
};

use crate::{Allocator, Result, ortsys, util};

#[derive(Debug)]
pub(crate) struct AdapterInner {
pub(crate) ptr: NonNull<ort_sys::OrtLoraAdapter>
}

impl Drop for AdapterInner {
fn drop(&mut self) {
ortsys![unsafe ReleaseLoraAdapter(self.ptr.as_ptr())];
}
}

#[derive(Debug, Clone)]
pub struct Adapter {
pub(crate) inner: Arc<AdapterInner>
}

impl Adapter {
pub fn from_file(path: impl AsRef<Path>, allocator: Option<&Allocator>) -> Result<Self> {
let path = util::path_to_os_char(path);
let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut);
let mut ptr = ptr::null_mut();
ortsys![unsafe CreateLoraAdapter(path.as_ptr(), allocator_ptr, &mut ptr)?];
Ok(Adapter {
inner: Arc::new(AdapterInner {
ptr: unsafe { NonNull::new_unchecked(ptr) }
})
})
}

pub fn from_memory(bytes: &[u8], allocator: Option<&Allocator>) -> Result<Self> {
let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut);
let mut ptr = ptr::null_mut();
ortsys![unsafe CreateLoraAdapterFromArray(bytes.as_ptr().cast(), bytes.len(), allocator_ptr, &mut ptr)?];
Ok(Adapter {
inner: Arc::new(AdapterInner {
ptr: unsafe { NonNull::new_unchecked(ptr) }
})
})
}

pub fn ptr(&self) -> *mut ort_sys::OrtLoraAdapter {
self.inner.ptr.as_ptr()
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#[cfg(all(test, not(feature = "fetch-models")))]
compile_error!("`cargo test --features fetch-models`!!1!");

pub(crate) mod adapter;
pub(crate) mod environment;
pub(crate) mod error;
pub(crate) mod execution_providers;
Expand Down Expand Up @@ -48,6 +49,7 @@ pub use self::tensor::ArrayExtensions;
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
pub use self::training::*;
pub use self::{
adapter::Adapter,
environment::{Environment, EnvironmentBuilder, EnvironmentGlobalThreadPoolOptions, get_environment, init},
error::{Error, ErrorCode, Result},
execution_providers::*,
Expand Down
9 changes: 9 additions & 0 deletions src/session/run_options.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc};

use crate::{
adapter::{Adapter, AdapterInner},
error::Result,
ortsys,
session::Output,
Expand Down Expand Up @@ -157,6 +158,7 @@ impl SelectedOutputMarker for HasSelectedOutputs {}
pub struct RunOptions<O: SelectedOutputMarker = NoSelectedOutputs> {
pub(crate) run_options_ptr: NonNull<ort_sys::OrtRunOptions>,
pub(crate) outputs: OutputSelector,
adapters: Vec<Arc<AdapterInner>>,
_marker: PhantomData<O>
}

Expand All @@ -175,6 +177,7 @@ impl RunOptions {
Ok(RunOptions {
run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) },
outputs: OutputSelector::default(),
adapters: Vec::new(),
_marker: PhantomData
})
}
Expand Down Expand Up @@ -303,6 +306,12 @@ impl<O: SelectedOutputMarker> RunOptions<O> {
ortsys![unsafe AddRunConfigEntry(self.run_options_ptr.as_ptr(), key.as_ptr(), value.as_ptr())?];
Ok(())
}

pub fn add_adapter(&mut self, adapter: &Adapter) -> Result<()> {
ortsys![unsafe RunOptionsAddActiveLoraAdapter(self.run_options_ptr.as_ptr(), adapter.ptr())?];
self.adapters.push(Arc::clone(&adapter.inner));
Ok(())
}
}

impl<O: SelectedOutputMarker> Drop for RunOptions<O> {
Expand Down

0 comments on commit d877fb3

Please sign in to comment.