Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve AST Node Classes: Fix mutable default arguments and enhance r… #277

Merged
merged 4 commits into from
Mar 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 95 additions & 69 deletions crosstl/backend/Vulkan/VulkanAst.py
Original file line number Diff line number Diff line change
@@ -1,106 +1,129 @@
from typing import Any


class ASTNode:
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"

def __str__(self) -> str:
return self.__repr__()

class TernaryOpNode:
def __init__(self, condition, true_expr, false_expr):

class TernaryOpNode(ASTNode):
def __init__(self, condition: ASTNode, true_expr: ASTNode, false_expr: ASTNode):
self.condition = condition
self.true_expr = true_expr
self.false_expr = false_expr

def __repr__(self):
return f"TernaryOpNode(condition={self.condition}, true_expr={self.true_expr}, false_expr={self.false_expr})"
def __repr__(self) -> str:
return (
f"TernaryOpNode(condition={self.condition}, "
f"true_expr={self.true_expr}, false_expr={self.false_expr})"
)


class ShaderNode:
class ShaderNode(ASTNode):
def __init__(
self,
spirv_version,
descriptor_sets,
shader_stages,
functions,
spirv_version: str,
descriptor_sets: list,
shader_stages: list,
functions: list,
):
self.spirv_version = spirv_version
self.descriptor_sets = descriptor_sets
self.shader_stages = shader_stages
self.functions = functions

def __repr__(self):
return f"ShaderNode(spirv_version={self.spirv_version}, descriptor_sets={self.descriptor_sets}, shader_stages={self.shader_stages}, functions={self.functions})"
def __repr__(self) -> str:
return (
f"ShaderNode(spirv_version={self.spirv_version}, "
f"descriptor_sets={self.descriptor_sets}, "
f"shader_stages={self.shader_stages}, functions={self.functions})"
)


class IfNode(ASTNode):
def __init__(
self,
if_condition,
if_body,
else_if_conditions=[],
else_if_bodies=[],
else_body=None,
if_condition: ASTNode,
if_body: list,
else_if_conditions: list = None,
else_if_bodies: list = None,
else_body: list = None,
):
self.if_condition = if_condition
self.if_body = if_body
self.else_if_conditions = else_if_conditions
self.else_if_bodies = else_if_bodies
self.else_if_conditions = (
else_if_conditions if else_if_conditions is not None else []
)
self.else_if_bodies = else_if_bodies if else_if_bodies is not None else []
self.else_body = else_body

def __repr__(self):
return f"IfNode(if_condition={self.if_condition}, if_body={self.if_body}, else_if_conditions={self.else_if_conditions}, else_if_bodies={self.else_if_bodies}, else_body={self.else_body})"
def __repr__(self) -> str:
return (
f"IfNode(if_condition={self.if_condition}, if_body={self.if_body}, "
f"else_if_conditions={self.else_if_conditions}, "
f"else_if_bodies={self.else_if_bodies}, else_body={self.else_body})"
)


class ForNode(ASTNode):
def __init__(self, init, condition, update, body):
def __init__(self, init: ASTNode, condition: ASTNode, update: ASTNode, body: list):
self.init = init
self.condition = condition
self.update = update
self.body = body

def __repr__(self):
return f"ForNode(init={self.init}, condition={self.condition}, update={self.update}, body={self.body})"
def __repr__(self) -> str:
return (
f"ForNode(init={self.init}, condition={self.condition}, "
f"update={self.update}, body={self.body})"
)


class ReturnNode(ASTNode):
def __init__(self, value):
def __init__(self, value: Any):
self.value = value

def __repr__(self):
def __repr__(self) -> str:
return f"ReturnNode(value={self.value})"


class FunctionCallNode(ASTNode):
def __init__(self, name, args):
def __init__(self, name: str, args: list):
self.name = name
self.args = args

def __repr__(self):
def __repr__(self) -> str:
return f"FunctionCallNode(name={self.name}, args={self.args})"


class BinaryOpNode(ASTNode):
def __init__(self, left, op, right):
def __init__(self, left: ASTNode, op: str, right: ASTNode):
self.left = left
self.op = op
self.right = right

def __repr__(self):
def __repr__(self) -> str:
return f"BinaryOpNode(left={self.left}, op={self.op}, right={self.right})"


class UnaryOpNode(ASTNode):
def __init__(self, op, operand):
def __init__(self, op: str, operand: ASTNode):
self.op = op
self.operand = operand

def __repr__(self):
def __repr__(self) -> str:
return f"UnaryOpNode(operator={self.op}, operand={self.operand})"


class DescriptorSetNode(ASTNode):
def __init__(self, set_number, bindings):
def __init__(self, set_number: int, bindings: list):
self.set_number = set_number
self.bindings = bindings

