Skip to content

Commit

Permalink
feat(training): simple trainer callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Aug 3, 2024
1 parent 227bc85 commit 733b7fa
Show file tree
Hide file tree
Showing 8 changed files with 478 additions and 242 deletions.
25 changes: 24 additions & 1 deletion examples/training/examples/train-clm-simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,36 @@ use std::{
path::Path
};

use kdam::BarExt;
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments};
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainerCallbacks, TrainingArguments};
use rand::RngCore;
use tokenizers::Tokenizer;

const BATCH_SIZE: usize = 16;
const SEQUENCE_LENGTH: usize = 256;

/// Trainer callback state for this example: wraps the `kdam` progress bar that the
/// `TrainerCallbacks` implementation updates with the current step and training loss.
struct LoggerCallback {
/// Progress bar whose total, postfix and position are driven from `train_step`.
progress_bar: kdam::Bar
}

impl LoggerCallback {
    /// Creates a callback with a freshly built `kdam` progress bar.
    ///
    /// The bar is built with `leave(true)`; every other builder option keeps its default.
    ///
    /// # Panics
    ///
    /// Panics if the `kdam` bar builder returns an error.
    pub fn new() -> Self {
        // `expect` instead of a bare `unwrap` so that a failure points at the
        // progress bar setup rather than producing an anonymous unwrap panic.
        let progress_bar = kdam::Bar::builder()
            .leave(true)
            .build()
            .expect("failed to build kdam progress bar");
        Self { progress_bar }
    }
}

impl TrainerCallbacks for LoggerCallback {
    /// Mirrors the trainer's progress onto the progress bar.
    ///
    /// Sets the bar's total to `state.max_steps`, shows `train_loss` (three decimal places)
    /// as the postfix, and moves the bar to `state.iter_step`. The trainer control handle is
    /// not used.
    ///
    /// Always returns `Ok(())`: the outcome of the bar update is deliberately discarded, so a
    /// failed progress bar update never surfaces as an error from this callback.
    fn train_step(&mut self, train_loss: f32, state: &ort::TrainerState, _: &mut ort::TrainerControl<'_>) -> ort::Result<()> {
        let bar = &mut self.progress_bar;

        // Re-applied on every call; the total is read from the trainer state each time.
        bar.total = state.max_steps;

        let postfix = format!("loss={train_loss:.3}");
        bar.set_postfix(postfix);

        // Best-effort update: the result is intentionally ignored.
        let _ = bar.update_to(state.iter_step);

        Ok(())
    }
}

fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();

Expand Down Expand Up @@ -78,6 +100,7 @@ fn main() -> ort::Result<()> {
.with_lr(7e-5)
.with_max_steps(5000)
.with_ckpt_strategy(CheckpointStrategy::Steps(500))
.with_callbacks(LoggerCallback::new())
)?;

trainer.export("trained-clm.onnx", ["probs"])?;
Expand Down
11 changes: 11 additions & 0 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ unsafe impl Sync for EnvironmentSingleton {}

static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) };

/// An `Environment` is a process-global structure, under which [`Session`](crate::Session)s are created.
///
/// Environments can be used to [configure global thread pools](EnvironmentBuilder::with_global_thread_pool), in
/// which all sessions share threads from the environment's pool, and to configure [default execution
/// providers](EnvironmentBuilder::with_execution_providers) for all sessions. In the context of `ort` specifically,
/// environments are also used to configure ONNX Runtime to send log messages through the [`tracing`] crate in Rust.
///
/// For ease of use, and since sessions require an environment to be created, `ort` will automatically create an
/// environment if one is not configured via [`init`] (or [`init_from`]). [`init`] can be called at any point in the
/// program (even after an environment has been automatically created), though every session created before the
/// re-configuration would need to be re-created in order to use the config from the new environment.
#[derive(Debug)]
pub struct Environment {
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
Expand Down
5 changes: 4 additions & 1 deletion src/training/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ mod simple;
mod trainer;

pub use self::{
simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments},
simple::{
iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainerCallbacks, TrainerControl, TrainerState,
TrainingArguments
},
trainer::Trainer
};

Expand Down
240 changes: 0 additions & 240 deletions src/training/simple.rs

This file was deleted.

Loading

0 comments on commit 733b7fa

Please sign in to comment.