Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: struct touching #22

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,24 +263,6 @@ def parse_Attribute(self):
if addr.value == "address": # for `self.code`
return IRnode.from_list(["~selfcode"], typ=BytesT(0))
return IRnode.from_list(["~extcode", addr], typ=BytesT(0))
# self.x: global attribute
elif (varinfo := self.expr._expr_info.var_info) is not None:
if varinfo.is_constant:
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

location = data_location_to_address_space(
varinfo.location, self.context.is_ctor_context
)

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

# Reserved keywords
elif (
Expand Down Expand Up @@ -336,17 +318,37 @@ def parse_Attribute(self):
"chain.id is unavailable prior to istanbul ruleset", self.expr
)
return IRnode.from_list(["chainid"], typ=UINT256_T)

# Other variables
else:
sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

# self.x: global attribute
if (varinfo := self.expr._expr_info.var_info) is not None:
if varinfo.is_constant:
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

location = data_location_to_address_space(
varinfo.location, self.context.is_ctor_context
)

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

def parse_Subscript(self):
sub = Expr(self.expr.value, self.context).ir_node
Expand Down
65 changes: 45 additions & 20 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,26 @@ def is_constant(self):
return res


@dataclass(frozen=True)
class VariableAccess:
variable: VarInfo
attrs: tuple[str, ...]

@classmethod
def from_base(cls, access: "VariableAccess"):
return cls(**access.__dict__)


@dataclass(frozen=True)
class VariableRead(VariableAccess):
pass


@dataclass(frozen=True)
class VariableWrite(VariableAccess):
pass


