diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 416b255fb763f..69cbe8d823450 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -262,6 +262,7 @@ "databricks-sdk>=0.9.0", "pyspark~=3.3.0", "requests", + "databricks-sql-connector", } mysql = sql_common | {"pymysql>=1.0.2"} diff --git a/metadata-ingestion/tests/performance/bigquery/bigquery_events.py b/metadata-ingestion/tests/performance/bigquery/bigquery_events.py index d9b5571a8015f..0e0bfe78c260f 100644 --- a/metadata-ingestion/tests/performance/bigquery/bigquery_events.py +++ b/metadata-ingestion/tests/performance/bigquery/bigquery_events.py @@ -2,7 +2,7 @@ import random import uuid from collections import defaultdict -from typing import Dict, Iterable, List, cast +from typing import Dict, Iterable, List, Set from typing_extensions import get_args @@ -15,7 +15,7 @@ ) from datahub.ingestion.source.bigquery_v2.bigquery_config import BigQueryV2Config from datahub.ingestion.source.bigquery_v2.usage import OPERATION_STATEMENT_TYPES -from tests.performance.data_model import Query, StatementType, Table, View +from tests.performance.data_model import Query, StatementType, Table # https://cloud.google.com/bigquery/docs/reference/auditlogs/rest/Shared.Types/BigQueryAuditMetadata.TableDataRead.Reason READ_REASONS = [ @@ -86,7 +86,7 @@ def generate_events( ref_from_table(parent, table_to_project) for field in query.fields_accessed if field.table.is_view() - for parent in cast(View, field.table).parents + for parent in field.table.upstreams ) ), referencedViews=referencedViews, @@ -96,7 +96,7 @@ def generate_events( query_on_view=True if referencedViews else False, ) ) - table_accesses = defaultdict(set) + table_accesses: Dict[BigQueryTableRef, Set[str]] = defaultdict(set) for field in query.fields_accessed: if not field.table.is_view(): table_accesses[ref_from_table(field.table, table_to_project)].add( @@ -104,7 +104,7 @@ def generate_events( ) else: # assuming that same fields are accessed in parent tables - for parent in cast(View, field.table).parents: + for parent in field.table.upstreams: table_accesses[ref_from_table(parent, table_to_project)].add( field.column ) diff --git a/metadata-ingestion/tests/performance/data_generation.py b/metadata-ingestion/tests/performance/data_generation.py index 67b156896909a..9b80d6260d408 100644 --- a/metadata-ingestion/tests/performance/data_generation.py +++ b/metadata-ingestion/tests/performance/data_generation.py @@ -8,16 +8,16 @@ This is a work in progress, built piecemeal as needed. 
""" import random -import uuid +from abc import ABCMeta, abstractmethod +from collections import OrderedDict from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Iterable, List, TypeVar, Union, cast +from typing import Collection, Iterable, List, Optional, TypeVar, Union, cast from faker import Faker from tests.performance.data_model import ( Column, - ColumnMapping, ColumnType, Container, FieldAccess, @@ -40,17 +40,46 @@ "UNKNOWN", ] +ID_COLUMN = "id" # Use to allow joins between all tables + + +class Distribution(metaclass=ABCMeta): + @abstractmethod + def _sample(self) -> int: + raise NotImplementedError + + def sample( + self, *, floor: Optional[int] = None, ceiling: Optional[int] = None + ) -> int: + value = self._sample() + if floor is not None: + value = max(value, floor) + if ceiling is not None: + value = min(value, ceiling) + return value + @dataclass(frozen=True) -class NormalDistribution: +class NormalDistribution(Distribution): mu: float sigma: float - def sample(self) -> int: + def _sample(self) -> int: return int(random.gauss(mu=self.mu, sigma=self.sigma)) - def sample_with_floor(self, floor: int = 1) -> int: - return max(int(random.gauss(mu=self.mu, sigma=self.sigma)), floor) + +@dataclass(frozen=True) +class LomaxDistribution(Distribution): + """See https://en.wikipedia.org/wiki/Lomax_distribution. + + Equivalent to pareto(scale, shape) - scale; scale * beta_prime(1, shape) + """ + + scale: float + shape: float + + def _sample(self) -> int: + return int(self.scale * (random.paretovariate(self.shape) - 1)) @dataclass @@ -72,9 +101,9 @@ def generate_data( num_containers: Union[List[int], int], num_tables: int, num_views: int, - columns_per_table: NormalDistribution = NormalDistribution(5, 2), - parents_per_view: NormalDistribution = NormalDistribution(2, 1), - view_definition_length: NormalDistribution = NormalDistribution(150, 50), + columns_per_table: Distribution = NormalDistribution(5, 2), + parents_per_view: Distribution = NormalDistribution(2, 1), + view_definition_length: Distribution = NormalDistribution(150, 50), time_range: timedelta = timedelta(days=14), ) -> SeedMetadata: # Assemble containers @@ -85,43 +114,32 @@ def generate_data( for i, num_in_layer in enumerate(num_containers): layer = [ Container( - f"{i}-container-{j}", + f"{_container_type(i)}_{j}", parent=random.choice(containers[-1]) if containers else None, ) for j in range(num_in_layer) ] containers.append(layer) - # Assemble tables + # Assemble tables and views, lineage, and definitions tables = [ - Table( - f"table-{i}", - container=random.choice(containers[-1]), - columns=[ - f"column-{j}-{uuid.uuid4()}" - for j in range(columns_per_table.sample_with_floor()) - ], - column_mapping=None, - ) - for i in range(num_tables) + _generate_table(i, containers[-1], columns_per_table) for i in range(num_tables) ] views = [ View( - f"view-{i}", - container=random.choice(containers[-1]), - columns=[ - f"column-{j}-{uuid.uuid4()}" - for j in range(columns_per_table.sample_with_floor()) - ], - column_mapping=None, - definition=f"{uuid.uuid4()}-{'*' * view_definition_length.sample_with_floor(10)}", - parents=random.sample(tables, parents_per_view.sample_with_floor()), + **{ # type: ignore + **_generate_table(i, containers[-1], columns_per_table).__dict__, + "name": f"view_{i}", + "definition": f"--{'*' * view_definition_length.sample(floor=0)}", + }, ) for i in range(num_views) ] - for table in tables + views: - _generate_column_mapping(table) + for view in views: 
+ view.upstreams = random.sample(tables, k=parents_per_view.sample(floor=1)) + + generate_lineage(tables, views) now = datetime.now(tz=timezone.utc) return SeedMetadata( @@ -133,6 +151,33 @@ def generate_data( ) +def generate_lineage( + tables: Collection[Table], + views: Collection[Table], + # Percentiles: 75th=0, 80th=1, 95th=2, 99th=4, 99.99th=15 + upstream_distribution: Distribution = LomaxDistribution(scale=3, shape=5), +) -> None: + num_upstreams = [upstream_distribution.sample(ceiling=100) for _ in tables] + # Prioritize tables with a lot of upstreams themselves + factor = 1 + len(tables) // 10 + table_weights = [1 + (num_upstreams[i] * factor) for i in range(len(tables))] + view_weights = [1] * len(views) + + # TODO: Python 3.9 use random.sample with counts + sample = [] + for table, weight in zip(tables, table_weights): + for _ in range(weight): + sample.append(table) + for view, weight in zip(views, view_weights): + for _ in range(weight): + sample.append(view) + for i, table in enumerate(tables): + table.upstreams = random.sample( # type: ignore + sample, + k=num_upstreams[i], + ) + + def generate_queries( seed_metadata: SeedMetadata, num_selects: int, @@ -146,12 +191,12 @@ def generate_queries( ) -> Iterable[Query]: faker = Faker() query_texts = [ - faker.paragraph(query_length.sample_with_floor(30) // 30) + faker.paragraph(query_length.sample(floor=30) // 30) for _ in range(num_unique_queries) ] all_tables = seed_metadata.tables + seed_metadata.views - users = [f"user-{i}@xyz.com" for i in range(num_users)] + users = [f"user_{i}@xyz.com" for i in range(num_users)] for i in range(num_selects): # Pure SELECT statements tables = _sample_list(all_tables, tables_per_select) all_columns = [ @@ -191,21 +236,43 @@ def generate_queries( ) -def _generate_column_mapping(table: Table) -> ColumnMapping: - d = {} - for column in table.columns: - d[column] = Column( - name=column, +def _container_type(i: int) -> str: + if i == 0: + return "database" + elif i == 1: + return "schema" + else: + return f"{i}container" + + +def _generate_table( + i: int, parents: List[Container], columns_per_table: Distribution +) -> Table: + num_columns = columns_per_table.sample(floor=1) + + columns = OrderedDict({ID_COLUMN: Column(ID_COLUMN, ColumnType.INTEGER, False)}) + for j in range(num_columns): + name = f"column_{j}" + columns[name] = Column( + name=name, type=random.choice(list(ColumnType)), nullable=random.random() < 0.1, # Fixed 10% chance for now ) - table.column_mapping = d - return d + return Table( + f"table_{i}", + container=random.choice(parents), + columns=columns, + upstreams=[], + ) def _sample_list(lst: List[T], dist: NormalDistribution, floor: int = 1) -> List[T]: - return random.sample(lst, min(dist.sample_with_floor(floor), len(lst))) + return random.sample(lst, min(dist.sample(floor=floor), len(lst))) def _random_time_between(start: datetime, end: datetime) -> datetime: return start + timedelta(seconds=(end - start).total_seconds() * random.random()) + + +if __name__ == "__main__": + z = generate_data(10, 1000, 10) diff --git a/metadata-ingestion/tests/performance/data_model.py b/metadata-ingestion/tests/performance/data_model.py index 9425fa827070e..728bb6ddde215 100644 --- a/metadata-ingestion/tests/performance/data_model.py +++ b/metadata-ingestion/tests/performance/data_model.py @@ -1,7 +1,9 @@ -from dataclasses import dataclass +import typing +from collections import OrderedDict +from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from 
typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from typing_extensions import Literal @@ -37,29 +39,63 @@ class ColumnType(str, Enum): @dataclass class Column: name: str - type: ColumnType - nullable: bool + type: ColumnType = ColumnType.STRING + nullable: bool = False ColumnRef = str ColumnMapping = Dict[ColumnRef, Column] -@dataclass +@dataclass(init=False) class Table: name: str container: Container - columns: List[ColumnRef] - column_mapping: Optional[ColumnMapping] + columns: typing.OrderedDict[ColumnRef, Column] = field(repr=False) + upstreams: List["Table"] = field(repr=False) + + def __init__( + self, + name: str, + container: Container, + columns: Union[List[str], Dict[str, Column]], + upstreams: List["Table"], + ): + self.name = name + self.container = container + self.upstreams = upstreams + if isinstance(columns, list): + self.columns = OrderedDict((col, Column(col)) for col in columns) + elif isinstance(columns, dict): + self.columns = OrderedDict(columns) + + @property + def name_components(self) -> List[str]: + lst = [self.name] + container: Optional[Container] = self.container + while container: + lst.append(container.name) + container = container.parent + return lst[::-1] def is_view(self) -> bool: return False -@dataclass +@dataclass(init=False) class View(Table): definition: str - parents: List[Table] + + def __init__( + self, + name: str, + container: Container, + columns: Union[List[str], Dict[str, Column]], + upstreams: List["Table"], + definition: str, + ): + super().__init__(name, container, columns, upstreams) + self.definition = definition def is_view(self) -> bool: return True diff --git a/metadata-ingestion/tests/performance/databricks/generator.py b/metadata-ingestion/tests/performance/databricks/generator.py new file mode 100644 index 0000000000000..29df325d856a1 --- /dev/null +++ b/metadata-ingestion/tests/performance/databricks/generator.py @@ -0,0 +1,177 @@ +import logging +import random +import string +from concurrent.futures import ThreadPoolExecutor, wait +from datetime import datetime +from typing import Callable, List, TypeVar, Union +from urllib.parse import urlparse + +from databricks.sdk import WorkspaceClient +from databricks.sdk.core import DatabricksError +from databricks.sdk.service.catalog import ColumnTypeName +from performance.data_generation import Distribution, LomaxDistribution, SeedMetadata +from performance.data_model import ColumnType, Container, Table, View +from performance.databricks.unity_proxy_mock import _convert_column_type +from sqlalchemy import create_engine + +from datahub.ingestion.source.sql.sql_config import make_sqlalchemy_uri + +logger = logging.getLogger(__name__) +T = TypeVar("T") + +MAX_WORKERS = 200 + + +class DatabricksDataGenerator: + def __init__(self, host: str, token: str, warehouse_id: str): + self.client = WorkspaceClient(host=host, token=token) + self.warehouse_id = warehouse_id + url = make_sqlalchemy_uri( + scheme="databricks", + username="token", + password=token, + at=urlparse(host).netloc, + db=None, + uri_opts={"http_path": f"/sql/1.0/warehouses/{warehouse_id}"}, + ) + engine = create_engine( + url, connect_args={"timeout": 600}, pool_size=MAX_WORKERS + ) + self.connection = engine.connect() + + def clear_data(self, seed_metadata: SeedMetadata) -> None: + for container in seed_metadata.containers[0]: + try: + self.client.catalogs.delete(container.name, force=True) + except DatabricksError: + pass + + def create_data( + self, + seed_metadata: SeedMetadata, + # 
Percentiles: 1st=0, 10th=7, 25th=21, 50th=58, 75th=152, 90th=364, 99th=2063, 99.99th=46316 + num_rows_distribution: Distribution = LomaxDistribution(scale=100, shape=1.5), + ) -> None: + """Create data in Databricks based on SeedMetadata.""" + for container in seed_metadata.containers[0]: + self._create_catalog(container) + for container in seed_metadata.containers[1]: + self._create_schema(container) + + _thread_pool_execute("create tables", seed_metadata.tables, self._create_table) + _thread_pool_execute("create views", seed_metadata.views, self._create_view) + _thread_pool_execute( + "populate tables", + seed_metadata.tables, + lambda t: self._populate_table( + t, num_rows_distribution.sample(ceiling=1_000_000) + ), + ) + _thread_pool_execute( + "create table lineage", seed_metadata.tables, self._create_table_lineage + ) + + def _create_catalog(self, catalog: Container) -> None: + try: + self.client.catalogs.get(catalog.name) + except DatabricksError: + self.client.catalogs.create(catalog.name) + + def _create_schema(self, schema: Container) -> None: + try: + self.client.schemas.get(f"{schema.parent.name}.{schema.name}") + except DatabricksError: + self.client.schemas.create(schema.name, schema.parent.name) + + def _create_table(self, table: Table) -> None: + try: + self.client.tables.delete(".".join(table.name_components)) + except DatabricksError: + pass + + columns = ", ".join( + f"{name} {_convert_column_type(column.type).value}" + for name, column in table.columns.items() + ) + self._execute_sql(f"CREATE TABLE {_quote_table(table)} ({columns})") + self._assert_table_exists(table) + + def _create_view(self, view: View) -> None: + self._execute_sql(_generate_view_definition(view)) + self._assert_table_exists(view) + + def _assert_table_exists(self, table: Table) -> None: + self.client.tables.get(".".join(table.name_components)) + + def _populate_table(self, table: Table, num_rows: int) -> None: + values = [ + ", ".join( + str(_generate_value(column.type)) for column in table.columns.values() + ) + for _ in range(num_rows) + ] + values_str = ", ".join(f"({value})" for value in values) + self._execute_sql(f"INSERT INTO {_quote_table(table)} VALUES {values_str}") + + def _create_table_lineage(self, table: Table) -> None: + for upstream in table.upstreams: + self._execute_sql(_generate_insert_lineage(table, upstream)) + + def _execute_sql(self, sql: str) -> None: + print(sql) + self.connection.execute(sql) + + +def _thread_pool_execute(desc: str, lst: List[T], fn: Callable[[T], None]) -> None: + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + futures = [executor.submit(fn, item) for item in lst] + wait(futures) + for future in futures: + try: + future.result() + except Exception as e: + logger.error(f"Error executing '{desc}': {e}", exc_info=True) + + +def _generate_value(t: ColumnType) -> Union[int, float, str, bool]: + ctn = _convert_column_type(t) + if ctn == ColumnTypeName.INT: + return random.randint(-(2**31), 2**31 - 1) + elif ctn == ColumnTypeName.DOUBLE: + return random.uniform(-(2**31), 2**31 - 1) + elif ctn == ColumnTypeName.STRING: + return ( + "'" + "".join(random.choice(string.ascii_letters) for _ in range(8)) + "'" + ) + elif ctn == ColumnTypeName.BOOLEAN: + return random.choice([True, False]) + elif ctn == ColumnTypeName.TIMESTAMP: + return random.randint(0, int(datetime.now().timestamp())) + else: + raise NotImplementedError(f"Unsupported type {ctn}") + + +def _generate_insert_lineage(table: Table, upstream: Table) -> str: + select = [] + for column in 
table.columns.values(): + matching_cols = [c for c in upstream.columns.values() if c.type == column.type] + if matching_cols: + upstream_col = random.choice(matching_cols) + select.append(f"{upstream_col.name} AS {column.name}") + else: + select.append(f"{_generate_value(column.type)} AS {column.name}") + + return f"INSERT INTO {_quote_table(table)} SELECT {', '.join(select)} FROM {_quote_table(upstream)}" + + +def _generate_view_definition(view: View) -> str: + from_statement = f"FROM {_quote_table(view.upstreams[0])} t0" + join_statement = " ".join( + f"JOIN {_quote_table(upstream)} t{i+1} ON t0.id = t{i+1}.id" + for i, upstream in enumerate(view.upstreams[1:]) + ) + return f"CREATE VIEW {_quote_table(view)} AS SELECT * {from_statement} {join_statement} {view.definition}" + + +def _quote_table(table: Table) -> str: + return ".".join(f"`{component}`" for component in table.name_components) diff --git a/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py b/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py index 593163e12bf0a..ee1caf6783ec1 100644 --- a/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py +++ b/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py @@ -88,22 +88,21 @@ def schemas(self, catalog: Catalog) -> Iterable[Schema]: def tables(self, schema: Schema) -> Iterable[Table]: for table in self._schema_to_table[schema.name]: columns = [] - if table.column_mapping: - for i, col_name in enumerate(table.columns): - column = table.column_mapping[col_name] - columns.append( - Column( - id=column.name, - name=column.name, - type_name=self._convert_column_type(column.type), - type_text=column.type.value, - nullable=column.nullable, - position=i, - comment=None, - type_precision=0, - type_scale=0, - ) + for i, col_name in enumerate(table.columns): + column = table.columns[col_name] + columns.append( + Column( + id=column.name, + name=column.name, + type_name=_convert_column_type(column.type), + type_text=column.type.value, + nullable=column.nullable, + position=i, + comment=None, + type_precision=0, + type_scale=0, ) + ) yield Table( id=f"{schema.id}.{table.name}", @@ -145,7 +144,7 @@ def query_history( yield Query( query_id=str(i), query_text=query.text, - statement_type=self._convert_statement_type(query.type), + statement_type=_convert_statement_type(query.type), start_time=query.timestamp, end_time=query.timestamp, user_id=hash(query.actor), @@ -160,24 +159,24 @@ def table_lineage(self, table: Table) -> None: def get_column_lineage(self, table: Table) -> None: pass - @staticmethod - def _convert_column_type(t: ColumnType) -> ColumnTypeName: - if t == ColumnType.INTEGER: - return ColumnTypeName.INT - elif t == ColumnType.FLOAT: - return ColumnTypeName.DOUBLE - elif t == ColumnType.STRING: - return ColumnTypeName.STRING - elif t == ColumnType.BOOLEAN: - return ColumnTypeName.BOOLEAN - elif t == ColumnType.DATETIME: - return ColumnTypeName.TIMESTAMP - else: - raise ValueError(f"Unknown column type: {t}") - - @staticmethod - def _convert_statement_type(t: StatementType) -> QueryStatementType: - if t == "CUSTOM" or t == "UNKNOWN": - return QueryStatementType.OTHER - else: - return QueryStatementType[t] + +def _convert_column_type(t: ColumnType) -> ColumnTypeName: + if t == ColumnType.INTEGER: + return ColumnTypeName.INT + elif t == ColumnType.FLOAT: + return ColumnTypeName.DOUBLE + elif t == ColumnType.STRING: + return ColumnTypeName.STRING + elif t == ColumnType.BOOLEAN: + return ColumnTypeName.BOOLEAN + elif t == 
ColumnType.DATETIME: + return ColumnTypeName.TIMESTAMP + else: + raise ValueError(f"Unknown column type: {t}") + + +def _convert_statement_type(t: StatementType) -> QueryStatementType: + if t == "CUSTOM" or t == "UNKNOWN": + return QueryStatementType.OTHER + else: + return QueryStatementType[t] diff --git a/metadata-ingestion/tests/unit/test_bigquery_source.py b/metadata-ingestion/tests/unit/test_bigquery_source.py index 4cfa5c48d2377..3cdb73d77d0a1 100644 --- a/metadata-ingestion/tests/unit/test_bigquery_source.py +++ b/metadata-ingestion/tests/unit/test_bigquery_source.py @@ -324,7 +324,7 @@ def test_get_projects_list_failure( {"project_id_pattern": {"deny": ["^test-project$"]}} ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) - caplog.records.clear() + caplog.clear() with caplog.at_level(logging.ERROR): projects = source._get_projects() assert len(caplog.records) == 1 diff --git a/metadata-ingestion/tests/unit/test_bigquery_usage.py b/metadata-ingestion/tests/unit/test_bigquery_usage.py index 1eb5d8b00e27c..c0055763bc15b 100644 --- a/metadata-ingestion/tests/unit/test_bigquery_usage.py +++ b/metadata-ingestion/tests/unit/test_bigquery_usage.py @@ -1,7 +1,7 @@ import logging import random from datetime import datetime, timedelta, timezone -from typing import Iterable, cast +from typing import Iterable from unittest.mock import MagicMock, patch import pytest @@ -45,15 +45,16 @@ ACTOR_2, ACTOR_2_URN = "b@acryl.io", "urn:li:corpuser:b" DATABASE_1 = Container("database_1") DATABASE_2 = Container("database_2") -TABLE_1 = Table("table_1", DATABASE_1, ["id", "name", "age"], None) -TABLE_2 = Table("table_2", DATABASE_1, ["id", "table_1_id", "value"], None) +TABLE_1 = Table("table_1", DATABASE_1, columns=["id", "name", "age"], upstreams=[]) +TABLE_2 = Table( + "table_2", DATABASE_1, columns=["id", "table_1_id", "value"], upstreams=[] +) VIEW_1 = View( name="view_1", container=DATABASE_1, columns=["id", "name", "total"], definition="VIEW DEFINITION 1", - parents=[TABLE_1, TABLE_2], - column_mapping=None, + upstreams=[TABLE_1, TABLE_2], ) ALL_TABLES = [TABLE_1, TABLE_2, VIEW_1] @@ -842,6 +843,7 @@ def test_usage_counts_no_columns( ) ), ] + caplog.clear() with caplog.at_level(logging.WARNING): workunits = usage_extractor._get_workunits_internal( events, [TABLE_REFS[TABLE_1.name]] @@ -938,7 +940,7 @@ def test_operational_stats( ).to_urn("PROD") for field in query.fields_accessed if field.table.is_view() - for parent in cast(View, field.table).parents + for parent in field.table.upstreams ) ), ),
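
For context on the generation knobs this diff introduces, here is a minimal sketch (not taken from the diff itself) of how the new `Distribution` API in `data_generation.py` is sampled. It assumes the repository's `tests/` directory is importable (e.g. when running under pytest); the parameter values shown are the defaults visible in this diff.

# Minimal sketch of the Distribution API added in data_generation.py.
# Assumes tests/ is on sys.path, as it is for the existing performance tests.
from tests.performance.data_generation import LomaxDistribution, NormalDistribution

columns_per_table = NormalDistribution(mu=5, sigma=2)        # default for generate_data
upstream_distribution = LomaxDistribution(scale=3, shape=5)  # default for generate_lineage

# sample() clamps the raw draw: floor= replaces the removed sample_with_floor(),
# ceiling= caps heavy-tailed Lomax draws.
num_columns = columns_per_table.sample(floor=1)
num_upstreams = upstream_distribution.sample(ceiling=100)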
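
And a rough end-to-end sketch, under stated assumptions, of wiring the seed-metadata generator to the new `DatabricksDataGenerator`. The host, token, and SQL warehouse id are placeholders for real workspace credentials; the two-layer container layout (catalog, then schema) matches what `create_data` reads from `seed_metadata.containers[0]` and `containers[1]`, and the `databricks-sql-connector` dependency added to setup.py is required.

# Rough sketch only: seed a Databricks workspace for a performance test run.
from tests.performance.data_generation import generate_data
from tests.performance.databricks.generator import DatabricksDataGenerator

seed_metadata = generate_data(
    num_containers=[1, 5],  # one catalog layer, one schema layer
    num_tables=100,
    num_views=10,
)

generator = DatabricksDataGenerator(
    host="https://<workspace>.cloud.databricks.com",  # placeholder
    token="<personal-access-token>",                  # placeholder
    warehouse_id="<sql-warehouse-id>",                # placeholder
)
generator.clear_data(seed_metadata)   # drop catalogs left over from a previous run
generator.create_data(seed_metadata)  # create catalogs/schemas/tables/views, then rows and lineage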