diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 335cfefb87..9c7f11dcb3 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -263,24 +263,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif (varinfo := self.expr._expr_info.var_info) is not None: - if varinfo.is_constant: - return Expr.parse_value_expr(varinfo.decl_node.value, self.context) - - location = data_location_to_address_space( - varinfo.location, self.context.is_ctor_context - ) - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -336,17 +318,37 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global attribute + if (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) + + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation="self." + self.expr.attr, + ) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + # contract type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e59132bd82..81f543cb17 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -194,6 +194,26 @@ def is_constant(self): return res +@dataclass(frozen=True) +class VariableAccess: + variable: VarInfo + attrs: tuple[str, ...] + + @classmethod + def from_base(cls, access: "VariableAccess"): + return cls(**access.__dict__) + + +@dataclass(frozen=True) +class VariableRead(VariableAccess): + pass + + +@dataclass(frozen=True) +class VariableWrite(VariableAccess): + pass + + @dataclass class ExprInfo: """ @@ -205,9 +225,8 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - - # the chain of attribute parents for this expr attribute_chain: list["ExprInfo"] = field(default_factory=list) + attr: Optional[str] = None def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -216,48 +235,54 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") - self._writes: OrderedSet[VarInfo] = OrderedSet() - self._reads: OrderedSet[VarInfo] = OrderedSet() + self.attribute_chain = self.attribute_chain or [] + + self._writes: OrderedSet[VariableWrite] = OrderedSet() + self._reads: OrderedSet[VariableRead] = OrderedSet() # find exprinfo in the attribute chain which has a varinfo # e.x. `x` will return varinfo for `x` # `module.foo` will return varinfo for `module.foo` - # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_root_varinfo(self) -> Optional[VarInfo]: - for expr_info in self.attribute_chain + [self]: - if expr_info.var_info is not None and not isinstance(expr_info.typ, SelfT): - return expr_info.var_info + # `self.my_struct.x.y` will return varinfo for `self.my_struct.x.y` + def get_variable_access(self) -> Optional[VariableAccess]: + chain = self.attribute_chain + [self] + for i, expr_info in enumerate(chain): + varinfo = expr_info.var_info + if varinfo is not None and not isinstance(varinfo, SelfT): + attrs = [] + for expr_info in chain[i:]: + if expr_info.attr is None: + break + attrs.append(expr_info.attr) + return VariableAccess(varinfo, tuple(attrs)) return None @classmethod - def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, - attribute_chain=attribute_chain or [], + **kwargs, ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo": modifiability = Modifiability.RUNTIME_CONSTANT if module_info.ownership >= ModuleOwnership.USES: modifiability = Modifiability.MODIFIABLE return cls( - module_info.module_t, - module_info=module_info, - modifiability=modifiability, - attribute_chain=attribute_chain or [], + module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs ) - def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": + def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} - if attribute_chain is not None: - fields["attribute_chain"] = attribute_chain - return self.__class__(typ=typ, **fields) + for t in to_copy: + assert t not in kwargs + return self.__class__(typ=typ, **fields, **kwargs) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0ee03fc9bd..f19dc3e4e7 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -20,7 +20,14 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo +from vyper.semantics.analysis.base import ( + Modifiability, + ModuleOwnership, + VariableAccess, + VariableRead, + VariableWrite, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -182,7 +189,7 @@ def __init__( self.func = fn_node._metadata["func_type"] self.expr_visitor = ExprVisitor(self) - self.loop_variables: list[Optional[VarInfo]] = [] + self.loop_variables: list[Optional[VariableAccess]] = [] def analyze(self): if self.func.analysed: @@ -221,8 +228,8 @@ def analyze(self): self.expr_visitor.visit(kwarg.default_value, kwarg.typ) @contextlib.contextmanager - def enter_for_loop(self, varinfo: Optional[VarInfo]): - self.loop_variables.append(varinfo) + def enter_for_loop(self, varaccess: Optional[VariableAccess]): + self.loop_variables.append(varaccess) try: yield finally: @@ -329,10 +336,13 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - var_info = info.get_root_varinfo() - assert var_info is not None + base_var = target + while isinstance(base_var, vy_ast.Subscript): + base_var = base_var.value - info._writes.add(var_info) + base_info = get_expr_info(base_var) + assert (var_access := base_info.get_variable_access()) is not None + info._writes.add(VariableWrite.from_base(var_access)) def _check_module_use(self, target: vy_ast.ExprNode): module_infos = [] @@ -444,7 +454,7 @@ def _analyse_list_iter(self, iter_node, target_type): # get the root varinfo from iter_val in case we need to peer # through folded constants info = get_expr_info(iter_val) - return info.get_root_varinfo() + return info.get_variable_access() def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): @@ -452,13 +462,13 @@ def visit_For(self, node): target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) - iter_varinfo = None + iter_var = None if isinstance(node.iter, vy_ast.Call): self._analyse_range_iter(node.iter, target_type) else: - iter_varinfo = self._analyse_list_iter(node.iter, target_type) + iter_var = self._analyse_list_iter(node.iter, target_type) - with self.namespace.enter_scope(), self.enter_for_loop(iter_varinfo): + with self.namespace.enter_scope(), self.enter_for_loop(iter_var): target_name = node.target.target.id # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( @@ -553,9 +563,9 @@ def visit(self, node, typ): # log variable accesses. # (note writes will get logged as both read+write) - varinfo = info.var_info - if varinfo is not None: - info._reads.add(varinfo) + var_access = info.get_variable_access() + if var_access is not None: + info._reads.add(VariableRead.from_base(var_access)) if self.function_analyzer: for s in self.function_analyzer.loop_variables: @@ -563,14 +573,16 @@ def visit(self, node, typ): continue if s in info._writes: + print(s, info._writes) msg = "Cannot modify loop variable" - if s.decl_node is not None: + var = s.variable + if var.decl_node is not None: msg += f" `{s.decl_node.target.id}`" - raise ImmutableViolation(msg, s.decl_node, node) + raise ImmutableViolation(msg, var.decl_node, node) variable_accesses = info._writes | info._reads for s in variable_accesses: - if s.is_module_variable(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node) self.func.mark_variable_writes(info._writes) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..64af036242 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -84,28 +84,26 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # propagate the parent exprinfo members down into the new expr # note: Attribute(expr value, identifier attr) - name = node.attr info = self.get_expr_info(node.value, is_callable=is_callable) + attr = node.attr attribute_chain = info.attribute_chain + [info] - t = info.typ.get_member(name, node) + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain, attr=attr) if isinstance(t, ModuleInfo): - return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain, attr=attr) - # it's something else, like my_struct.foo - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t, attribute_chain=attribute_chain, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - attribute_chain = info.attribute_chain + [info] - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t) return ExprInfo(t) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 1b612c9b81..bb65a10ad9 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -21,7 +21,8 @@ Modifiability, ModuleInfo, StateMutability, - VarInfo, + VariableRead, + VariableWrite, VarOffset, ) from vyper.semantics.analysis.utils import ( @@ -119,10 +120,10 @@ def __init__( self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # writes to variables from this function - self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + self._variable_writes: OrderedSet[VariableWrite] = OrderedSet() # reads of variables from this function - self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + self._variable_reads: OrderedSet[VariableRead] = OrderedSet() # list of modules used (accessed state) by this function self._used_modules: OrderedSet[ModuleInfo] = OrderedSet()