Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

format_to_parts #131

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions fluent.runtime/fluent/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import absolute_import, unicode_literals

import six

import babel
import babel.numbers
import babel.plural
Expand All @@ -9,7 +11,7 @@

from .builtins import BUILTINS
from .prepare import Compiler
from .resolver import ResolverEnvironment, CurrentEnvironment
from .resolver import ResolverEnvironment, CurrentEnvironment, TextElement
from .utils import native_to_fluent
from .fallback import FluentLocalization, AbstractResourceLoader, FluentResourceLoader

Expand Down Expand Up @@ -86,7 +88,10 @@ def _lookup(self, entry_id, term=False):
self._compiled[compiled_id] = self._compiler(entry)
return self._compiled[compiled_id]

def format_pattern(self, pattern, args=None):
def format_to_parts(self, pattern, errors, args=None):
if isinstance(pattern, TextElement):
yield pattern.value
return
if args is not None:
fluent_args = {
argname: native_to_fluent(argvalue)
Expand All @@ -95,17 +100,26 @@ def format_pattern(self, pattern, args=None):
else:
fluent_args = {}

errors = []
env = ResolverEnvironment(context=self,
current=CurrentEnvironment(args=fluent_args),
errors=errors)
for part in pattern(env):
yield part

# Format `pattern` with the optional `args` and return a two-item list
# `[result, errors]`, where `result` is the formatted string and `errors`
# collects the problems recorded while formatting.
# NOTE(review): this span is a flattened diff. The two consecutive
# `result = ...` assignments below look like the before/after forms of a
# single statement; only the second one uses names defined in this method
# (`env` is not assigned anywhere in this body). Confirm against the real file.
def format_pattern(self, pattern, args=None):
errors = []
try:
result = pattern(env)
# Pull parts from format_to_parts, render each one with format_part, and
# concatenate them. `errors` is passed in so resolver errors land in the
# same list that is returned below.
result = ''.join(self.format_part(part) for part in self.format_to_parts(pattern, errors, args=args))
except ValueError as e:
# A ValueError aborts formatting: it is recorded as an error and the whole
# result is replaced by the '{???}' placeholder.
errors.append(e)
result = '{???}'
return [result, errors]

def format_part(self, fluentish):
    """Render a single resolved part as text.

    String parts are returned unchanged. Any other part is asked to
    format itself via ``fluentish.format``, which receives this
    object's ``_babel_locale``.
    """
    already_text = isinstance(fluentish, six.string_types)
    return fluentish if already_text else fluentish.format(self._babel_locale)

def _get_babel_locale(self):
for l in self.locales:
try:
Expand Down
2 changes: 1 addition & 1 deletion fluent.runtime/fluent/runtime/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ def compile_Pattern(self, _, elements, **kwargs):
if len(elements) == 1:
return elements[0]
return resolver.TextElement(
''.join(child(None) for child in elements)
''.join(part for child in elements for part in child(None))
)
121 changes: 63 additions & 58 deletions fluent.runtime/fluent/runtime/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
"""


# Prevent expansion of too long placeables, for memory DOS protection
MAX_PART_LENGTH = 2500


@attr.s
class CurrentEnvironment(object):
# The parts of ResolverEnvironment that we want to mutate (and restore)
Expand Down Expand Up @@ -85,6 +81,19 @@ class BaseResolver(object):
def __call__(self, env):
    """Resolve this node against ``env``; the base implementation always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError()

def to_type(self, env):
    """Collapse the parts produced by calling this resolver into one value.

    The resolver is called with ``env`` and its parts are collected
    eagerly. When exactly one part comes back it is returned as-is, so
    its type is kept. For any other number of parts, each part is
    rendered with ``env.context.format_part`` and the rendered strings
    are concatenated (an empty string when there are no parts).
    """
    resolved = list(self(env))
    if len(resolved) != 1:
        rendered = []
        for piece in resolved:
            rendered.append(env.context.format_part(piece))
        return ''.join(rendered)
    return resolved[0]


class Literal(BaseResolver):
    """Common base for literal resolvers; adds nothing to BaseResolver.

    TextElement and StringLiteral in this module derive from it.
    """
Expand Down Expand Up @@ -124,56 +133,43 @@ def __init__(self, *args, **kwargs):
# Resolve this pattern's elements against `env`, guarding against cyclic
# references and against messages that expand to too many parts.
# NOTE(review): this span is a flattened diff, so two forms of the body are
# interleaved below: one that joins the resolved elements into a string and
# returns it, and one that yields parts one at a time. Confirm which lines
# survive in the real file before relying on statement order here.
def __call__(self, env):
# A pattern that is already being resolved is reported as a cyclic
# reference and produces a FluentNone instead of recursing.
if self in env.active_patterns:
env.errors.append(FluentCyclicReferenceError("Cyclic reference"))
return FluentNone()
yield FluentNone()
return
# Mark this pattern as in progress for the cyclic-reference check above.
env.active_patterns.add(self)
elements = self.elements
# Joining form: the budget is checked up front from the element count, and
# the pattern is un-marked before the ValueError is raised.
remaining_parts = self.MAX_PARTS - env.part_count
if len(self.elements) > remaining_parts:
env.active_patterns.remove(self)
raise ValueError("Too many parts in message (> {0}), "
"aborting.".format(self.MAX_PARTS))
retval = ''.join(
resolve(element(env), env) for element in elements
)
env.part_count += len(elements)
# Yielding form: every part produced by every element is counted as it is
# yielded, so the limit applies to emitted parts rather than to elements.
# NOTE(review): here the ValueError is raised before the
# `env.active_patterns.remove(self)` line below is reached, unlike the
# joining form above. Presumably the loop wants a try/finally so the pattern
# is un-marked on error or when the generator is closed early — confirm.
for element in self.elements:
for part in element(env):
yield part
env.part_count += 1
if env.part_count > self.MAX_PARTS:
raise ValueError("Too many parts in message (> {0}), "
"aborting.".format(self.MAX_PARTS))
env.active_patterns.remove(self)
return retval


def resolve(fluentish, env):
    """Turn a resolved value into the string to be concatenated.

    Args:
        fluentish: either a FluentType or a string.
        env: resolver environment; ``env.context._babel_locale`` is passed
            to ``FluentType.format``.

    Returns:
        The formatted string for a FluentType, or the string itself.
        Values of any other type fall through and yield ``None``.

    Raises:
        ValueError: if a string is longer than ``MAX_PART_LENGTH``.
    """
    if isinstance(fluentish, FluentType):
        return fluentish.format(env.context._babel_locale)
    if isinstance(fluentish, six.string_types):
        if len(fluentish) > MAX_PART_LENGTH:
            # Report the limit that was actually exceeded. The message used to
            # print Pattern.MAX_PARTS, which is the part-count limit and not
            # the character limit checked here.
            raise ValueError(
                "Too many characters in placeable "
                "({}, max allowed is {})".format(len(fluentish), MAX_PART_LENGTH)
            )
        return fluentish


# Resolver for literal text: resolving it produces `self.value`, without
# consulting `env`.
# NOTE(review): flattened diff — the `return` and `yield` lines below look
# like two forms of the same statement (value-returning vs. generator);
# confirm which one the real file keeps.
class TextElement(FTL.TextElement, Literal):
def __call__(self, env):
return self.value
yield self.value


# NOTE(review): flattened diff — two class headers and two bodies share this
# span. One body resolves the expression to a string and wraps it in
# U+2068 / U+2069 when `env.context.use_isolating` is set; the other simply
# forwards every part produced by `self.expression(env)` with no isolation
# marks. Presumably the second header/body pair is the surviving one —
# confirm against the real file.
class Placeable(FTL.Placeable, BaseResolver):
class NeverIsolatingPlaceable(FTL.Placeable, BaseResolver):
def __call__(self, env):
inner = resolve(self.expression(env), env)
if not env.context.use_isolating:
return inner
return "\u2068" + inner + "\u2069"
# Pass-through form: re-yield the expression's parts unchanged.
for part in self.expression(env):
yield part


# NOTE(review): flattened diff — two class headers and two bodies share this
# span. One body resolves the expression to a string and returns it; the
# other yields the expression's parts, bracketed by U+2068 (FIRST STRONG
# ISOLATE) and U+2069 (POP DIRECTIONAL ISOLATE) when
# `env.context.use_isolating` is true. Confirm which pair the real file keeps.
class NeverIsolatingPlaceable(FTL.Placeable, BaseResolver):
class Placeable(NeverIsolatingPlaceable):
def __call__(self, env):
inner = resolve(self.expression(env), env)
return inner
# Generator form: the isolation marks are emitted as parts of their own,
# separate from the expression's parts.
# NOTE(review): a consumer that counts yielded parts will therefore count
# the two marks as well — confirm that is intended for any part limit.
if env.context.use_isolating:
yield "\u2068"
for part in self.expression(env):
yield part
if env.context.use_isolating:
yield "\u2069"


# Resolver for string literals: resolving it produces the 'value' entry of
# `self.parse()`, without consulting `env`.
# NOTE(review): flattened diff — the `return` and `yield` lines below look
# like two forms of the same statement; confirm which one the real file keeps.
class StringLiteral(FTL.StringLiteral, Literal):
def __call__(self, env):
return self.parse()['value']
yield self.parse()['value']


class NumberLiteral(FTL.NumberLiteral, BaseResolver):
Expand All @@ -185,7 +181,7 @@ def __init__(self, value, **kwargs):
self.value = FluentInt(self.value)

# Resolving a number literal produces the stored `self.value`; `env` is unused.
# NOTE(review): flattened diff — the `return` and `yield` lines below look
# like two forms of the same statement; confirm which one the real file keeps.
def __call__(self, env):
return self.value
yield self.value


class EntryReference(BaseResolver):
Expand All @@ -196,11 +192,12 @@ def __call__(self, env):
pattern = entry.attributes[self.attribute.name]
else:
pattern = entry.value
return pattern(env)
for part in pattern(env):
yield part
except LookupError:
ref_id = reference_to_id(self)
env.errors.append(unknown_reference_error_obj(ref_id))
return FluentNone('{{{}}}'.format(ref_id))
yield FluentNone('{{{}}}'.format(ref_id))


class MessageReference(FTL.MessageReference, EntryReference):
Expand All @@ -213,11 +210,12 @@ def __call__(self, env):
if self.arguments.positional:
env.errors.append(FluentFormatError("Ignored positional arguments passed to term '{0}'"
.format(reference_to_id(self))))
kwargs = {kwarg.name.name: kwarg.value(env) for kwarg in self.arguments.named}
kwargs = {kwarg.name.name: kwarg.value.to_type(env) for kwarg in self.arguments.named}
else:
kwargs = None
with env.modified_for_term_reference(args=kwargs):
return super(TermReference, self).__call__(env)
for part in super(TermReference, self).__call__(env):
yield part


class VariableReference(FTL.VariableReference, BaseResolver):
Expand All @@ -229,13 +227,16 @@ def __call__(self, env):
if env.current.error_for_missing_arg:
env.errors.append(
FluentReferenceError("Unknown external: {0}".format(name)))
return FluentNone(name)
yield FluentNone(name)
return

if isinstance(arg_val, (FluentType, six.text_type)):
return arg_val
env.errors.append(TypeError("Unsupported external type: {0}, {1}"
.format(name, type(arg_val))))
return FluentNone(name)
yield arg_val
else:
env.errors.append(TypeError(
"Unsupported external type: {0}, {1}".format(name, type(arg_val))
))
yield FluentNone(name)


class Attribute(FTL.Attribute, BaseResolver):
Expand All @@ -244,8 +245,9 @@ class Attribute(FTL.Attribute, BaseResolver):

class SelectExpression(FTL.SelectExpression, BaseResolver):
# Resolve the selector to a key, then produce the output of the variant that
# select_from_select_expression picks for that key.
# NOTE(review): flattened diff — two forms are interleaved. One calls the
# selector directly and returns the selected result; the other collapses the
# selector's parts with `to_type` and re-yields the selected parts. Confirm
# which lines the real file keeps.
def __call__(self, env):
key = self.selector(env)
return self.select_from_select_expression(env, key=key)
key = self.selector.to_type(env)
for part in self.select_from_select_expression(env, key=key):
yield part

def select_from_select_expression(self, env, key):
default = None
Expand All @@ -254,7 +256,7 @@ def select_from_select_expression(self, env, key):
if variant.default:
default = variant

if match(key, variant.key(env), env):
if match(key, variant.key, env):
found = variant
break

Expand Down Expand Up @@ -283,12 +285,14 @@ def match(val1, val2, env):


# Select-expression variant.
# NOTE(review): flattened diff — the bare `pass` and the `__init__` below look
# like two alternative class bodies; confirm which one the real file keeps.
class Variant(FTL.Variant, BaseResolver):
pass
# Resolve the key once at construction time, so the stored key is a plain
# value rather than a resolver.
# NOTE(review): `to_type` is called with `None` as the environment. Its
# single-part path never touches `env`, but its multi-part path reads
# `env.context`, so this presumably relies on variant keys always resolving
# to exactly one part — confirm.
def __init__(self, key, value, default=False, **kwargs):
key = key.to_type(None)
super(Variant, self).__init__(key, value, default=default, **kwargs)


# Resolver for identifiers: resolving it produces `self.name`, without
# consulting `env`.
# NOTE(review): flattened diff — the `return` and `yield` lines below look
# like two forms of the same statement; confirm which one the real file keeps.
class Identifier(FTL.Identifier, BaseResolver):
def __call__(self, env):
return self.name
yield self.name


class CallArguments(FTL.CallArguments, BaseResolver):
Expand All @@ -297,21 +301,22 @@ class CallArguments(FTL.CallArguments, BaseResolver):

# Resolver for function calls: evaluates the call arguments, looks the
# function up by name in `env.context._functions`, and produces the call's
# result. Lookup failures and exceptions raised by the function are appended
# to `env.errors`, and a FluentNone labelled "<name>()" is produced instead.
# NOTE(review): flattened diff — paired `return ...` / `yield ...` lines and
# the two pairs of `args` / `kwargs` assignments below look like
# before/after forms of the same statements; confirm which the real file keeps.
class FunctionReference(FTL.FunctionReference, BaseResolver):
def __call__(self, env):
args = [arg(env) for arg in self.arguments.positional]
kwargs = {kwarg.name.name: kwarg.value(env) for kwarg in self.arguments.named}
# `to_type` form: each argument resolver is collapsed to a single value
# before being handed to the function.
args = [arg.to_type(env) for arg in self.arguments.positional]
kwargs = {kwarg.name.name: kwarg.value.to_type(env) for kwarg in self.arguments.named}
function_name = self.id.name
try:
function = env.context._functions[function_name]
except LookupError:
env.errors.append(FluentReferenceError("Unknown function: {0}"
.format(function_name)))
return FluentNone(function_name + "()")
yield FluentNone(function_name + "()")
return

# Any exception from the called function is recorded rather than propagated.
try:
return function(*args, **kwargs)
yield function(*args, **kwargs)
except Exception as e:
env.errors.append(e)
return FluentNone(function_name + "()")
yield FluentNone(function_name + "()")


class NamedArgument(FTL.NamedArgument, BaseResolver):
Expand Down
2 changes: 1 addition & 1 deletion fluent.runtime/tests/test_bomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_max_length_protection(self):
val, errs = self.ctx.format_pattern(self.ctx.get_message('lolz').value)
self.assertEqual(val, '{???}')
self.assertNotEqual(len(errs), 0)
self.assertIn('Too many characters', str(errs[-1]))
self.assertIn('Too many parts', str(errs[-1]))

def test_max_expansions_protection(self):
# Without protection, emptylolz will take a really long time to
Expand Down