diff --git a/metrics_layer/cli/seeding.py b/metrics_layer/cli/seeding.py index 381f011..29ac4c5 100644 --- a/metrics_layer/cli/seeding.py +++ b/metrics_layer/cli/seeding.py @@ -266,8 +266,13 @@ def seed(self, auto_tag_searchable_fields: bool = False): dumper.dump_yaml_file(project_data, zenlytic_project_path) def get_model_name(self, current_models: list): - if len(current_models) > 0: + if len(current_models) == 1: return current_models[0].name + elif len(current_models) > 1: + for model in current_models: + if self.connection and model.connection == self.connection.name: + return model.name + raise ValueError("Multiple models found, but none match the connection name") return self.default_model_name def make_models(self): @@ -588,7 +593,9 @@ def table_query(self): return query + ";" if self.connection.type not in Definitions.no_semicolon_warehouses else query def run_query(self, query: str): - if self.run_query_override: + if self.run_query_override and self.connection is not None: + return self.run_query_override(query, connection_name=self.connection.name) + elif self.run_query_override and self.connection is None: return self.run_query_override(query) return self.metrics_layer.run_query( query, self.connection, run_pre_queries=False, start_warehouse=True diff --git a/metrics_layer/core/model/field.py b/metrics_layer/core/model/field.py index a09fc90..d553722 100644 --- a/metrics_layer/core/model/field.py +++ b/metrics_layer/core/model/field.py @@ -2775,11 +2775,17 @@ def join_graphs(self): base = list(set.intersection(*base_collection)) if self.is_cumulative(): - return base + return self._wrap_with_model_name(base) edges = self.view.project.join_graph.merged_results_graph(self.view.model).in_edges(self.id()) extended = [f"merged_result_{mr}" for mr, _ in edges] if self.loses_join_ability_with_other_views(): - return extended - return list(sorted(base + extended)) + return self._wrap_with_model_name(extended) + return self._wrap_with_model_name(list(sorted(base + extended))) + + def _wrap_with_model_name(self, join_graphs: list): + if len(models := self.view.project.models()) > 1: + model_index = [m.name for m in models].index(self.view.model.name) + return [f"m{model_index}_{jg}" for jg in join_graphs] + return join_graphs diff --git a/metrics_layer/core/model/project.py b/metrics_layer/core/model/project.py index 771a4ae..9ea90cb 100644 --- a/metrics_layer/core/model/project.py +++ b/metrics_layer/core/model/project.py @@ -116,9 +116,13 @@ def remove_field(self, field_name: str, view_name: str, refresh_cache: bool = Tr def timezone(self): if self._timezone: return self._timezone - for m in self.models(): - if m.timezone: - return m.timezone + timezones = list(set(m.timezone for m in self.models() if m.timezone)) + if len(timezones) == 1: + return timezones[0] + elif len(timezones) > 1: + raise QueryError( + "Multiple timezones found in models, please specify only one timezone across models" + ) return None @property diff --git a/metrics_layer/core/sql/arbitrary_merge_resolve.py b/metrics_layer/core/sql/arbitrary_merge_resolve.py index afed927..20b06e9 100644 --- a/metrics_layer/core/sql/arbitrary_merge_resolve.py +++ b/metrics_layer/core/sql/arbitrary_merge_resolve.py @@ -37,6 +37,7 @@ def __init__( self.project = project self.connections = connections self.connection = None + self.model = None # All queries are merged queries (obviously) self.query_kind = QueryKindTypes.merged self.kwargs = kwargs @@ -76,6 +77,7 @@ def get_query(self, semicolon: bool = True): ) mapping_lookup = self._mapping_lookup + self.model = resolver.model clean_where = [{**w, "field": mapping_lookup.get(w["field"].lower(), w["field"])} for w in self.where] clean_having = [ {**h, "field": mapping_lookup.get(h["field"].lower(), h["field"])} for h in self.having diff --git a/metrics_layer/core/sql/resolve.py b/metrics_layer/core/sql/resolve.py index 07670b7..5a058da 100644 --- a/metrics_layer/core/sql/resolve.py +++ b/metrics_layer/core/sql/resolve.py @@ -1,4 +1,4 @@ -from collections import defaultdict +from collections import Counter, defaultdict from copy import deepcopy from typing import List, Union @@ -413,7 +413,7 @@ def _get_model_for_query(self, model_name: str = None, metrics: list = [], dimen return self._derive_model(metrics, dimensions) def _derive_model(self, metrics: list, dimensions: list): - all_model_names = [] + all_model_names, mapping_model_names = [], [] models = self.project.models() for f in metrics + dimensions: try: @@ -423,18 +423,21 @@ def _derive_model(self, metrics: list, dimensions: list): for model in models: try: self.project.get_mapped_field(f, model=model) - all_model_names.append(model.name) - break + mapping_model_names.append(model.name) except Exception: pass all_model_names = list(set(all_model_names)) - - if len(all_model_names) == 0: + if len(all_model_names) == 0 and len(mapping_model_names) > 0: # In a case that there are no models in the query, we'll just use the first model # in the project. This case should be limited to only mapping-only queries, so this is safe. - return self.project.models()[0] - elif len(all_model_names) == 1: + model_counts = Counter(mapping_model_names) + sorted_models = [m for m, _ in model_counts.most_common()] + return self.project.get_model(sorted_models[0]) + elif len(all_model_names) == 1 and ( + len(mapping_model_names) == 0 + or (len(mapping_model_names) > 0 and all_model_names[0] in mapping_model_names) + ): return self.project.get_model(list(all_model_names)[0]) else: raise QueryError(