From 0586e742e6b94e76af1a7415054591a6050a15e4 Mon Sep 17 00:00:00 2001
From: prsabahrami
Date: Fri, 17 Jan 2025 15:12:31 -0500
Subject: [PATCH] initial commit for conditional dependencies support
---
cpp/include/resolvo.h | 9 ++++
cpp/src/lib.rs | 20 +++++++++
src/conflict.rs | 89 ++++++++++++++++++++++++++++++++++---
src/requirement.rs | 30 ++++++++++---
src/snapshot.rs | 8 ++++
src/solver/cache.rs | 5 +++
src/solver/clause.rs | 100 +++++++++++++++++++++++++++++++++++++++---
src/solver/mod.rs | 2 +-
8 files changed, 244 insertions(+), 19 deletions(-)
diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h
index 97d00f5..c824df4 100644
--- a/cpp/include/resolvo.h
+++ b/cpp/include/resolvo.h
@@ -24,6 +24,15 @@ inline Requirement requirement_union(VersionSetUnionId id) {
return cbindgen_private::resolvo_requirement_union(id);
}
+/**
+ * Specifies a conditional requirement, where the requirement is only active when the condition is met.
+ * @param condition The version set that must be satisfied for the requirement to be active.
+ * @param requirement The version set that must be satisfied when the condition is met.
+ */
+inline Requirement requirement_conditional(VersionSetId condition, VersionSetId requirement) {
+ return cbindgen_private::resolvo_requirement_conditional(condition, requirement);
+}
+
/**
* Called to solve a package problem.
*
diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs
index 781e365..8267d9e 100644
--- a/cpp/src/lib.rs
+++ b/cpp/src/lib.rs
@@ -48,6 +48,11 @@ pub enum Requirement {
/// cbindgen:derive-eq
/// cbindgen:derive-neq
Union(VersionSetUnionId),
+ /// Specifies a conditional requirement, where the requirement is only active when the condition is met.
+ /// First VersionSetId is the condition, second is the requirement.
+ /// cbindgen:derive-eq
+ /// cbindgen:derive-neq
+ ConditionalRequires(VersionSetId, VersionSetId),
}
impl From for crate::Requirement {
@@ -55,6 +60,9 @@ impl From for crate::Requirement {
match value {
resolvo::Requirement::Single(id) => Requirement::Single(id.into()),
resolvo::Requirement::Union(id) => Requirement::Union(id.into()),
+ resolvo::Requirement::ConditionalRequires(condition, requirement) => {
+ Requirement::ConditionalRequires(condition.into(), requirement.into())
+ }
}
}
}
@@ -64,6 +72,9 @@ impl From for resolvo::Requirement {
match value {
Requirement::Single(id) => resolvo::Requirement::Single(id.into()),
Requirement::Union(id) => resolvo::Requirement::Union(id.into()),
+ Requirement::ConditionalRequires(condition, requirement) => {
+ resolvo::Requirement::ConditionalRequires(condition.into(), requirement.into())
+ }
}
}
}
@@ -539,6 +550,15 @@ pub extern "C" fn resolvo_requirement_union(
Requirement::Union(version_set_union_id)
}
+#[no_mangle]
+#[allow(unused)]
+pub extern "C" fn resolvo_requirement_conditional(
+ condition: VersionSetId,
+ requirement: VersionSetId,
+) -> Requirement {
+ Requirement::ConditionalRequires(condition, requirement)
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/src/conflict.rs b/src/conflict.rs
index 3d121b6..48bc77a 100644
--- a/src/conflict.rs
+++ b/src/conflict.rs
@@ -11,14 +11,17 @@ use petgraph::{
Direction,
};
-use crate::solver::variable_map::VariableOrigin;
use crate::{
internal::{
arena::ArenaId,
- id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId},
+ id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VariableId, VersionSetId},
},
runtime::AsyncRuntime,
- solver::{clause::Clause, Solver},
+ solver::{
+ clause::Clause,
+ variable_map::{VariableMap, VariableOrigin},
+ Solver,
+ },
DependencyProvider, Interner, Requirement,
};
@@ -160,6 +163,49 @@ impl Conflict {
ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)),
);
}
+ &Clause::Conditional(package_id, condition, then_version_set_id) => {
+ let solvable = package_id
+ .as_solvable_or_root(&solver.variable_map)
+ .expect("only solvables can be excluded");
+ let package_node = Self::add_node(&mut graph, &mut nodes, solvable);
+
+ let candidates = solver.async_runtime.block_on(solver.cache.get_or_cache_sorted_candidates(then_version_set_id)).unwrap_or_else(|_| {
+ unreachable!("The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`")
+ });
+
+ if candidates.is_empty() {
+ tracing::trace!(
+ "{package_id:?} conditionally requires {then_version_set_id:?}, which has no candidates"
+ );
+ graph.add_edge(
+ package_node,
+ unresolved_node,
+ ConflictEdge::ConditionalRequires(then_version_set_id, condition),
+ );
+ } else {
+ for &candidate_id in candidates {
+ tracing::trace!("{package_id:?} conditionally requires {candidate_id:?}");
+
+ let candidate_node =
+ Self::add_node(&mut graph, &mut nodes, candidate_id.into());
+ graph.add_edge(
+ package_node,
+ candidate_node,
+ ConflictEdge::ConditionalRequires(then_version_set_id, condition),
+ );
+ }
+ }
+
+ // Add an edge for the unsatisfied condition if it exists
+ if let Some(condition_solvable) = condition.as_solvable(&solver.variable_map) {
+ let condition_node = Self::add_node(&mut graph, &mut nodes, condition_solvable.into());
+ graph.add_edge(
+ package_node,
+ condition_node,
+ ConflictEdge::Conflict(ConflictCause::UnsatisfiedCondition(condition)),
+ );
+ }
+ }
}
}
@@ -205,7 +251,7 @@ impl Conflict {
solver: &'a Solver,
) -> DisplayUnsat<'a, D> {
let graph = self.graph(solver);
- DisplayUnsat::new(graph, solver.provider())
+ DisplayUnsat::new(graph, solver.provider(), &solver.variable_map)
}
}
@@ -239,13 +285,15 @@ impl ConflictNode {
}
/// An edge in the graph representation of a [`Conflict`]
-#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
+#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum ConflictEdge {
/// The target node is a candidate for the dependency specified by the
/// [`Requirement`]
Requires(Requirement),
/// The target node is involved in a conflict, caused by `ConflictCause`
Conflict(ConflictCause),
+ /// The target node is a candidate for a conditional dependency
+ ConditionalRequires(Requirement, VariableId),
}
impl ConflictEdge {
@@ -253,12 +301,14 @@ impl ConflictEdge {
match self {
ConflictEdge::Requires(match_spec_id) => Some(match_spec_id),
ConflictEdge::Conflict(_) => None,
+ ConflictEdge::ConditionalRequires(match_spec_id, _) => Some(match_spec_id),
}
}
fn requires(self) -> Requirement {
match self {
ConflictEdge::Requires(match_spec_id) => match_spec_id,
+ ConflictEdge::ConditionalRequires(match_spec_id, _) => match_spec_id,
ConflictEdge::Conflict(_) => panic!("expected requires edge, found conflict"),
}
}
@@ -275,6 +325,8 @@ pub(crate) enum ConflictCause {
ForbidMultipleInstances,
/// The node was excluded
Excluded,
+ /// The condition for a conditional dependency was not satisfied
+ UnsatisfiedCondition(VariableId),
}
/// Represents a node that has been merged with others
@@ -307,6 +359,7 @@ impl ConflictGraph {
&self,
f: &mut impl std::io::Write,
interner: &impl Interner,
+ variable_map: &VariableMap,
simplify: bool,
) -> Result<(), std::io::Error> {
let graph = &self.graph;
@@ -356,6 +409,16 @@ impl ConflictGraph {
"already installed".to_string()
}
ConflictEdge::Conflict(ConflictCause::Excluded) => "excluded".to_string(),
+ ConflictEdge::Conflict(ConflictCause::UnsatisfiedCondition(condition)) => {
+ let condition_solvable = condition.as_solvable(variable_map)
+ .expect("condition must be a solvable");
+ format!("unsatisfied condition: {}", condition_solvable.display(interner))
+ }
+ ConflictEdge::ConditionalRequires(requirement, condition) => {
+ let condition_solvable = condition.as_solvable(variable_map)
+ .expect("condition must be a solvable");
+ format!("if {} then {}", condition_solvable.display(interner), requirement.display(interner))
+ }
};
let target = match target {
@@ -494,6 +557,7 @@ impl ConflictGraph {
.edges_directed(nx, Direction::Outgoing)
.map(|e| match e.weight() {
ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()),
+ ConflictEdge::ConditionalRequires(version_set_id, _) => (version_set_id, e.target()),
ConflictEdge::Conflict(_) => unreachable!(),
})
.chunk_by(|(&version_set_id, _)| version_set_id);
@@ -540,6 +604,7 @@ impl ConflictGraph {
.edges_directed(nx, Direction::Outgoing)
.map(|e| match e.weight() {
ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()),
+ ConflictEdge::ConditionalRequires(version_set_id, _) => (version_set_id, e.target()),
ConflictEdge::Conflict(_) => unreachable!(),
})
.chunk_by(|(&version_set_id, _)| version_set_id);
@@ -673,10 +738,11 @@ pub struct DisplayUnsat<'i, I: Interner> {
installable_set: HashSet,
missing_set: HashSet,
interner: &'i I,
+ variable_map: &'i VariableMap,
}
impl<'i, I: Interner> DisplayUnsat<'i, I> {
- pub(crate) fn new(graph: ConflictGraph, interner: &'i I) -> Self {
+ pub(crate) fn new(graph: ConflictGraph, interner: &'i I, variable_map: &'i VariableMap) -> Self {
let merged_candidates = graph.simplify(interner);
let installable_set = graph.get_installable_set();
let missing_set = graph.get_missing_set();
@@ -687,6 +753,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> {
installable_set,
missing_set,
interner,
+ variable_map,
}
}
@@ -1020,6 +1087,7 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> {
let conflict = match e.weight() {
ConflictEdge::Requires(_) => continue,
ConflictEdge::Conflict(conflict) => conflict,
+ ConflictEdge::ConditionalRequires(_, _) => continue,
};
// The only possible conflict at the root level is a Locked conflict
@@ -1045,6 +1113,15 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> {
)?;
}
ConflictCause::Excluded => continue,
+ &ConflictCause::UnsatisfiedCondition(condition) => {
+ let condition_solvable = condition.as_solvable(self.variable_map)
+ .expect("condition must be a solvable");
+ writeln!(
+ f,
+ "{indent}condition {} is not satisfied",
+ condition_solvable.display(self.interner),
+ )?;
+ }
};
}
}
diff --git a/src/requirement.rs b/src/requirement.rs
index 244ec48..c8e876f 100644
--- a/src/requirement.rs
+++ b/src/requirement.rs
@@ -13,6 +13,9 @@ pub enum Requirement {
/// This variant is typically used for requirements that can be satisfied by two or more
/// version sets belonging to _different_ packages.
Union(VersionSetUnionId),
+ /// Specifies a conditional requirement, where the requirement is only active when the condition is met.
+ /// First VersionSetId is the condition, second is the requirement.
+ ConditionalRequires(VersionSetId, VersionSetId),
}
impl Default for Requirement {
@@ -46,12 +49,15 @@ impl Requirement {
&'i self,
interner: &'i impl Interner,
) -> impl Iterator- + 'i {
- match *self {
+ match self {
Requirement::Single(version_set) => {
- itertools::Either::Left(std::iter::once(version_set))
+ itertools::Either::Left(itertools::Either::Left(std::iter::once(*version_set)))
}
Requirement::Union(version_set_union) => {
- itertools::Either::Right(interner.version_sets_in_union(version_set_union))
+ itertools::Either::Left(itertools::Either::Right(interner.version_sets_in_union(*version_set_union)))
+ }
+ Requirement::ConditionalRequires(condition, requirement) => {
+ itertools::Either::Right(std::iter::once(*condition).chain(std::iter::once(*requirement)))
}
}
}
@@ -64,18 +70,18 @@ pub(crate) struct DisplayRequirement<'i, I: Interner> {
impl<'i, I: Interner> Display for DisplayRequirement<'i, I> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match *self.requirement {
+ match self.requirement {
Requirement::Single(version_set) => write!(
f,
"{} {}",
self.interner
- .display_name(self.interner.version_set_name(version_set)),
- self.interner.display_version_set(version_set)
+ .display_name(self.interner.version_set_name(*version_set)),
+ self.interner.display_version_set(*version_set)
),
Requirement::Union(version_set_union) => {
let formatted_version_sets = self
.interner
- .version_sets_in_union(version_set_union)
+ .version_sets_in_union(*version_set_union)
.format_with(" | ", |version_set, f| {
f(&format_args!(
"{} {}",
@@ -87,6 +93,16 @@ impl<'i, I: Interner> Display for DisplayRequirement<'i, I> {
write!(f, "{}", formatted_version_sets)
}
+ Requirement::ConditionalRequires(condition, requirement) => {
+ write!(
+ f,
+ "if {} then {} {}",
+ self.interner.display_version_set(*condition),
+ self.interner
+ .display_name(self.interner.version_set_name(*requirement)),
+ self.interner.display_version_set(*requirement)
+ )
+ }
}
}
}
diff --git a/src/snapshot.rs b/src/snapshot.rs
index 0b8b6d2..0c4a986 100644
--- a/src/snapshot.rs
+++ b/src/snapshot.rs
@@ -243,6 +243,14 @@ impl DependencySnapshot {
.version_set_unions
.insert(version_set_union_id, version_sets);
}
+ Requirement::ConditionalRequires(condition, requirement) => {
+ if seen.insert(Element::VersionSet(condition)) {
+ queue.push_back(Element::VersionSet(condition));
+ }
+ if seen.insert(Element::VersionSet(requirement)) {
+ queue.push_back(Element::VersionSet(requirement));
+ }
+ }
}
}
}
diff --git a/src/solver/cache.rs b/src/solver/cache.rs
index cf6c6cc..00c8c92 100644
--- a/src/solver/cache.rs
+++ b/src/solver/cache.rs
@@ -280,6 +280,11 @@ impl SolverCache {
}
}
}
+ Requirement::ConditionalRequires(condition, requirement) => {
+ let candidates = self.get_or_cache_sorted_candidates_for_version_set(condition).await?;
+ let sorted_candidates = self.get_or_cache_sorted_candidates_for_version_set(requirement).await?;
+ Ok(sorted_candidates)
+ }
}
}
diff --git a/src/solver/clause.rs b/src/solver/clause.rs
index f034130..4e41c0e 100644
--- a/src/solver/clause.rs
+++ b/src/solver/clause.rs
@@ -11,12 +11,10 @@ use crate::{
internal::{
arena::{Arena, ArenaId},
id::{ClauseId, LearntClauseId, StringId, VersionSetId},
- },
- solver::{
+ }, solver::{
decision_map::DecisionMap, decision_tracker::DecisionTracker, variable_map::VariableMap,
VariableId,
- },
- Interner, NameId, Requirement,
+ }, DependencyProvider, Interner, NameId, Requirement
};
/// Represents a single clause in the SAT problem
@@ -46,7 +44,7 @@ use crate::{
/// limited set of clauses. There are thousands of clauses for a particular
/// dependency resolution problem, and we try to keep the [`Clause`] enum small.
/// A naive implementation would store a `Vec`.
-#[derive(Copy, Clone, Debug)]
+#[derive(Clone, Copy, Debug)]
pub(crate) enum Clause {
/// An assertion that the root solvable must be installed
///
@@ -77,6 +75,10 @@ pub(crate) enum Clause {
///
/// In SAT terms: (¬A ∨ ¬B)
Constrains(VariableId, VariableId, VersionSetId),
+ /// In SAT terms: (¬A ∨ ¬C ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable,
+ /// C is the condition, and B1 to B99 represent the possible candidates for
+ /// the provided [`Requirement`].
+ Conditional(VariableId, VariableId, Requirement),
/// Forbids the package on the right-hand side
///
/// Note that the package on the left-hand side is not part of the clause,
@@ -230,6 +232,45 @@ impl Clause {
)
}
+ fn conditional_impl(
+ package_id: VariableId,
+ condition: VariableId,
+ then: Requirement,
+ candidates: impl IntoIterator
- ,
+ decision_tracker: &DecisionTracker,
+ ) -> (Self, Option<[Literal; 2]>, bool) {
+ // It only makes sense to introduce a conditional clause when the package is undecided or going to be installed
+ assert_ne!(decision_tracker.assigned_value(package_id), Some(false));
+ assert_ne!(decision_tracker.assigned_value(condition), Some(false));
+
+ let kind = Clause::Conditional(package_id, condition, then);
+ let mut candidates = candidates.into_iter().peekable();
+ let first_candidate = candidates.peek().copied();
+ if let Some(first_candidate) = first_candidate {
+ match candidates.find(|&c| decision_tracker.assigned_value(c) != Some(false)) {
+ // Watch any candidate that is not assigned to false
+ Some(watched_candidate) => (
+ kind,
+ Some([package_id.negative(), watched_candidate.positive()]),
+ false,
+ ),
+
+ // All candidates are assigned to false! Therefore, the clause conflicts with the
+ // current decisions. There are no valid watches for it at the moment, but we will
+ // assign default ones nevertheless, because they will become valid after the solver
+ // restarts.
+ None => (
+ kind,
+ Some([package_id.negative(), first_candidate.positive()]),
+ true,
+ ),
+ }
+ } else {
+ // If there are no candidates there is no need to watch anything.
+ (kind, None, false)
+ }
+ }
+
/// Tries to fold over all the literals in the clause.
///
/// This function is useful to iterate, find, or filter the literals in a
@@ -272,6 +313,17 @@ impl Clause {
Clause::Lock(_, s) => [s.negative(), VariableId::root().negative()]
.into_iter()
.try_fold(init, visit),
+ Clause::Conditional(package_id, condition, then) => {
+ [package_id.negative(), condition.negative()]
+ .into_iter()
+ .chain(
+ requirements_to_sorted_candidates[&then]
+ .iter()
+ .flatten()
+ .map(|&s| s.positive()),
+ )
+ .try_fold(init, visit)
+ }
}
}
@@ -419,6 +471,33 @@ impl WatchedLiterals {
(Self::from_kind_and_initial_watches(watched_literals), kind)
}
+ /// Shorthand method to construct a [Clause::Conditional] without requiring
+ /// complicated arguments.
+ ///
+ /// The returned boolean value is true when adding the clause resulted in a
+ /// conflict.
+ pub fn conditional(
+ package_id: VariableId,
+ condition: VariableId,
+ then: Requirement,
+ candidates: impl IntoIterator
- ,
+ decision_tracker: &DecisionTracker,
+ ) -> (Option, bool, Clause) {
+ let (kind, watched_literals, conflict) = Clause::conditional_impl(
+ package_id,
+ condition,
+ then,
+ candidates,
+ decision_tracker,
+ );
+
+ (
+ WatchedLiterals::from_kind_and_initial_watches(watched_literals),
+ conflict,
+ kind,
+ )
+ }
+
fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Option {
let watched_literals = watched_literals?;
debug_assert!(watched_literals[0] != watched_literals[1]);
@@ -611,6 +690,17 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> {
other,
)
}
+ Clause::Conditional(package_id, condition, then) => {
+ write!(
+ f,
+ "Conditional({}({:?}), {}({:?}), {})",
+ package_id.display(self.variable_map, self.interner),
+ package_id,
+ condition.display(self.variable_map, self.interner),
+ condition,
+ then.display(self.interner)
+ )
+ }
}
}
}
diff --git a/src/solver/mod.rs b/src/solver/mod.rs
index 8c0e026..cf3adf0 100644
--- a/src/solver/mod.rs
+++ b/src/solver/mod.rs
@@ -695,7 +695,7 @@ impl Solver {
fn resolve_dependencies(&mut self, mut level: u32) -> Result {
loop {
// Make a decision. If no decision could be made it means the problem is
- // satisfyable.
+ // satisfiable.
let Some((candidate, required_by, clause_id)) = self.decide() else {
break;
};