From 013537b27b25b35f69f110ac14f4e9dea7ce222e Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Wed, 31 Jan 2024 20:17:15 +0100 Subject: [PATCH] Get rid of thread handles in favor of simple id-based map --- examples/callbacks.rs | 4 +- examples/scheduler_ordering.rs | 4 +- lib/functions.rs | 29 ++++++-- lib/handle.rs | 130 --------------------------------- lib/lib.rs | 3 +- lib/queue.rs | 21 +----- lib/result_map.rs | 41 +++++++++++ lib/runtime.rs | 75 +++++++++++++++---- lib/traits.rs | 14 ++-- lib/util.rs | 11 +++ 10 files changed, 153 insertions(+), 179 deletions(-) delete mode 100644 lib/handle.rs create mode 100644 lib/result_map.rs diff --git a/examples/callbacks.rs b/examples/callbacks.rs index a88d794..ee4da89 100644 --- a/examples/callbacks.rs +++ b/examples/callbacks.rs @@ -26,13 +26,13 @@ pub fn main() -> LuaResult<()> { // Load the main script into the runtime, and keep track of the thread we spawn let main = lua.load(MAIN_SCRIPT); - let handle = rt.push_thread_front(main, ())?; + let id = rt.push_thread_front(main, ())?; // Run until completion block_on(rt.run()); // We should have gotten the error back from our script - assert!(handle.result(&lua).unwrap().is_err()); + assert!(rt.thread_result(id).unwrap().is_err()); Ok(()) } diff --git a/examples/scheduler_ordering.rs b/examples/scheduler_ordering.rs index d91b3dd..d774ec6 100644 --- a/examples/scheduler_ordering.rs +++ b/examples/scheduler_ordering.rs @@ -32,13 +32,13 @@ pub fn main() -> LuaResult<()> { // Load the main script into the runtime, and keep track of the thread we spawn let main = lua.load(MAIN_SCRIPT); - let handle = rt.push_thread_front(main, ())?; + let id = rt.push_thread_front(main, ())?; // Run until completion block_on(rt.run()); // We should have gotten proper values back from our script - let res = handle.result(&lua).unwrap().unwrap(); + let res = rt.thread_result(id).unwrap().unwrap(); let nums = Vec::::from_lua_multi(res, &lua)?; assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]); diff --git a/lib/functions.rs b/lib/functions.rs index 7f50d60..e52160f 100644 --- a/lib/functions.rs +++ b/lib/functions.rs @@ -6,8 +6,10 @@ use mlua::prelude::*; use crate::{ error_callback::ThreadErrorCallback, queue::{DeferredThreadQueue, SpawnedThreadQueue}, + result_map::ThreadResultMap, runtime::Runtime, - util::LuaThreadOrFunction, + thread_id::ThreadId, + util::{is_poll_pending, LuaThreadOrFunction, ThreadResult}, }; const ERR_METADATA_NOT_ATTACHED: &str = "\ @@ -63,24 +65,39 @@ impl<'lua> Functions<'lua> { .app_data_ref::() .expect(ERR_METADATA_NOT_ATTACHED) .clone(); + let result_map = lua + .app_data_ref::() + .expect(ERR_METADATA_NOT_ATTACHED) + .clone(); + let spawn_map = result_map.clone(); let spawn = lua.create_function( move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { let thread = tof.into_thread(lua)?; if thread.status() == LuaThreadStatus::Resumable { // NOTE: We need to resume the thread once instantly for correct behavior, // and only if we get the pending value back we can spawn to async executor - match thread.resume::<_, LuaValue>(args.clone()) { + match thread.resume::<_, LuaMultiValue>(args.clone()) { Ok(v) => { - if v.as_light_userdata() - .map(|l| l == Lua::poll_pending()) - .unwrap_or_default() - { + if v.get(0).map(is_poll_pending).unwrap_or_default() { spawn_queue.push_item(lua, &thread, args)?; + } else { + // Not pending, store the value + let id = ThreadId::from(&thread); + if spawn_map.is_tracked(id) { + let res = ThreadResult::new(Ok(v), lua); + spawn_map.insert(id, res); + } } } Err(e) => { error_callback.call(&e); + // Not pending, store the error + let id = ThreadId::from(&thread); + if spawn_map.is_tracked(id) { + let res = ThreadResult::new(Err(e), lua); + spawn_map.insert(id, res); + } } }; } diff --git a/lib/handle.rs b/lib/handle.rs deleted file mode 100644 index 56b3458..0000000 --- a/lib/handle.rs +++ /dev/null @@ -1,130 +0,0 @@ -#![allow(unused_imports)] -#![allow(clippy::missing_panics_doc)] -#![allow(clippy::module_name_repetitions)] - -use std::{ - cell::{Cell, RefCell}, - rc::Rc, -}; - -use event_listener::Event; -use mlua::prelude::*; - -use crate::{ - runtime::Runtime, - status::Status, - traits::IntoLuaThread, - util::{run_until_yield, ThreadResult, ThreadWithArgs}, -}; - -/** - A handle to a thread that has been spawned onto a [`Runtime`]. - - This handle contains a public method, [`Handle::result`], which may - be used to extract the result of the thread, once it finishes running. - - A result may be waited for using the [`Handle::listen`] method. -*/ -#[derive(Debug, Clone)] -pub struct Handle { - thread: Rc>>, - result: Rc>>, - status: Rc>, - event: Rc, -} - -impl Handle { - pub(crate) fn new<'lua>( - lua: &'lua Lua, - thread: impl IntoLuaThread<'lua>, - args: impl IntoLuaMulti<'lua>, - ) -> LuaResult { - let thread = thread.into_lua_thread(lua)?; - let args = args.into_lua_multi(lua)?; - - let packed = ThreadWithArgs::new(lua, thread, args)?; - - Ok(Self { - thread: Rc::new(RefCell::new(Some(packed))), - result: Rc::new(RefCell::new(None)), - status: Rc::new(Cell::new(false)), - event: Rc::new(Event::new()), - }) - } - - pub(crate) fn create_thread<'lua>(&self, lua: &'lua Lua) -> LuaResult> { - let env = lua.create_table()?; - env.set("handle", self.clone())?; - lua.load("return handle:resume()") - .set_name("__runtime_handle") - .set_environment(env) - .into_lua_thread(lua) - } - - fn take<'lua>(&self, lua: &'lua Lua) -> (LuaThread<'lua>, LuaMultiValue<'lua>) { - self.thread - .borrow_mut() - .take() - .expect("thread handle may only be taken once") - .into_inner(lua) - } - - fn set<'lua>(&self, lua: &'lua Lua, result: &LuaResult>, is_final: bool) { - self.result - .borrow_mut() - .replace(ThreadResult::new(result.clone(), lua)); - self.status.replace(is_final); - if is_final { - self.event.notify(usize::MAX); - } - } - - /** - Extracts the result for this thread handle. - - Depending on the current [`Runtime::status`], this method will return: - - - [`Status::NotStarted`]: returns `None`. - - [`Status::Running`]: may return `Some(Ok(v))` or `Some(Err(e))`, but it is not guaranteed. - - [`Status::Completed`]: returns `Some(Ok(v))` or `Some(Err(e))`. - - Note that this method also takes the value out of the handle, so it may only be called once. - - Any subsequent calls after this method returns `Some` will return `None`. - */ - #[must_use] - pub fn result<'lua>(&self, lua: &'lua Lua) -> Option>> { - let mut res = self.result.borrow_mut(); - res.take().map(|r| r.value(lua)) - } - - /** - Waits for this handle to have its final result available. - - Does not wait if the final result is already available. - */ - pub async fn listen(&self) { - if !self.status.get() { - self.event.listen().await; - } - } -} - -impl LuaUserData for Handle { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("resume", |lua, this, (): ()| async move { - /* - 1. Take the thread and args out of the handle - 2. Run the thread until it yields or completes - 3. Store the result of the thread in the lua registry - 4. Return the result of the thread back to lua as well, so that - it may be caught using the runtime and any error callback(s) - */ - let (thread, args) = this.take(lua); - let result = run_until_yield(thread.clone(), args).await; - let is_final = thread.status() != LuaThreadStatus::Resumable; - this.set(lua, &result, is_final); - result - }); - } -} diff --git a/lib/lib.rs b/lib/lib.rs index d17a15e..81f8745 100644 --- a/lib/lib.rs +++ b/lib/lib.rs @@ -1,7 +1,7 @@ mod error_callback; mod functions; -mod handle; mod queue; +mod result_map; mod runtime; mod status; mod thread_id; @@ -9,7 +9,6 @@ mod traits; mod util; pub use functions::Functions; -pub use handle::Handle; pub use runtime::Runtime; pub use status::Status; pub use thread_id::ThreadId; diff --git a/lib/queue.rs b/lib/queue.rs index 54837f4..1e08580 100644 --- a/lib/queue.rs +++ b/lib/queue.rs @@ -6,7 +6,7 @@ use event_listener::Event; use futures_lite::{Future, FutureExt}; use mlua::prelude::*; -use crate::{handle::Handle, traits::IntoLuaThread, util::ThreadWithArgs}; +use crate::{traits::IntoLuaThread, util::ThreadWithArgs, ThreadId}; /** Queue for storing [`LuaThread`]s with associated arguments. @@ -32,31 +32,18 @@ impl ThreadQueue { lua: &'lua Lua, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult<()> { + ) -> LuaResult { let thread = thread.into_lua_thread(lua)?; let args = args.into_lua_multi(lua)?; tracing::trace!("pushing item to queue with {} args", args.len()); + let id = ThreadId::from(&thread); let stored = ThreadWithArgs::new(lua, thread, args)?; self.queue.push(stored).into_lua_err()?; self.event.notify(usize::MAX); - Ok(()) - } - - pub fn push_item_with_handle<'lua>( - &self, - lua: &'lua Lua, - thread: impl IntoLuaThread<'lua>, - args: impl IntoLuaMulti<'lua>, - ) -> LuaResult { - let handle = Handle::new(lua, thread, args)?; - let handle_thread = handle.create_thread(lua)?; - - self.push_item(lua, handle_thread, ())?; - - Ok(handle) + Ok(id) } pub fn drain_items<'outer, 'lua>( diff --git a/lib/result_map.rs b/lib/result_map.rs new file mode 100644 index 0000000..bcdf6b8 --- /dev/null +++ b/lib/result_map.rs @@ -0,0 +1,41 @@ +use std::{ + cell::RefCell, + collections::{HashMap, HashSet}, + rc::Rc, +}; + +use crate::{thread_id::ThreadId, util::ThreadResult}; + +#[derive(Clone)] +pub(crate) struct ThreadResultMap { + tracked: Rc>>, + inner: Rc>>, +} + +impl ThreadResultMap { + pub fn new() -> Self { + Self { + tracked: Rc::new(RefCell::new(HashSet::new())), + inner: Rc::new(RefCell::new(HashMap::new())), + } + } + + pub fn track(&self, id: ThreadId) { + self.tracked.borrow_mut().insert(id); + } + + pub fn is_tracked(&self, id: ThreadId) -> bool { + self.tracked.borrow().contains(&id) + } + + pub fn insert(&self, id: ThreadId, result: ThreadResult) { + assert!(self.is_tracked(id), "Thread must be tracked"); + self.inner.borrow_mut().insert(id, result); + } + + pub fn remove(&self, id: ThreadId) -> Option { + let res = self.inner.borrow_mut().remove(&id)?; + self.tracked.borrow_mut().remove(&id); + Some(res) + } +} diff --git a/lib/runtime.rs b/lib/runtime.rs index b00f355..06315ff 100644 --- a/lib/runtime.rs +++ b/lib/runtime.rs @@ -14,11 +14,12 @@ use tracing::Instrument; use crate::{ error_callback::ThreadErrorCallback, - handle::Handle, queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue}, + result_map::ThreadResultMap, status::Status, + thread_id::ThreadId, traits::IntoLuaThread, - util::run_until_yield, + util::{run_until_yield, ThreadResult}, }; const ERR_METADATA_ALREADY_ATTACHED: &str = "\ @@ -45,6 +46,7 @@ pub struct Runtime<'lua> { queue_spawn: SpawnedThreadQueue, queue_defer: DeferredThreadQueue, error_callback: ThreadErrorCallback, + result_map: ThreadResultMap, status: Rc>, } @@ -63,7 +65,7 @@ impl<'lua> Runtime<'lua> { let queue_spawn = SpawnedThreadQueue::new(); let queue_defer = DeferredThreadQueue::new(); let error_callback = ThreadErrorCallback::default(); - let status = Rc::new(Cell::new(Status::NotStarted)); + let result_map = ThreadResultMap::new(); assert!( lua.app_data_ref::().is_none(), @@ -77,16 +79,24 @@ impl<'lua> Runtime<'lua> { lua.app_data_ref::().is_none(), "{ERR_METADATA_ALREADY_ATTACHED}" ); + assert!( + lua.app_data_ref::().is_none(), + "{ERR_METADATA_ALREADY_ATTACHED}" + ); lua.set_app_data(queue_spawn.clone()); lua.set_app_data(queue_defer.clone()); lua.set_app_data(error_callback.clone()); + lua.set_app_data(result_map.clone()); + + let status = Rc::new(Cell::new(Status::NotStarted)); Runtime { lua, queue_spawn, queue_defer, error_callback, + result_map, status, } } @@ -142,7 +152,7 @@ impl<'lua> Runtime<'lua> { # Returns - Returns a [`Handle`] that can be used to retrieve the result of the thread. + Returns a [`ThreadId`] that can be used to retrieve the result of the thread. Note that the result may not be available until [`Runtime::run`] completes. @@ -154,10 +164,11 @@ impl<'lua> Runtime<'lua> { &self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult { + ) -> LuaResult { tracing::debug!(deferred = false, "new runtime thread"); - self.queue_spawn - .push_item_with_handle(self.lua, thread, args) + let id = self.queue_spawn.push_item(self.lua, thread, args)?; + self.result_map.track(id); + Ok(id) } /** @@ -169,7 +180,7 @@ impl<'lua> Runtime<'lua> { # Returns - Returns a [`Handle`] that can be used to retrieve the result of the thread. + Returns a [`ThreadId`] that can be used to retrieve the result of the thread. Note that the result may not be available until [`Runtime::run`] completes. @@ -181,10 +192,30 @@ impl<'lua> Runtime<'lua> { &self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult { + ) -> LuaResult { tracing::debug!(deferred = true, "new runtime thread"); - self.queue_defer - .push_item_with_handle(self.lua, thread, args) + let id = self.queue_defer.push_item(self.lua, thread, args)?; + self.result_map.track(id); + Ok(id) + } + + /** + Gets the tracked result for the [`LuaThread`] with the given [`ThreadId`]. + + Depending on the current [`Runtime::status`], this method will return: + + - [`Status::NotStarted`]: returns `None`. + - [`Status::Running`]: may return `Some(Ok(v))` or `Some(Err(e))`, but it is not guaranteed. + - [`Status::Completed`]: returns `Some(Ok(v))` or `Some(Err(e))`. + + Note that this method also takes the value out of the runtime and + stops tracking the given thread, so it may only be called once. + + Any subsequent calls after this method returns `Some` will return `None`. + */ + #[must_use] + pub fn thread_result(&self, id: ThreadId) -> Option>> { + self.result_map.remove(id).map(|r| r.value(self.lua)) } /** @@ -245,14 +276,29 @@ impl<'lua> Runtime<'lua> { when there are new Lua threads to enqueue and potentially more work to be done. */ let fut = async { + let result_map = self.result_map.clone(); let process_thread = |thread: LuaThread<'lua>, args| { // NOTE: Thread may have been cancelled from Lua // before we got here, so we need to check it again if thread.status() == LuaThreadStatus::Resumable { + // Check if we should be tracking this thread + let id = ThreadId::from(&thread); + let id_tracked = result_map.is_tracked(id); + let result_map_inner = if id_tracked { + Some(result_map.clone()) + } else { + None + }; + // Spawn it on the executor and store the result when done local_exec .spawn(async move { - if let Err(e) = run_until_yield(thread, args).await { - self.error_callback.call(&e); + let res = run_until_yield(thread, args).await; + if let Err(e) = res.as_ref() { + self.error_callback.call(e); + } + if id_tracked { + let thread_res = ThreadResult::new(res, self.lua); + result_map_inner.unwrap().insert(id, thread_res); } }) .detach(); @@ -352,5 +398,8 @@ impl Drop for Runtime<'_> { self.lua .remove_app_data::() .expect(ERR_METADATA_REMOVED); + self.lua + .remove_app_data::() + .expect(ERR_METADATA_REMOVED); } } diff --git a/lib/traits.rs b/lib/traits.rs index e370aa0..363f935 100644 --- a/lib/traits.rs +++ b/lib/traits.rs @@ -8,9 +8,9 @@ use mlua::prelude::*; use async_executor::{Executor, Task}; use crate::{ - handle::Handle, queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue}, runtime::Runtime, + thread_id::ThreadId, }; /** @@ -76,7 +76,7 @@ pub trait LuaRuntimeExt<'lua> { &'lua self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult; + ) -> LuaResult; /** Pushes (defers) a lua thread to the **back** of the current runtime. @@ -91,7 +91,7 @@ pub trait LuaRuntimeExt<'lua> { &'lua self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult; + ) -> LuaResult; /** Spawns the given future on the current executor and returns its [`Task`]. @@ -180,22 +180,22 @@ impl<'lua> LuaRuntimeExt<'lua> for Lua { &'lua self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult { + ) -> LuaResult { let queue = self .app_data_ref::() .expect("lua threads can only be pushed within a runtime"); - queue.push_item_with_handle(self, thread, args) + queue.push_item(self, thread, args) } fn push_thread_back( &'lua self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult { + ) -> LuaResult { let queue = self .app_data_ref::() .expect("lua threads can only be pushed within a runtime"); - queue.push_item_with_handle(self, thread, args) + queue.push_item(self, thread, args) } fn spawn(&self, fut: impl Future + Send + 'static) -> Task { diff --git a/lib/util.rs b/lib/util.rs index 933886b..9d37829 100644 --- a/lib/util.rs +++ b/lib/util.rs @@ -21,6 +21,17 @@ pub(crate) async fn run_until_yield<'lua>( stream.next().await.unwrap() } +/** + Checks if the given [`LuaValue`] is the async `POLL_PENDING` constant. +*/ +#[inline] +pub(crate) fn is_poll_pending(value: &LuaValue) -> bool { + value + .as_light_userdata() + .map(|l| l == Lua::poll_pending()) + .unwrap_or_default() +} + /** Representation of a [`LuaResult`] with an associated [`LuaMultiValue`] currently stored in the Lua registry. */