Skip to content

Commit

Permalink
separate strict and static analysis
Browse files Browse the repository at this point in the history
Summary:
Only runs strict analysis when the strict flag is set to true,
which only happens when there is `import __static__` or force strict is
set.

This should allow us to remove the cases of `from __strict__ import allow_side_effects`.

Reviewed By: carljm

Differential Revision: D49736502

fbshipit-source-id: e9ee49cefb02bf4dfc53e3c1282abbf733d2c8e7
  • Loading branch information
pilleye authored and facebook-github-bot committed Nov 10, 2023
1 parent d5f7bef commit 4aa9d7d
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 29 deletions.
4 changes: 2 additions & 2 deletions CinderX/test_cinderx/test_compiler/test_static/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,12 @@ def compile_strict(
):
compiler = self.get_strict_compiler(enable_patching=enable_patching)

code, is_valid_strict = compiler.load_compiled_module_from_source(
code, is_valid_strict, _is_static = compiler.load_compiled_module_from_source(
self.clean_code(codestr),
f"{modname}.py",
modname,
optimize,
override_flags=override_flags,
override_flags=override_flags or Flags(is_strict=True),
)
assert is_valid_strict
return code
Expand Down
2 changes: 1 addition & 1 deletion CinderX/test_cinderx/test_compiler/test_strict/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def check_and_compile(
):
compiler = Compiler([], "", [], [])
source = inspect.cleandoc("\n" + source)
code, is_valid_strict = compiler.load_compiled_module_from_source(
code, is_valid_strict, _is_static = compiler.load_compiled_module_from_source(
source, f"{modname}.py", modname, optimize
)
assert is_valid_strict
Expand Down
27 changes: 20 additions & 7 deletions CinderX/test_cinderx/test_compiler/test_strict/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from compiler.strict.compiler import StrictModuleError
from compiler.strict.loader import (
_MAGIC_LEN,
_MAGIC_NONSTRICT,
_MAGIC_STRICT,
_MAGIC_NEITHER_STRICT_NOR_STATIC,
_MAGIC_STRICT_OR_STATIC,
install,
StrictModule,
StrictModuleTestingPatchProxy,
Expand Down Expand Up @@ -357,7 +357,7 @@ def test_magic_number(self) -> None:
mod = self.sbx.strict_import("a")

with open(mod.__cached__, "rb") as fh:
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_STRICT)
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_STRICT_OR_STATIC)

BAD_MAGIC = (65535).to_bytes(2, "little") + b"\r\n"

Expand All @@ -368,15 +368,15 @@ def test_magic_number(self) -> None:
mod2 = self.sbx.strict_import("a")

with open(mod2.__cached__, "rb") as fh:
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_STRICT)
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_STRICT_OR_STATIC)

def test_magic_number_non_strict(self) -> None:
"""Extra magic number is written to pycs, and validated."""
self.sbx.write_file("a.py", "x=2")
mod = self.sbx.strict_import("a")

with open(mod.__cached__, "rb") as fh:
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_NONSTRICT)
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_NEITHER_STRICT_NOR_STATIC)

BAD_MAGIC = (65535).to_bytes(2, "little") + b"\r\n"

Expand All @@ -387,7 +387,7 @@ def test_magic_number_non_strict(self) -> None:
mod2 = self.sbx.strict_import("a")

with open(mod2.__cached__, "rb") as fh:
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_NONSTRICT)
self.assertEqual(fh.read(_MAGIC_LEN), _MAGIC_NEITHER_STRICT_NOR_STATIC)

