From 82198f5afed9ac66918546e76a4b68e434d10356 Mon Sep 17 00:00:00 2001
From: Yjn024 <jiening.yu@outlook.com>
Date: Tue, 23 Jan 2024 18:49:56 +0800
Subject: [PATCH] Safe Event implementation (#26)

* rewrite Event

* fix errors

* send + sync in tests
---
 core/src/component.rs   | 102 ++++++++++--------
 core/src/item/event.rs  |  21 ++--
 util/event/src/lib.rs   | 228 ++++++++++++++++++++++++----------------
 util/event/src/tests.rs |  57 ++++++----
 4 files changed, 240 insertions(+), 168 deletions(-)

diff --git a/core/src/component.rs b/core/src/component.rs
index fbace03a..9ab69c27 100644
--- a/core/src/component.rs
+++ b/core/src/component.rs
@@ -2,11 +2,12 @@ use std::{
     any::{type_name, TypeId},
     collections::HashMap,
     ops::{Deref, DerefMut},
+    sync::Arc,
 };
 
 use bytes::Bytes;
 use rimecraft_edcode::Encode;
-use rimecraft_event::Event;
+use rimecraft_event::{DefaultEvent, Event};
 use rimecraft_primitives::{id, Id, SerDeUpdate};
 use tracing::{trace_span, warn};
 
@@ -196,12 +197,12 @@ impl From<ComponentsBuilder> for Components {
     }
 }
 
