Skip to content

Commit

Permalink
Merge branch 'WIP-draggable_rules' into dev
Browse files Browse the repository at this point in the history
Merge draggable-rules into dev
  • Loading branch information
stephanzwicknagl committed Feb 2, 2024
2 parents a183d3d + edc5eed commit 034e620
Show file tree
Hide file tree
Showing 83 changed files with 3,447 additions and 1,805 deletions.
3 changes: 1 addition & 2 deletions backend/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ packages = find:
python_requires = >=3.7
install_requires =
networkx>=2.4
flask==2.2.0
Werkzeug==2.2.2
flask>=2.2.0
clingo>=5.6.0
flask-cors>=3.0
requests>=2.26.0
Expand Down
90 changes: 53 additions & 37 deletions backend/src/viasp/asp/justify.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
"""This module is concerned with finding reasons for why a stable model is found."""
from collections import defaultdict
from typing import List, Collection, Dict, Iterable, Union
from typing import List, Collection, Dict, Iterable, Union, cast

import networkx as nx

from clingo import Control, Symbol, Model, ast

from clingo.ast import AST, Function
from clingo.ast import AST
from networkx import DiGraph

from .reify import ProgramAnalyzer, has_an_interval
from .recursion import RecursionReasoner
from .utils import insert_atoms_into_nodes, identify_reasons, calculate_spacing_factor
from .utils import insert_atoms_into_nodes, identify_reasons, harmonize_uuids, calculate_spacing_factor
from ..shared.model import Node, Transformation, SymbolIdentifier
from ..shared.simple_logging import info, warn
from ..shared.util import pairwise, get_leafs_from_graph


def stringify_fact(fact: Function) -> str:
def stringify_fact(fact: Symbol) -> str:
return f"{str(fact)}."


