Skip to content

Commit

Permalink
Allow to define the float precision when converting types
Browse files Browse the repository at this point in the history
  • Loading branch information
msaelices committed Sep 23, 2023
1 parent 1992d05 commit bd69c66
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 6 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ You can read the usage by running `py2mojo --help`:

```bash
❯ py2mojo --help
usage: py2mojo [-h] [--inplace] [--extension {mojo,🔥}] [--convert-def-to-fn | --no-convert-def-to-fn] [--convert-class-to-struct | --no-convert-class-to-struct] filenames [filenames ...]
usage: py2mojo [-h] [--inplace] [--extension {mojo,🔥}] [--convert-def-to-fn | --no-convert-def-to-fn] [--convert-class-to-struct | --no-convert-class-to-struct] [--float-precision {32,64}]
filenames [filenames ...]

positional arguments:
filenames
Expand All @@ -27,6 +28,7 @@ options:
--extension {mojo,🔥} File extension of the generated files
--convert-def-to-fn, --no-convert-def-to-fn
--convert-class-to-struct, --no-convert-class-to-struct
--float-precision {32,64}
```

Examples:
Expand Down
2 changes: 1 addition & 1 deletion py2mojo/converters/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _replace_assignment(tokens: list[Token], i: int, rules: RuleSet, new_type: s
def convert_assignment(node: ast.AnnAssign, rules: RuleSet) -> Iterable:
"""Convert an assignment to a mojo assignment."""
curr_type = get_annotation_type(node.annotation)
new_type = get_mojo_type(curr_type)
new_type = get_mojo_type(curr_type, rules)
if not new_type:
return

Expand Down
4 changes: 2 additions & 2 deletions py2mojo/converters/functiondef.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def convert_functiondef(node: ast.FunctionDef, rules: RuleSet = 0) -> Iterable:
node, 'For converting a "def" function to "fn", the declaration needs to be fully type annotated'
)
curr_type = get_annotation_type(arg.annotation)
new_type = get_mojo_type(curr_type)
new_type = get_mojo_type(curr_type, rules)

if not new_type:
continue
Expand All @@ -85,7 +85,7 @@ def convert_functiondef(node: ast.FunctionDef, rules: RuleSet = 0) -> Iterable:

if node.returns:
curr_type = get_annotation_type(node.returns)
new_type = get_mojo_type(curr_type)
new_type = get_mojo_type(curr_type, rules)
if not new_type:
return

Expand Down
6 changes: 4 additions & 2 deletions py2mojo/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from rich.text import Text
from tokenize_rt import UNIMPORTANT_WS, Offset, Token

from .rules import RuleSet


def ast_to_offset(node: ast.expr | ast.stmt) -> Offset:
return Offset(node.lineno, node.col_offset)
Expand Down Expand Up @@ -103,11 +105,11 @@ def get_annotation_type(node: ast.AST) -> str:
return curr_type


def get_mojo_type(curr_type: str) -> str:
def get_mojo_type(curr_type: str, rules: RuleSet) -> str:
"""Returns the corresponding Mojo type for the given Python type."""
patterns = [
(re.compile(r'int'), 'Int'),
(re.compile(r'float'), 'Float64'),
(re.compile(r'float'), f'Float{rules.float_precision}'),
]

prev_type = ''
Expand Down
6 changes: 6 additions & 0 deletions py2mojo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def main(argv: Sequence[str] | None = None) -> int:
default=True,
action=argparse.BooleanOptionalAction,
)
parser.add_argument(
'--float-precision',
default=32,
type=int,
choices=[32, 64],
)
args = parser.parse_args(argv)

for filename in args.filenames:
Expand Down
2 changes: 2 additions & 0 deletions py2mojo/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
class RuleSet:
convert_def_to_fn: bool = False
convert_class_to_struct: bool = False
float_precision: int = 64


def get_rules(args: argparse.Namespace) -> RuleSet:
return RuleSet(
convert_def_to_fn=args.convert_def_to_fn,
convert_class_to_struct=args.convert_class_to_struct,
float_precision=args.float_precision,
)
6 changes: 6 additions & 0 deletions tests/test_assignment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from helpers import validate
from py2mojo.rules import RuleSet


def test_assignment_with_basic_types():
Expand All @@ -10,6 +11,11 @@ def test_assignment_with_basic_types():
'x: float = 10.5',
'var x: Float64 = 10.5',
)
validate(
'x: float = 10.5',
'var x: Float32 = 10.5',
rules=RuleSet(float_precision=32),
)
validate(
'x: str = "foo"',
'var x: str = "foo"',
Expand Down
8 changes: 8 additions & 0 deletions tests/test_functiondef.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ def test_functiondef_with_list_types(python_type, mojo_type):
)


def test_functiondef_with_float_in_precision():
validate(
'def add(x: float, y: float) -> float: return x + y',
'fn add(x: Float32, y: Float32) -> Float32: return x + y',
rules=RuleSet(convert_def_to_fn=True, float_precision=32),
)


def test_functiondef_inside_classes():
validate(
'''
Expand Down

0 comments on commit bd69c66

Please sign in to comment.