Skip to content

Commit

Permalink
Create consistency constraints member variable
Browse files Browse the repository at this point in the history
  • Loading branch information
duembgen committed Dec 9, 2024
1 parent a8db5cf commit 8048f56
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 49 deletions.
6 changes: 3 additions & 3 deletions _test/test_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,11 @@ def plot_admm():
A_dicts.append(A_dict)
plt.close("all")

eq_list = problem.get_consistency_constraints()
assert len(eq_list) == 2 * 6
problem.consistency_constraints()
assert len(problem.Es) == 2 * 6
counter = {(1, 0): 0, (2, 1): 0}
plots = {(1, 0): plt.subplots(2, 6)[1], (2, 1): plt.subplots(2, 6)[1]}
for k, l, Ak, Al in eq_list:
for k, l, Ak, Al in problem.Es:
# Ak.matshow(ax=plots[(k, l)][0, counter[(k, l)]])
# Al.matshow(ax=plots[(k, l)][1, counter[(k, l)]])
plots[(k, l)][0, counter[(k, l)]].matshow(Ak.toarray())
Expand Down
15 changes: 8 additions & 7 deletions _test/test_homqcqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import unittest

import numpy as np
from poly_matrix import PolyMatrix

from cert_tools import HomQCQP
from cert_tools.hom_qcqp import greedy_cover
from cert_tools.linalg_tools import svec
from cert_tools.sdp_solvers import solve_sdp_homqcqp
from cert_tools.sparse_solvers import solve_clarabel, solve_dsdp
from cert_tools.test_tools import get_chain_rot_prob, get_loop_rot_prob

from poly_matrix import PolyMatrix

# import pytest


Expand Down Expand Up @@ -151,12 +151,12 @@ def test_consistency_constraints(self):
nvars = 5
problem = get_chain_rot_prob(N=nvars)
problem.clique_decomposition() # Run clique decomposition
eq_list = problem.get_consistency_constraints()
problem.consistency_constraints(constrain_only_h_row=False)