-type CompPEvent = Event<dyn Fn(TypeId, &mut Components)>;
+type CompPEvent = DefaultEvent<dyn Fn(TypeId, &mut Components) + Send + Sync>;
 
 static ATTACH_EVENTS: parking_lot::RwLock<CompPEvent> =
     parking_lot::RwLock::new(Event::new(|listeners| {
-        Box::new(move |type_id, components| {
-            for listener in listeners {
+        Arc::new(move |type_id, components| {
+            for listener in &listeners {
                 listener(type_id, components)
             }
         })
@@ -362,34 +363,37 @@ where
 
         let _ = span.enter();
 
-        match components
-            .get_mut::<Component<Event<dyn Fn(&mut HashMap<Id, Bytes>) -> anyhow::Result<()>>>>(
-                &NET_SEND_ID,
-            ) {
-            Ok(Component(event)) => event.register(Box::new(move |map| {
-                let this = unsafe { &*ptr };
+        match components.get_mut::<Component<BytesPEvent>>(&NET_SEND_ID) {
+            Ok(Component(event)) => rimecraft_event::register!(
+                event,
+                Arc::new(move |map| {
+                    let this = unsafe { &*ptr };
 
-                map.insert(this.1.clone(), {
-                    let mut bytes_mut = bytes::BytesMut::new();
-                    this.0.encode(&mut bytes_mut)?;
+                    map.insert(this.1.clone(), {
+                        let mut bytes_mut = bytes::BytesMut::new();
+                        this.0.encode(&mut bytes_mut)?;
 
-                    bytes_mut.into()
-                });
+                        bytes_mut.into()
+                    });
 
-                Ok(())
-            })),
+                    Ok(())
+                })
+            ),
             Err(err) => {
                 warn!("network sending event not found: {err}");
             }
         }
 
         match components.get_mut::<Component<BytesPEvent>>(&NET_RECV_ID) {
-            Ok(Component(event)) => event.register(Box::new(move |map| {
-                let this = unsafe { &mut *ptr };
-                let mut bytes = map.remove(&this.1).unwrap();
-
-                this.0.update(&mut bytes).map_err(From::from)
-            })),
+            Ok(Component(event)) => rimecraft_event::register!(
+                event,
+                Arc::new(move |map| {
+                    let this = unsafe { &mut *ptr };
+                    let mut bytes = map.remove(&this.1).unwrap();
+
+                    this.0.update(&mut bytes).map_err(From::from)
+                })
+            ),
             Err(err) => {
                 warn!("network receiving event not found: {err}");
             }
@@ -478,13 +482,14 @@ where
     }
 }
 
-type BytesPEvent = Event<dyn Fn(&mut HashMap<Id, Bytes>) -> anyhow::Result<()>>;
+type BytesPEvent =
+    DefaultEvent<dyn Fn(&mut HashMap<Id, Bytes>) -> anyhow::Result<()> + Send + Sync>;
 
 #[inline]
 fn net_event_comp() -> Component<BytesPEvent> {
     Component(Event::new(|listeners| {
-        Box::new(move |map| {
-            for listener in listeners {
+        Arc::new(move |map| {
+            for listener in &listeners {
                 listener(map)?;
             }
 
@@ -530,29 +535,32 @@ where
 
         let _ = span.enter();
 
-        match components.get_mut::<Component<
-            Event<dyn Fn(&mut HashMap<Id, fastnbt::Value>) -> fastnbt::error::Result<()>>,
-        >>(&NBT_SAVE_ID)
-        {
-            Ok(Component(event)) => event.register(Box::new(move |map| {
-                let this = unsafe { &*ptr };
-                map.insert(
-                    this.1.clone(),
-                    this.0.serialize(&mut fastnbt::value::Serializer)?,
-                );
-
-                Ok(())
-            })),
+        match components.get_mut::<Component<ValuePEvent>>(&NBT_SAVE_ID) {
+            Ok(Component(event)) => rimecraft_event::register!(
+                event,
+                Arc::new(move |map| {
+                    let this = unsafe { &*ptr };
+                    map.insert(
+                        this.1.clone(),
+                        this.0.serialize(&mut fastnbt::value::Serializer)?,
+                    );
+
+                    Ok(())
+                })
+            ),
             Err(err) => {
                 warn!("nbt saving event not found: {err}");
             }
         }
 
         match components.get_mut::<Component<ValuePEvent>>(&NBT_READ_ID) {
-            Ok(Component(event)) => event.register(Box::new(move |map| {
-                let this = unsafe { &mut *ptr };
-                this.0.update(&map.remove(&this.1).unwrap())
-            })),
+            Ok(Component(event)) => rimecraft_event::register!(
+                event,
+                Arc::new(move |map| {
+                    let this = unsafe { &mut *ptr };
+                    this.0.update(&map.remove(&this.1).unwrap())
+                })
+            ),
             Err(err) => {
                 warn!("nbt reading event not found: {err}");
             }
@@ -641,12 +649,14 @@ where
     }
 }
 
-type ValuePEvent = Event<dyn Fn(&mut HashMap<Id, fastnbt::Value>) -> fastnbt::error::Result<()>>;
+type ValuePEvent = DefaultEvent<
+    dyn Fn(&mut HashMap<Id, fastnbt::Value>) -> fastnbt::error::Result<()> + Send + Sync,
+>;
 
 fn nbt_event_comp() -> Component<ValuePEvent> {
     Component(Event::new(|listeners| {
-        Box::new(move |map| {
-            for listener in listeners {
+        Arc::new(move |map| {
+            for listener in &listeners {
                 listener(map)?;
             }
 
diff --git a/core/src/item/event.rs b/core/src/item/event.rs
index 1e58ca55..cb3d83af 100644
--- a/core/src/item/event.rs
+++ b/core/src/item/event.rs
@@ -1,12 +1,15 @@
-use rimecraft_event::Event;
+use std::sync::Arc;
+
+use rimecraft_event::{DefaultEvent, Event};
 
 use super::Item;
 
-pub static POST_PROCESS_NBT: Event<dyn Fn(Item, &mut rimecraft_nbt_ext::Compound)> =
-    Event::new(|listeners| {
-        Box::new(move |item, nbt| {
-            for listener in listeners {
-                listener(item, nbt)
-            }
-        })
-    });
+pub static POST_PROCESS_NBT: DefaultEvent<
+    dyn Fn(Item, &mut rimecraft_nbt_ext::Compound) + Send + Sync,
+> = Event::new(|listeners| {
+    Arc::new(move |item, nbt| {
+        for listener in &listeners {
+            listener(item, nbt)
+        }
+    })
+});
diff --git a/util/event/src/lib.rs b/util/event/src/lib.rs
index 50a45403..7e9234d4 100644
--- a/util/event/src/lib.rs
+++ b/util/event/src/lib.rs
@@ -1,127 +1,169 @@
-#[cfg(test)]
-mod tests;
-
-use std::sync::atomic::{AtomicBool, Ordering};
-
-use parking_lot::RwLock;
+use std::{ops::Deref, sync::Arc};
 
-/// Listeners and cache.
-type LisCac<T, Phase> = (Vec<(Phase, *const T)>, Option<Box<T>>, Vec<&'static T>);
+use parking_lot::{RwLock, RwLockReadGuard};
 
 /// A type containing listeners of this event,
 /// which can be invoked by an invoker.
 ///
-/// The listeners are sorted by phases ([`i8`] by default)
+/// The listeners are sorted by phases,
 /// that can be called in order.
-pub struct Event<T, Phase = i8>
-where
-    T: ?Sized + 'static,
-{
-    /// Whether listeners has been modified before requesting the invoker.
-    dirty: AtomicBool,
+pub struct Event<T, F, P> {
+    listeners: Vec<(T, P)>,
+    factory: F,
+    invoker: RwLock<Option<T>>,
+}
 
-    invoker_factory: fn(&'static [&'static T]) -> Box<T>,
+/// A sequence of listeners.
+#[derive(Debug)]
+pub struct Listeners<T> {
+    inner: Arc<Vec<T>>,
+}
 
-    /// 0: raw listeners with phases\
-    /// 1: cached invoker\
-    /// 2: cached listener references
-    lis_cac: RwLock<LisCac<T, Phase>>,
+impl<T> Listeners<T> {
+    /// Whether there are no listeners
+    /// in this sequence.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
 }
 
-impl<T, Phase> Event<T, Phase>
-where
-    T: ?Sized,
-    Phase: Ord,
-{
-    /// Create a new event with provided event factory.
-    ///
-    /// To avoid lifetime problems in the factory, listeners
-    /// provided are all in static references so that they're
-    /// able to be copied and moved.
-    /// So you should add a `move` keyword before the closure
-    /// to return in the factory.
-    pub const fn new(factory: fn(&'static [&'static T]) -> Box<T>) -> Self {
-        Self {
-            lis_cac: RwLock::new((Vec::new(), None, Vec::new())),
-            invoker_factory: factory,
-            dirty: AtomicBool::new(false),
+impl<'a, T: Deref> IntoIterator for &'a Listeners<T> {
+    type Item = &'a <T as Deref>::Target;
+
+    type IntoIter = ListenersIter<'a, T>;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        ListenersIter {
+            inner: self.inner.iter(),
         }
     }
+}
 
-    /// Get the invoker of this event.
-    ///
-    /// Once the invoker is created, it will be cached until
-    /// the next modification of listeners, and will be re-created
-    /// by the factory.
-    pub fn invoker(&self) -> &T {
-        if self.dirty.load(Ordering::Acquire) {
-            let mut write_guard = self.lis_cac.write();
-            write_guard.0.sort_by(|e0, e1| Phase::cmp(&e0.0, &e1.0));
-            self.dirty.store(false, Ordering::Release);
-
-            write_guard.2 = write_guard.0.iter().map(|e| unsafe { &*e.1 }).collect();
-            write_guard.1 = Some((self.invoker_factory)(unsafe {
-                &*(&write_guard.2 as *const Vec<&'static T>)
-            }));
-        } else if self.lis_cac.read().1.is_none() {
-            let mut write_guard = self.lis_cac.write();
-            write_guard.1 = Some((self.invoker_factory)(unsafe {
-                &*(&write_guard.2 as *const Vec<&'static T>)
-            }));
+impl<T> Clone for Listeners<T> {
+    #[inline]
+    fn clone(&self) -> Self {
+        Self {
+            inner: self.inner.clone(),
         }
-
-        unsafe { &*(&**self.lis_cac.read().1.as_ref().unwrap() as *const T) }
     }
+}
 
-    /// Register a listener to this event for the specified phase.
-    pub fn register_with_phase(&mut self, listener: Box<T>, phase: Phase) {
-        self.lis_cac
-            .get_mut()
-            .0
-            .push((phase, Box::into_raw(listener)));
+#[derive(Debug)]
+pub struct ListenersIter<'a, T> {
+    inner: std::slice::Iter<'a, T>,
+}
 
-        if !self.dirty.load(Ordering::Acquire) {
-            self.dirty.store(true, Ordering::Release);
-        }
+impl<'a, T: Deref> Iterator for ListenersIter<'a, T> {
+    type Item = &'a <T as Deref>::Target;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        self.inner.next().map(Deref::deref)
     }
 }
 
-impl<T, Phase> Event<T, Phase>
-where
-    T: ?Sized,
-    Phase: Ord + Default,
-{
-    /// Register a listener to this event for the default phase.
+/// A cell that can be dereferenced
+/// to the listener.
+#[derive(Debug)]
+pub struct Listener<T> {
+    inner: T,
+}
+
+impl<T: Deref> Deref for Listener<T> {
+    type Target = <T as Deref>::Target;
+
     #[inline]
-    pub fn register(&mut self, listener: Box<T>) {
-        self.register_with_phase(listener, Default::default())
+    fn deref(&self) -> &Self::Target {
+        &self.inner
     }
 }
 
-impl<T, Phase> Drop for Event<T, Phase>
-where
-    T: ?Sized,
-{
-    fn drop(&mut self) {
-        let mut vec = Vec::new();
-        std::mem::swap(&mut self.lis_cac.get_mut().0, &mut vec);
+#[derive(Debug)]
+pub struct Invoker<'a, T> {
+    inner: RwLockReadGuard<'a, Option<T>>,
+}
+
+impl<'a, T: Deref> Deref for Invoker<'a, T> {
+    type Target = <T as Deref>::Target;
+
+    #[inline]
+    fn deref(&self) -> &Self::Target {
+        self.inner.as_ref().unwrap()
+    }
+}
 
-        for value in vec {
-            let _ = unsafe { Box::from_raw(value.1 as *mut T) };
+impl<T, F, P> Event<T, F, P> {
+    /// Creates a new event with provided invoker factory.
+    pub const fn new(factory: F) -> Self {
+        Self {
+            listeners: Vec::new(),
+            factory,
+            invoker: RwLock::new(None),
         }
     }
+
+    #[inline]
+    fn make_dirty(&mut self) {
+        self.invoker.get_mut().take();
+    }
+
+    /// Registers a listener with given phase into
+    /// this event.
+    #[inline]
+    pub fn register(&mut self, listener: T, phase: P) {
+        self.listeners.push((listener, phase));
+        self.make_dirty()
+    }
 }
 
-unsafe impl<T, Phase> Send for Event<T, Phase>
+impl<T, F, P> Event<T, F, P>
 where
-    T: ?Sized,
-    Phase: Ord + Send,
+    F: Fn(Listeners<T>) -> T,
+    P: Ord,
+    T: Clone,
 {
+    /// Obtains the invoker of this event.
+    pub fn invoker(&self) -> Invoker<'_, T> {
+        {
+            let rg = self.invoker.read();
+            if rg.as_ref().is_some() {
+                return Invoker { inner: rg };
+            }
+        }
+
+        let mut listeners = self
+            .listeners
+            .iter()
+            .map(|(l, p)| (l.clone(), p))
+            .collect::<Vec<_>>();
+        listeners.sort_by_key(|(_, p)| *p);
+        let listeners = Listeners {
+            inner: Arc::new(listeners.into_iter().map(|(l, _)| l).collect()),
+        };
+
+        *self.invoker.write() = Some((self.factory)(listeners));
+        Invoker {
+            inner: self.invoker.read(),
+        }
+    }
 }
 
-unsafe impl<T, Phase> Sync for Event<T, Phase>
-where
-    T: ?Sized,
-    Phase: Ord + Sync,
-{
+/// Registers a listener into the event.
+#[macro_export]
+macro_rules! register {
+    ($e:expr, $l:expr, $p:expr$(,)?) => {
+        $e.register($l, $p)
+    };
+    ($e:expr, $l:expr$(,)?) => {
+        $crate::register!($e, $l, ::core::default::Default::default())
+    };
 }
+
+pub type InvokerFactory<T> = fn(Listeners<T>) -> T;
+
+pub type DefaultEvent<T, P = i8> = Event<Arc<T>, InvokerFactory<Arc<T>>, P>;
+
+#[cfg(test)]
+mod tests;
diff --git a/util/event/src/tests.rs b/util/event/src/tests.rs
index 13f28d94..bef57267 100644
--- a/util/event/src/tests.rs
+++ b/util/event/src/tests.rs
@@ -1,10 +1,14 @@
+use std::sync::Arc;
+
+use crate::DefaultEvent;
+
 use super::Event;
 
 #[test]
 fn registering_invoking() {
-    let mut event: Event<dyn Fn(&str) -> bool> = Event::new(|listeners| {
-        Box::new(move |string| {
-            for listener in listeners {
+    let mut event: DefaultEvent<dyn Fn(&str) -> bool + Send + Sync> = Event::new(|listeners| {
+        Arc::new(move |string| {
+            for listener in &listeners {
                 if !listener(string) {
                     return false;
                 }
@@ -17,13 +21,18 @@ fn registering_invoking() {
         "minecraft by mojang is a propritary software."
     ));
 
-    event.register(Box::new(|string| {
-        !string.to_lowercase().contains("propritary software")
-    }));
-    event.register(Box::new(|string| !string.to_lowercase().contains("mojang")));
-    event.register(Box::new(|string| {
-        !string.to_lowercase().contains("minecraft")
-    }));
+    register!(
+        event,
+        Arc::new(|string: &str| { !string.to_lowercase().contains("propritary software") })
+    );
+    register!(
+        event,
+        Arc::new(|string: &str| !string.to_lowercase().contains("mojang"))
+    );
+    register!(
+        event,
+        Arc::new(|string| { !string.to_lowercase().contains("minecraft") })
+    );
 
     assert!(!event.invoker()(
         "minecraft by mojang is a propritary software."
@@ -31,25 +40,32 @@ fn registering_invoking() {
 
     assert!(event.invoker()("i love krlite."));
 
-    event.register(Box::new(|string| !string.to_lowercase().contains("krlite")));
+    register!(
+        event,
+        Arc::new(|string| !string.to_lowercase().contains("krlite"))
+    );
 
     assert!(!event.invoker()("i love krlite."));
 }
 
 #[test]
 fn phases() {
-    let mut event: Event<dyn Fn(&mut String)> = Event::new(|listeners| {
-        Box::new(move |string| {
-            for listener in listeners {
+    let mut event: DefaultEvent<dyn Fn(&mut String) + Send + Sync> = Event::new(|listeners| {
+        Arc::new(move |string| {
+            for listener in &listeners {
                 listener(string);
             }
         })
     });
 
-    event.register(Box::new(|string| string.push_str("genshin impact ")));
-    event.register_with_phase(Box::new(|string| string.push_str("you're right, ")), -3);
-    event.register_with_phase(Box::new(|string| string.push_str("but ")), -2);
-    event.register_with_phase(Box::new(|string| string.push_str("is a...")), 10);
+    register!(event, Arc::new(|string| string.push_str("genshin impact ")));
+    register!(
+        event,
+        Arc::new(|string| string.push_str("you're right, ")),
+        -3,
+    );
+    register!(event, Arc::new(|string| string.push_str("but ")), -2);
+    register!(event, Arc::new(|string| string.push_str("is a...")), 10);
 
     {
         let mut string = String::new();
@@ -57,8 +73,9 @@ fn phases() {
         assert_eq!(string, "you're right, but genshin impact is a...");
     }
 
-    event.register_with_phase(
-        Box::new(|string| string.push_str("genshin impact, bootstrap! ")),
+    register!(
+        event,
+        Arc::new(|string| string.push_str("genshin impact, bootstrap! ")),
         -100,
     );