Skip to content

Commit

Permalink
Feature/fix count non additive metrics (#191)
Browse files Browse the repository at this point in the history
* fix counts for non additive dims

* fix week datatype bigquery issue

* bump version to 0.12.12

---------

Co-authored-by: Paul Blankley <paul@zenlytic.com>
  • Loading branch information
pblankley and pblankley authored May 9, 2024
1 parent 24f17fd commit 589ee68
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 12 deletions.
14 changes: 9 additions & 5 deletions metrics_layer/core/model/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,9 @@ def sql(self):

else_0 = False
if non_additive_dimension := self.non_additive_dimension:
else_0 = True
# We need to do else 0 if it's a numeric operation like sum, average, etc
# But we need to do else null if it is a non numeric op like count, count_distinct
else_0 = self.type not in {ZenlyticType.count, ZenlyticType.count_distinct}
if isinstance(self.non_additive_dimension, dict):
filters_to_apply += [
{
Expand Down Expand Up @@ -558,8 +560,6 @@ def strict_replaced_query(self):

def _needs_symmetric_aggregate(self, functional_pk: MetricsLayerBase):
if functional_pk:
if functional_pk == Definitions.does_not_exist:
return True
try:
field_pk_id = self.view.primary_key.id()
except AttributeError:
Expand All @@ -569,7 +569,9 @@ def _needs_symmetric_aggregate(self, functional_pk: MetricsLayerBase):
"Define the primary key by adding primary_key: yes to the field "
"that is the primary key of the table."
)
different_functional_pk = field_pk_id != functional_pk.id()
different_functional_pk = (
functional_pk == Definitions.does_not_exist or field_pk_id != functional_pk.id()
)
else:
different_functional_pk = False
return different_functional_pk
Expand Down Expand Up @@ -1103,7 +1105,9 @@ def apply_dimension_group_time_sql(self, sql: str, query_type: str):
f"CAST(DATETIME_TRUNC(CAST({s} AS DATETIME), HOUR) AS {self.datatype.upper()})"
),
"date": lambda s, qt: f"CAST(DATE_TRUNC(CAST({s} AS DATE), DAY) AS {self.datatype.upper()})",
"week": self._week_dimension_group_time_sql,
"week": lambda s, qt: (
f"CAST({self._week_dimension_group_time_sql(s, qt)} AS {self.datatype.upper()})"
),
"month": lambda s, qt: ( # noqa
f"CAST(DATE_TRUNC(CAST({s} AS DATE), MONTH) AS {self.datatype.upper()})"
),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "metrics_layer"
version = "0.12.11"
version = "0.12.12"
description = "The open source metrics layer."
authors = ["Paul Blankley <paul@zenlytic.com>"]
keywords = ["Metrics Layer", "Business Intelligence", "Analytics"]
Expand Down
16 changes: 16 additions & 0 deletions tests/config/metrics_layer_config/views/mrr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,22 @@ fields:
type: count
sql: ${parent_account_id}

- name: accounts_beginning_of_month
field_type: measure
type: count
sql: ${parent_account_id}
non_additive_dimension:
name: record_raw
window_choice: min

- name: accounts_end_of_month
field_type: measure
type: count_distinct
sql: ${parent_account_id}
non_additive_dimension:
name: record_raw
window_choice: max

- name: mrr_end_of_month
field_type: measure
type: sum
Expand Down
35 changes: 33 additions & 2 deletions tests/test_join_query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pytest

from metrics_layer.core.exceptions import JoinError, QueryError
Expand Down Expand Up @@ -68,6 +70,36 @@ def test_query_no_join_average_distinct(connection):
assert query == correct


@pytest.mark.query
@pytest.mark.parametrize("field", ["order_lines.order_week", "orders.order_week"])
def test_query_bigquery_week_filter_type_conversion(connection, field):
    """BigQuery week truncation must be CAST back to the column's datatype.

    Covers both the base view (DATE column) and a joined view (TIMESTAMP
    column) so the outer CAST and the filter-literal constructor agree.
    """
    filter_value = datetime(year=2021, month=8, day=4)
    generated_sql = connection.get_sql_query(
        metrics=["total_item_revenue"],
        dimensions=["channel"],
        where=[{"field": field, "expression": "greater_than", "value": filter_value}],
        query_type="BIGQUERY",
    )

    # Pick the expected datatype, underlying column, and join clause per view.
    if field == "order_lines.order_week":
        cast_as = "DATE"
        sql_field = "order_lines.order_date"
        join = ""
    else:
        cast_as = "TIMESTAMP"
        sql_field = "orders.order_date"
        join = "LEFT JOIN analytics.orders orders ON order_lines.order_unique_id=orders.id "
    expected = (
        "SELECT order_lines.sales_channel as order_lines_channel,SUM(order_lines.revenue) as"
        f" order_lines_total_item_revenue FROM analytics.order_line_items order_lines {join}WHERE"
        f" CAST(DATE_TRUNC(CAST({sql_field} AS DATE), WEEK) AS {cast_as})>{cast_as}('2021-08-04 00:00:00')"
        " GROUP BY order_lines_channel;"
    )
    assert generated_sql == expected


@pytest.mark.query
def test_query_single_join(connection):
query = connection.get_sql_query(metrics=["total_item_revenue"], dimensions=["channel", "new_vs_repeat"])
Expand Down Expand Up @@ -231,7 +263,6 @@ def test_query_single_join_metric_with_sub_field(connection):
assert query == correct


# TODO need one like this with order lines and rainfall
@pytest.mark.query
def test_query_single_join_with_forced_additional_join(connection):
query = connection.get_sql_query(
Expand All @@ -249,7 +280,7 @@ def test_query_single_join_with_forced_additional_join(connection):
"(country_detail.rain) IS NOT NULL THEN country_detail.country ELSE NULL END), "
"0)) as country_detail_avg_rainfall FROM analytics.discount_detail discount_detail "
"LEFT JOIN analytics_live.discounts discounts ON discounts.discount_id=discount_detail.discount_id "
"AND DATE_TRUNC(CAST(discounts.order_date AS DATE), WEEK) is not null LEFT JOIN "
"AND CAST(DATE_TRUNC(CAST(discounts.order_date AS DATE), WEEK) AS TIMESTAMP) is not null LEFT JOIN "
"(SELECT * FROM ANALYTICS.COUNTRY_DETAIL) as country_detail "
"ON discounts.country=country_detail.country "
"GROUP BY discount_detail_discount_promo_name;"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_listing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
@pytest.mark.project
def test_list_metrics(connection):
metrics = connection.list_metrics()
assert len(metrics) == 55
assert len(metrics) == 57

metrics = connection.list_metrics(view_name="order_lines", names_only=True)
assert len(metrics) == 11
Expand Down
20 changes: 20 additions & 0 deletions tests/test_non_additive_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@ def test_mrr_non_additive_dimension_no_group_by_max(connection, metric_suffix):
assert query == correct


@pytest.mark.query
@pytest.mark.parametrize("metric_suffix", ["end_of_month", "beginning_of_month"])
def test_mrr_non_additive_dimension_no_group_by_counts(connection, metric_suffix):
    """Count-type measures with a non-additive dimension must not use 'else 0'.

    count_distinct wraps the CASE in COUNT(DISTINCT(...)); plain count uses
    COUNT(...). Either way the CASE has no ELSE branch, so non-matching rows
    fall out as NULL instead of being counted as 0.
    """
    generated_sql = connection.get_sql_query(metrics=[f"accounts_{metric_suffix}"])

    # end_of_month -> MAX window + COUNT(DISTINCT(...)); beginning_of_month -> MIN + COUNT(...)
    if metric_suffix == "end_of_month":
        window_func, agg_open, agg_close = "MAX", "COUNT(DISTINCT(", "))"
    else:
        window_func, agg_open, agg_close = "MIN", "COUNT(", ")"
    window_lower = window_func.lower()
    expected = (
        f"WITH cte_accounts_{metric_suffix}_record_raw AS (SELECT {window_func}(mrr.record_date) as"
        f" mrr_{window_lower}_record_raw FROM analytics.mrr_by_customer mrr ORDER BY"
        f" mrr_{window_lower}_record_raw DESC) SELECT {agg_open}case when"
        f" mrr.record_date=cte_accounts_{metric_suffix}_record_raw.mrr_{window_lower}_record_raw then"
        f" mrr.parent_account_id end{agg_close} as mrr_accounts_{metric_suffix} FROM analytics.mrr_by_customer"
        f" mrr JOIN cte_accounts_{metric_suffix}_record_raw ON 1=1 ORDER BY"
        f" mrr_accounts_{metric_suffix} DESC;"
    )
    assert generated_sql == expected


@pytest.mark.query
def test_mrr_non_additive_dimension_no_group_by_multi_cte(connection):
query = connection.get_sql_query(
Expand Down
9 changes: 6 additions & 3 deletions tests/test_simple_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ def test_simple_query_dimension_group_timezone(connections, field: str, group: s
" AS TIMESTAMP) AS DATE), DAY) AS TIMESTAMP)"
),
"week": ( # noqa
"CAST(DATE_TRUNC(CAST(CAST(DATETIME(CAST(simple.order_date AS TIMESTAMP), 'America/New_York')"
" AS TIMESTAMP) AS DATE) + 1, WEEK) - 1 AS TIMESTAMP)"
"CAST(CAST(DATE_TRUNC(CAST(CAST(DATETIME(CAST(simple.order_date AS TIMESTAMP),"
" 'America/New_York') AS TIMESTAMP) AS DATE) + 1, WEEK) - 1 AS TIMESTAMP) AS TIMESTAMP)"
),
}
where = (
Expand Down Expand Up @@ -849,7 +849,10 @@ def test_simple_query_dimension_group(connections, group: str, query_type: str):
"minute": "CAST(DATETIME_TRUNC(CAST(simple.order_date AS DATETIME), MINUTE) AS TIMESTAMP)",
"hour": "CAST(DATETIME_TRUNC(CAST(simple.order_date AS DATETIME), HOUR) AS TIMESTAMP)",
"date": "CAST(DATE_TRUNC(CAST(simple.order_date AS DATE), DAY) AS TIMESTAMP)",
"week": "CAST(DATE_TRUNC(CAST(simple.order_date AS DATE) + 1, WEEK) - 1 AS TIMESTAMP)",
"week": (
"CAST(CAST(DATE_TRUNC(CAST(simple.order_date AS DATE) + 1, WEEK) - 1 AS TIMESTAMP) AS"
" TIMESTAMP)"
),
"month": "CAST(DATE_TRUNC(CAST(simple.order_date AS DATE), MONTH) AS TIMESTAMP)",
"quarter": "CAST(DATE_TRUNC(CAST(simple.order_date AS DATE), QUARTER) AS TIMESTAMP)",
"year": "CAST(DATE_TRUNC(CAST(simple.order_date AS DATE), YEAR) AS TIMESTAMP)",
Expand Down

0 comments on commit 589ee68

Please sign in to comment.