Skip to content

Commit

Permalink
refactor: reduce watchmap memory size (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra authored Jan 2, 2025
1 parent bebb6b4 commit f89c0b2
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 70 deletions.
39 changes: 24 additions & 15 deletions src/internal/id.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::fmt::{Display, Formatter};
use std::{
fmt::{Display, Formatter},
num::NonZeroU32,
};

use crate::{internal::arena::ArenaId, Interner};

Expand Down Expand Up @@ -165,32 +168,24 @@ impl From<SolvableId> for u32 {

#[repr(transparent)]
#[derive(Copy, Clone, PartialOrd, Ord, Eq, PartialEq, Debug, Hash)]
pub(crate) struct ClauseId(u32);
pub(crate) struct ClauseId(NonZeroU32);

impl ClauseId {
/// There is a guarentee that ClauseId(0) will always be
/// There is a guarentee that ClauseId(1) will always be
/// "Clause::InstallRoot". This assumption is verified by the solver.
pub(crate) fn install_root() -> Self {
Self(0)
}

pub(crate) fn is_null(self) -> bool {
self.0 == u32::MAX
}

pub(crate) fn null() -> ClauseId {
ClauseId(u32::MAX)
Self(unsafe { NonZeroU32::new_unchecked(1) })
}
}

impl ArenaId for ClauseId {
fn from_usize(x: usize) -> Self {
assert!(x < u32::MAX as usize, "clause id too big");
Self(x as u32)
// SAFETY: Safe because we always add 1 to the index
Self(unsafe { NonZeroU32::new_unchecked((x + 1).try_into().expect("clause id too big")) })
}

fn to_usize(self) -> usize {
self.0 as usize
(self.0.get() - 1) as usize
}
}

Expand Down Expand Up @@ -236,3 +231,17 @@ impl ArenaId for DependenciesId {
self.0 as usize
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_clause_id_size() {
// Verify that the size of a ClauseId is the same as an Option<ClauseId>.
assert_eq!(
std::mem::size_of::<ClauseId>(),
std::mem::size_of::<Option<ClauseId>>()
);
}
}
15 changes: 15 additions & 0 deletions src/internal/mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ impl<TId: ArenaId, TValue> Mapping<TId, TValue> {
previous_value
}

/// Unset a specific value in the mapping, returns the previous value.
pub fn unset(&mut self, id: TId) -> Option<TValue> {
let idx = id.to_usize();
let (chunk, offset) = Self::chunk_and_offset(idx);
if chunk >= self.chunks.len() {
return None;
}

let previous_value = self.chunks[chunk][offset].take();
if previous_value.is_some() {
self.len -= 1;
}
previous_value
}

/// Get a specific value in the mapping with bound checks
pub fn get(&self, id: TId) -> Option<&TValue> {
let (chunk, offset) = Self::chunk_and_offset(id.to_usize());
Expand Down
62 changes: 31 additions & 31 deletions src/solver/clause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ pub(crate) struct ClauseState {
// The ids of the solvables this clause is watching
pub watched_literals: [Literal; 2],
// The ids of the next clause in each linked list that this clause is part of
pub(crate) next_watches: [ClauseId; 2],
pub(crate) next_watches: [Option<ClauseId>; 2],
}

impl ClauseState {
Expand Down Expand Up @@ -417,15 +417,15 @@ impl ClauseState {

let clause = Self {
watched_literals,
next_watches: [ClauseId::null(), ClauseId::null()],
next_watches: [None, None],
};

debug_assert!(!clause.has_watches() || watched_literals[0] != watched_literals[1]);

clause
}

pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: ClauseId) {
pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: Option<ClauseId>) {
self.next_watches[watch_index] = linked_clause;
}

Expand All @@ -444,7 +444,7 @@ impl ClauseState {
}

#[inline]
pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> ClauseId {
pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> Option<ClauseId> {
if solvable_id == self.watched_literals[0].solvable_id() {
self.next_watches[0]
} else {
Expand Down Expand Up @@ -650,7 +650,7 @@ mod test {
use super::*;
use crate::{internal::arena::ArenaId, solver::decision::Decision};

fn clause(next_clauses: [ClauseId; 2], watch_literals: [Literal; 2]) -> ClauseState {
fn clause(next_clauses: [Option<ClauseId>; 2], watch_literals: [Literal; 2]) -> ClauseState {
ClauseState {
watched_literals: watch_literals,
next_watches: next_clauses,
Expand Down Expand Up @@ -691,21 +691,24 @@ mod test {
#[test]
fn test_unlink_clause_different() {
let clause1 = clause(
[ClauseId::from_usize(2), ClauseId::from_usize(3)],
[
ClauseId::from_usize(2).into(),
ClauseId::from_usize(3).into(),
],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1211).negative(),
],
);
let clause2 = clause(
[ClauseId::null(), ClauseId::from_usize(3)],
[None, ClauseId::from_usize(3).into()],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1208).negative(),
],
);
let clause3 = clause(
[ClauseId::null(), ClauseId::null()],
[None, None],
[
InternalSolvableId::from_usize(1211).negative(),
InternalSolvableId::from_usize(42).negative(),
Expand All @@ -723,10 +726,7 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::null(), ClauseId::from_usize(3)]
)
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(3).into()])
}

// Unlink 1
Expand All @@ -740,24 +740,24 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::from_usize(2), ClauseId::null()]
)
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
}
}

#[test]
fn test_unlink_clause_same() {
let clause1 = clause(
[ClauseId::from_usize(2), ClauseId::from_usize(2)],
[
ClauseId::from_usize(2).into(),
ClauseId::from_usize(2).into(),
],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1211).negative(),
],
);
let clause2 = clause(
[ClauseId::null(), ClauseId::null()],
[None, None],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1211).negative(),
Expand All @@ -775,10 +775,7 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::null(), ClauseId::from_usize(2)]
)
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(2).into()])
}