def __repr__(self):
def __repr__(self) -> str:
return (
f"DescriptorSetNode(set_number={self.set_number}, bindings={self.bindings})"
)
Expand All @@ -109,12 +132,12 @@ def __repr__(self):
class LayoutNode(ASTNode):
def __init__(
self,
bindings,
push_constant,
layout_type,
data_type,
variable_name,
struct_fields,
bindings: list,
push_constant: bool,
layout_type: str,
data_type: str,
variable_name: str,
struct_fields: list,
):
self.bindings = bindings
self.push_constant = push_constant
Expand All @@ -123,7 +146,7 @@ def __init__(
self.variable_name = variable_name
self.struct_fields = struct_fields

def __repr__(self):
def __repr__(self) -> str:
return (
f"LayoutNode(bindings={self.bindings}, push_constant={self.push_constant}, "
f"layout_type={self.layout_type}, data_type={self.data_type}, "
Expand All @@ -132,127 +155,130 @@ def __repr__(self):


class UniformNode(ASTNode):
def __init__(self, name, var_type, value=None):
def __init__(self, name: str, var_type: str, value: Any = None):
self.name = name
self.var_type = var_type
self.value = value

def __repr__(self):
def __repr__(self) -> str:
return f"UniformNode(name={self.name}, var_type={self.var_type}, value={self.value})"


class ShaderStageNode(ASTNode):
def __init__(self, stage, entry_point):
def __init__(self, stage: str, entry_point: str):
self.stage = stage
self.entry_point = entry_point

def __repr__(self):
def __repr__(self) -> str:
return f"ShaderStageNode(stage={self.stage}, entry_point={self.entry_point})"


class PushConstantNode(ASTNode):
def __init__(self, size, values):
def __init__(self, size: int, values: list):
self.size = size
self.values = values

def __repr__(self):
def __repr__(self) -> str:
return f"PushConstantNode(size={self.size}, values={self.values})"


class StructNode(ASTNode):
def __init__(self, name, members):
def __init__(self, name: str, members: list):
self.name = name
self.members = members

def __repr__(self):
def __repr__(self) -> str:
return f"StructNode(name={self.name}, members={self.members})"


class FunctionNode(ASTNode):
def __init__(self, name, return_type, parameters, body):
def __init__(self, name: str, return_type: str, parameters: list, body: list):
self.name = name
self.return_type = return_type
self.parameters = parameters
self.body = body

def __repr__(self):
return f"FunctionNode(name={self.name}, return_type={self.return_type}, parameters={self.parameters}, body={self.body})"
def __repr__(self) -> str:
return (
f"FunctionNode(name={self.name}, return_type={self.return_type}, "
f"parameters={self.parameters}, body={self.body})"
)


class MemberAccessNode(ASTNode):
def __init__(self, object, member):
def __init__(self, object: ASTNode, member: str):
self.object = object
self.member = member

def __repr__(self):
def __repr__(self) -> str:
return f"MemberAccessNode(object={self.object}, member={self.member})"


class VariableNode(ASTNode):
def __init__(self, name, var_type):
def __init__(self, name: str, var_type: str):
self.name = name
self.var_type = var_type

def __repr__(self):
def __repr__(self) -> str:
return f"VariableNode(name={self.name}, var_type={self.var_type})"


class SwitchNode(ASTNode):
def __init__(self, expression, cases):
def __init__(self, expression: ASTNode, cases: list):
self.expression = expression
self.cases = cases

def __repr__(self):
def __repr__(self) -> str:
return f"SwitchNode(expression={self.expression}, cases={self.cases})"


class CaseNode(ASTNode):
def __init__(self, value, body):
def __init__(self, value: Any, body: list):
self.value = value
self.body = body

def __repr__(self):
def __repr__(self) -> str:
return f"CaseNode(value={self.value}, body={self.body})"


class DefaultNode(ASTNode):
def __init__(self, statements):
def __init__(self, statements: list):
self.statements = statements

def __repr__(self):
def __repr__(self) -> str:
return f"DefaultNode(statements={self.statements})"


class WhileNode(ASTNode):
def __init__(self, condition, body):
def __init__(self, condition: ASTNode, body: list):
self.condition = condition
self.body = body

def __repr__(self):
def __repr__(self) -> str:
return f"WhileNode(condition={self.condition}, body={self.body})"


class DoWhileNode(ASTNode):
def __init__(self, body, condition):
def __init__(self, body: list, condition: ASTNode):
self.body = body
self.condition = condition

def __repr__(self):
def __repr__(self) -> str:
return f"DoWhileNode(body={self.body}, condition={self.condition})"


class AssignmentNode(ASTNode):
def __init__(self, name, value):
def __init__(self, name: str, value: Any):
self.name = name
self.value = value

def __repr__(self):
def __repr__(self) -> str:
return f"AssignmentNode(name={self.name}, value={self.value})"


class BreakNode(ASTNode):
def __init__(self):
pass

def __repr__(self):
def __repr__(self) -> str:
return f"BreakNode()"
Loading