@dataclass
class ExprInfo:
"""
Expand All @@ -205,9 +225,8 @@ class ExprInfo:
module_info: Optional[ModuleInfo] = None
location: DataLocation = DataLocation.UNSET
modifiability: Modifiability = Modifiability.MODIFIABLE

# the chain of attribute parents for this expr
attribute_chain: list["ExprInfo"] = field(default_factory=list)
attr: Optional[str] = None

def __post_init__(self):
should_match = ("typ", "location", "modifiability")
Expand All @@ -216,48 +235,54 @@ def __post_init__(self):
if getattr(self.var_info, attr) != getattr(self, attr):
raise CompilerPanic("Bad analysis: non-matching {attr}: {self}")

self._writes: OrderedSet[VarInfo] = OrderedSet()
self._reads: OrderedSet[VarInfo] = OrderedSet()
self.attribute_chain = self.attribute_chain or []

self._writes: OrderedSet[VariableWrite] = OrderedSet()
self._reads: OrderedSet[VariableRead] = OrderedSet()

# find exprinfo in the attribute chain which has a varinfo
# e.x. `x` will return varinfo for `x`
# `module.foo` will return varinfo for `module.foo`
# `self.my_struct.x.y` will return varinfo for `self.my_struct`
def get_root_varinfo(self) -> Optional[VarInfo]:
for expr_info in self.attribute_chain + [self]:
if expr_info.var_info is not None and not isinstance(expr_info.typ, SelfT):
return expr_info.var_info
# `self.my_struct.x.y` will return varinfo for `self.my_struct.x.y`
def get_variable_access(self) -> Optional[VariableAccess]:
chain = self.attribute_chain + [self]
for i, expr_info in enumerate(chain):
varinfo = expr_info.var_info
if varinfo is not None and not isinstance(varinfo, SelfT):
attrs = []
for expr_info in chain[i:]:
if expr_info.attr is None:
break
attrs.append(expr_info.attr)
return VariableAccess(varinfo, tuple(attrs))
return None

@classmethod
def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo":
def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo":
return cls(
var_info.typ,
var_info=var_info,
location=var_info.location,
modifiability=var_info.modifiability,
attribute_chain=attribute_chain or [],
**kwargs,
)

@classmethod
def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo":
def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo":
modifiability = Modifiability.RUNTIME_CONSTANT
if module_info.ownership >= ModuleOwnership.USES:
modifiability = Modifiability.MODIFIABLE

return cls(
module_info.module_t,
module_info=module_info,
modifiability=modifiability,
attribute_chain=attribute_chain or [],
module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs
)

def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo":
def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo":
"""
Return a copy of the ExprInfo but with the type set to something else
"""
to_copy = ("location", "modifiability")
fields = {k: getattr(self, k) for k in to_copy}
if attribute_chain is not None:
fields["attribute_chain"] = attribute_chain
return self.__class__(typ=typ, **fields)
for t in to_copy:
assert t not in kwargs
return self.__class__(typ=typ, **fields, **kwargs)
46 changes: 29 additions & 17 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
VariableDeclarationException,
VyperException,
)
from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo
from vyper.semantics.analysis.base import (
Modifiability,
ModuleOwnership,
VariableAccess,
VariableRead,
VariableWrite,
VarInfo,
)
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.utils import (
get_common_types,
Expand Down Expand Up @@ -182,7 +189,7 @@ def __init__(
self.func = fn_node._metadata["func_type"]
self.expr_visitor = ExprVisitor(self)

self.loop_variables: list[Optional[VarInfo]] = []
self.loop_variables: list[Optional[VariableAccess]] = []

def analyze(self):
if self.func.analysed:
Expand Down Expand Up @@ -221,8 +228,8 @@ def analyze(self):
self.expr_visitor.visit(kwarg.default_value, kwarg.typ)

@contextlib.contextmanager
def enter_for_loop(self, varinfo: Optional[VarInfo]):
self.loop_variables.append(varinfo)
def enter_for_loop(self, varaccess: Optional[VariableAccess]):
self.loop_variables.append(varaccess)
try:
yield
finally:
Expand Down Expand Up @@ -329,10 +336,13 @@ def _handle_modification(self, target: vy_ast.ExprNode):
if info.modifiability == Modifiability.CONSTANT:
raise ImmutableViolation("Constant value cannot be written to.")

var_info = info.get_root_varinfo()
assert var_info is not None
base_var = target
while isinstance(base_var, vy_ast.Subscript):
base_var = base_var.value

info._writes.add(var_info)
base_info = get_expr_info(base_var)
assert (var_access := base_info.get_variable_access()) is not None
info._writes.add(VariableWrite.from_base(var_access))

def _check_module_use(self, target: vy_ast.ExprNode):
module_infos = []
Expand Down Expand Up @@ -444,21 +454,21 @@ def _analyse_list_iter(self, iter_node, target_type):
# get the root varinfo from iter_val in case we need to peer
# through folded constants
info = get_expr_info(iter_val)
return info.get_root_varinfo()
return info.get_variable_access()

def visit_For(self, node):
if not isinstance(node.target.target, vy_ast.Name):
raise StructureException("Invalid syntax for loop iterator", node.target.target)

target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY)

iter_varinfo = None
iter_var = None
if isinstance(node.iter, vy_ast.Call):
self._analyse_range_iter(node.iter, target_type)
else:
iter_varinfo = self._analyse_list_iter(node.iter, target_type)
iter_var = self._analyse_list_iter(node.iter, target_type)

with self.namespace.enter_scope(), self.enter_for_loop(iter_varinfo):
with self.namespace.enter_scope(), self.enter_for_loop(iter_var):
target_name = node.target.target.id
# maybe we should introduce a new Modifiability: LOOP_VARIABLE
self.namespace[target_name] = VarInfo(
Expand Down Expand Up @@ -553,24 +563,26 @@ def visit(self, node, typ):

# log variable accesses.
# (note writes will get logged as both read+write)
varinfo = info.var_info
if varinfo is not None:
info._reads.add(varinfo)
var_access = info.get_variable_access()
if var_access is not None:
info._reads.add(VariableRead.from_base(var_access))

if self.function_analyzer:
for s in self.function_analyzer.loop_variables:
if s is None:
continue

if s in info._writes:
print(s, info._writes)
msg = "Cannot modify loop variable"
if s.decl_node is not None:
var = s.variable
if var.decl_node is not None:
msg += f" `{s.decl_node.target.id}`"
raise ImmutableViolation(msg, s.decl_node, node)
raise ImmutableViolation(msg, var.decl_node, node)

variable_accesses = info._writes | info._reads
for s in variable_accesses:
if s.is_module_variable():
if s.variable.is_module_variable():
self.function_analyzer._check_module_use(node)

self.func.mark_variable_writes(info._writes)
Expand Down
14 changes: 6 additions & 8 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,28 +84,26 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex
# propagate the parent exprinfo members down into the new expr
# note: Attribute(expr value, identifier attr)

name = node.attr
info = self.get_expr_info(node.value, is_callable=is_callable)
attr = node.attr

attribute_chain = info.attribute_chain + [info]

t = info.typ.get_member(name, node)
t = info.typ.get_member(attr, node)

# it's a top-level variable
if isinstance(t, VarInfo):
return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain)
return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain, attr=attr)

if isinstance(t, ModuleInfo):
return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain)
return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain, attr=attr)

# it's something else, like my_struct.foo
return info.copy_with_type(t, attribute_chain=attribute_chain)
return info.copy_with_type(t, attribute_chain=attribute_chain, attr=attr)

# If it's a Subscript, propagate the subscriptable varinfo
if isinstance(node, vy_ast.Subscript):
info = self.get_expr_info(node.value)
attribute_chain = info.attribute_chain + [info]
return info.copy_with_type(t, attribute_chain=attribute_chain)
return info.copy_with_type(t)

return ExprInfo(t)

Expand Down
7 changes: 4 additions & 3 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
Modifiability,
ModuleInfo,
StateMutability,
VarInfo,
VariableRead,
VariableWrite,
VarOffset,
)
from vyper.semantics.analysis.utils import (
Expand Down Expand Up @@ -119,10 +120,10 @@ def __init__(
self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet()

# writes to variables from this function
self._variable_writes: OrderedSet[VarInfo] = OrderedSet()
self._variable_writes: OrderedSet[VariableWrite] = OrderedSet()

# reads of variables from this function
self._variable_reads: OrderedSet[VarInfo] = OrderedSet()
self._variable_reads: OrderedSet[VariableRead] = OrderedSet()

# list of modules used (accessed state) by this function
self._used_modules: OrderedSet[ModuleInfo] = OrderedSet()
Expand Down
Loading