// Unlink 1
Expand All @@ -792,10 +789,7 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::from_usize(2), ClauseId::null()]
)
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
}
}

Expand All @@ -820,7 +814,10 @@ mod test {

// No conflict, still one candidate available
decisions
.try_add_decision(Decision::new(candidate1.into(), false, ClauseId::null()), 1)
.try_add_decision(
Decision::new(candidate1.into(), false, ClauseId::from_usize(0)),
1,
)
.unwrap();
let (clause, conflict, _kind) = ClauseState::requires(
parent,
Expand All @@ -834,7 +831,10 @@ mod test {

// Conflict, no candidates available
decisions
.try_add_decision(Decision::new(candidate2.into(), false, ClauseId::null()), 1)
.try_add_decision(
Decision::new(candidate2.into(), false, ClauseId::install_root()),
1,
)
.unwrap();
let (clause, conflict, _kind) = ClauseState::requires(
parent,
Expand All @@ -848,7 +848,7 @@ mod test {

// Panic
decisions
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
.unwrap();
let panicked = std::panic::catch_unwind(|| {
ClauseState::requires(
Expand Down Expand Up @@ -878,7 +878,7 @@ mod test {

// Conflict, forbidden package installed
decisions
.try_add_decision(Decision::new(forbidden, true, ClauseId::null()), 1)
.try_add_decision(Decision::new(forbidden, true, ClauseId::install_root()), 1)
.unwrap();
let (clause, conflict, _kind) =
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions);
Expand All @@ -888,7 +888,7 @@ mod test {

// Panic
decisions
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
.unwrap();
let panicked = std::panic::catch_unwind(|| {
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions)
Expand Down
20 changes: 8 additions & 12 deletions src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1435,11 +1435,8 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
// solvable
let mut old_predecessor_clause_id: Option<ClauseId>;
let mut predecessor_clause_id: Option<ClauseId> = None;
let mut clause_id = self
.watches
.first_clause_watching_literal(watched_literal)
.unwrap_or(ClauseId::null());
while !clause_id.is_null() {
let mut next_clause_id = self.watches.first_clause_watching_literal(watched_literal);
while let Some(clause_id) = next_clause_id {
debug_assert!(
predecessor_clause_id != Some(clause_id),
"Linked list is circular!"
Expand All @@ -1466,8 +1463,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
predecessor_clause_id = Some(clause_id);

// Configure the next clause to visit
let this_clause_id = clause_id;
clause_id = clause_state.next_watched_clause(watched_literal.solvable_id());
next_clause_id = clause_state.next_watched_clause(watched_literal.solvable_id());

// Determine which watch turned false.
let (watch_index, other_watch_index) = if clause_state.watched_literals[0]
Expand All @@ -1492,7 +1488,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
// If the other watch is already true, we can simply skip
// this clause.
} else if let Some(variable) = clause_state.next_unwatched_literal(
&clauses[this_clause_id.to_usize()],
&clauses[clause_id.to_usize()],
&self.learnt_clauses,
&self.cache.requirement_to_sorted_candidates,
self.decision_tracker.map(),
Expand All @@ -1501,7 +1497,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
self.watches.update_watched(
predecessor_clause_state,
clause_state,
this_clause_id,
clause_id,
watch_index,
watched_literal,
variable,
Expand All @@ -1527,20 +1523,20 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
Decision::new(
remaining_watch.solvable_id(),
remaining_watch.satisfying_value(),
this_clause_id,
clause_id,
),
level,
)
.map_err(|_| {
PropagationError::Conflict(
remaining_watch.solvable_id(),
true,
this_clause_id,
clause_id,
)
})?;

if decided {
let clause = &clauses[this_clause_id.to_usize()];
let clause = &clauses[clause_id.to_usize()];
match clause {
// Skip logging for ForbidMultipleInstances, which is so noisy
Clause::ForbidMultipleInstances(..) => {}
Expand Down
19 changes: 7 additions & 12 deletions src/solver/watch_map.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::solver::clause::Literal;
use crate::{
internal::{id::ClauseId, mapping::Mapping},
solver::clause::ClauseState,
solver::clause::{ClauseState, Literal},
};

/// A map from solvables to the clauses that are watching them
Expand All @@ -20,9 +19,7 @@ impl WatchMap {

pub(crate) fn start_watching(&mut self, clause: &mut ClauseState, clause_id: ClauseId) {
for (watch_index, watched_literal) in clause.watched_literals.into_iter().enumerate() {
let already_watching = self
.first_clause_watching_literal(watched_literal)
.unwrap_or(ClauseId::null());
let already_watching = self.first_clause_watching_literal(watched_literal);
clause.link_to_clause(watch_index, already_watching);
self.watch_literal(watched_literal, clause_id);
}
Expand All @@ -42,18 +39,16 @@ impl WatchMap {
if let Some(predecessor_clause) = predecessor_clause {
// Unlink the clause
predecessor_clause.unlink_clause(clause, previous_watch.solvable_id(), watch_index);
} else {
} else if let Some(next_watch) = clause.next_watches[watch_index] {
// This was the first clause in the chain
self.map
.insert(previous_watch, clause.next_watches[watch_index]);
self.map.insert(previous_watch, next_watch);
} else {
self.map.unset(previous_watch);
}

// Set the new watch
clause.watched_literals[watch_index] = new_watch;
let previous_clause_id = self
.map
.insert(new_watch, clause_id)
.unwrap_or(ClauseId::null());
let previous_clause_id = self.map.insert(new_watch, clause_id);
clause.next_watches[watch_index] = previous_clause_id;
}

Expand Down

0 comments on commit f89c0b2

Please sign in to comment.