From 0586e742e6b94e76af1a7415054591a6050a15e4 Mon Sep 17 00:00:00 2001 From: prsabahrami Date: Fri, 17 Jan 2025 15:12:31 -0500 Subject: [PATCH] initial commit for conditional dependencies support --- cpp/include/resolvo.h | 9 ++++ cpp/src/lib.rs | 20 +++++++++ src/conflict.rs | 89 ++++++++++++++++++++++++++++++++++--- src/requirement.rs | 30 ++++++++++--- src/snapshot.rs | 8 ++++ src/solver/cache.rs | 5 +++ src/solver/clause.rs | 100 +++++++++++++++++++++++++++++++++++++++--- src/solver/mod.rs | 2 +- 8 files changed, 244 insertions(+), 19 deletions(-) diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h index 97d00f5..c824df4 100644 --- a/cpp/include/resolvo.h +++ b/cpp/include/resolvo.h @@ -24,6 +24,15 @@ inline Requirement requirement_union(VersionSetUnionId id) { return cbindgen_private::resolvo_requirement_union(id); } +/** + * Specifies a conditional requirement, where the requirement is only active when the condition is met. + * @param condition The version set that must be satisfied for the requirement to be active. + * @param requirement The version set that must be satisfied when the condition is met. + */ +inline Requirement requirement_conditional(VersionSetId condition, VersionSetId requirement) { + return cbindgen_private::resolvo_requirement_conditional(condition, requirement); +} + /** * Called to solve a package problem. * diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index 781e365..8267d9e 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -48,6 +48,11 @@ pub enum Requirement { /// cbindgen:derive-eq /// cbindgen:derive-neq Union(VersionSetUnionId), + /// Specifies a conditional requirement, where the requirement is only active when the condition is met. + /// First VersionSetId is the condition, second is the requirement. + /// cbindgen:derive-eq + /// cbindgen:derive-neq + ConditionalRequires(VersionSetId, VersionSetId), } impl From for crate::Requirement { @@ -55,6 +60,9 @@ impl From for crate::Requirement { match value { resolvo::Requirement::Single(id) => Requirement::Single(id.into()), resolvo::Requirement::Union(id) => Requirement::Union(id.into()), + resolvo::Requirement::ConditionalRequires(condition, requirement) => { + Requirement::ConditionalRequires(condition.into(), requirement.into()) + } } } } @@ -64,6 +72,9 @@ impl From for resolvo::Requirement { match value { Requirement::Single(id) => resolvo::Requirement::Single(id.into()), Requirement::Union(id) => resolvo::Requirement::Union(id.into()), + Requirement::ConditionalRequires(condition, requirement) => { + resolvo::Requirement::ConditionalRequires(condition.into(), requirement.into()) + } } } } @@ -539,6 +550,15 @@ pub extern "C" fn resolvo_requirement_union( Requirement::Union(version_set_union_id) } +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_requirement_conditional( + condition: VersionSetId, + requirement: VersionSetId, +) -> Requirement { + Requirement::ConditionalRequires(condition, requirement) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/conflict.rs b/src/conflict.rs index 3d121b6..48bc77a 100644 --- a/src/conflict.rs +++ b/src/conflict.rs @@ -11,14 +11,17 @@ use petgraph::{ Direction, }; -use crate::solver::variable_map::VariableOrigin; use crate::{ internal::{ arena::ArenaId, - id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId}, + id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VariableId, VersionSetId}, }, runtime::AsyncRuntime, - solver::{clause::Clause, Solver}, + solver::{ + clause::Clause, + variable_map::{VariableMap, VariableOrigin}, + Solver, + }, DependencyProvider, Interner, Requirement, }; @@ -160,6 +163,49 @@ impl Conflict { ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)), ); } + &Clause::Conditional(package_id, condition, then_version_set_id) => { + let solvable = package_id + .as_solvable_or_root(&solver.variable_map) + .expect("only solvables can be excluded"); + let package_node = Self::add_node(&mut graph, &mut nodes, solvable); + + let candidates = solver.async_runtime.block_on(solver.cache.get_or_cache_sorted_candidates(then_version_set_id)).unwrap_or_else(|_| { + unreachable!("The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`") + }); + + if candidates.is_empty() { + tracing::trace!( + "{package_id:?} conditionally requires {then_version_set_id:?}, which has no candidates" + ); + graph.add_edge( + package_node, + unresolved_node, + ConflictEdge::ConditionalRequires(then_version_set_id, condition), + ); + } else { + for &candidate_id in candidates { + tracing::trace!("{package_id:?} conditionally requires {candidate_id:?}"); + + let candidate_node = + Self::add_node(&mut graph, &mut nodes, candidate_id.into()); + graph.add_edge( + package_node, + candidate_node, + ConflictEdge::ConditionalRequires(then_version_set_id, condition), + ); + } + } + + // Add an edge for the unsatisfied condition if it exists + if let Some(condition_solvable) = condition.as_solvable(&solver.variable_map) { + let condition_node = Self::add_node(&mut graph, &mut nodes, condition_solvable.into()); + graph.add_edge( + package_node, + condition_node, + ConflictEdge::Conflict(ConflictCause::UnsatisfiedCondition(condition)), + ); + } + } } } @@ -205,7 +251,7 @@ impl Conflict { solver: &'a Solver, ) -> DisplayUnsat<'a, D> { let graph = self.graph(solver); - DisplayUnsat::new(graph, solver.provider()) + DisplayUnsat::new(graph, solver.provider(), &solver.variable_map) } } @@ -239,13 +285,15 @@ impl ConflictNode { } /// An edge in the graph representation of a [`Conflict`] -#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)] pub(crate) enum ConflictEdge { /// The target node is a candidate for the dependency specified by the /// [`Requirement`] Requires(Requirement), /// The target node is involved in a conflict, caused by `ConflictCause` Conflict(ConflictCause), + /// The target node is a candidate for a conditional dependency + ConditionalRequires(Requirement, VariableId), } impl ConflictEdge { @@ -253,12 +301,14 @@ impl ConflictEdge { match self { ConflictEdge::Requires(match_spec_id) => Some(match_spec_id), ConflictEdge::Conflict(_) => None, + ConflictEdge::ConditionalRequires(match_spec_id, _) => Some(match_spec_id), } } fn requires(self) -> Requirement { match self { ConflictEdge::Requires(match_spec_id) => match_spec_id, + ConflictEdge::ConditionalRequires(match_spec_id, _) => match_spec_id, ConflictEdge::Conflict(_) => panic!("expected requires edge, found conflict"), } } @@ -275,6 +325,8 @@ pub(crate) enum ConflictCause { ForbidMultipleInstances, /// The node was excluded Excluded, + /// The condition for a conditional dependency was not satisfied + UnsatisfiedCondition(VariableId), } /// Represents a node that has been merged with others @@ -307,6 +359,7 @@ impl ConflictGraph { &self, f: &mut impl std::io::Write, interner: &impl Interner, + variable_map: &VariableMap, simplify: bool, ) -> Result<(), std::io::Error> { let graph = &self.graph; @@ -356,6 +409,16 @@ impl ConflictGraph { "already installed".to_string() } ConflictEdge::Conflict(ConflictCause::Excluded) => "excluded".to_string(), + ConflictEdge::Conflict(ConflictCause::UnsatisfiedCondition(condition)) => { + let condition_solvable = condition.as_solvable(variable_map) + .expect("condition must be a solvable"); + format!("unsatisfied condition: {}", condition_solvable.display(interner)) + } + ConflictEdge::ConditionalRequires(requirement, condition) => { + let condition_solvable = condition.as_solvable(variable_map) + .expect("condition must be a solvable"); + format!("if {} then {}", condition_solvable.display(interner), requirement.display(interner)) + } }; let target = match target { @@ -494,6 +557,7 @@ impl ConflictGraph { .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::ConditionalRequires(version_set_id, _) => (version_set_id, e.target()), ConflictEdge::Conflict(_) => unreachable!(), }) .chunk_by(|(&version_set_id, _)| version_set_id); @@ -540,6 +604,7 @@ impl ConflictGraph { .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::ConditionalRequires(version_set_id, _) => (version_set_id, e.target()), ConflictEdge::Conflict(_) => unreachable!(), }) .chunk_by(|(&version_set_id, _)| version_set_id); @@ -673,10 +738,11 @@ pub struct DisplayUnsat<'i, I: Interner> { installable_set: HashSet, missing_set: HashSet, interner: &'i I, + variable_map: &'i VariableMap, } impl<'i, I: Interner> DisplayUnsat<'i, I> { - pub(crate) fn new(graph: ConflictGraph, interner: &'i I) -> Self { + pub(crate) fn new(graph: ConflictGraph, interner: &'i I, variable_map: &'i VariableMap) -> Self { let merged_candidates = graph.simplify(interner); let installable_set = graph.get_installable_set(); let missing_set = graph.get_missing_set(); @@ -687,6 +753,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { installable_set, missing_set, interner, + variable_map, } } @@ -1020,6 +1087,7 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { let conflict = match e.weight() { ConflictEdge::Requires(_) => continue, ConflictEdge::Conflict(conflict) => conflict, + ConflictEdge::ConditionalRequires(_, _) => continue, }; // The only possible conflict at the root level is a Locked conflict @@ -1045,6 +1113,15 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { )?; } ConflictCause::Excluded => continue, + &ConflictCause::UnsatisfiedCondition(condition) => { + let condition_solvable = condition.as_solvable(self.variable_map) + .expect("condition must be a solvable"); + writeln!( + f, + "{indent}condition {} is not satisfied", + condition_solvable.display(self.interner), + )?; + } }; } } diff --git a/src/requirement.rs b/src/requirement.rs index 244ec48..c8e876f 100644 --- a/src/requirement.rs +++ b/src/requirement.rs @@ -13,6 +13,9 @@ pub enum Requirement { /// This variant is typically used for requirements that can be satisfied by two or more /// version sets belonging to _different_ packages. Union(VersionSetUnionId), + /// Specifies a conditional requirement, where the requirement is only active when the condition is met. + /// First VersionSetId is the condition, second is the requirement. + ConditionalRequires(VersionSetId, VersionSetId), } impl Default for Requirement { @@ -46,12 +49,15 @@ impl Requirement { &'i self, interner: &'i impl Interner, ) -> impl Iterator + 'i { - match *self { + match self { Requirement::Single(version_set) => { - itertools::Either::Left(std::iter::once(version_set)) + itertools::Either::Left(itertools::Either::Left(std::iter::once(*version_set))) } Requirement::Union(version_set_union) => { - itertools::Either::Right(interner.version_sets_in_union(version_set_union)) + itertools::Either::Left(itertools::Either::Right(interner.version_sets_in_union(*version_set_union))) + } + Requirement::ConditionalRequires(condition, requirement) => { + itertools::Either::Right(std::iter::once(*condition).chain(std::iter::once(*requirement))) } } } @@ -64,18 +70,18 @@ pub(crate) struct DisplayRequirement<'i, I: Interner> { impl<'i, I: Interner> Display for DisplayRequirement<'i, I> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match *self.requirement { + match self.requirement { Requirement::Single(version_set) => write!( f, "{} {}", self.interner - .display_name(self.interner.version_set_name(version_set)), - self.interner.display_version_set(version_set) + .display_name(self.interner.version_set_name(*version_set)), + self.interner.display_version_set(*version_set) ), Requirement::Union(version_set_union) => { let formatted_version_sets = self .interner - .version_sets_in_union(version_set_union) + .version_sets_in_union(*version_set_union) .format_with(" | ", |version_set, f| { f(&format_args!( "{} {}", @@ -87,6 +93,16 @@ impl<'i, I: Interner> Display for DisplayRequirement<'i, I> { write!(f, "{}", formatted_version_sets) } + Requirement::ConditionalRequires(condition, requirement) => { + write!( + f, + "if {} then {} {}", + self.interner.display_version_set(*condition), + self.interner + .display_name(self.interner.version_set_name(*requirement)), + self.interner.display_version_set(*requirement) + ) + } } } } diff --git a/src/snapshot.rs b/src/snapshot.rs index 0b8b6d2..0c4a986 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -243,6 +243,14 @@ impl DependencySnapshot { .version_set_unions .insert(version_set_union_id, version_sets); } + Requirement::ConditionalRequires(condition, requirement) => { + if seen.insert(Element::VersionSet(condition)) { + queue.push_back(Element::VersionSet(condition)); + } + if seen.insert(Element::VersionSet(requirement)) { + queue.push_back(Element::VersionSet(requirement)); + } + } } } } diff --git a/src/solver/cache.rs b/src/solver/cache.rs index cf6c6cc..00c8c92 100644 --- a/src/solver/cache.rs +++ b/src/solver/cache.rs @@ -280,6 +280,11 @@ impl SolverCache { } } } + Requirement::ConditionalRequires(condition, requirement) => { + let candidates = self.get_or_cache_sorted_candidates_for_version_set(condition).await?; + let sorted_candidates = self.get_or_cache_sorted_candidates_for_version_set(requirement).await?; + Ok(sorted_candidates) + } } } diff --git a/src/solver/clause.rs b/src/solver/clause.rs index f034130..4e41c0e 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -11,12 +11,10 @@ use crate::{ internal::{ arena::{Arena, ArenaId}, id::{ClauseId, LearntClauseId, StringId, VersionSetId}, - }, - solver::{ + }, solver::{ decision_map::DecisionMap, decision_tracker::DecisionTracker, variable_map::VariableMap, VariableId, - }, - Interner, NameId, Requirement, + }, DependencyProvider, Interner, NameId, Requirement }; /// Represents a single clause in the SAT problem @@ -46,7 +44,7 @@ use crate::{ /// limited set of clauses. There are thousands of clauses for a particular /// dependency resolution problem, and we try to keep the [`Clause`] enum small. /// A naive implementation would store a `Vec`. -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Copy, Debug)] pub(crate) enum Clause { /// An assertion that the root solvable must be installed /// @@ -77,6 +75,10 @@ pub(crate) enum Clause { /// /// In SAT terms: (¬A ∨ ¬B) Constrains(VariableId, VariableId, VersionSetId), + /// In SAT terms: (¬A ∨ ¬C ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// C is the condition, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + Conditional(VariableId, VariableId, Requirement), /// Forbids the package on the right-hand side /// /// Note that the package on the left-hand side is not part of the clause, @@ -230,6 +232,45 @@ impl Clause { ) } + fn conditional_impl( + package_id: VariableId, + condition: VariableId, + then: Requirement, + candidates: impl IntoIterator, + decision_tracker: &DecisionTracker, + ) -> (Self, Option<[Literal; 2]>, bool) { + // It only makes sense to introduce a conditional clause when the package is undecided or going to be installed + assert_ne!(decision_tracker.assigned_value(package_id), Some(false)); + assert_ne!(decision_tracker.assigned_value(condition), Some(false)); + + let kind = Clause::Conditional(package_id, condition, then); + let mut candidates = candidates.into_iter().peekable(); + let first_candidate = candidates.peek().copied(); + if let Some(first_candidate) = first_candidate { + match candidates.find(|&c| decision_tracker.assigned_value(c) != Some(false)) { + // Watch any candidate that is not assigned to false + Some(watched_candidate) => ( + kind, + Some([package_id.negative(), watched_candidate.positive()]), + false, + ), + + // All candidates are assigned to false! Therefore, the clause conflicts with the + // current decisions. There are no valid watches for it at the moment, but we will + // assign default ones nevertheless, because they will become valid after the solver + // restarts. + None => ( + kind, + Some([package_id.negative(), first_candidate.positive()]), + true, + ), + } + } else { + // If there are no candidates there is no need to watch anything. + (kind, None, false) + } + } + /// Tries to fold over all the literals in the clause. /// /// This function is useful to iterate, find, or filter the literals in a @@ -272,6 +313,17 @@ impl Clause { Clause::Lock(_, s) => [s.negative(), VariableId::root().negative()] .into_iter() .try_fold(init, visit), + Clause::Conditional(package_id, condition, then) => { + [package_id.negative(), condition.negative()] + .into_iter() + .chain( + requirements_to_sorted_candidates[&then] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit) + } } } @@ -419,6 +471,33 @@ impl WatchedLiterals { (Self::from_kind_and_initial_watches(watched_literals), kind) } + /// Shorthand method to construct a [Clause::Conditional] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + pub fn conditional( + package_id: VariableId, + condition: VariableId, + then: Requirement, + candidates: impl IntoIterator, + decision_tracker: &DecisionTracker, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::conditional_impl( + package_id, + condition, + then, + candidates, + decision_tracker, + ); + + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Option { let watched_literals = watched_literals?; debug_assert!(watched_literals[0] != watched_literals[1]); @@ -611,6 +690,17 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { other, ) } + Clause::Conditional(package_id, condition, then) => { + write!( + f, + "Conditional({}({:?}), {}({:?}), {})", + package_id.display(self.variable_map, self.interner), + package_id, + condition.display(self.variable_map, self.interner), + condition, + then.display(self.interner) + ) + } } } } diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 8c0e026..cf3adf0 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -695,7 +695,7 @@ impl Solver { fn resolve_dependencies(&mut self, mut level: u32) -> Result { loop { // Make a decision. If no decision could be made it means the problem is - // satisfyable. + // satisfiable. let Some((candidate, required_by, clause_id)) = self.decide() else { break; };