diff --git a/mensor/backends/sql.py b/mensor/backends/sql.py index f1dcca2..a3790a9 100644 --- a/mensor/backends/sql.py +++ b/mensor/backends/sql.py @@ -26,7 +26,8 @@ class SQLDialect(object): 'sum': lambda x: "SUM({})".format(x), 'mean': lambda x: "AVG({})".format(x), 'sos': lambda x: "SUM(POW({}, 2))".format(x), - 'count': lambda x: "COUNT({})".format(x) + 'count': lambda x: "COUNT({})".format(x), + '1': lambda x: "1", } TEMPLATE_BASE = textwrap.dedent(""" @@ -187,7 +188,7 @@ def dialect(self): def query(self, sql): print(sql) - raise NotImplementedError("This SQLExecutor goes no further.") + raise NotImplementedError("DebugSQLExecutor prints SQL but cannot execute.") class SQLMeasureProvider(MeasureProvider): @@ -197,7 +198,7 @@ class SQLMeasureProvider(MeasureProvider): @classmethod def _on_registered(cls, key): - for agg in ['sum', 'mean', 'sos', 'count']: + for agg in ['sum', 'mean', 'sos', 'count', '1']: global_stats_registry.aggregations.register( name=agg, backend=key, @@ -205,7 +206,6 @@ def _on_registered(cls, key): ) def __init__(self, *args, sql=None, executor=None, **kwargs): - if not executor: executor = DebugSQLExecutor() elif isinstance(executor, str): @@ -256,14 +256,13 @@ def get_sql(self, *args, **kwargs): def _get_ir(self, unit_type, measures, segment_by, where, joins, stats_registry, stats, covariates, **opts): field_map = self._field_map(unit_type, measures, segment_by, joins) - rebase_agg = not unit_type.is_unique sql = self._template_environment.get_template(self.dialect.TEMPLATE_BASE).render( _sql=self._sql(unit_type=unit_type, measures=measures, segment_by=segment_by, where=where, joins=joins, stats=stats, covariates=covariates, **opts), field_map=field_map, provider=self, table_name=self._table_name(unit_type), dimensions=self._get_dimensions_sql(field_map, segment_by), - measures=self._get_measures_sql(field_map, unit_type, measures, rebase_agg, stats_registry, stats, covariates), + measures=self._get_measures_sql(field_map, unit_type, measures, stats_registry, stats, covariates), groupby=self._get_groupby_sql(field_map, segment_by), joins=joins, constraints=self._get_where_sql(field_map, where), @@ -325,29 +324,29 @@ def _get_dimensions_sql(self, field_map, dimensions): ) return dims - def _get_measures_sql(self, field_map, unit_type, measures, rebase_agg, stats_registry, stats, covariates): + def _get_measures_sql(self, field_map, unit_type, measures, stats_registry, stats, covariates): aggs = [] - + rebase_agg = not unit_type.is_unique if rebase_agg and stats: raise NotImplementedError("Computing stats and rebasing units simultaneously has not been implemented for the SQL backend.") - else: - for measure in measures: - if not measure.private: - for fieldname, transforms in measure.get_fields(unit_type=unit_type, stats=stats, stats_registry=stats_registry, rebase_agg=rebase_agg).items(): - - field = '1' if measure == 'count' else field_map['measures'][measure.via_name] - if transforms.get('pre_agg'): - field = transforms['pre_agg'](field, self.dialect) - field = transforms['agg'](field, self.dialect) - if transforms.get('post_agg'): - field = transforms['post_agg'](field, self.dialect) - - aggs.append( - '{col_op} AS {f}'.format( - col_op=field, - f=self._col(fieldname), - ) + + for measure in measures: + if not measure.private: + for fieldname, transforms in measure.get_fields(unit_type=unit_type, stats=stats, stats_registry=stats_registry, rebase_agg=rebase_agg).items(): + + field = '1' if measure == 'count' else field_map['measures'][measure.via_name] + if transforms.get('pre_agg'): + field = transforms['pre_agg'](field, self.dialect) + field = transforms['agg'](field, self.dialect) + if transforms.get('post_agg'): + field = transforms['post_agg'](field, self.dialect) + + aggs.append( + '{col_op} AS {f}'.format( + col_op=field, + f=self._col(fieldname), ) + ) return aggs diff --git a/mensor/measures/registry.py b/mensor/measures/registry.py index dbd8fca..21dfffb 100644 --- a/mensor/measures/registry.py +++ b/mensor/measures/registry.py @@ -196,8 +196,14 @@ def _features_lookup(self, unit_type, kind, attr_filter=None): mask = None if kind in ('foreign_key', 'reverse_foreign_key') and avail_unit_type == feature.name: mask = unit_type.name + feature_attrs = feature.attrs + feature_attrs.update({ + 'unit_type': unit_type, + 'mask': mask, + 'kind': kind, + }) features.append( - _ResolvedFeature(feature.name, providers=[d.provider for d in instances], unit_type=unit_type, mask=mask, kind=kind) + _ResolvedFeature(providers=[d.provider for d in instances], **feature_attrs) ) return features diff --git a/mensor/measures/types.py b/mensor/measures/types.py index d741816..0ab71ac 100644 --- a/mensor/measures/types.py +++ b/mensor/measures/types.py @@ -380,7 +380,9 @@ def __init__(self, name, unit_type=None, via=None, external=False, private=False if self.ALLOW_ALL_ATTRIBUTES or attr in self.EXTRA_ATTRIBUTES: setattr(self, attr, value) else: - raise KeyError("No such attribute {}.".format(attr)) + raise AttributeError( + "Cannot initialize {}<{}> with attribute {}.".format(self.__class__.__name__, self.name, attr) + ) def __getattr__(self, name): if name.startswith('_'): @@ -446,10 +448,7 @@ def transforms(self): @transforms.setter def transforms(self, transforms): # TODO: Check structure of transforms dict - if not transforms: - self._transforms = {} - else: - self._transforms = transforms + self._transforms = {} if not transforms else transforms @property def as_external(self): @@ -692,8 +691,8 @@ def desc(self): class _Dimension(_ProvidedFeature): - def __init__(self, name, expr=None, default=None, desc=None, shared=False, partition=False, requires_constraint=False, provider=None): - _ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider) + def __init__(self, name, expr=None, default=None, desc=None, shared=False, partition=False, requires_constraint=False, provider=None, **attrs): + _ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider, **attrs) if not shared and partition: raise ValueError("Partitions must be shared.") self.partition = partition @@ -777,12 +776,15 @@ def matches(self, unit_type, reverse=False): class _Measure(_ProvidedFeature): def __init__(self, name, expr=None, default=None, desc=None, - distribution='normal', shared=False, provider=None): - _ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider) + distribution='normal', shared=False, provider=None, **attrs): + _ProvidedFeature.__init__( + self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider, + **attrs + ) self.distribution = distribution def transforms_for_unit_type(self, unit_type, stats_registry=None): - transforms = { + transforms = { # defaults 'pre_agg': None, 'agg': 'sum', 'post_agg': None, @@ -790,7 +792,9 @@ def transforms_for_unit_type(self, unit_type, stats_registry=None): 'rebase_agg': 'sum', 'post_rebase_agg': None } + if isinstance(self.transforms, dict): + transforms.update(self.transforms.get('_default', {})) transforms.update(self.transforms.get(unit_type, {})) backend_aggs = stats_registry.aggregations.for_provider(self.provider) @@ -828,14 +832,15 @@ def get_fields(self, unit_type=None, stats=True, rebase_agg=False, stats_registr """ assert stats_registry is not None assert not (rebase_agg and stats) + if for_pandas: from mensor.backends.pandas import PandasMeasureProvider provider = PandasMeasureProvider else: provider = self.provider + transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry) if stats: - transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry) return OrderedDict([ ( ( @@ -850,18 +855,17 @@ def get_fields(self, unit_type=None, stats=True, rebase_agg=False, stats_registr ) for field_name, agg_method in stats_registry.distribution_for_provider(self.distribution, provider).items() ]) - else: - transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry) - return OrderedDict([ - ( - '{fieldname}|raw'.format(fieldname=self.fieldname(role=None, unit_type=unit_type if not rebase_agg else None)), - { - 'agg': transforms['rebase_agg'] if rebase_agg else transforms['agg'], - 'pre_agg': transforms['pre_rebase_agg'] if rebase_agg else transforms['pre_agg'], - 'post_agg': transforms['post_rebase_agg'] if rebase_agg else transforms['post_agg'], - } - ) - ]) + + return OrderedDict([ + ( + '{fieldname}|raw'.format(fieldname=self.fieldname(role=None, unit_type=unit_type if not rebase_agg else None)), + { + 'agg': transforms['rebase_agg'] if rebase_agg else transforms['agg'], + 'pre_agg': transforms['pre_rebase_agg'] if rebase_agg else transforms['pre_agg'], + 'post_agg': transforms['post_rebase_agg'] if rebase_agg else transforms['post_agg'], + } + ) + ]) @classmethod def get_all_fields(self, measures, unit_type=None, stats=True, rebase_agg=False, stats_registry=None, for_pandas=False):