From bd69c6609b85637e11f969305d2b1ba770898f8f Mon Sep 17 00:00:00 2001 From: Manuel Saelices Date: Sun, 24 Sep 2023 00:17:56 +0200 Subject: [PATCH] Allow to define the float precision when converting types --- README.md | 4 +++- py2mojo/converters/assignment.py | 2 +- py2mojo/converters/functiondef.py | 4 ++-- py2mojo/helpers.py | 6 ++++-- py2mojo/main.py | 6 ++++++ py2mojo/rules.py | 2 ++ tests/test_assignment.py | 6 ++++++ tests/test_functiondef.py | 8 ++++++++ 8 files changed, 32 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d2bbb15..3eaa0c2 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,8 @@ You can read the usage by running `py2mojo --help`: ```bash ❯ py2mojo --help -usage: py2mojo [-h] [--inplace] [--extension {mojo,🔥}] [--convert-def-to-fn | --no-convert-def-to-fn] [--convert-class-to-struct | --no-convert-class-to-struct] filenames [filenames ...] +usage: py2mojo [-h] [--inplace] [--extension {mojo,🔥}] [--convert-def-to-fn | --no-convert-def-to-fn] [--convert-class-to-struct | --no-convert-class-to-struct] [--float-precision {32,64}] + filenames [filenames ...] positional arguments: filenames @@ -27,6 +28,7 @@ options: --extension {mojo,🔥} File extension of the generated files --convert-def-to-fn, --no-convert-def-to-fn --convert-class-to-struct, --no-convert-class-to-struct + --float-precision {32,64} ``` Examples: diff --git a/py2mojo/converters/assignment.py b/py2mojo/converters/assignment.py index 8cfd2fe..1aacb86 100644 --- a/py2mojo/converters/assignment.py +++ b/py2mojo/converters/assignment.py @@ -23,7 +23,7 @@ def _replace_assignment(tokens: list[Token], i: int, rules: RuleSet, new_type: s def convert_assignment(node: ast.AnnAssign, rules: RuleSet) -> Iterable: """Convert an assignment to a mojo assignment.""" curr_type = get_annotation_type(node.annotation) - new_type = get_mojo_type(curr_type) + new_type = get_mojo_type(curr_type, rules) if not new_type: return diff --git a/py2mojo/converters/functiondef.py b/py2mojo/converters/functiondef.py index cb5ffc6..ea7435d 100644 --- a/py2mojo/converters/functiondef.py +++ b/py2mojo/converters/functiondef.py @@ -69,7 +69,7 @@ def convert_functiondef(node: ast.FunctionDef, rules: RuleSet = 0) -> Iterable: node, 'For converting a "def" function to "fn", the declaration needs to be fully type annotated' ) curr_type = get_annotation_type(arg.annotation) - new_type = get_mojo_type(curr_type) + new_type = get_mojo_type(curr_type, rules) if not new_type: continue @@ -85,7 +85,7 @@ def convert_functiondef(node: ast.FunctionDef, rules: RuleSet = 0) -> Iterable: if node.returns: curr_type = get_annotation_type(node.returns) - new_type = get_mojo_type(curr_type) + new_type = get_mojo_type(curr_type, rules) if not new_type: return diff --git a/py2mojo/helpers.py b/py2mojo/helpers.py index 4af4359..76e0545 100644 --- a/py2mojo/helpers.py +++ b/py2mojo/helpers.py @@ -6,6 +6,8 @@ from rich.text import Text from tokenize_rt import UNIMPORTANT_WS, Offset, Token +from .rules import RuleSet + def ast_to_offset(node: ast.expr | ast.stmt) -> Offset: return Offset(node.lineno, node.col_offset) @@ -103,11 +105,11 @@ def get_annotation_type(node: ast.AST) -> str: return curr_type -def get_mojo_type(curr_type: str) -> str: +def get_mojo_type(curr_type: str, rules: RuleSet) -> str: """Returns the corresponding Mojo type for the given Python type.""" patterns = [ (re.compile(r'int'), 'Int'), - (re.compile(r'float'), 'Float64'), + (re.compile(r'float'), f'Float{rules.float_precision}'), ] prev_type = '' diff --git a/py2mojo/main.py b/py2mojo/main.py index 5d04c0b..1647685 100644 --- a/py2mojo/main.py +++ b/py2mojo/main.py @@ -105,6 +105,12 @@ def main(argv: Sequence[str] | None = None) -> int: default=True, action=argparse.BooleanOptionalAction, ) + parser.add_argument( + '--float-precision', + default=32, + type=int, + choices=[32, 64], + ) args = parser.parse_args(argv) for filename in args.filenames: diff --git a/py2mojo/rules.py b/py2mojo/rules.py index 0d4f432..2fdc9f4 100644 --- a/py2mojo/rules.py +++ b/py2mojo/rules.py @@ -6,10 +6,12 @@ class RuleSet: convert_def_to_fn: bool = False convert_class_to_struct: bool = False + float_precision: int = 64 def get_rules(args: argparse.Namespace) -> RuleSet: return RuleSet( convert_def_to_fn=args.convert_def_to_fn, convert_class_to_struct=args.convert_class_to_struct, + float_precision=args.float_precision, ) diff --git a/tests/test_assignment.py b/tests/test_assignment.py index 0bc9f81..5db81c1 100644 --- a/tests/test_assignment.py +++ b/tests/test_assignment.py @@ -1,4 +1,5 @@ from helpers import validate +from py2mojo.rules import RuleSet def test_assignment_with_basic_types(): @@ -10,6 +11,11 @@ def test_assignment_with_basic_types(): 'x: float = 10.5', 'var x: Float64 = 10.5', ) + validate( + 'x: float = 10.5', + 'var x: Float32 = 10.5', + rules=RuleSet(float_precision=32), + ) validate( 'x: str = "foo"', 'var x: str = "foo"', diff --git a/tests/test_functiondef.py b/tests/test_functiondef.py index 7f1f049..15ce397 100644 --- a/tests/test_functiondef.py +++ b/tests/test_functiondef.py @@ -62,6 +62,14 @@ def test_functiondef_with_list_types(python_type, mojo_type): ) +def test_functiondef_with_float_in_precision(): + validate( + 'def add(x: float, y: float) -> float: return x + y', + 'fn add(x: Float32, y: Float32) -> Float32: return x + y', + rules=RuleSet(convert_def_to_fn=True, float_precision=32), + ) + + def test_functiondef_inside_classes(): validate( '''