def test_strict_loader_toggle(self) -> None:
"""Repeat imports with strict module loader toggled off/on/off work correctly."""
Expand Down Expand Up @@ -462,6 +462,7 @@ def test_cross_module_static(self) -> None:
self.sbx.write_file(
"astatic.py",
"""
import __strict__
import __static__
class C:
def f(self) -> int:
Expand All @@ -471,6 +472,7 @@ def f(self) -> int:
self.sbx.write_file(
"bstatic.py",
"""
import __strict__
import __static__
from astatic import C
def f() -> int:
Expand All @@ -492,6 +494,7 @@ def test_cross_module_static_typestub(self) -> None:
self.sbx.write_file(
"math.pyi",
"""
import __strict__
import __static__
def gcd(a: int, b: int) -> int:
Expand All @@ -501,6 +504,7 @@ def gcd(a: int, b: int) -> int:
self.sbx.write_file(
"bstatic.py",
"""
import __strict__
import __static__
from math import gcd
def e() -> int:
Expand All @@ -524,7 +528,7 @@ def gcd(a: int, b: int) -> int:
self.sbx.write_file(
"bstatic.py",
"""
import __static__
import __strict__
from math import gcd
def e() -> int:
return gcd(15, 25)
Expand All @@ -540,6 +544,7 @@ def test_cross_module_static_typestub_ensure_types_untrusted(self) -> None:
self.sbx.write_file(
"math.pyi",
"""
import __strict__
import __static__
def gcd(a: int, b: int) -> int:
Expand All @@ -549,6 +554,7 @@ def gcd(a: int, b: int) -> int:
self.sbx.write_file(
"bstatic.py",
"""
import __strict__
import __static__
from math import gcd
def e() -> int:
Expand All @@ -565,6 +571,7 @@ def test_cross_module_static_typestub_missing(self) -> None:
self.sbx.write_file(
"astatic.py",
"""
import __strict__
import __static__
from math import gcd
def e() -> int:
Expand Down Expand Up @@ -2030,6 +2037,7 @@ def test_static_python(self) -> None:
self.sbx.write_file(
"a.py",
"""
import __strict__
import __static__
from typing import Optional
class C:
Expand All @@ -2045,6 +2053,7 @@ def test_static_python_del_builtin(self) -> None:
self.sbx.write_file(
"a.py",
"""
import __strict__
import __static__
for int in [1, 2]:
pass
Expand All @@ -2060,6 +2069,7 @@ def test_static_python_import_from_fixed_module(self) -> None:
self.sbx.write_file(
"a.py",
"""
import __strict__
import __static__
from typing import List
""",
Expand All @@ -2071,6 +2081,7 @@ def test_static_python_final_globals_patch(self) -> None:
self.sbx.write_file(
"a.py",
"""
import __strict__
import __static__
from typing import Final
Expand Down Expand Up @@ -2530,6 +2541,7 @@ def maybe_throw(c):
raise ImportError("this module fails to import")
"""
modcode = """
import __strict__
import __static__
from __strict__ import allow_side_effects
from flag import maybe_throw
Expand All @@ -2541,6 +2553,7 @@ class C:
"""
othercode = """
from __future__ import annotations
import __strict__
import __static__
from __strict__ import allow_side_effects
Expand Down
2 changes: 1 addition & 1 deletion Lib/compiler/strict/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# Increment this whenever we change the output of the strict modules
# interpreter. It must stay below 32768 (15 bits), because we use the high bit
# to encode strictness of the module.
MAGIC_NUMBER = 49
MAGIC_NUMBER = 50


DEFAULT_STUB_PATH = os.path.dirname(__file__) + "/stubs"
Expand Down
9 changes: 4 additions & 5 deletions Lib/compiler/strict/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,21 +190,20 @@ def load_compiled_module_from_source(

if not flags.is_static and not flags.is_strict:
code = self._compile_basic(name, pyast, filename, optimize)
return (code, False)
return (code, False, False)

# TODO: Remove the check when static is enabled in the next diff to isolate errors
is_valid_strict = False
if flags.is_strict or flags.is_static:
if flags.is_strict:
is_valid_strict = self._strict_analyze(
source, flags, symbols, filename, name, submodule_search_locations
)

if flags.is_static:
code = self._compile_static(pyast, symbols, filename, name, optimize)
return (code, is_valid_strict)
return (code, is_valid_strict, True)
else:
code = self._compile_strict(pyast, symbols, filename, name, optimize)
return (code, is_valid_strict)
return (code, is_valid_strict, False)

def _get_source(
self,
Expand Down
32 changes: 19 additions & 13 deletions Lib/compiler/strict/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@
Compiler = Compiler


_MAGIC_STRICT: bytes = (MAGIC_NUMBER + 2**15).to_bytes(2, "little") + b"\r\n"
_MAGIC_STRICT_OR_STATIC: bytes = (MAGIC_NUMBER + 2**15).to_bytes(
2, "little"
) + b"\r\n"
# We don't actually need to increment anything here, because the strict modules
# AST rewrite has no impact on pycs for non-strict modules. So we just always
# use two zero bytes. This simplifies generating "fake" strict pycs for
# known-not-to-be-strict third-party modules.
_MAGIC_NONSTRICT: bytes = (0).to_bytes(2, "little") + b"\r\n"
_MAGIC_LEN: int = len(_MAGIC_STRICT)
_MAGIC_NEITHER_STRICT_NOR_STATIC: bytes = (0).to_bytes(2, "little") + b"\r\n"
_MAGIC_LEN: int = len(_MAGIC_STRICT_OR_STATIC)


@final
Expand Down Expand Up @@ -177,7 +179,7 @@ def __del__(self) -> None:


class StrictSourceFileLoader(SourceFileLoader):
strict: bool = False
strict_or_static: bool = False
compiler: Optional[Compiler] = None
module: Optional[ModuleType] = None

Expand Down Expand Up @@ -261,10 +263,10 @@ def get_data(self, path: bytes | str) -> bytes:
if is_pyc:
self.bytecode_found = True
magic = data[:_MAGIC_LEN]
if magic == _MAGIC_NONSTRICT:
self.strict = False
elif magic == _MAGIC_STRICT:
self.strict = True
if magic == _MAGIC_NEITHER_STRICT_NOR_STATIC:
self.strict_or_static = False
elif magic == _MAGIC_STRICT_OR_STATIC:
self.strict_or_static = True
else:
# This is a bit ugly: OSError is the only kind of error that
# get_code() ignores from get_data(). But this is way better
Expand All @@ -279,7 +281,11 @@ def set_data(self, path: bytes | str, data: bytes, *, _mode=0o666) -> None:
assert isinstance(path, str)
if path.endswith(tuple(BYTECODE_SUFFIXES)):
path = add_strict_tag(path, self.enable_patching)
magic = _MAGIC_STRICT if self.strict else _MAGIC_NONSTRICT
magic = (
_MAGIC_STRICT_OR_STATIC
if self.strict_or_static
else _MAGIC_NEITHER_STRICT_NOR_STATIC
)
data = magic + data
return super().set_data(path, data, _mode=_mode)

Expand Down Expand Up @@ -317,7 +323,7 @@ def source_to_code(
# Let the ast transform attempt to validate the strict module. This
# will return an unmodified module if import __strict__ isn't
# actually at the top-level
code, is_valid_strict = self.ensure_compiler(
code, is_valid_strict, is_static = self.ensure_compiler(
self.import_path,
self.stub_path,
self.allow_list_prefix,
Expand All @@ -333,11 +339,11 @@ def source_to_code(
submodule_search_locations,
override_flags=Flags(is_strict=force),
)
self.strict = is_valid_strict
self.strict_or_static = is_valid_strict or is_static
assert code is not None
return code
self.strict = False

self.strict_or_static = False
return code

def exec_module(self, module: ModuleType) -> None:
Expand All @@ -364,7 +370,7 @@ def exec_module(self, module: ModuleType) -> None:
if cached and spec and spec.cached:
spec.cached = cached

if self.strict:
if self.strict_or_static:
if spec is None:
raise ImportError(f"Missing module spec for {module.__name__}")

Expand Down

0 comments on commit 4aa9d7d

Please sign in to comment.