Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ports review #7

Merged
merged 3 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/andromede/expression/context_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def variable(self, node: VariableNode) -> ExpressionNode:
def parameter(self, node: ParameterNode) -> ExpressionNode:
return ComponentParameterNode(self.component_id, node.name)

def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
raise ValueError(
"This expression has already been associated to another component."
)

def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
raise ValueError(
"This expression has already been associated to another component."
)
Expand Down
12 changes: 6 additions & 6 deletions src/andromede/expression/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ class CopyVisitor(ExpressionVisitorOperations[ExpressionNode]):
Simply copies the whole AST.
"""

def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
return ComponentParameterNode(node.component_id, node.name)

def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
return ComponentVariableNode(node.component_id, node.name)

def literal(self, node: LiteralNode) -> ExpressionNode:
return LiteralNode(node.value)

Expand All @@ -63,6 +57,12 @@ def variable(self, node: VariableNode) -> ExpressionNode:
def parameter(self, node: ParameterNode) -> ExpressionNode:
return ParameterNode(node.name)

def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
return ComponentVariableNode(node.component_id, node.name)

def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
return ComponentParameterNode(node.component_id, node.name)

def copy_expression_range(
self, expression_range: ExpressionRange
) -> ExpressionRange:
Expand Down
12 changes: 6 additions & 6 deletions src/andromede/expression/degree.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ class ExpressionDegreeVisitor(ExpressionVisitor[int]):
Computes degree of expression with respect to variables.
"""

def comp_parameter(self, node: ComponentParameterNode) -> int:
return 0

def comp_variable(self, node: ComponentVariableNode) -> int:
return 1

def literal(self, node: LiteralNode) -> int:
return 0

Expand Down Expand Up @@ -78,6 +72,12 @@ def variable(self, node: VariableNode) -> int:
def parameter(self, node: ParameterNode) -> int:
return 0

def comp_variable(self, node: ComponentVariableNode) -> int:
return 1

def comp_parameter(self, node: ComponentParameterNode) -> int:
return 0

def time_operator(self, node: TimeOperatorNode) -> int:
if node.name in ["TimeShift", "TimeEvaluation"]:
return visit(node.operand, self)
Expand Down
12 changes: 6 additions & 6 deletions src/andromede/expression/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,6 @@ class EvaluationVisitor(ExpressionVisitorOperations[float]):

context: ValueProvider

def comp_parameter(self, node: ComponentParameterNode) -> float:
return self.context.get_component_parameter_value(node.component_id, node.name)

def comp_variable(self, node: ComponentVariableNode) -> float:
return self.context.get_component_variable_value(node.component_id, node.name)

def literal(self, node: LiteralNode) -> float:
return node.value

Expand All @@ -120,6 +114,12 @@ def variable(self, node: VariableNode) -> float:
def parameter(self, node: ParameterNode) -> float:
return self.context.get_parameter_value(node.name)

def comp_parameter(self, node: ComponentParameterNode) -> float:
return self.context.get_component_parameter_value(node.component_id, node.name)

def comp_variable(self, node: ComponentVariableNode) -> float:
return self.context.get_component_variable_value(node.component_id, node.name)

def time_operator(self, node: TimeOperatorNode) -> float:
raise NotImplementedError()

Expand Down
25 changes: 25 additions & 0 deletions src/andromede/expression/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ def param(name: str) -> ParameterNode:

@dataclass(frozen=True, eq=False)
class ComponentParameterNode(ExpressionNode):
"""
Represents one parameter of one component.

When building actual equations for a system,
we need to associated each parameter to its
actual component, at some point.
"""

component_id: str
name: str

Expand All @@ -189,6 +197,14 @@ def comp_param(component_id: str, name: str) -> ComponentParameterNode:

@dataclass(frozen=True, eq=False)
class ComponentVariableNode(ExpressionNode):
"""
Represents one variable of one component.

When building actual equations for a system,
we need to associated each variable to its
actual component, at some point.
"""

component_id: str
name: str

