Skip to content

Commit

Permalink
- enable configuration of default transforms for measures
Browse files (browse the repository at this point in the history)
- add '1' SQL transform, aka "any" (ignores the field and emits the constant 1)
  • Loading branch information
danfrankcb committed Aug 30, 2018
1 parent d00a0e7 commit 8009d35
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 49 deletions.
49 changes: 24 additions & 25 deletions mensor/backends/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class SQLDialect(object):
'sum': lambda x: "SUM({})".format(x),
'mean': lambda x: "AVG({})".format(x),
'sos': lambda x: "SUM(POW({}, 2))".format(x),
'count': lambda x: "COUNT({})".format(x)
'count': lambda x: "COUNT({})".format(x),
'1': lambda x: "1",
}

TEMPLATE_BASE = textwrap.dedent("""
Expand Down Expand Up @@ -187,7 +188,7 @@ def dialect(self):

def query(self, sql):
print(sql)
raise NotImplementedError("This SQLExecutor goes no further.")
raise NotImplementedError("DebugSQLExecutor prints SQL but cannot execute.")


class SQLMeasureProvider(MeasureProvider):
Expand All @@ -197,15 +198,14 @@ class SQLMeasureProvider(MeasureProvider):

@classmethod
def _on_registered(cls, key):
for agg in ['sum', 'mean', 'sos', 'count']:
for agg in ['sum', 'mean', 'sos', 'count', '1']:
global_stats_registry.aggregations.register(
name=agg,
backend=key,
agg=eval("lambda field, dialect: dialect.AGG_METHODS['{}'](field)".format(agg), {}, {})
)

def __init__(self, *args, sql=None, executor=None, **kwargs):

if not executor:
executor = DebugSQLExecutor()
elif isinstance(executor, str):
Expand Down Expand Up @@ -256,14 +256,13 @@ def get_sql(self, *args, **kwargs):

def _get_ir(self, unit_type, measures, segment_by, where, joins, stats_registry, stats, covariates, **opts):
field_map = self._field_map(unit_type, measures, segment_by, joins)
rebase_agg = not unit_type.is_unique
sql = self._template_environment.get_template(self.dialect.TEMPLATE_BASE).render(
_sql=self._sql(unit_type=unit_type, measures=measures, segment_by=segment_by, where=where, joins=joins, stats=stats, covariates=covariates, **opts),
field_map=field_map,
provider=self,
table_name=self._table_name(unit_type),
dimensions=self._get_dimensions_sql(field_map, segment_by),
measures=self._get_measures_sql(field_map, unit_type, measures, rebase_agg, stats_registry, stats, covariates),
measures=self._get_measures_sql(field_map, unit_type, measures, stats_registry, stats, covariates),
groupby=self._get_groupby_sql(field_map, segment_by),
joins=joins,
constraints=self._get_where_sql(field_map, where),
Expand Down Expand Up @@ -325,29 +324,29 @@ def _get_dimensions_sql(self, field_map, dimensions):
)
return dims

def _get_measures_sql(self, field_map, unit_type, measures, rebase_agg, stats_registry, stats, covariates):
def _get_measures_sql(self, field_map, unit_type, measures, stats_registry, stats, covariates):
aggs = []

rebase_agg = not unit_type.is_unique
if rebase_agg and stats:
raise NotImplementedError("Computing stats and rebasing units simultaneously has not been implemented for the SQL backend.")
else:
for measure in measures:
if not measure.private:
for fieldname, transforms in measure.get_fields(unit_type=unit_type, stats=stats, stats_registry=stats_registry, rebase_agg=rebase_agg).items():

field = '1' if measure == 'count' else field_map['measures'][measure.via_name]
if transforms.get('pre_agg'):
field = transforms['pre_agg'](field, self.dialect)
field = transforms['agg'](field, self.dialect)
if transforms.get('post_agg'):
field = transforms['post_agg'](field, self.dialect)

aggs.append(
'{col_op} AS {f}'.format(
col_op=field,
f=self._col(fieldname),
)

for measure in measures:
if not measure.private:
for fieldname, transforms in measure.get_fields(unit_type=unit_type, stats=stats, stats_registry=stats_registry, rebase_agg=rebase_agg).items():

field = '1' if measure == 'count' else field_map['measures'][measure.via_name]
if transforms.get('pre_agg'):
field = transforms['pre_agg'](field, self.dialect)
field = transforms['agg'](field, self.dialect)
if transforms.get('post_agg'):
field = transforms['post_agg'](field, self.dialect)

aggs.append(
'{col_op} AS {f}'.format(
col_op=field,
f=self._col(fieldname),
)
)

return aggs

Expand Down
8 changes: 7 additions & 1 deletion mensor/measures/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,14 @@ def _features_lookup(self, unit_type, kind, attr_filter=None):
mask = None
if kind in ('foreign_key', 'reverse_foreign_key') and avail_unit_type == feature.name:
mask = unit_type.name
feature_attrs = feature.attrs
feature_attrs.update({
'unit_type': unit_type,
'mask': mask,
'kind': kind,
})
features.append(
_ResolvedFeature(feature.name, providers=[d.provider for d in instances], unit_type=unit_type, mask=mask, kind=kind)
_ResolvedFeature(providers=[d.provider for d in instances], **feature_attrs)
)
return features

Expand Down
50 changes: 27 additions & 23 deletions mensor/measures/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ def __init__(self, name, unit_type=None, via=None, external=False, private=False
if self.ALLOW_ALL_ATTRIBUTES or attr in self.EXTRA_ATTRIBUTES:
setattr(self, attr, value)
else:
raise KeyError("No such attribute {}.".format(attr))
raise AttributeError(
"Cannot initialize {}<{}> with attribute {}.".format(self.__class__.__name__, self.name, attr)
)

def __getattr__(self, name):
if name.startswith('_'):
Expand Down Expand Up @@ -446,10 +448,7 @@ def transforms(self):
@transforms.setter
def transforms(self, transforms):
# TODO: Check structure of transforms dict
if not transforms:
self._transforms = {}
else:
self._transforms = transforms
self._transforms = {} if not transforms else transforms

@property
def as_external(self):
Expand Down Expand Up @@ -692,8 +691,8 @@ def desc(self):

class _Dimension(_ProvidedFeature):

def __init__(self, name, expr=None, default=None, desc=None, shared=False, partition=False, requires_constraint=False, provider=None):
_ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider)
def __init__(self, name, expr=None, default=None, desc=None, shared=False, partition=False, requires_constraint=False, provider=None, **attrs):
_ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider, **attrs)
if not shared and partition:
raise ValueError("Partitions must be shared.")
self.partition = partition
Expand Down Expand Up @@ -777,20 +776,25 @@ def matches(self, unit_type, reverse=False):
class _Measure(_ProvidedFeature):

def __init__(self, name, expr=None, default=None, desc=None,
distribution='normal', shared=False, provider=None):
_ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider)
distribution='normal', shared=False, provider=None, **attrs):
_ProvidedFeature.__init__(
self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider,
**attrs
)
self.distribution = distribution

def transforms_for_unit_type(self, unit_type, stats_registry=None):
transforms = {
transforms = { # defaults
'pre_agg': None,
'agg': 'sum',
'post_agg': None,
'pre_rebase_agg': None,
'rebase_agg': 'sum',
'post_rebase_agg': None
}

if isinstance(self.transforms, dict):
transforms.update(self.transforms.get('_default', {}))
transforms.update(self.transforms.get(unit_type, {}))

backend_aggs = stats_registry.aggregations.for_provider(self.provider)
Expand Down Expand Up @@ -828,14 +832,15 @@ def get_fields(self, unit_type=None, stats=True, rebase_agg=False, stats_registr
"""
assert stats_registry is not None
assert not (rebase_agg and stats)

if for_pandas:
from mensor.backends.pandas import PandasMeasureProvider
provider = PandasMeasureProvider
else:
provider = self.provider

transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry)
if stats:
transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry)
return OrderedDict([
(
(
Expand All @@ -850,18 +855,17 @@ def get_fields(self, unit_type=None, stats=True, rebase_agg=False, stats_registr
)
for field_name, agg_method in stats_registry.distribution_for_provider(self.distribution, provider).items()
])
else:
transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry)
return OrderedDict([
(
'{fieldname}|raw'.format(fieldname=self.fieldname(role=None, unit_type=unit_type if not rebase_agg else None)),
{
'agg': transforms['rebase_agg'] if rebase_agg else transforms['agg'],
'pre_agg': transforms['pre_rebase_agg'] if rebase_agg else transforms['pre_agg'],
'post_agg': transforms['post_rebase_agg'] if rebase_agg else transforms['post_agg'],
}
)
])

return OrderedDict([
(
'{fieldname}|raw'.format(fieldname=self.fieldname(role=None, unit_type=unit_type if not rebase_agg else None)),
{
'agg': transforms['rebase_agg'] if rebase_agg else transforms['agg'],
'pre_agg': transforms['pre_rebase_agg'] if rebase_agg else transforms['pre_agg'],
'post_agg': transforms['post_rebase_agg'] if rebase_agg else transforms['post_agg'],
}
)
])

@classmethod
def get_all_fields(self, measures, unit_type=None, stats=True, rebase_agg=False, stats_registry=None, for_pandas=False):
Expand Down

0 comments on commit 8009d35

Please sign in to comment.