Skip to content

Commit

Permalink
Add generalized signature function to analyzer
Browse files Browse the repository at this point in the history
The analyzer passes AST nodes to the make signature function in order to
generate a signature for the node in the dependency graph.

This commit adds a generalized signature function that is used to
generate signatures for all possible AST nodes. `None` is returned for
AST that is irrelevant for the dependency graph (e.g. aggregates).

Adding tests for all possible AST nodes that can be analyzed.
  • Loading branch information
stephanzwicknagl committed Jun 20, 2024
1 parent 893d4e1 commit 9bce464
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 34 deletions.
96 changes: 71 additions & 25 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
parse_string,
ASTType,
AST,
Literal as astLiteral,
SymbolicAtom as astSymbolicAtom
)
from viasp.shared.util import hash_transformation_rules

Expand All @@ -26,27 +28,48 @@
def is_fact(rule, dependencies):
return len(rule.body) == 0 and not len(dependencies)

def make_signature_from_terms(term) -> Optional[Tuple[str, int]]:
if term.ast_type == ASTType.SymbolicTerm:
return term.symbol.name, 0
elif term.ast_type == ASTType.Variable:
return (term.name, 0)
elif term.ast_type == ASTType.UnaryOperation:
return make_signature_from_terms(term.argument)
elif term.ast_type == ASTType.BinaryOperation:
return None
elif term.ast_type == ASTType.Interval:
return None
elif term.ast_type == ASTType.Function:
return (term.name, len(term.arguments))
elif term.ast_type == ASTType.Pool:
return make_signature_from_terms(term.arguments[0])
raise ValueError(f"Could not make signature of {term}.")

def make_signature(ast: Union[ast.Literal, ast.ConditionalLiteral]) -> Optional[Tuple[str, int]]: # type: ignore
"""
Is used to create a signature for a literal or conditional literal for placing it in the dependency graph.
`None` is returned for types of literals that are unsupported or neglected in the dependency graph.
"""
if ast.ast_type == ASTType.Literal:
atom = ast.atom
else:
return None

def make_signature(literal: ast.Literal) -> Tuple[str, int]: # type: ignore
if literal.atom.ast_type in [
ASTType.BodyAggregate, ASTType.BooleanConstant
]:
return literal, 0
unpacked = literal.atom.symbol
if unpacked.ast_type in [ASTType.Variable, ASTType.Function]:
return (
unpacked.name,
len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0,
)
if unpacked.ast_type == ASTType.SymbolicTerm:
return (
unpacked.symbol.name,
len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0,
)
if unpacked.ast_type == ASTType.Pool:
unpacked = unpacked.arguments[0]
return (unpacked.name, len(unpacked.arguments))
raise ValueError(f"Could not make signature of {literal}")
if atom.ast_type == ASTType.BodyAggregate:
return None
elif atom.ast_type == ASTType.BooleanConstant:
return None
elif atom.ast_type == ASTType.Comparison:
return None
elif atom.ast_type == ASTType.Aggregate:
return None
elif atom.ast_type == ASTType.TheoryAtom:
raise ValueError(f"Could not make signature of {ast}.")
elif atom.ast_type == ASTType.SymbolicAtom:
term = atom.symbol
return make_signature_from_terms(term)

raise ValueError(f"Could not make signature of {ast}, {ast.ast_type}.")


def filter_body_arithmetic(elem: ast.Literal): # type: ignore
Expand Down Expand Up @@ -222,7 +245,7 @@ def get_conflict_free_iterindex(self):
transformations.
"""
return self._get_conflict_free_version_of_name("n")

def get_conflict_free_derivable(self):
"""
For use in generation of subgraphs at recursive
Expand Down Expand Up @@ -369,7 +392,8 @@ def register_rule_conditions(
conditions: List[ast.Literal]) -> None: # type: ignore
for c in conditions:
c_sig = make_signature(c)
self.conditions[c_sig].add(rule)
if c_sig is not None:
self.conditions[c_sig].add(rule)

