Skip to content

Commit

Permalink
add more robust support for multi connections
Browse files Browse the repository at this point in the history
  • Loading branch information
pblankley committed Aug 21, 2024
1 parent 9616a20 commit e91f8c8
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 16 deletions.
11 changes: 9 additions & 2 deletions metrics_layer/cli/seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,13 @@ def seed(self, auto_tag_searchable_fields: bool = False):
dumper.dump_yaml_file(project_data, zenlytic_project_path)

def get_model_name(self, current_models: list):
if len(current_models) > 0:
if len(current_models) == 1:
return current_models[0].name
elif len(current_models) > 1:
for model in current_models:
if self.connection and model.connection == self.connection.name:
return model.name
raise ValueError("Multiple models found, but none match the connection name")
return self.default_model_name

def make_models(self):
Expand Down Expand Up @@ -588,7 +593,9 @@ def table_query(self):
return query + ";" if self.connection.type not in Definitions.no_semicolon_warehouses else query

def run_query(self, query: str):
if self.run_query_override:
if self.run_query_override and self.connection is not None:
return self.run_query_override(query, connection_name=self.connection.name)
elif self.run_query_override and self.connection is None:
return self.run_query_override(query)
return self.metrics_layer.run_query(
query, self.connection, run_pre_queries=False, start_warehouse=True
Expand Down
12 changes: 9 additions & 3 deletions metrics_layer/core/model/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2775,11 +2775,17 @@ def join_graphs(self):
base = list(set.intersection(*base_collection))

if self.is_cumulative():
return base
return self._wrap_with_model_name(base)

edges = self.view.project.join_graph.merged_results_graph(self.view.model).in_edges(self.id())
extended = [f"merged_result_{mr}" for mr, _ in edges]

if self.loses_join_ability_with_other_views():
return extended
return list(sorted(base + extended))
return self._wrap_with_model_name(extended)
return self._wrap_with_model_name(list(sorted(base + extended)))

def _wrap_with_model_name(self, join_graphs: list):
if len(models := self.view.project.models()) > 1:
model_index = [m.name for m in models].index(self.view.model.name)
return [f"m{model_index}_{jg}" for jg in join_graphs]
return join_graphs
10 changes: 7 additions & 3 deletions metrics_layer/core/model/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,13 @@ def remove_field(self, field_name: str, view_name: str, refresh_cache: bool = Tr
def timezone(self):
if self._timezone:
return self._timezone
for m in self.models():
if m.timezone:
return m.timezone
timezones = list(set(m.timezone for m in self.models() if m.timezone))
if len(timezones) == 1:
return timezones[0]
elif len(timezones) > 1:
raise QueryError(
"Multiple timezones found in models, please specify only one timezone across models"
)
return None

@property
Expand Down
2 changes: 2 additions & 0 deletions metrics_layer/core/sql/arbitrary_merge_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
self.project = project
self.connections = connections
self.connection = None
self.model = None
# All queries are merged queries (obviously)
self.query_kind = QueryKindTypes.merged
self.kwargs = kwargs
Expand Down Expand Up @@ -76,6 +77,7 @@ def get_query(self, semicolon: bool = True):
)

mapping_lookup = self._mapping_lookup
self.model = resolver.model
clean_where = [{**w, "field": mapping_lookup.get(w["field"].lower(), w["field"])} for w in self.where]
clean_having = [
{**h, "field": mapping_lookup.get(h["field"].lower(), h["field"])} for h in self.having
Expand Down
19 changes: 11 additions & 8 deletions metrics_layer/core/sql/resolve.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import defaultdict
from collections import Counter, defaultdict
from copy import deepcopy
from typing import List, Union

Expand Down Expand Up @@ -413,7 +413,7 @@ def _get_model_for_query(self, model_name: str = None, metrics: list = [], dimen
return self._derive_model(metrics, dimensions)

def _derive_model(self, metrics: list, dimensions: list):
all_model_names = []
all_model_names, mapping_model_names = [], []
models = self.project.models()
for f in metrics + dimensions:
try:
Expand All @@ -423,18 +423,21 @@ def _derive_model(self, metrics: list, dimensions: list):
for model in models:
try:
self.project.get_mapped_field(f, model=model)
all_model_names.append(model.name)
break
mapping_model_names.append(model.name)
except Exception:
pass

all_model_names = list(set(all_model_names))

if len(all_model_names) == 0:
if len(all_model_names) == 0 and len(mapping_model_names) > 0:
# In a case that there are no models in the query, we'll just use the first model
# in the project. This case should be limited to only mapping-only queries, so this is safe.
return self.project.models()[0]
elif len(all_model_names) == 1:
model_counts = Counter(mapping_model_names)
sorted_models = [m for m, _ in model_counts.most_common()]
return self.project.get_model(sorted_models[0])
elif len(all_model_names) == 1 and (
len(mapping_model_names) == 0
or (len(mapping_model_names) > 0 and all_model_names[0] in mapping_model_names)
):
return self.project.get_model(list(all_model_names)[0])
else:
raise QueryError(
Expand Down

0 comments on commit e91f8c8

Please sign in to comment.