diff --git a/papermill/tests/test_translators.py b/papermill/tests/test_translators.py index 0edc1f07..7be403c5 100644 --- a/papermill/tests/test_translators.py +++ b/papermill/tests/test_translators.py @@ -628,3 +628,41 @@ def test_translate_comment_sh(test_input, expected): ) def test_translate_codify_sh(parameters, expected): assert translators.BashTranslator.codify(parameters) == expected + + +# Stata section +@pytest.mark.parametrize( + "test_input,expected", + [ + ("foo", """`"foo"'"""), + ("foo bar", """`"foo bar"'"""), + ('foo"bar', """`"foo"bar"'"""), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, '1'), + (False, '0'), + (None, '""'), + ], +) +def test_translate_type_stata(test_input, expected): + assert translators.StataTranslator.translate(test_input) == expected + + +@pytest.mark.parametrize("test_input,expected", [("", '*'), ("foo", '* foo'), ("['best effort']", "* ['best effort']")]) +def test_translate_comment_stata(test_input, expected): + assert translators.StataTranslator.comment(test_input) == expected + + +@pytest.mark.parametrize( + "parameters,expected", + [ + ({"foo": "bar"}, '''* Parameters\nglobal foo = `"bar"'\n'''), + ({"foo": True}, '* Parameters\nglobal foo = 1\n'), + ({"foo": 5}, '* Parameters\nglobal foo = 5\n'), + ({"foo": 1.1}, '* Parameters\nglobal foo = 1.1\n'), + ], +) +def test_translate_codify_stata(parameters, expected): + assert translators.StataTranslator.codify(parameters) == expected diff --git a/papermill/translators.py b/papermill/translators.py index 1cb43d89..d04b305e 100644 --- a/papermill/translators.py +++ b/papermill/translators.py @@ -545,6 +545,31 @@ def assign(cls, name, str_val): return f'{name}={str_val}' +class StataTranslator(Translator): + @classmethod + def translate_escaped_str(cls, str_val): + if isinstance(str_val, str): + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') + return f"""`"{str_val}"'""" + + @classmethod + def translate_none(cls, val): + return '""' + + @classmethod + def translate_bool(cls, val): + return '1' if val else '0' + + @classmethod + def comment(cls, cmt_str): + return f'* {cmt_str}'.strip() + + @classmethod + def assign(cls, name, str_val): + return f'global {name} = {str_val}' + + # Instantiate a PapermillIO instance and register Handlers. papermill_translators = PapermillTranslators() papermill_translators.register("python", PythonTranslator) @@ -559,6 +584,7 @@ def assign(cls, name, str_val): papermill_translators.register("sparkkernel", ScalaTranslator) papermill_translators.register("sparkrkernel", RTranslator) papermill_translators.register("bash", BashTranslator) +papermill_translators.register("stata", StataTranslator) def translate_parameters(kernel_name, language, parameters, comment='Parameters'):