def get_h_symbols_from_model(wrapped_stable_model: Iterable[Symbol],
def get_h_symbols_from_model(wrapped_stable_model: Iterable[str],
transformed_prg: Collection[Union[str, AST]],
facts: List[Symbol],
constants: List[Symbol],
Expand Down Expand Up @@ -60,39 +60,49 @@ def get_facts(original_program) -> Collection[Symbol]:


def collect_h_symbols_and_create_nodes(h_symbols: Collection[Symbol], relevant_indices, pad: bool, supernode_symbols: frozenset = frozenset([])) -> List[Node]:
tmp_symbol: Dict[int, List[SymbolIdentifier]] = defaultdict(list)
tmp_reason: Dict[int, Dict[Symbol, List[Symbol]]] = defaultdict(dict)
tmp_symbol: Dict[int, List[Symbol]] = defaultdict(list)
tmp_symbol_identifier: Dict[int, List[SymbolIdentifier]] = defaultdict(list)
tmp_reason: Dict[int, Dict[str, List[Symbol]]] = defaultdict(dict)
for sym in h_symbols:
rule_nr, symbol, reasons = sym.arguments
tmp_symbol[rule_nr.number].append(symbol)
tmp_reason[rule_nr.number][str(symbol)] = reasons.arguments
for rule_nr in tmp_symbol.keys():
tmp_symbol[rule_nr] = set(tmp_symbol[rule_nr])
tmp_symbol[rule_nr] = map(lambda symbol: next(filter(
tmp_symbol[rule_nr] = list(tmp_symbol[rule_nr])
tmp_symbol_identifier[rule_nr] = list(map(lambda symbol: next(filter(
lambda supernode_symbol: supernode_symbol==symbol, supernode_symbols)) if
symbol in supernode_symbols else
SymbolIdentifier(symbol),tmp_symbol[rule_nr])
SymbolIdentifier(symbol),tmp_symbol[rule_nr]))
if pad:
h_symbols = [
Node(frozenset(tmp_symbol[rule_nr]), rule_nr, reason=tmp_reason[rule_nr]) if rule_nr in tmp_symbol else Node(frozenset(), rule_nr) for
rule_nr in relevant_indices]
h_nodes: List[Node] = [
Node(frozenset(tmp_symbol_identifier[rule_nr]),
rule_nr,
reason=tmp_reason[rule_nr])
if rule_nr in tmp_symbol
else Node(frozenset(), rule_nr)
for rule_nr in relevant_indices]
else:
h_symbols = [
Node(frozenset(tmp_symbol[rule_nr]), rule_nr, reason=tmp_reason[rule_nr]) if rule_nr in tmp_symbol else Node(frozenset(), rule_nr)
h_nodes: List[Node] = [
Node(frozenset(tmp_symbol_identifier[rule_nr]),
rule_nr,
reason=tmp_reason[rule_nr])
if rule_nr in tmp_symbol
else Node(frozenset(), rule_nr)
for rule_nr in range(1, max(tmp_symbol.keys(), default=-1) + 1)]

return h_symbols
return h_nodes


def make_reason_path_from_facts_to_stable_model(wrapped_stable_model,
rule_mapping: Dict[int, Union[AST, str]],
fact_node: Node, h_syms,
recursive_transformations:frozenset,
rule_mapping: Dict[int, Transformation],
fact_node: Node,
h_symbols: List[Symbol],
recursive_transformations:set,
h="h",
analyzer: ProgramAnalyzer = ProgramAnalyzer(),
pad=True) \
-> nx.DiGraph:
h_syms = collect_h_symbols_and_create_nodes(h_syms, rule_mapping.keys(), pad)
h_syms: List[Node] = collect_h_symbols_and_create_nodes(h_symbols, rule_mapping.keys(), pad)
h_syms.sort(key=lambda node: node.rule_nr)
h_syms.insert(0, fact_node)

Expand Down Expand Up @@ -132,27 +142,29 @@ def make_transformation_mapping(transformations: Iterable[Transformation]):


def append_noops(result_graph: DiGraph, analyzer: ProgramAnalyzer):
next_transformation_id = max(t.id for t in analyzer.get_sorted_program()) + 1
next_transformation_id = max(t.id for t in next(analyzer.get_sorted_program())) + 1
leaves = list(get_leafs_from_graph(result_graph))
leaf: Node
for leaf in leaves:
noop_node = Node(frozenset(), next_transformation_id, leaf.atoms)
result_graph.add_edge(leaf, noop_node,
transformation=Transformation(next_transformation_id,
[str(pt) for pt in analyzer.pass_through]))
tuple(analyzer.pass_through)))


def build_graph(wrapped_stable_models: Collection[str], transformed_prg: Collection[AST],
def build_graph(wrapped_stable_models: List[List[str]],
transformed_prg: Collection[AST],
sorted_program: List[Transformation],
analyzer: ProgramAnalyzer,
recursion_transformations: frozenset) -> nx.DiGraph:
recursion_transformations: set) -> nx.DiGraph:
paths: List[nx.DiGraph] = []
facts = analyzer.get_facts()
conflict_free_h = analyzer.get_conflict_free_h()
conflict_free_h_showTerm = analyzer.get_conflict_free_h_showTerm()
identifiable_facts = map(SymbolIdentifier,facts)
sorted_program = analyzer.get_sorted_program()
identifiable_facts = list(map(SymbolIdentifier, facts))
mapping = make_transformation_mapping(sorted_program)
fact_node = Node(frozenset(identifiable_facts), -1, frozenset(identifiable_facts))
fact_node = Node(frozenset(identifiable_facts), -1,
frozenset(identifiable_facts))
if not len(mapping):
info(f"Program only contains facts. {fact_node}")
single_node_graph = nx.DiGraph()
Expand All @@ -163,15 +175,17 @@ def build_graph(wrapped_stable_models: Collection[str], transformed_prg: Collect
analyzer.get_constants(),
conflict_free_h,
conflict_free_h_showTerm)
new_path = make_reason_path_from_facts_to_stable_model(model, mapping, fact_node, h_symbols, recursion_transformations, conflict_free_h, analyzer)
new_path = make_reason_path_from_facts_to_stable_model(
model, mapping, fact_node, h_symbols, recursion_transformations,
conflict_free_h, analyzer)
paths.append(new_path)

result_graph = nx.DiGraph()
result_graph.update(join_paths_with_facts(paths))
calculate_spacing_factor(result_graph)
identify_reasons(result_graph)
if analyzer.pass_through:
append_noops(result_graph, analyzer)
calculate_spacing_factor(result_graph)
identify_reasons(result_graph)
return result_graph


Expand All @@ -191,7 +205,7 @@ def filter_body_aggregates(element: AST):


def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,
transformation: Union[AST, str], conflict_free_h: str,
transformation: Transformation, conflict_free_h: str,
analyzer: ProgramAnalyzer) -> Union[bool, nx.DiGraph]:
"""
Get a recursion explanation for the given facts and the recursive transformation.
Expand All @@ -207,14 +221,14 @@ def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,

init = [fact.symbol for fact in facts]
justification_program = ""
model_str: str = analyzer.get_conflict_free_model()
n_str: str = analyzer.get_conflict_free_iterindex()
model_str: str = analyzer.get_conflict_free_model() if analyzer else "model"
n_str: str = analyzer.get_conflict_free_iterindex() if analyzer else "n"

for rule in transformation.rules:
deps = defaultdict(list)
loc = rule.location

_ = analyzer.visit(rule.head, deps=deps)
_ = analyzer.visit(rule.head, deps=deps) # type: ignore
if not deps:
deps[rule.head] = []
for dependant, conditions in deps.items():
Expand All @@ -231,8 +245,8 @@ def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,
False))
dependant = ast.Literal(loc, ast.Sign.NoSign, symbol)

new_body: List[ast.Literal] = []
reason_literals: List[ast.Literal] = []
new_body: List[ast.Literal] = [] # type: ignore
reason_literals: List[ast.Literal] = [] # type: ignore
_ = analyzer.visit_sequence(
rule.body, reasons=reason_literals, new_body=new_body, rename_variables=False)
loc_fun = ast.Function(loc, n_str, [], False)
Expand All @@ -257,7 +271,9 @@ def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,
new_body = [x for i, x in enumerate(
new_body) if x not in new_body[:i]]
# rename variables inside body aggregates
new_body = analyzer.visit_sequence(new_body, rename_variables=True)
new_body = list(
analyzer.visit_sequence(cast(ast.ASTSequence, new_body),
rename_variables=True))
new_body = [ast.Function(loc, model_str, [bb], 0) for bb in filter(filter_body_aggregates,new_body)]
new_body.append(ast.Function(loc, f"not {model_str}", [dependant], 0))
justification_program += "\n".join(map(str, (ast.Rule(rule.location, new_head, new_body)
Expand Down
Loading

0 comments on commit 034e620

Please sign in to comment.