Expand Down Expand Up @@ -321,6 +337,15 @@ def expression_range(


class InstancesTimeIndex:
"""
Defines a set of time indices on which a time operator operates.

In particular, it defines time indices created by the shift operator.

The actual indices can either be defined as a time range defined by
2 expression, or as a list of expressions.
"""

expressions: Union[List[ExpressionNode], ExpressionRange]

def __init__(
Expand Down
28 changes: 16 additions & 12 deletions src/andromede/expression/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,6 @@ class TimeScenarioIndexingVisitor(ExpressionVisitor[IndexingStructure]):

context: IndexingStructureProvider

def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure:
return self.context.get_component_parameter_structure(
node.component_id, node.name
)

def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure:
return self.context.get_component_variable_structure(
node.component_id, node.name
)

def literal(self, node: LiteralNode) -> IndexingStructure:
return IndexingStructure(False, False)

Expand Down Expand Up @@ -109,6 +99,16 @@ def parameter(self, node: ParameterNode) -> IndexingStructure:
scenario = self.context.get_parameter_structure(node.name).scenario == True
return IndexingStructure(time, scenario)

def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure:
return self.context.get_component_variable_structure(
node.component_id, node.name
)

def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure:
return self.context.get_component_parameter_structure(
node.component_id, node.name
)

def time_operator(self, node: TimeOperatorNode) -> IndexingStructure:
time_operator_cls = getattr(andromede.expression.time_operator, node.name)
if time_operator_cls.rolling():
Expand All @@ -126,10 +126,14 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> IndexingStructure:
return IndexingStructure(visit(node.operand, self).time, False)

def port_field(self, node: PortFieldNode) -> IndexingStructure:
raise ValueError("Should be instantiated before computing indexing structure.")
raise ValueError(
"Port fields must be resolved before computing indexing structure."
)

def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStructure:
raise ValueError("Should be instantiated before computing indexing structure.")
raise ValueError(
"Port fields aggregators must be resolved before computing indexing structure."
)


def compute_indexation(
Expand Down
10 changes: 5 additions & 5 deletions src/andromede/expression/port_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ class PortResolver(CopyVisitor):
their corresponding expression.
"""

ports_expressions: Dict[PortFieldKey, List[ExpressionNode]]
component_id: str
ports_expressions: Dict[PortFieldKey, List[ExpressionNode]]

def port_field(self, node: PortFieldNode) -> ExpressionNode:
expression = self.ports_expressions[
expressions = self.ports_expressions[
PortFieldKey(
self.component_id, PortFieldId(node.port_name, node.field_name)
)
]
if len(expression) != 1:
if len(expressions) != 1:
raise ValueError(
f"Invalid number of expression for port : {node.port_name}"
)
else:
return expression[0]
return expressions[0]

def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode:
if node.aggregator != "PortSum":
Expand All @@ -80,4 +80,4 @@ def resolve_port(
component_id: str,
ports_expressions: Dict[PortFieldKey, List[ExpressionNode]],
) -> ExpressionNode:
return visit(expression, PortResolver(ports_expressions, component_id))
return visit(expression, PortResolver(component_id, ports_expressions))
12 changes: 6 additions & 6 deletions src/andromede/expression/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,6 @@ class PrinterVisitor(ExpressionVisitor[str]):
TODO: remove parenthis where not necessary.
"""

def comp_parameter(self, node: ComponentParameterNode) -> str:
return f"{node.component_id}.{node.name}"

def comp_variable(self, node: ComponentVariableNode) -> str:
return f"{node.component_id}.{node.name}"

def literal(self, node: LiteralNode) -> str:
return str(node.value)

Expand Down Expand Up @@ -98,6 +92,12 @@ def variable(self, node: VariableNode) -> str:
def parameter(self, node: ParameterNode) -> str:
return node.name

def comp_variable(self, node: ComponentVariableNode) -> str:
return f"{node.component_id}.{node.name}"

def comp_parameter(self, node: ComponentParameterNode) -> str:
return f"{node.component_id}.{node.name}"

# TODO: Add pretty print for node.instances_index
def time_operator(self, node: TimeOperatorNode) -> str:
return f"({visit(node.operand, self)}.{str(node.name)}({node.instances_index}))"
Expand Down
14 changes: 7 additions & 7 deletions src/andromede/expression/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,31 @@ def parameter(self, node: ParameterNode) -> T:
...

@abstractmethod
def time_operator(self, node: TimeOperatorNode) -> T:
def comp_parameter(self, node: ComponentParameterNode) -> T:
...

@abstractmethod
def time_aggregator(self, node: TimeAggregatorNode) -> T:
def comp_variable(self, node: ComponentVariableNode) -> T:
...

@abstractmethod
def scenario_operator(self, node: ScenarioOperatorNode) -> T:
def time_operator(self, node: TimeOperatorNode) -> T:
...

@abstractmethod
def port_field(self, node: PortFieldNode) -> T:
def time_aggregator(self, node: TimeAggregatorNode) -> T:
...

@abstractmethod
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T:
def scenario_operator(self, node: ScenarioOperatorNode) -> T:
...

@abstractmethod
def comp_parameter(self, node: ComponentParameterNode) -> T:
def port_field(self, node: PortFieldNode) -> T:
...

@abstractmethod
def comp_variable(self, node: ComponentVariableNode) -> T:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T:
...


Expand Down
Loading
Loading