def register_rule_dependencies(
self,
Expand All @@ -381,15 +405,18 @@ def register_rule_dependencies(
for (cond, pos_cond) in deps.values():
for c in filter(filter_body_arithmetic, cond):
c_sig = make_signature(c)
self.conditions[c_sig].add(rule)
if c_sig is not None:
self.conditions[c_sig].add(rule)
for c in filter(filter_body_arithmetic, pos_cond):
c_sig = make_signature(c)
self.positive_conditions[c_sig].add(rule)
if c_sig is not None:
self.positive_conditions[c_sig].add(rule)

for v in deps.keys():
if v.ast_type == ASTType.Literal and v.atom.ast_type != ASTType.BooleanConstant:
v_sig = make_signature(v)
self.dependants[v_sig].add(rule)
if v_sig is not None:
self.dependants[v_sig].add(rule)

def get_body_aggregate_elements(self, body: Sequence[AST]) -> List[AST]:
body_aggregate_elements: List[AST] = []
Expand Down Expand Up @@ -911,3 +938,22 @@ def reify_recursion_transformation(transformation: Transformation,
for rule in transformation.rules.ast:
result.extend(cast(Iterable[AST], visitor.visit(rule)))
return result


class LiteralsCollector(Transformer):

def visit_Literal(
self,
literal: ast.Literal, # type: ignore
**kwargs: Any) -> AST:
literals: List[AST] = kwargs.get("literals", [])
literals.append(literal)
return literal.update(**self.visit_children(literal, **kwargs))


def collect_literals(program: str):
visitor = LiteralsCollector()
literals = []
parse_string(program,
lambda rule: visitor.visit(rule, literals=literals) and None)
return literals
90 changes: 81 additions & 9 deletions backend/test/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from clingo.ast import AST

from viasp.asp.ast_types import (SUPPORTED_TYPES, UNSUPPORTED_TYPES)
from viasp.asp.reify import ProgramAnalyzer
from viasp.asp.reify import ProgramAnalyzer, collect_literals, make_signature
from viasp.shared.util import hash_transformation_rules
from viasp.server.database import GraphAccessor, get_or_create_encoding_id

Expand Down Expand Up @@ -348,7 +348,7 @@ def test_body_conditional_literal_sorted_in_show_term(app_context):
rules = ["hc(U,V) :- edge(U,V).", "#show allnodes : node(X): hc(_,X)."]
program = """
node(1..2). edge(1,2). edge(2,1).
""" + "\n".join(rules)
""" + "\n".join(rules)
transformer = ProgramAnalyzer()
result = transformer.sort_program(program)
assert len(result) == len(rules)
Expand Down Expand Up @@ -424,10 +424,82 @@ def test_loop_recursion_gets_recognized(app_context):
)) in recursive_rules, "Hash is determined by transformatinos."


def test_signature_literal(app_context):
test_string = """
a :- Lit(b).
-h(R,T) :- b.
lit(c) :- b.
"""
pass
def test_signature_pool():
pool = """holds(X) :- map(a(X);a(X+1))."""
literals = collect_literals(pool)
assert len(literals) == 2
assert make_signature(literals[0]) == ('holds', 1)
assert make_signature(literals[1]) == ('map', 1)


def test_signature_boolean_constant():
boolean_constant = """a:- #true."""
literals = collect_literals(boolean_constant)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == None

def test_signature_theory_atom():
theory_atom = """b:- &diff { T1-T2 } <= -D."""
literals = collect_literals(theory_atom)
assert make_signature(literals[0]) == ('b', 0)
with pytest.raises(ValueError) as e_info:
make_signature(literals[1])

def test_signature_aggregate():
aggregate = """a:- 1{b(X):c(X)}."""
literals = collect_literals(aggregate)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == None # ?
assert make_signature(literals[2]) == ('b', 1)
assert make_signature(literals[3]) == ('c', 1)

def test_signature_body_aggregate():
body_aggregate= """a:- 1=#sum{b(X):c(X)}."""
literals = collect_literals(body_aggregate)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == None
assert make_signature(literals[2]) == ('c', 1)

def test_signature_comparison():
comparison = """a:- Z<X+Y."""
literals = collect_literals(comparison)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == None # ?

def test_signature_unary_operation():
unary_operation = """a:- -b."""
literals = collect_literals(unary_operation)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == ('b', 0)

unary_operation = """-a:- b."""
literals = collect_literals(unary_operation)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == ('b', 0)

def test_signature_function():
function = """a:- b."""
literals = collect_literals(function)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == ('b', 0)

def test_signature_function_with_variable():
function_with_variable = """a:- b(X,Y,Z)."""
literals = collect_literals(function_with_variable)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == ('b', 3)

def test_signature_function_with_interval():
function_with_interval = """b(X) :- a(1..2)."""
literals = collect_literals(function_with_interval)
assert make_signature(literals[0]) == ('b', 1)
assert make_signature(literals[1]) == ('a', 1)


def test_signature_conditional_literal():
conditional_literal = """a:- b(X):c(X)."""
literals = collect_literals(conditional_literal)
assert make_signature(literals[0]) == ('a', 0)
assert make_signature(literals[1]) == ('b', 1)
assert make_signature(literals[2]) == ('c', 1)
# signature of the conditional literal itself

0 comments on commit 9bce464

Please sign in to comment.