Skip to content

Commit

Permalink
Remove unsafe from OnNewBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 27, 2024
1 parent 53b1cc3 commit 13ef3fa
Show file tree
Hide file tree
Showing 10 changed files with 14 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ macro_rules! impl_buffer_hook_traits {
Self: 'dev,
{
#[inline]
unsafe fn on_new_buffer(
fn on_new_buffer(
&'dev self,
device: &'dev D,
new_buf: &mut Buffer<'dev, T, D, S>,
) {
unsafe { self.modules.on_new_buffer(device, new_buf) }
self.modules.on_new_buffer(device, new_buf)
}
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,12 @@ macro_rules! pass_down_tape_actions {
{
#[inline]
unsafe fn tape(&self) -> Option<&$crate::Tape<'dev>> {
self.modules.tape()
unsafe { self.modules.tape() }
}

#[inline]
unsafe fn tape_mut(&self) -> Option<&mut $crate::Tape<'dev>> {
self.modules.tape_mut()
unsafe { self.modules.tape_mut() }
}
}
};
Expand Down
3 changes: 1 addition & 2 deletions src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use crate::{Shape, Unit, WrappedData};
use super::{Buffer, Device};

pub trait OnNewBuffer<'dev, T: Unit, D: Device, S: Shape = ()> {
#[track_caller]
unsafe fn on_new_buffer<'s>(
fn on_new_buffer<'s>(
&'dev self,
_device: &'dev D,
_new_buf: &'s mut Buffer<'dev, T, D, S>,
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ pub trait HostPtr<T>: PtrType {

#[inline]
unsafe fn as_slice(&self) -> &[T] {
core::slice::from_raw_parts(self.ptr(), self.size())
unsafe { core::slice::from_raw_parts(self.ptr(), self.size()) }
}

#[inline]
unsafe fn as_mut_slice(&mut self) -> &mut [T] {
core::slice::from_raw_parts_mut(self.ptr_mut(), self.size())
unsafe { core::slice::from_raw_parts_mut(self.ptr_mut(), self.size()) }
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ where
Mods: OnNewBuffer<'dev, T, D, S> + CachedBuffers,
{
#[inline]
unsafe fn on_new_buffer(&'dev self, device: &'dev D, new_buf: &mut Buffer<'dev, T, D, S>) {
fn on_new_buffer(&'dev self, device: &'dev D, new_buf: &mut Buffer<'dev, T, D, S>) {
// let mut no_grads = self.no_grads_pool.borrow_mut();
// let wrapped_data = unsafe { new_buf.data.shallow() };

Expand Down
2 changes: 1 addition & 1 deletion src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ where
S: Shape,
{
#[inline]
unsafe fn on_new_buffer(&'a self, device: &'a D, new_buf: &mut Buffer<'a, T, D, S>) {
fn on_new_buffer(&'a self, device: &'a D, new_buf: &mut Buffer<'a, T, D, S>) {
self.modules.on_new_buffer(device, new_buf)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/modules/fork.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<'a, Mods: OnNewBuffer<'a, T, D, S>, T: Unit, D: Device, S: Shape> OnNewBuff
for Fork<Mods>
{
#[inline]
unsafe fn on_new_buffer(&'a self, device: &'a D, new_buf: &mut crate::Buffer<'a, T, D, S>) {
fn on_new_buffer(&'a self, device: &'a D, new_buf: &mut crate::Buffer<'a, T, D, S>) {
self.modules.on_new_buffer(device, new_buf)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ where
S: Shape,
{
#[inline]
unsafe fn on_new_buffer<'s>(&'a self, device: &'a D, new_buf: &'s mut Buffer<'a, T, D, S>) {
fn on_new_buffer<'s>(&'a self, device: &'a D, new_buf: &'s mut Buffer<'a, T, D, S>) {
unsafe { register_buf_copyable(&mut self.buffers.borrow_mut(), new_buf) };
self.modules.on_new_buffer(device, new_buf)
}
Expand Down
2 changes: 1 addition & 1 deletion src/parents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ macro_rules! impl_parents {
unsafe fn replication<'a>(self) -> Self::Replicated<'a> {
#[allow(non_snake_case)]
let ($($to_impl,)+) = self;
($($to_impl.replicate(),)+)
unsafe { ($($to_impl.replicate(),)+) }
}
}
};
Expand Down
4 changes: 2 additions & 2 deletions tests/demo_impl/cuda/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use custos::{cuda::launch_kernel, prelude::Number, Buffer, CDatatype, OnDropBuffer, Shape, CUDA};
use custos::{cuda::launch_kernel, prelude::Number, Buffer, CDatatype, Shape, WrappedData, CUDA};

pub fn cu_element_wise<Mods: OnDropBuffer, T: Number, S: Shape>(
pub fn cu_element_wise<Mods: WrappedData, T: Number, S: Shape>(
device: &CUDA<Mods>,
lhs: &Buffer<T, CUDA<Mods>, S>,
rhs: &Buffer<T, CUDA<Mods>, S>,
Expand Down

0 comments on commit 13ef3fa

Please sign in to comment.