# check the number of constraints generated
clq_dim = 10 # homogenizing var plus rotation
n_cons_per_sep = round(clq_dim * (clq_dim + 1) / 2)
assert len(eq_list) == (nvars - 2) * n_cons_per_sep, ValueError(
assert len(problem.Es) == (nvars - 2) * n_cons_per_sep, ValueError(
"Wrong number of equality constraints"
)

Expand Down Expand Up @@ -215,7 +215,7 @@ def test_decompose_matrix(self):
# problem.decompose_matrix(A)

# functionality
for method in ["split", "first", "greedy-cover"]:
for method in ["split", "greedy-cover"]:
C_d = problem.decompose_matrix(C, method=method)
assert len(C_d.keys()) == nvars - 1, ValueError(
f"{method} Method: Wrong number of cliques in decomposed matrix"
Expand Down Expand Up @@ -250,6 +250,7 @@ def test_solve_primal_dsdp(self, rank1=False):
problem = get_chain_rot_prob(N=nvars, locked_pose=locked_pose)
problem.clique_decomposition() # get cliques
# Solve decomposed problem (Interior Point Version)
problem.consistency_constraints(constrain_only_h_row=False)
c_list, info = solve_dsdp(problem, verbose=True, tol=1e-8) # check solutions

# Solve non-decomposed problem
Expand Down Expand Up @@ -402,7 +403,7 @@ def test_solve_dual_dsdp(self, rank1=False):
# test.test_consistency_constraints()
# test.test_greedy_cover()
# test.test_decompose_matrix()
# test.test_solve_primal_dsdp()
test.test_solve_primal_dsdp()
# test.test_solve_dual_dsdp()
test.test_standard_form()
# test.test_standard_form()
# test.test_clarabel()
5 changes: 2 additions & 3 deletions _test/test_sdp_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@

import mosek
import numpy as np

from cert_tools import (
solve_sdp_mosek,
solve_low_rank_sdp,
solve_sdp_fusion,
solve_sdp_cvxpy,
solve_sdp_fusion,
solve_sdp_mosek,
)

root_dir = os.path.abspath(os.path.dirname(__file__) + "/../")
Expand Down
17 changes: 7 additions & 10 deletions cert_tools/admm_clique.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import cvxpy as cp
import numpy as np
import scipy.sparse as sp

from cert_tools.base_clique import BaseClique
from cert_tools.hom_qcqp import HomQCQP
from cert_tools.sdp_solvers import adjust_Q
Expand Down Expand Up @@ -81,19 +80,17 @@ def create_admm_cliques_from_problem(problem: HomQCQP, variable=["x_", "z_"]):
problem.var_sizes, fixed=["h"], variable=variable
)
problem.clique_decomposition(clique_data=clique_data)

eq_list = problem.get_consistency_constraints()
problem.consistency_constraints()

Q_dict = problem.decompose_matrix(problem.C, method="split")
A_dict_list = [problem.decompose_matrix(A, method="first") for A in problem.As]
A_dict_list = [(problem.assign_matrix(A), A) for A in problem.As]
# A_dict_list = [problem.decompose_matrix(A, method="first") for A in problem.As]
admm_cliques = []
for clique in problem.cliques:
Constraints = [problem.get_homog_constraint(clique.var_sizes)]
for A_dict in A_dict_list:
if clique.index in A_dict.keys():
Constraints.append(
(A_dict[clique.index].get_matrix(clique.var_sizes), 0.0)
)
for idx, A in A_dict_list:
if clique.index in idx:
Constraints.append((A.get_matrix(clique.var_sizes), 0.0))
admm_clique = ADMMClique(
Q=Q_dict[clique.index].get_matrix(clique.var_sizes),
Constraints=Constraints,
Expand All @@ -107,7 +104,7 @@ def create_admm_cliques_from_problem(problem: HomQCQP, variable=["x_", "z_"]):
# F @ vech(Xk) + G @ vech(Xl) = 0
F_dict = dict()
G_dict = dict()
for k, l, Ak, Al in eq_list:
for k, l, Ak, Al in problem.Es:
if k == clique.index:
if l in F_dict:
# TODO(FD) we currently need to use the full matrix and not just the upper half
Expand Down
7 changes: 5 additions & 2 deletions cert_tools/admm_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from multiprocessing import Pipe, Process

import cvxpy as cp
import matplotlib.pylab as plt # for debugging
import mosek.fusion as fu
import numpy as np

from cert_tools.admm_clique import ADMMClique, update_rho
from cert_tools.fusion_tools import mat_fusion, read_costs_from_mosek
from cert_tools.sdp_solvers import (
Expand Down Expand Up @@ -229,6 +229,9 @@ def solve_inner_sdp_fusion(
"cost": cost,
"msg": f"solved with status {M.getProblemStatus()}",
}
else:
raise ValueError(f"Problem status is: {M.getProblemStatus()}")

return X, info


Expand All @@ -251,7 +254,7 @@ def solve_inner_sdp(

if use_fusion:
# for debugging only
err = clique.F @ clique.X.flatten() + clique.g
# err = clique.F @ clique.X.flatten() + clique.g
# print(f"current error: {err}")

return solve_inner_sdp_fusion(
Expand Down
15 changes: 7 additions & 8 deletions cert_tools/hom_qcqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, homog_var="h"):
self.dim = 0 # total size of variables
self.C = None # cost matrix
self.As = None # list of constraints
self.Es = None # list of consistency constraints
self.asg = Graph() # Aggregate sparsity graph
self.cliques = [] # List of clique objects
self.order = [] # Elimination ordering
Expand Down Expand Up @@ -391,7 +392,7 @@ def get_standard_form(self, vec_order="C"):
b = svec(C.toarray(), vec_order)
return P, q, A, b

def get_consistency_constraints(self):
def consistency_constraints(self, constrain_only_h_row=CONSTRAIN_ONLY_H_ROW):
"""Return a list of constraints that enforce equalities between
clique variables. List consist of 4-tuples: (k, l, A_k, A_l)
where k and l are the indices of the cliques for which the equality is
Expand All @@ -403,7 +404,7 @@ def get_consistency_constraints(self):
PolyMatrix module.
"""
# Lopp through edges in the junction tree
eq_list = []
self.Es = []
for l, clique_l in enumerate(self.cliques):
# Get parent clique object and separator set
k = clique_l.parent
Expand All @@ -418,7 +419,7 @@ def get_consistency_constraints(self):
size_l = clique_l.size

# Define constraint matrices only in one row
if CONSTRAIN_ONLY_H_ROW:
if constrain_only_h_row:
assert "h" in sepset

hom_k = int(clique_k._get_indices(var_list="h"))
Expand All @@ -442,7 +443,7 @@ def get_consistency_constraints(self):
(-vals, (rows_l, cols_l)),
(size_l, size_l),
)
eq_list.append((k, l, A_k, A_l))
self.Es.append((k, l, A_k, A_l))

A_k = sp.coo_matrix(
([1.0], ([hom_k], [hom_k])),
Expand All @@ -452,7 +453,7 @@ def get_consistency_constraints(self):
([-1.0], ([hom_l], [hom_l])),
(size_l, size_l),
)
eq_list.append((k, l, A_k, A_l))
self.Es.append((k, l, A_k, A_l))
continue

# Define sparse constraint matrices for each element in the seperator overlap
Expand Down Expand Up @@ -480,9 +481,7 @@ def get_consistency_constraints(self):
(vals_l, (rows_l, cols_l)),
(size_l, size_l),
)
eq_list.append((k, l, A_k, A_l))

return eq_list
self.Es.append((k, l, A_k, A_l))

def assign_matrix(self, pmat: PolyMatrix):
"""Assign a matrix to the clique that it corresponds to.
Expand Down
31 changes: 15 additions & 16 deletions cert_tools/sparse_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

from poly_matrix import PolyMatrix

CONSTRAIN_ALL_OVERLAP = False

TOL = 1e-5


Expand Down Expand Up @@ -428,15 +426,17 @@ def get_decomp_fusion_expr(pmat_in, decomp_method="split"):
# CLIQUE CONSISTENCY EQUALITIES
if verbose:
print("Generating overlap consistency constraints")
clq_constrs = problem.get_consistency_constraints()

if problem.Es is None:
problem.consistency_constraints()
# TEST reduce number of clique
if reduce_constrs is not None:
n_constrs = int(reduce_constrs * len(clq_constrs))
clq_constrs = random.sample(clq_constrs, n_constrs)
n_constrs = int(reduce_constrs * len(problem.Es))
clq_constrs = random.sample(problem.Es, n_constrs)
if verbose:
print("Adding overlap consistency constraints to problem")
cnt = 0
for k, l, A_k, A_l in clq_constrs:
for k, l, A_k, A_l in problem.Es:
# Convert sparse array to fusion sparse matrix
A_k_fusion = sparse_to_fusion(A_k)
A_l_fusion = sparse_to_fusion(A_l)
Expand Down Expand Up @@ -484,16 +484,15 @@ def get_decomp_fusion_expr(pmat_in, decomp_method="split"):

# EXTRACT SOLN
status = M.getProblemStatus()
if status == fu.ProblemStatus.PrimalAndDualFeasible:
# Get MOSEK cost
cost = M.primalObjValue()
clq_list = [cvar.level().reshape(cvar.shape) for cvar in cvars]
dual = [cvar.dual().reshape(cvar.shape) for cvar in cvars]
info["success"] = True
info["dual"] = dual
info["cost"] = cost
else:
print("Solve Failed - Mosek Status: " + str(status))
if status != fu.ProblemStatus.PrimalAndDualFeasible:
print("Warning: solve failed -- mosek status: " + str(status))

cost = M.primalObjValue()
clq_list = [cvar.level().reshape(cvar.shape) for cvar in cvars]
dual = [cvar.dual().reshape(cvar.shape) for cvar in cvars]
info["success"] = True
info["dual"] = dual
info["cost"] = cost

return clq_list, info

Expand Down

0 comments on commit 8048f56

Please sign in to comment.