From 9bce46432345a6292cfd259bb9aac1c4cffdf5f7 Mon Sep 17 00:00:00 2001 From: Stephan Zwicknagl Date: Thu, 20 Jun 2024 13:18:18 -1000 Subject: [PATCH] Add generalized signature function to analyzer The analyzer passes AST nodes to the make signature function in order to generate a signature for the node in the dependency graph. This commit adds a generalized signature function that is used to generate signatures for all possible AST nodes. `None` is returned for AST that is irrelevant for the dependency graph (e.g. aggregates). Adding tests for all possible AST nodes that can be analyzed. --- backend/src/viasp/asp/reify.py | 96 +++++++++++++++++++++++++--------- backend/test/test_analyzer.py | 90 +++++++++++++++++++++++++++---- 2 files changed, 152 insertions(+), 34 deletions(-) diff --git a/backend/src/viasp/asp/reify.py b/backend/src/viasp/asp/reify.py index fd192919..8c654c69 100644 --- a/backend/src/viasp/asp/reify.py +++ b/backend/src/viasp/asp/reify.py @@ -9,6 +9,8 @@ parse_string, ASTType, AST, + Literal as astLiteral, + SymbolicAtom as astSymbolicAtom ) from viasp.shared.util import hash_transformation_rules @@ -26,27 +28,48 @@ def is_fact(rule, dependencies): return len(rule.body) == 0 and not len(dependencies) +def make_signature_from_terms(term) -> Optional[Tuple[str, int]]: + if term.ast_type == ASTType.SymbolicTerm: + return term.symbol.name, 0 + elif term.ast_type == ASTType.Variable: + return (term.name, 0) + elif term.ast_type == ASTType.UnaryOperation: + return make_signature_from_terms(term.argument) + elif term.ast_type == ASTType.BinaryOperation: + return None + elif term.ast_type == ASTType.Interval: + return None + elif term.ast_type == ASTType.Function: + return (term.name, len(term.arguments)) + elif term.ast_type == ASTType.Pool: + return make_signature_from_terms(term.arguments[0]) + raise ValueError(f"Could not make signature of {term}.") + +def make_signature(ast: Union[ast.Literal, ast.ConditionalLiteral]) -> Optional[Tuple[str, int]]: # type: ignore + """ + Is used to create a signature for a literal or conditional literal for placing it in the dependency graph. + `None` is returned for types of literals that are unsupported or neglected in the dependency graph. + """ + if ast.ast_type == ASTType.Literal: + atom = ast.atom + else: + return None -def make_signature(literal: ast.Literal) -> Tuple[str, int]: # type: ignore - if literal.atom.ast_type in [ - ASTType.BodyAggregate, ASTType.BooleanConstant - ]: - return literal, 0 - unpacked = literal.atom.symbol - if unpacked.ast_type in [ASTType.Variable, ASTType.Function]: - return ( - unpacked.name, - len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0, - ) - if unpacked.ast_type == ASTType.SymbolicTerm: - return ( - unpacked.symbol.name, - len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0, - ) - if unpacked.ast_type == ASTType.Pool: - unpacked = unpacked.arguments[0] - return (unpacked.name, len(unpacked.arguments)) - raise ValueError(f"Could not make signature of {literal}") + if atom.ast_type == ASTType.BodyAggregate: + return None + elif atom.ast_type == ASTType.BooleanConstant: + return None + elif atom.ast_type == ASTType.Comparison: + return None + elif atom.ast_type == ASTType.Aggregate: + return None + elif atom.ast_type == ASTType.TheoryAtom: + raise ValueError(f"Could not make signature of {ast}.") + elif atom.ast_type == ASTType.SymbolicAtom: + term = atom.symbol + return make_signature_from_terms(term) + + raise ValueError(f"Could not make signature of {ast}, {ast.ast_type}.") def filter_body_arithmetic(elem: ast.Literal): # type: ignore @@ -222,7 +245,7 @@ def get_conflict_free_iterindex(self): transformations. """ return self._get_conflict_free_version_of_name("n") - + def get_conflict_free_derivable(self): """ For use in generation of subgraphs at recursive @@ -369,7 +392,8 @@ def register_rule_conditions( conditions: List[ast.Literal]) -> None: # type: ignore for c in conditions: c_sig = make_signature(c) - self.conditions[c_sig].add(rule) + if c_sig is not None: + self.conditions[c_sig].add(rule) def register_rule_dependencies( self, @@ -381,15 +405,18 @@ def register_rule_dependencies( for (cond, pos_cond) in deps.values(): for c in filter(filter_body_arithmetic, cond): c_sig = make_signature(c) - self.conditions[c_sig].add(rule) + if c_sig is not None: + self.conditions[c_sig].add(rule) for c in filter(filter_body_arithmetic, pos_cond): c_sig = make_signature(c) - self.positive_conditions[c_sig].add(rule) + if c_sig is not None: + self.positive_conditions[c_sig].add(rule) for v in deps.keys(): if v.ast_type == ASTType.Literal and v.atom.ast_type != ASTType.BooleanConstant: v_sig = make_signature(v) - self.dependants[v_sig].add(rule) + if v_sig is not None: + self.dependants[v_sig].add(rule) def get_body_aggregate_elements(self, body: Sequence[AST]) -> List[AST]: body_aggregate_elements: List[AST] = [] @@ -911,3 +938,22 @@ def reify_recursion_transformation(transformation: Transformation, for rule in transformation.rules.ast: result.extend(cast(Iterable[AST], visitor.visit(rule))) return result + + +class LiteralsCollector(Transformer): + + def visit_Literal( + self, + literal: ast.Literal, # type: ignore + **kwargs: Any) -> AST: + literals: List[AST] = kwargs.get("literals", []) + literals.append(literal) + return literal.update(**self.visit_children(literal, **kwargs)) + + +def collect_literals(program: str): + visitor = LiteralsCollector() + literals = [] + parse_string(program, + lambda rule: visitor.visit(rule, literals=literals) and None) + return literals diff --git a/backend/test/test_analyzer.py b/backend/test/test_analyzer.py index d640e6eb..341f6d14 100644 --- a/backend/test/test_analyzer.py +++ b/backend/test/test_analyzer.py @@ -2,7 +2,7 @@ from clingo.ast import AST from viasp.asp.ast_types import (SUPPORTED_TYPES, UNSUPPORTED_TYPES) -from viasp.asp.reify import ProgramAnalyzer +from viasp.asp.reify import ProgramAnalyzer, collect_literals, make_signature from viasp.shared.util import hash_transformation_rules from viasp.server.database import GraphAccessor, get_or_create_encoding_id @@ -348,7 +348,7 @@ def test_body_conditional_literal_sorted_in_show_term(app_context): rules = ["hc(U,V) :- edge(U,V).", "#show allnodes : node(X): hc(_,X)."] program = """ node(1..2). edge(1,2). edge(2,1). - """ + "\n".join(rules) + """ + "\n".join(rules) transformer = ProgramAnalyzer() result = transformer.sort_program(program) assert len(result) == len(rules) @@ -424,10 +424,82 @@ def test_loop_recursion_gets_recognized(app_context): )) in recursive_rules, "Hash is determined by transformatinos." -def test_signature_literal(app_context): - test_string = """ - a :- Lit(b). - -h(R,T) :- b. - lit(c) :- b. - """ - pass \ No newline at end of file +def test_signature_pool(): + pool = """holds(X) :- map(a(X);a(X+1)).""" + literals = collect_literals(pool) + assert len(literals) == 2 + assert make_signature(literals[0]) == ('holds', 1) + assert make_signature(literals[1]) == ('map', 1) + + +def test_signature_boolean_constant(): + boolean_constant = """a:- #true.""" + literals = collect_literals(boolean_constant) + assert make_signature(literals[0]) == ('a', 0) + assert make_signature(literals[1]) == None + +def test_signature_theory_atom(): + theory_atom = """b:- &diff { T1-T2 } <= -D.""" + literals = collect_literals(theory_atom) + assert make_signature(literals[0]) == ('b', 0) + with pytest.raises(ValueError) as e_info: + make_signature(literals[1]) + +def test_signature_aggregate(): + aggregate = """a:- 1{b(X):c(X)}.""" + literals = collect_literals(aggregate) + assert make_signature(literals[0]) == ('a', 0) + assert make_signature(literals[1]) == None # ? + assert make_signature(literals[2]) == ('b', 1) + assert make_signature(literals[3]) == ('c', 1) + +def test_signature_body_aggregate(): + body_aggregate= """a:- 1=#sum{b(X):c(X)}.""" + literals = collect_literals(body_aggregate) + assert make_signature(literals[0]) == ('a', 0) + assert make_signature(literals[1]) == None + assert make_signature(literals[2]) == ('c', 1) + +def test_signature_comparison(): + comparison = """a:- Z