From d8353ae3e302064a98e5a938777637adac476cb3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 14:16:39 -0500 Subject: [PATCH 1/3] fix: struct touching --- vyper/codegen/expr.py | 61 +++++++++++++++++-------------- vyper/semantics/analysis/base.py | 56 +++++++++++++++++++--------- vyper/semantics/analysis/local.py | 12 +++--- vyper/semantics/analysis/utils.py | 17 ++++++--- 4 files changed, 89 insertions(+), 57 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 335cfefb87..2aed6af4b2 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -37,6 +37,7 @@ VyperException, tag_exceptions, ) +from vyper.semantics.analysis.base import VarAttributeInfo from vyper.semantics.types import ( AddressT, BoolT, @@ -263,24 +264,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif (varinfo := self.expr._expr_info.var_info) is not None: - if varinfo.is_constant: - return Expr.parse_value_expr(varinfo.decl_node.value, self.context) - - location = data_location_to_address_space( - varinfo.location, self.context.is_ctor_context - ) - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -336,17 +319,39 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global attribute + if (varinfo := self.expr._expr_info.var_info) is not None and not isinstance( + varinfo, VarAttributeInfo + ): + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) + + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation="self." + self.expr.attr, + ) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + # contract type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e59132bd82..b1ab262b9b 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -194,6 +194,29 @@ def is_constant(self): return res +@dataclass(kw_only=True) +class VarAttributeInfo(VarInfo): + attr: str + parent: VarInfo + + def __hash__(self): + return super().__hash__() + + @classmethod + def from_varinfo(cls, varinfo: VarInfo, attr: str, typ: VyperType): + location = varinfo.location + modifiability = varinfo.modifiability + return cls( + typ=typ, location=location, modifiability=modifiability, attr=attr, parent=varinfo + ) + + +@dataclass +class AttributeInfo: + attr: str + expr_info: "ExprInfo" + + @dataclass class ExprInfo: """ @@ -205,9 +228,7 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - - # the chain of attribute parents for this expr - attribute_chain: list["ExprInfo"] = field(default_factory=list) + attribute_chain: list[AttributeInfo] = field(default_factory=list) def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -216,6 +237,8 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") + self.attribute_chain = self.attribute_chain or [] + self._writes: OrderedSet[VarInfo] = OrderedSet() self._reads: OrderedSet[VarInfo] = OrderedSet() @@ -223,41 +246,40 @@ def __post_init__(self): # e.x. `x` will return varinfo for `x` # `module.foo` will return varinfo for `module.foo` # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_root_varinfo(self) -> Optional[VarInfo]: - for expr_info in self.attribute_chain + [self]: - if expr_info.var_info is not None and not isinstance(expr_info.typ, SelfT): - return expr_info.var_info + def get_closest_varinfo(self) -> Optional[VarInfo]: + for attr_info in reversed(self.attribute_chain + [self]): + var_info = getattr(attr_info, "expr_info", attr_info).var_info # type: ignore + if var_info is not None and not isinstance(var_info, SelfT): + return var_info return None @classmethod - def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, - attribute_chain=attribute_chain or [], + **kwargs, ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo": modifiability = Modifiability.RUNTIME_CONSTANT if module_info.ownership >= ModuleOwnership.USES: modifiability = Modifiability.MODIFIABLE return cls( - module_info.module_t, - module_info=module_info, - modifiability=modifiability, - attribute_chain=attribute_chain or [], + module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs ) - def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": + def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} - if attribute_chain is not None: - fields["attribute_chain"] = attribute_chain + for t in to_copy: + assert t not in kwargs + fields.update(kwargs) return self.__class__(typ=typ, **fields) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0ee03fc9bd..a2ff81a363 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -329,16 +329,14 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - var_info = info.get_root_varinfo() - assert var_info is not None - - info._writes.add(var_info) + assert (varinfo := info.get_closest_varinfo()) is not None + info._writes.add(varinfo) def _check_module_use(self, target: vy_ast.ExprNode): module_infos = [] for t in get_expr_info(target).attribute_chain: - if t.module_info is not None: - module_infos.append(t.module_info) + if t.expr_info.module_info is not None: + module_infos.append(t.expr_info.module_info) if len(module_infos) == 0: return @@ -444,7 +442,7 @@ def _analyse_list_iter(self, iter_node, target_type): # get the root varinfo from iter_val in case we need to peer # through folded constants info = get_expr_info(iter_val) - return info.get_root_varinfo() + return info.get_closest_varinfo() def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..ea0d6bc98d 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,13 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ( + ExprInfo, + Modifiability, + ModuleInfo, + VarAttributeInfo, + VarInfo, +) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -84,12 +90,11 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # propagate the parent exprinfo members down into the new expr # note: Attribute(expr value, identifier attr) - name = node.attr info = self.get_expr_info(node.value, is_callable=is_callable) - attribute_chain = info.attribute_chain + [info] - t = info.typ.get_member(name, node) + attr = node.attr + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): @@ -99,7 +104,9 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) # it's something else, like my_struct.foo - return info.copy_with_type(t, attribute_chain=attribute_chain) + assert (varinfo := info.var_info) is not None + child_varinfo = VarAttributeInfo.from_varinfo(varinfo=varinfo, attr=attr, typ=t) + return ExprInfo.from_varinfo(child_varinfo) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): From 15111ccd17fddd79f6f74d9407100ed759611af5 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 15:13:45 -0500 Subject: [PATCH 2/3] yeet VarAttributeInfo --- vyper/codegen/expr.py | 4 +--- vyper/semantics/analysis/base.py | 20 ++++---------------- vyper/semantics/analysis/utils.py | 6 +++--- 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 2aed6af4b2..c504145a57 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -323,9 +323,7 @@ def parse_Attribute(self): # Other variables # self.x: global attribute - if (varinfo := self.expr._expr_info.var_info) is not None and not isinstance( - varinfo, VarAttributeInfo - ): + if (varinfo := self.expr._expr_info.var_info) is not None: if varinfo.is_constant: return Expr.parse_value_expr(varinfo.decl_node.value, self.context) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index b1ab262b9b..47095ae271 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -193,22 +193,10 @@ def is_constant(self): assert res == (self.modifiability == Modifiability.CONSTANT) return res - -@dataclass(kw_only=True) -class VarAttributeInfo(VarInfo): - attr: str - parent: VarInfo - - def __hash__(self): - return super().__hash__() - - @classmethod - def from_varinfo(cls, varinfo: VarInfo, attr: str, typ: VyperType): - location = varinfo.location - modifiability = varinfo.modifiability - return cls( - typ=typ, location=location, modifiability=modifiability, attr=attr, parent=varinfo - ) +@dataclass +class VariableAccess: + variable: VarInfo + attrs: tuple[str] @dataclass diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ea0d6bc98d..fbde793e6a 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -91,7 +91,8 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # note: Attribute(expr value, identifier attr) info = self.get_expr_info(node.value, is_callable=is_callable) - attribute_chain = info.attribute_chain + [info] + attr_info = AttributeInfo(attr, info) + attribute_chain = info.attribute_chain + [attr_info] attr = node.attr t = info.typ.get_member(attr, node) @@ -111,8 +112,7 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - attribute_chain = info.attribute_chain + [info] - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t) return ExprInfo(t) From 7a7547fa9b7c9eb7a4e7ce7b42a28d3b7029b900 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 15:13:45 -0500 Subject: [PATCH 3/3] yeet VarAttributeInfo --- vyper/codegen/expr.py | 1 - vyper/semantics/analysis/base.py | 49 +++++++++++++++++++----------- vyper/semantics/analysis/local.py | 50 ++++++++++++++++++++----------- vyper/semantics/analysis/utils.py | 23 +++++--------- vyper/semantics/types/function.py | 7 +++-- 5 files changed, 75 insertions(+), 55 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index c504145a57..9c7f11dcb3 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -37,7 +37,6 @@ VyperException, tag_exceptions, ) -from vyper.semantics.analysis.base import VarAttributeInfo from vyper.semantics.types import ( AddressT, BoolT, diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 47095ae271..81f543cb17 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -193,16 +193,25 @@ def is_constant(self): assert res == (self.modifiability == Modifiability.CONSTANT) return res -@dataclass + +@dataclass(frozen=True) class VariableAccess: variable: VarInfo - attrs: tuple[str] + attrs: tuple[str, ...] + @classmethod + def from_base(cls, access: "VariableAccess"): + return cls(**access.__dict__) -@dataclass -class AttributeInfo: - attr: str - expr_info: "ExprInfo" + +@dataclass(frozen=True) +class VariableRead(VariableAccess): + pass + + +@dataclass(frozen=True) +class VariableWrite(VariableAccess): + pass @dataclass @@ -216,7 +225,8 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - attribute_chain: list[AttributeInfo] = field(default_factory=list) + attribute_chain: list["ExprInfo"] = field(default_factory=list) + attr: Optional[str] = None def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -227,18 +237,24 @@ def __post_init__(self): self.attribute_chain = self.attribute_chain or [] - self._writes: OrderedSet[VarInfo] = OrderedSet() - self._reads: OrderedSet[VarInfo] = OrderedSet() + self._writes: OrderedSet[VariableWrite] = OrderedSet() + self._reads: OrderedSet[VariableRead] = OrderedSet() # find exprinfo in the attribute chain which has a varinfo # e.x. `x` will return varinfo for `x` # `module.foo` will return varinfo for `module.foo` - # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_closest_varinfo(self) -> Optional[VarInfo]: - for attr_info in reversed(self.attribute_chain + [self]): - var_info = getattr(attr_info, "expr_info", attr_info).var_info # type: ignore - if var_info is not None and not isinstance(var_info, SelfT): - return var_info + # `self.my_struct.x.y` will return varinfo for `self.my_struct.x.y` + def get_variable_access(self) -> Optional[VariableAccess]: + chain = self.attribute_chain + [self] + for i, expr_info in enumerate(chain): + varinfo = expr_info.var_info + if varinfo is not None and not isinstance(varinfo, SelfT): + attrs = [] + for expr_info in chain[i:]: + if expr_info.attr is None: + break + attrs.append(expr_info.attr) + return VariableAccess(varinfo, tuple(attrs)) return None @classmethod @@ -269,5 +285,4 @@ def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": fields = {k: getattr(self, k) for k in to_copy} for t in to_copy: assert t not in kwargs - fields.update(kwargs) - return self.__class__(typ=typ, **fields) + return self.__class__(typ=typ, **fields, **kwargs) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a2ff81a363..f19dc3e4e7 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -20,7 +20,14 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo +from vyper.semantics.analysis.base import ( + Modifiability, + ModuleOwnership, + VariableAccess, + VariableRead, + VariableWrite, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -182,7 +189,7 @@ def __init__( self.func = fn_node._metadata["func_type"] self.expr_visitor = ExprVisitor(self) - self.loop_variables: list[Optional[VarInfo]] = [] + self.loop_variables: list[Optional[VariableAccess]] = [] def analyze(self): if self.func.analysed: @@ -221,8 +228,8 @@ def analyze(self): self.expr_visitor.visit(kwarg.default_value, kwarg.typ) @contextlib.contextmanager - def enter_for_loop(self, varinfo: Optional[VarInfo]): - self.loop_variables.append(varinfo) + def enter_for_loop(self, varaccess: Optional[VariableAccess]): + self.loop_variables.append(varaccess) try: yield finally: @@ -329,14 +336,19 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - assert (varinfo := info.get_closest_varinfo()) is not None - info._writes.add(varinfo) + base_var = target + while isinstance(base_var, vy_ast.Subscript): + base_var = base_var.value + + base_info = get_expr_info(base_var) + assert (var_access := base_info.get_variable_access()) is not None + info._writes.add(VariableWrite.from_base(var_access)) def _check_module_use(self, target: vy_ast.ExprNode): module_infos = [] for t in get_expr_info(target).attribute_chain: - if t.expr_info.module_info is not None: - module_infos.append(t.expr_info.module_info) + if t.module_info is not None: + module_infos.append(t.module_info) if len(module_infos) == 0: return @@ -442,7 +454,7 @@ def _analyse_list_iter(self, iter_node, target_type): # get the root varinfo from iter_val in case we need to peer # through folded constants info = get_expr_info(iter_val) - return info.get_closest_varinfo() + return info.get_variable_access() def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): @@ -450,13 +462,13 @@ def visit_For(self, node): target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) - iter_varinfo = None + iter_var = None if isinstance(node.iter, vy_ast.Call): self._analyse_range_iter(node.iter, target_type) else: - iter_varinfo = self._analyse_list_iter(node.iter, target_type) + iter_var = self._analyse_list_iter(node.iter, target_type) - with self.namespace.enter_scope(), self.enter_for_loop(iter_varinfo): + with self.namespace.enter_scope(), self.enter_for_loop(iter_var): target_name = node.target.target.id # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( @@ -551,9 +563,9 @@ def visit(self, node, typ): # log variable accesses. # (note writes will get logged as both read+write) - varinfo = info.var_info - if varinfo is not None: - info._reads.add(varinfo) + var_access = info.get_variable_access() + if var_access is not None: + info._reads.add(VariableRead.from_base(var_access)) if self.function_analyzer: for s in self.function_analyzer.loop_variables: @@ -561,14 +573,16 @@ def visit(self, node, typ): continue if s in info._writes: + print(s, info._writes) msg = "Cannot modify loop variable" - if s.decl_node is not None: + var = s.variable + if var.decl_node is not None: msg += f" `{s.decl_node.target.id}`" - raise ImmutableViolation(msg, s.decl_node, node) + raise ImmutableViolation(msg, var.decl_node, node) variable_accesses = info._writes | info._reads for s in variable_accesses: - if s.is_module_variable(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node) self.func.mark_variable_writes(info._writes) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index fbde793e6a..64af036242 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,13 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ( - ExprInfo, - Modifiability, - ModuleInfo, - VarAttributeInfo, - VarInfo, -) +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -91,23 +85,20 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # note: Attribute(expr value, identifier attr) info = self.get_expr_info(node.value, is_callable=is_callable) - attr_info = AttributeInfo(attr, info) - attribute_chain = info.attribute_chain + [attr_info] - attr = node.attr + + attribute_chain = info.attribute_chain + [info] + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain, attr=attr) if isinstance(t, ModuleInfo): - return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain, attr=attr) - # it's something else, like my_struct.foo - assert (varinfo := info.var_info) is not None - child_varinfo = VarAttributeInfo.from_varinfo(varinfo=varinfo, attr=attr, typ=t) - return ExprInfo.from_varinfo(child_varinfo) + return info.copy_with_type(t, attribute_chain=attribute_chain, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 1b612c9b81..bb65a10ad9 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -21,7 +21,8 @@ Modifiability, ModuleInfo, StateMutability, - VarInfo, + VariableRead, + VariableWrite, VarOffset, ) from vyper.semantics.analysis.utils import ( @@ -119,10 +120,10 @@ def __init__( self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # writes to variables from this function - self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + self._variable_writes: OrderedSet[VariableWrite] = OrderedSet() # reads of variables from this function - self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + self._variable_reads: OrderedSet[VariableRead] = OrderedSet() # list of modules used (accessed state) by this function self._used_modules: OrderedSet[ModuleInfo] = OrderedSet()