Skip to content

Commit

Permalink
feature: compare local pids (#611)
Browse files Browse the repository at this point in the history
* rustler_sys: add 'nif_compare_pids'
* rustler: add (Partial)Eq/Ord for LocalPid
* rustler_tests: add tests for LocalPid cmp/eq
* sys: define 'enif_compare_pids' to behave like macro in C code
* tests: add unit test with equality check for local pids
  • Loading branch information
hengermax authored May 29, 2024
1 parent b882d51 commit 127e255
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 0 deletions.
22 changes: 22 additions & 0 deletions rustler/src/types/local_pid.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::wrapper::{pid, ErlNifPid};
use crate::{Decoder, Encoder, Env, Error, NifResult, Term};
use std::cmp::Ordering;
use std::mem::MaybeUninit;

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -36,6 +37,27 @@ impl Encoder for LocalPid {
}
}

impl PartialEq for LocalPid {
fn eq(&self, other: &Self) -> bool {
unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) == 0 }
}
}

impl Eq for LocalPid {}

impl PartialOrd for LocalPid {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for LocalPid {
fn cmp(&self, other: &Self) -> Ordering {
let cmp = unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) };
cmp.cmp(&0)
}
}

impl<'a> Env<'a> {
/// Return the calling process's pid.
///
Expand Down
6 changes: 6 additions & 0 deletions rustler_sys/src/rustler_sys_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ pub unsafe fn enif_make_pid(_env: *mut ErlNifEnv, pid: ErlNifPid) -> ERL_NIF_TER
pid.pid
}

/// See [enif_compare_pids](http://erlang.org/doc/man/erl_nif.html#enif_compare_pids) in the Erlang docs
pub unsafe fn enif_compare_pids(pid1: *const ErlNifPid, pid2: *const ErlNifPid) -> c_int {
// Mimics the implementation of the enif_compare_pids macro
enif_compare((*pid1).pid, (*pid2).pid)
}

/// See [ErlNifSysInfo](http://www.erlang.org/doc/man/erl_nif.html#ErlNifSysInfo) in the Erlang docs.
#[allow(missing_copy_implementations)]
#[repr(C)]
Expand Down
3 changes: 3 additions & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ defmodule RustlerTest do
def sum_list(_), do: err()
def make_list(), do: err()

def compare_local_pids(_, _), do: err()
def are_equal_local_pids(_, _), do: err()

def term_debug(_), do: err()

def term_debug_and_reparse(term) do
Expand Down
3 changes: 3 additions & 0 deletions rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod test_dirty;
mod test_env;
mod test_error;
mod test_list;
mod test_local_pid;
mod test_map;
mod test_nif_attrs;
mod test_path;
Expand All @@ -28,6 +29,8 @@ rustler::init!(
test_primitives::echo_i128,
test_list::sum_list,
test_list::make_list,
test_local_pid::compare_local_pids,
test_local_pid::are_equal_local_pids,
test_term::term_debug,
test_term::term_eq,
test_term::term_cmp,
Expand Down
17 changes: 17 additions & 0 deletions rustler_tests/native/rustler_test/src/test_local_pid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use std::cmp::Ordering;

use rustler::LocalPid;

#[rustler::nif]
pub fn compare_local_pids(lhs: LocalPid, rhs: LocalPid) -> i32 {
match lhs.cmp(&rhs) {
Ordering::Less => -1,
Ordering::Equal => 0,
Ordering::Greater => 1,
}
}

#[rustler::nif]
pub fn are_equal_local_pids(lhs: LocalPid, rhs: LocalPid) -> bool {
lhs == rhs
}
35 changes: 35 additions & 0 deletions rustler_tests/test/local_pid_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
defmodule RustlerTest.LocalPidTest do
use ExUnit.Case, async: true

def make_pid() do
{:ok, pid} = Task.start(fn -> :ok end)
pid
end

def compare(lhs, rhs) do
cond do
lhs < rhs -> -1
lhs == rhs -> 0
lhs > rhs -> 1
end
end

test "local pid comparison" do
# We make sure that the code we have in rust code matches the comparisons
# that are performed in the BEAM code.
pids = for _ <- 1..3, do: make_pid()

for lhs <- pids, rhs <- pids do
assert RustlerTest.compare_local_pids(lhs, rhs) == compare(lhs, rhs)
end
end

test "local pid equality" do
pids = for _ <- 1..3, do: make_pid()

for lhs <- pids, rhs <- pids do
expected = lhs == rhs
assert RustlerTest.are_equal_local_pids(lhs, rhs) == expected
end
end
end

0 comments on commit 127e255

Please sign in to comment.