Move softmax bench to candle-nn
ivarflakstad committed Jan 20, 2025
1 parent eb6985e commit b094d09
Showing 4 changed files with 56 additions and 83 deletions.
83 changes: 1 addition & 82 deletions candle-core/benches/benchmarks/reduce.rs
@@ -1,8 +1,7 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Storage, Tensor};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::ops::Deref;
use std::time::Instant;

fn run_sum(a: &Tensor) {
@@ -12,55 +11,10 @@ fn run_arg_min(a: &Tensor) {
a.argmin_keepdim(2).unwrap();
}

// NOTE: Should this be removed? Softmax impls live in candle-nn.
fn softmax(a: &Tensor) -> candle_core::Result<()> {
use candle_core::{backend::BackendStorage, DType};
let (storage, layout) = a.storage_and_layout();

let device = a.device();

if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match a.dtype() {
DType::F32 => "softmax_f32",
DType::F16 => "softmax_f16",
DType::BF16 => "softmax_bf16",
dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
};

let n = layout.stride().len();
if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
}

let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&output,
)
.unwrap();
}
Ok(())
}

fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
let (lo, up) = (-1000.0f32, 1000.0f32);
for device in handler.devices {
run_softmax(c, &device, (lo, up));
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));

run_reduce(c, &device, (lo, up), false);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
@@ -79,41 +33,6 @@ fn criterion_benchmark(c: &mut Criterion) {
}
}

fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
if !device.is_metal() {
return;
}

let b = 1;
let m = 1024;
let k = 1024;
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();

let flops = b * m * k * T::DTYPE.size_in_bytes();

let name = match T::DTYPE {
DType::F32 => "softmax_f32",
DType::F16 => "softmax_f16",
DType::BF16 => "softmax_bf16",
_ => "softmax",
};
softmax(&a).unwrap();

let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
softmax(black_box(&a)).unwrap();
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn run_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
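The removed run_softmax helper only ran on Metal devices (it returned early otherwise) and dispatched the kernel directly via candle_metal_kernels::call_last_softmax. The replacement benchmark below goes through the public candle_nn::ops::softmax_last_dim op instead, so it runs on every device reported by BenchDeviceHandler.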
6 changes: 5 additions & 1 deletion candle-nn/benches/bench_main.rs
@@ -1,4 +1,8 @@
mod benchmarks;

use criterion::criterion_main;
criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches);
criterion_main!(
benchmarks::softmax::benches,
benchmarks::layer_norm::benches,
benchmarks::conv::benches
);
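Each module under candle-nn/benches/benchmarks exposes a benches group built with criterion_group!, and criterion_main! now registers the new softmax group alongside layer_norm and conv. To run only the relocated benchmark, a Criterion name filter along the lines of "cargo bench -p candle-nn -- softmax" should work, assuming the usual Criterion CLI behavior.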
1 change: 1 addition & 0 deletions candle-nn/benches/benchmarks/mod.rs
@@ -1,5 +1,6 @@
pub(crate) mod conv;
pub(crate) mod layer_norm;
pub(crate) mod softmax;

use candle::{Device, Result};

49 changes: 49 additions & 0 deletions candle-nn/benches/benchmarks/softmax.rs
@@ -0,0 +1,49 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Tensor};
use candle_nn::ops::softmax_last_dim;
use criterion::Throughput;
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;

fn run(input: &Tensor) {
let _ = softmax_last_dim(&input).unwrap();
}

const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;

fn run_softmax_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let elements = B * M * K;

let input = Tensor::rand(-1000.0f32, 1000.0f32, (B, M, K), &device)
.unwrap()
.to_dtype(dtype)
.unwrap();

let flops = elements * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&input));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_softmax_benchmark(c, &d, DType::F32, "softmax_f32");
run_softmax_benchmark(c, &d, DType::BF16, "softmax_bf16");
run_softmax_benchmark(c, &d, DType::F16, "softmax_f16");
}
}

criterion_group!(benches, criterion_benchmark);
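
For reference, a minimal sketch (not part of this commit) of the op the new benchmark exercises, assuming a CPU device and that the core crate is imported as candle, as in the benchmark files above:

use candle::{Device, Tensor};
use candle_nn::ops::softmax_last_dim;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // Same shape and value range as the benchmark input: (1, 1024, 1024), f32.
    let input = Tensor::rand(-1000.0f32, 1000.0f32, (1, 1024, 1024), &device)?;
    // Softmax over the last dimension; each row of 1024 values sums to ~1.
    let probs = softmax_last_dim(&input)?;
    assert_eq!(probs.dims(), &[1, 1024, 1024]);
    Ok(())
}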
