diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml index 4cb2d74..b06405b 100644 --- a/.github/ISSUE_TEMPLATE/bug.yml +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -57,7 +57,7 @@ body: id: os attributes: label: Operating System - description: Which operating system do you have UCX installed on? + description: Which operating system do you have DQX installed on? options: - macOS - Linux diff --git a/README.md b/README.md index 5a2b328..ffa1f74 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Simplified Data Quality checking at Scale for PySpark Workloads on streaming and * [Uninstall DQX from the Databricks workspace](#uninstall-dqx-from-the-databricks-workspace) * [How to use it](#how-to-use-it) * [Demos](#demos) - * [Data Profiling](#data-profiling) + * [Data Profiling and Quality Rules Generation](#data-profiling-and-quality-rules-generation) * [In Python](#in-python) * [Using CLI](#using-cli) * [Validating quality rules (checks)](#validating-quality-rules--checks-) @@ -125,6 +125,7 @@ and other configuration options. The cli command will install the following components in the workspace: - A Python [wheel file](https://peps.python.org/pep-0427/) with the library packaged. - DQX configuration file ('config.yml'). +- Profiling workflow for generating quality rule candidates. - Quality dashboard for monitoring to display information about the data quality issues. DQX configuration file can contain multiple run configurations defining specific set of input, output and quarantine locations etc. @@ -136,7 +137,7 @@ run_config: - name: default checks_file: checks.yml curated_location: main.dqx.curated - input_locations: main.dqx.input + input_location: main.dqx.input output_location: main.dqx.output profile_summary_stats_file: profile_summary_stats.yml quarantine_location: main.dqx.quarantine @@ -152,6 +153,11 @@ by setting 'DQX_FORCE_INSTALL' environment variable. 
The following options are a * `DQX_FORCE_INSTALL=global databricks labs install dqx`: will force the installation to be for root only (`/Applications/dqx`) * `DQX_FORCE_INSTALL=user databricks labs install dqx`: will force the installation to be for user only (`/Users//.dqx`) +To list all installed dqx workflows in the workspace and their latest run state, execute the following command: +```commandline +databricks labs dqx workflows +``` + ### Install the tool on the Databricks cluster After you install the tool on the workspace, you need to install the DQX package on a Databricks cluster. @@ -212,7 +218,7 @@ you can upload the following notebooks in the Databricks workspace to try it out * [DQX Demo Notebook](demos/dqx_demo.py) - demonstrates how to use DQX for data quality checks. * [DQX DLT Demo Notebook](demos/dqx_dlt_demo.py) - demonstrates how to use DQX with Delta Live Tables (DLT). -## Data Profiling +## Data Profiling and Quality Rules Generation Data profiling is run to profile the input data and generate quality rule candidates with summary statistics. The generated rules/checks are input for the quality checking (see [Adding quality checks to the application](#adding-quality-checks-to-the-application)). @@ -246,19 +252,27 @@ dlt_expectations = dlt_generator.generate_dlt_rules(profiles) ### Using CLI You must install DQX in the workspace before (see [installation](#installation-in-a-databricks-workspace)). +As part of the installation, profiler workflow is installed. It can be run manually in the workspace UI or using the CLI as below. -Run profiling job: +Run profiler workflow: ```commandline databricks labs dqx profile --run-config "default" ``` -If run config is not provided, the "default" run config will be used. The run config is used to select specific run configuration from 'config.yml'. +You will find the generated quality rule candidates and summary statistics in the installation folder as defined in the run config. 
+If run config is not provided, the "default" run config will be used. The run config is used to select a specific run configuration from the 'config.yml'. -The following DQX configuration from 'config.yml' will be used by default: +The following DQX configuration options from 'config.yml' are used: - 'input_location': input data as a path or a table. -- 'input_format': input data format. -- 'checks_file': relative location of the generated quality rule candidates (default: `checks.yml`). Can be json or yaml file. -- 'profile_summary_stats_file': relative location of the summary statistics (default: `profile_summary.yml`). Can be json or yaml file. +- 'input_format': input data format. Required if input data is a path. +- 'checks_file': relative location of the generated quality rule candidates (default: `checks.yml`). +- 'profile_summary_stats_file': relative location of the summary statistics (default: `profile_summary.yml`). + +Logs are printed in the console and saved in the installation folder. +To show the saved logs from the latest profiler workflow run, visit the Databricks workspace UI or execute the following command: +```commandline +databricks labs dqx logs --workflow profiler +``` ## Validating quality rules (checks) diff --git a/demos/dqx_demo.py b/demos/dqx_demo.py index 7597826..ef37ddd 100644 --- a/demos/dqx_demo.py +++ b/demos/dqx_demo.py @@ -7,7 +7,7 @@ # MAGIC %md # MAGIC ### Installation DQX in the workspace # MAGIC -# MAGIC Install DQX in the workspace as per the instructions [here](https://github.com/databrickslabs/dqx?tab=readme-ov-file#installation). +# MAGIC Install DQX in the workspace (default user installation) as per the instructions [here](https://github.com/databrickslabs/dqx?tab=readme-ov-file#installation). # MAGIC Use default filename for data quality rules. 
# COMMAND ---------- @@ -17,11 +17,13 @@ # COMMAND ---------- -import subprocess +import glob +import os user_name = spark.sql('select current_user() as user').collect()[0]['user'] -pip_install_path = f"/Workspace/Users/{user_name}/.dqx/wheels/databricks_labs_dqx-*.whl" -%pip install {pip_install_path} +dqx_wheel_files = glob.glob(f"/Workspace/Users/{user_name}/.dqx/wheels/databricks_labs_dqx-*.whl") +dqx_latest_wheel = max(dqx_wheel_files, key=os.path.getctime) +%pip install {dqx_latest_wheel} # COMMAND ---------- diff --git a/demos/dqx_dlt_demo.py b/demos/dqx_dlt_demo.py index e24f0d7..5ad718c 100644 --- a/demos/dqx_dlt_demo.py +++ b/demos/dqx_dlt_demo.py @@ -1,11 +1,15 @@ # Databricks notebook source -# 1. Install DQX in the workspace as per the instructions here: https://github.com/databrickslabs/dqx?tab=readme-ov-file#installation +# 1. Install DQX in the workspace (default user installation) as per the instructions here: https://github.com/databrickslabs/dqx?tab=readme-ov-file#installation # Use default filename for data quality rules. # 2. 
Install DQX in the cluster +import glob +import os + user_name = "marcin.wojtyczka@databricks.com" # cannot dynamically retrieve user name as "System-User" is always returned: spark.sql('select current_user() as user').collect()[0]['user'] -pip_install_path = f"/Workspace/Users/{user_name}/.dqx/wheels/databricks_labs_dqx-*.whl" -%pip install {pip_install_path} +dqx_wheel_files = glob.glob(f"/Workspace/Users/{user_name}/.dqx/wheels/databricks_labs_dqx-*.whl") +dqx_latest_wheel = max(dqx_wheel_files, key=os.path.getctime) +%pip install {dqx_latest_wheel} # COMMAND ---------- diff --git a/labs.yml b/labs.yml index 584a14d..2b1ab71 100644 --- a/labs.yml +++ b/labs.yml @@ -1,10 +1,10 @@ --- name: dqx -description: Common libraries for Databricks Labs +description: Data Quality Framework for PySpark Workloads install: - script: src/databricks/labs/dqx/install.py + script: src/databricks/labs/dqx/installer/install.py uninstall: - script: src/databricks/labs/dqx/uninstall.py + script: src/databricks/labs/dqx/installer/uninstall.py entrypoint: src/databricks/labs/dqx/cli.py min_python: 3.10 commands: @@ -23,4 +23,20 @@ commands: description: Run config to use table_template: |- {{range .}}{{.error}} - {{end}} \ No newline at end of file + {{end}} + - name: profile + description: Profile input data and generate quality rule (checks) candidates + flags: + - name: run-config + description: (Optional) Selects run configuration from installation config. If not provided, use the "default" run configuration. + - name: workflows + description: Show deployed workflows and their latest run state + table_template: |- + Workflow\tWorkflow ID\tState\tStarted + {{range .}}{{.workflow}}\t{{.workflow_id}}\t{{.state}}\t{{.started}} + {{end}} + - name: logs + description: Show logs from the latest job run + flags: + - name: workflow + description: Name of the workflow to show logs for, e.g. 
profiler \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0ce7c8a..254e69b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,9 @@ dependencies = ["databricks-labs-blueprint>=0.9.1,<0.10", "databricks-sdk~=0.30", "databricks-labs-lsql>=0.5,<0.13"] +[project.entry-points.databricks] +runtime = "databricks.labs.dqx.runtime:main" + [project.urls] Issues = "https://github.com/databrickslabs/dqx/issues" Source = "https://github.com/databrickslabs/dqx" @@ -76,8 +79,8 @@ path = ".venv" [tool.hatch.envs.default.scripts] test = "pytest -n 10 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20" -coverage = "pytest -n 10 --cov src tests/ --timeout 240 --cov-report=html --durations 20" -integration = "pytest -n 10 --timeout 240 --cov src tests/integration --durations 20" +coverage = "pytest -n 10 --cov src tests/ --timeout 480 --cov-report=html --durations 20" +integration = "pytest -n 10 --timeout 480 --cov src tests/integration --durations 20" fmt = ["black . --extend-exclude 'demos/'", "ruff check . --fix", "mypy . --exclude 'demos/*'", @@ -539,7 +542,9 @@ disable = [ "consider-using-assignment-expr", "logging-fstring-interpolation", "consider-using-any-or-all", - "unnecessary-default-type-args" + "unnecessary-default-type-args", + "mock-no-usage", + "broad-exception-caught", ] # Enable the message, report, category or checker with the given id(s). 
You can diff --git a/src/databricks/labs/dqx/__init__.py b/src/databricks/labs/dqx/__init__.py index 54db6de..df97799 100644 --- a/src/databricks/labs/dqx/__init__.py +++ b/src/databricks/labs/dqx/__init__.py @@ -14,8 +14,8 @@ r"(?:\+(?P[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" ) -# Add ucx/ for projects depending on ucx as a library +# Add dqx/ for projects depending on dqx as a library ua.with_extra("dqx", __version__) -# Add ucx/ for re-packaging of ucx, where product name is omitted +# Add dqx/ for re-packaging of dqx, where product name is omitted ua.with_product("dqx", __version__) diff --git a/src/databricks/labs/dqx/cli.py b/src/databricks/labs/dqx/cli.py index 34c7bf7..ad82e5b 100644 --- a/src/databricks/labs/dqx/cli.py +++ b/src/databricks/labs/dqx/cli.py @@ -82,5 +82,46 @@ def validate_checks( return errors_list +@dqx.command +def profile(w: WorkspaceClient, *, run_config: str = "default", ctx: WorkspaceContext | None = None) -> None: + """ + Profile input data and generate quality rule (checks) candidates. + + :param w: The WorkspaceClient instance to use for accessing the workspace. + :param run_config: The name of the run configuration to use. + :param ctx: The WorkspaceContext instance to use for accessing the workspace. + """ + ctx = ctx or WorkspaceContext(w) + ctx.deployed_workflows.run_workflow("profiler", run_config) + + +@dqx.command +def workflows(w: WorkspaceClient, *, ctx: WorkspaceContext | None = None): + """ + Show deployed workflows and their state + + :param w: The WorkspaceClient instance to use for accessing the workspace. + :param ctx: The WorkspaceContext instance to use for accessing the workspace. 
+ """ + ctx = ctx or WorkspaceContext(w) + logger.info("Fetching deployed jobs...") + latest_job_status = ctx.deployed_workflows.latest_job_status() + print(json.dumps(latest_job_status)) + return latest_job_status + + +@dqx.command +def logs(w: WorkspaceClient, *, workflow: str | None = None, ctx: WorkspaceContext | None = None): + """ + Show logs of the latest job run. + + :param w: The WorkspaceClient instance to use for accessing the workspace. + :param workflow: The name of the workflow to show logs for. + :param ctx: The WorkspaceContext instance to use for accessing the workspace + """ + ctx = ctx or WorkspaceContext(w) + ctx.deployed_workflows.relay_logs(workflow) + + if __name__ == "__main__": dqx() diff --git a/src/databricks/labs/dqx/config.py b/src/databricks/labs/dqx/config.py index dabc806..e7baa05 100644 --- a/src/databricks/labs/dqx/config.py +++ b/src/databricks/labs/dqx/config.py @@ -10,12 +10,14 @@ class RunConfig: """Configuration class for the data quality checks""" name: str = "default" # name of the run configuration - input_locations: str | None = None # input data path or a table + input_location: str | None = None # input data path or a table input_format: str | None = "delta" # input data format output_table: str | None = None # output data table quarantine_table: str | None = None # quarantined data table checks_file: str | None = "checks.yml" # file containing quality rules / checks profile_summary_stats_file: str | None = "profile_summary_stats.yml" # file containing profile summary statistics + override_clusters: dict[str, str] | None = None + spark_conf: dict[str, str] | None = None @dataclass diff --git a/src/databricks/labs/dqx/contexts/application.py b/src/databricks/labs/dqx/contexts/application.py index 2b973d0..255a349 100644 --- a/src/databricks/labs/dqx/contexts/application.py +++ b/src/databricks/labs/dqx/contexts/application.py @@ -1,12 +1,12 @@ import abc import logging -from datetime import timedelta from functools import 
cached_property from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.installer import InstallState from databricks.labs.blueprint.tui import Prompts from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2 +from databricks.labs.dqx.installer.workflows_installer import DeployedWorkflows from databricks.sdk import WorkspaceClient from databricks.labs.dqx.config import WorkspaceConfig @@ -57,10 +57,6 @@ def installation(self): def config(self) -> WorkspaceConfig: return self.installation.load(WorkspaceConfig) - @cached_property - def verify_timeout(self): - return timedelta(minutes=2) - @cached_property def wheels(self): return WheelsV2(self.installation, self.product_info) @@ -69,12 +65,14 @@ def wheels(self): def install_state(self): return InstallState.from_installation(self.installation) + @cached_property + def deployed_workflows(self) -> DeployedWorkflows: + return DeployedWorkflows(self.workspace_client, self.install_state) + class CliContext(GlobalContext, abc.ABC): """ Abstract base class for global context, providing common properties and methods for workspace management. - - :param named_parameters: Optional dictionary of named parameters. 
""" @cached_property diff --git a/src/databricks/labs/dqx/contexts/workflow_task.py b/src/databricks/labs/dqx/contexts/workflow_task.py deleted file mode 100644 index 8d8e372..0000000 --- a/src/databricks/labs/dqx/contexts/workflow_task.py +++ /dev/null @@ -1,91 +0,0 @@ -from functools import cached_property -from pathlib import Path - -from databricks.labs.blueprint.installation import Installation -from databricks.labs.lsql.backends import RuntimeBackend, SqlBackend -from databricks.sdk import WorkspaceClient, core -from databricks.labs.dqx.contexts.application import GlobalContext -from databricks.labs.dqx.config import WorkspaceConfig -from databricks.labs.dqx.__about__ import __version__ - - -class RuntimeContext(GlobalContext): - """ - Returns the WorkspaceClient instance. - - :return: The WorkspaceClient instance. - """ - - @cached_property - def _config_path(self) -> Path: - config = self.named_parameters.get("config") - if not config: - raise ValueError("config flag is required") - return Path(config) - - @cached_property - def config(self) -> WorkspaceConfig: - """ - Loads and returns the workspace configuration. - - :return: The WorkspaceConfig instance. - """ - return Installation.load_local(WorkspaceConfig, self._config_path) - - @cached_property - def connect_config(self) -> core.Config: - """ - Returns the connection configuration. - - :return: The core.Config instance. - :raises AssertionError: If the connect configuration is not provided. - """ - connect = self.config.connect - assert connect, "connect is required" - return connect - - @cached_property - def workspace_client(self) -> WorkspaceClient: - """ - Returns the WorkspaceClient instance. - - :return: The WorkspaceClient instance. - """ - return WorkspaceClient(config=self.connect_config, product='dqx', product_version=__version__) - - @cached_property - def sql_backend(self) -> SqlBackend: - """ - Returns the SQL backend for the runtime. - - :return: The SqlBackend instance. 
- """ - return RuntimeBackend(debug_truncate_bytes=self.connect_config.debug_truncate_bytes) - - @cached_property - def installation(self) -> Installation: - """ - Returns the installation instance for the runtime. - - :return: The Installation instance. - """ - install_folder = self._config_path.parent.as_posix().removeprefix("/Workspace") - return Installation(self.workspace_client, "dqx", install_folder=install_folder) - - @cached_property - def workspace_id(self) -> int: - """ - Returns the workspace ID. - - :return: The workspace ID as an integer. - """ - return self.workspace_client.get_workspace_id() - - @cached_property - def parent_run_id(self) -> int: - """ - Returns the parent run ID. - - :return: The parent run ID as an integer. - """ - return int(self.named_parameters["parent_run_id"]) diff --git a/src/databricks/labs/dqx/contexts/workflows.py b/src/databricks/labs/dqx/contexts/workflows.py new file mode 100644 index 0000000..5328fed --- /dev/null +++ b/src/databricks/labs/dqx/contexts/workflows.py @@ -0,0 +1,81 @@ +from functools import cached_property +from pathlib import Path +from pyspark.sql import SparkSession + +from databricks.labs.blueprint.installation import Installation +from databricks.sdk import WorkspaceClient, core +from databricks.labs.dqx.contexts.application import GlobalContext +from databricks.labs.dqx.config import WorkspaceConfig, RunConfig +from databricks.labs.dqx.__about__ import __version__ +from databricks.labs.dqx.profiler.generator import DQGenerator +from databricks.labs.dqx.profiler.profiler import DQProfiler +from databricks.labs.dqx.profiler.runner import ProfilerRunner + + +class RuntimeContext(GlobalContext): + + @cached_property + def _config_path(self) -> Path: + config = self.named_parameters.get("config") + if not config: + raise ValueError("config flag is required") + return Path(config) + + @cached_property + def config(self) -> WorkspaceConfig: + """Loads and returns the workspace configuration.""" + return 
Installation.load_local(WorkspaceConfig, self._config_path) + + @cached_property + def run_config(self) -> RunConfig: + """Loads and returns the run configuration.""" + run_config_name = self.named_parameters.get("run_config_name") + if not run_config_name: + raise ValueError("Run config flag is required") + return self.config.get_run_config(run_config_name) + + @cached_property + def connect_config(self) -> core.Config: + """ + Returns the connection configuration. + + :return: The core.Config instance. + :raises AssertionError: If the connect configuration is not provided. + """ + connect = self.config.connect + assert connect, "connect is required" + return connect + + @cached_property + def workspace_client(self) -> WorkspaceClient: + """Returns the WorkspaceClient instance.""" + return WorkspaceClient( + config=self.connect_config, product=self.product_info.product_name(), product_version=__version__ + ) + + @cached_property + def installation(self) -> Installation: + """Returns the installation instance for the runtime.""" + install_folder = self._config_path.parent.as_posix().removeprefix("/Workspace") + return Installation(self.workspace_client, self.product_info.product_name(), install_folder=install_folder) + + @cached_property + def workspace_id(self) -> int: + """Returns the workspace ID.""" + return self.workspace_client.get_workspace_id() + + @cached_property + def parent_run_id(self) -> int: + """Returns the parent run ID.""" + return int(self.named_parameters["parent_run_id"]) + + @cached_property + def profiler(self) -> ProfilerRunner: + """Returns the ProfilerRunner instance.""" + spark_session = SparkSession.builder.getOrCreate() + profiler = DQProfiler(self.workspace_client) + generator = DQGenerator(self.workspace_client) + + return ProfilerRunner( + self.workspace_client, spark_session, installation=self.installation, profiler=profiler, generator=generator + ) diff --git a/src/databricks/labs/dqx/contexts/workspace_cli.py 
b/src/databricks/labs/dqx/contexts/workspace_cli.py index fc49cd5..b1b8345 100644 --- a/src/databricks/labs/dqx/contexts/workspace_cli.py +++ b/src/databricks/labs/dqx/contexts/workspace_cli.py @@ -19,9 +19,5 @@ def __init__(self, ws: WorkspaceClient, named_parameters: dict[str, str] | None @cached_property def workspace_client(self) -> WorkspaceClient: - """ - Returns the WorkspaceClient instance. - - :return: The WorkspaceClient instance. - """ + """Returns the WorkspaceClient instance.""" return self._ws diff --git a/src/databricks/labs/dqx/engine.py b/src/databricks/labs/dqx/engine.py index 1efb0b9..8801b0d 100644 --- a/src/databricks/labs/dqx/engine.py +++ b/src/databricks/labs/dqx/engine.py @@ -1,14 +1,14 @@ +import logging import os import functools as ft import inspect import itertools -import json -import logging from pathlib import Path from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum from typing import Any +import yaml import pyspark.sql.functions as F from pyspark.sql import Column, DataFrame @@ -94,11 +94,12 @@ def __post_init__(self): def rule_criticality(self) -> str: """Returns criticality of the check. - :return: string describing criticality - `warn` or `error`. Raises exception if it's something else + :return: string describing criticality - `warn` or `error`. + :raises ValueError: if criticality is invalid. 
""" criticality = self.criticality - if criticality not in {Criticality.WARN.value and criticality, Criticality.ERROR.value}: - criticality = Criticality.ERROR.value + if criticality not in {Criticality.WARN.value, Criticality.ERROR.value}: + raise ValueError(f"Invalid criticality value: {criticality}") return criticality @@ -588,5 +589,5 @@ def _deserialize_dicts(cls, checks: list[dict[str, str]]) -> list[dict]: for item in checks: for key, value in item.items(): if value.startswith("{") and value.endswith("}"): - item[key] = json.loads(value.replace("'", '"')) + item[key] = yaml.safe_load(value.replace("'", '"')) return checks diff --git a/src/databricks/labs/dqx/installer/__init__.py b/src/databricks/labs/dqx/installer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/databricks/labs/dqx/install.py b/src/databricks/labs/dqx/installer/install.py similarity index 88% rename from src/databricks/labs/dqx/install.py rename to src/databricks/labs/dqx/installer/install.py index 5d2f7c5..93b9964 100644 --- a/src/databricks/labs/dqx/install.py +++ b/src/databricks/labs/dqx/installer/install.py @@ -1,6 +1,5 @@ import logging import os -import re import webbrowser from functools import cached_property from requests.exceptions import ConnectionError as RequestsConnectionError @@ -16,29 +15,18 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.core import with_user_agent_extra from databricks.sdk.errors import InvalidParameterValue, NotFound, PermissionDenied +from databricks.labs.dqx.installer.workflows_installer import WorkflowsDeployment +from databricks.labs.dqx.runtime import Workflows from databricks.labs.dqx.__about__ import __version__ from databricks.labs.dqx.config import WorkspaceConfig, RunConfig from databricks.labs.dqx.contexts.workspace_cli import WorkspaceContext - +from databricks.labs.dqx.utils import extract_major_minor logger = logging.getLogger(__name__) with_user_agent_extra("cmd", "install") -def 
extract_major_minor(version_string: str): - """ - Extracts the major and minor version from a version string. - - :param version_string: The version string to extract from. - :return: The major.minor version as a string, or None if not found. - """ - match = re.search(r"(\d+\.\d+)", version_string) - if match: - return match.group(1) - return None - - class WorkspaceInstaller(WorkspaceContext): """ Installer for DQX workspace. @@ -58,6 +46,8 @@ def __init__(self, ws: WorkspaceClient, environ: dict[str, str] | None = None): msg = "WorkspaceInstaller is not supposed to be executed in Databricks Runtime" raise SystemExit(msg) + self._tasks = Workflows.all().tasks() + @cached_property def upgrades(self): """ @@ -85,29 +75,35 @@ def installation(self): def run( self, default_config: WorkspaceConfig | None = None, - config: WorkspaceConfig | None = None, ) -> WorkspaceConfig: """ Runs the installation process. :param default_config: Optional default configuration. - :param config: Optional configuration to use. :return: The final WorkspaceConfig used for the installation. :raises ManyError: If multiple errors occur during installation. :raises TimeoutError: If a timeout occurs during installation. 
""" logger.info(f"Installing DQX v{self.product_info.version()}") try: - if config is None: - config = self.configure(default_config) - if self._is_testing(): - return config + config = self.configure(default_config) + workflows_deployment = WorkflowsDeployment( + config, + config.get_run_config().name, + self.installation, + self.install_state, + self.workspace_client, + self.wheels, + self.product_info, + self._tasks, + ) workspace_installation = WorkspaceInstallation( config, self.installation, self.install_state, self.workspace_client, + workflows_deployment, self.prompts, self.product_info, ) @@ -132,9 +128,9 @@ def _prompt_for_new_installation(self) -> WorkspaceConfig: logger.info("Please answer a couple of questions to configure DQX") log_level = self.prompts.question("Log level", default="INFO").upper() - input_locations = self.prompts.question( - "Provide locations for the input data " - "as a path or table in the UC fully qualified format `..`)", + input_location = self.prompts.question( + "Provide location for the input data " + "as a path or table in the UC fully qualified format `catalog.schema.table`)", default="skipped", valid_regex=r"^\w.+$", ) @@ -146,13 +142,14 @@ def _prompt_for_new_installation(self) -> WorkspaceConfig: ) output_table = self.prompts.question( - "Provide output table in the UC fully qualified format `..
`", + "Provide output table in the UC fully qualified format `catalog.schema.table`", default="skipped", valid_regex=r"^\w.+$", ) quarantine_table = self.prompts.question( - "Provide quarantined table in the UC fully qualified format `..
`", + "Provide quarantined table in the UC fully qualified format `catalog.schema.table` " + "(use output table if skipped)", default="skipped", valid_regex=r"^\w.+$", ) @@ -174,7 +171,7 @@ def _prompt_for_new_installation(self) -> WorkspaceConfig: log_level=log_level, run_configs=[ RunConfig( - input_locations=input_locations, + input_location=input_location, input_format=input_format, output_table=output_table, quarantine_table=quarantine_table, @@ -269,12 +266,14 @@ def __init__( installation: Installation, install_state: InstallState, ws: WorkspaceClient, + workflows_installer: WorkflowsDeployment, prompts: Prompts, product_info: ProductInfo, ): self._config = config self._installation = installation self._install_state = install_state + self._workflows_installer = workflows_installer self._ws = ws self._prompts = prompts self._product_info = product_info @@ -292,13 +291,20 @@ def current(cls, ws: WorkspaceClient): installation = product_info.current_installation(ws) install_state = InstallState.from_installation(installation) config = installation.load(WorkspaceConfig) + run_config_name = config.get_run_config().name prompts = Prompts() + wheels = product_info.wheels(ws) + tasks = Workflows.all().tasks() + workflows_installer = WorkflowsDeployment( + config, run_config_name, installation, install_state, ws, wheels, product_info, tasks + ) return cls( config, installation, install_state, ws, + workflows_installer, prompts, product_info, ) @@ -326,6 +332,9 @@ def _upload_wheel(self) -> None: wheel_path = self._wheels.upload_to_wsfs() logger.info(f"Wheel uploaded to /Workspace{wheel_path}") + def _remove_jobs(self): + self._workflows_installer.remove_jobs() + def run(self) -> bool: """ Runs the workflow installation. @@ -333,7 +342,7 @@ def run(self) -> bool: :return: True if the installation finished successfully, False otherwise. 
""" logger.info(f"Installing DQX v{self._product_info.version()}") - install_tasks = [self._upload_wheel] + install_tasks = [self._workflows_installer.create_jobs] Threads.strict("installing components", install_tasks) logger.info("Installation completed successfully!") @@ -344,7 +353,7 @@ def uninstall(self): Uninstalls DQX from the workspace, including project folder, dashboards, and jobs. """ if self._prompts and not self._prompts.confirm( - "Do you want to uninstall DQX from the workspace too, this would " + "Do you want to uninstall DQX from the workspace? this would " "remove dqx project folder, dashboards, and jobs" ): return @@ -356,6 +365,7 @@ def uninstall(self): logger.error(f"Check if {self._installation.install_folder()} is present") return + self._remove_jobs() self._installation.remove() logger.info("Uninstalling DQX complete") diff --git a/src/databricks/labs/dqx/installer/logs.py b/src/databricks/labs/dqx/installer/logs.py new file mode 100644 index 0000000..a77a738 --- /dev/null +++ b/src/databricks/labs/dqx/installer/logs.py @@ -0,0 +1,195 @@ +import contextlib +import datetime as dt +import logging +import os +import re +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from logging.handlers import TimedRotatingFileHandler +from pathlib import Path +from typing import TextIO + +from databricks.labs.blueprint.logger import install_logger + +from databricks.sdk.retries import retried + +from databricks.labs.dqx.__about__ import __version__ + +logger = logging.getLogger(__name__) + + +@dataclass +class LogRecord: + timestamp: int + job_id: int + job_name: str + task_name: str + job_run_id: int + level: str + component: str + message: str + + +@dataclass +class PartialLogRecord: + """The information found within a log file record.""" + + time: dt.time + level: str + component: str + message: str + + +def peak_multi_line_message(log: TextIO, pattern: 
re.Pattern) -> tuple[str, re.Match | None, str]: + """ + A single log record message may span multiple log lines. In this case, the regex on + subsequent lines do not match. + + Args: + log (TextIO): The log file IO. + pattern (re.Pattern): The regex pattern for a log line. + """ + multi_line_message = "" + line = log.readline() + match = pattern.match(line) + while len(line) > 0 and match is None: + multi_line_message += "\n" + line.rstrip() + line = log.readline() + match = pattern.match(line) + return line, match, multi_line_message + + +def parse_logs(log: TextIO) -> Iterator[PartialLogRecord]: + """Parse the logs to retrieve values for PartialLogRecord fields. + + Args: + log (TextIO): The log file IO. + """ + time_format = "%H:%M:%S" + # This regex matches the log format defined in databricks.labs.dqx.installer.logs.TaskLogger + log_format = r"(\d+:\d+:\d+)\s(\w+)\s\[(.+)\]\s\{\w+\}\s(.+)" + pattern = re.compile(log_format) + + line = log.readline() + match = pattern.match(line) + if match is None: + logger.warning(f"Logs do not match expected format ({log_format}): {line}") + return + while len(line) > 0: + assert match is not None + time, *groups, message = match.groups() + + next_line, next_match, multi_line_message = peak_multi_line_message(log, pattern) + + time = dt.datetime.strptime(time, time_format).time() + # Mypy can't determine length of regex expressions + partial_log_record = PartialLogRecord(time, *groups, message + multi_line_message) # type: ignore + + yield partial_log_record + + line, match = next_line, next_match + + +class TaskLogger(contextlib.AbstractContextManager): + # files are available in the workspace only once their handlers are closed, + # so we rotate files log every minute to make them available for download. 
    def __init__(
        self,
        install_dir: Path,
        workflow: str,
        job_id: str,
        task_name: str,
        job_run_id: str,
        log_level="INFO",
        attempt: str = "0",
    ):
        # Capture the identifiers needed to build the per-run log folder and to
        # render job/run links in the README written next to the logs.
        self._log_level = log_level
        self._workflow = workflow
        self._job_id = job_id
        self._job_run_id = job_run_id
        self._databricks_logger = logging.getLogger("databricks")
        self._app_logger = logging.getLogger("databricks.labs.dqx")
        self._log_path = self._get_log_path(install_dir, workflow, job_run_id, attempt)
        self.log_file = self._log_path / f"{task_name}.log"
        self._app_logger.info(f"DQX v{__version__} After workflow finishes, see debug logs at {self.log_file}")

    @classmethod
    def _get_log_path(cls, install_dir: Path, workflow: str, workflow_run_id: str | int, attempt: str | int) -> Path:
        # One folder per (workflow, run, repair attempt): <install>/logs/<workflow>/run-<run_id>-<attempt>
        return install_dir / "logs" / workflow / f"run-{workflow_run_id}-{attempt}"

    def __repr__(self):
        return self.log_file.as_posix()

    def __enter__(self):
        # Route all "databricks" logging into a rotating debug file under the install
        # folder, while the console keeps the configured (less verbose) level.
        self._log_path.mkdir(parents=True, exist_ok=True)
        self._init_debug_logfile()
        self._init_run_readme()
        self._databricks_logger.setLevel(logging.DEBUG)
        self._app_logger.setLevel(logging.DEBUG)
        console_handler = install_logger(self._log_level)
        self._databricks_logger.removeHandler(console_handler)
        self._databricks_logger.addHandler(self._file_handler)
        return self

    def __exit__(self, _t, error, _tb):
        # On failure, tell the user how to download the remote log file locally;
        # always flush/close so the file becomes visible in the workspace.
        if error:
            log_file_for_cli = str(self.log_file).removeprefix("/Workspace")
            cli_command = f"databricks workspace export /{log_file_for_cli}"
            self._app_logger.error(f"Execute `{cli_command}` locally to troubleshoot with more details. {error}")
            self._databricks_logger.debug("Task crash details", exc_info=error)
        self._file_handler.flush()
        self._file_handler.close()

    def _init_debug_logfile(self):
        # Format must match the regex in parse_logs(); rotated every minute because
        # workspace files only become downloadable once their handler closes them.
        log_format = "%(asctime)s %(levelname)s [%(name)s] {%(threadName)s} %(message)s"
        log_formatter = logging.Formatter(fmt=log_format, datefmt="%H:%M:%S")
        self._file_handler = TimedRotatingFileHandler(self.log_file.as_posix(), when="M", interval=1)
        self._file_handler.setFormatter(log_formatter)
        self._file_handler.setLevel(logging.DEBUG)

    def _init_run_readme(self):
        log_readme = self._log_path.joinpath("README.md")
        if log_readme.exists():
            return
        # this may race when run from multiple tasks, therefore it must be multiprocess safe
        with self._exclusive_open(str(log_readme), mode="w") as f:
            f.write(f"# Logs for the DQX {self._workflow} workflow\n")
            f.write("This folder contains DQX log files.\n\n")
            f.write(f"See the [{self._workflow} workflow](/#job/{self._job_id}) and ")
            f.write(f"[run #{self._job_run_id}](/#job/{self._job_id}/run/{self._job_run_id})\n")
+ """ + lockfile_name = filename + ".lock" + lockfile = cls._create_lock(lockfile_name) + + try: + with open(filename, encoding="utf8", **kwargs) as f: + yield f + finally: + try: + os.close(lockfile) + finally: + os.unlink(lockfile_name) + + @staticmethod + @retried(on=[FileExistsError], timeout=timedelta(seconds=5)) + def _create_lock(lockfile_name): + while True: # wait until the lock file can be opened + f = os.open(lockfile_name, os.O_CREAT | os.O_EXCL) + break + return f diff --git a/src/databricks/labs/dqx/installer/mixins.py b/src/databricks/labs/dqx/installer/mixins.py new file mode 100644 index 0000000..01e7cf8 --- /dev/null +++ b/src/databricks/labs/dqx/installer/mixins.py @@ -0,0 +1,26 @@ +import logging +import os + +from databricks.labs.blueprint.installation import Installation +from databricks.sdk import WorkspaceClient + +from databricks.labs.dqx.config import WorkspaceConfig + +logger = logging.getLogger(__name__) + + +class InstallationMixin: + def __init__(self, config: WorkspaceConfig, installation: Installation, ws: WorkspaceClient): + self._config = config + self._installation = installation + self._ws = ws + + def _name(self, name: str) -> str: + prefix = os.path.basename(self._installation.install_folder()).removeprefix('.') + return f"[{prefix.upper()}] {name}" + + @property + def _my_username(self): + if not hasattr(self, "_me"): + self._me = self._ws.current_user.me() + return self._me.user_name diff --git a/src/databricks/labs/dqx/uninstall.py b/src/databricks/labs/dqx/installer/uninstall.py similarity index 82% rename from src/databricks/labs/dqx/uninstall.py rename to src/databricks/labs/dqx/installer/uninstall.py index 9dae109..b0f428a 100644 --- a/src/databricks/labs/dqx/uninstall.py +++ b/src/databricks/labs/dqx/installer/uninstall.py @@ -3,10 +3,11 @@ from databricks.sdk import WorkspaceClient from databricks.labs.dqx.__about__ import __version__ -from databricks.labs.dqx.install import WorkspaceInstallation +from 
logger = logging.getLogger(__name__)

# Global registry of declared tasks, keyed by task name.
_TASKS: dict[str, "Task"] = {}


@dataclass
class Task:
    """A single unit of work within a deployed workflow job."""

    workflow: str  # name of the workflow (job) the task belongs to
    name: str  # task key within the job
    doc: str  # human-readable description shown in the job UI
    fn: Callable[[WorkspaceConfig, WorkspaceClient, SqlBackend, Installation], None]
    depends_on: list[str] | None = None  # sibling task names that must run first
    job_cluster: str = "main"  # job cluster key the task runs on

    def dependencies(self):
        """List of dependencies"""
        return self.depends_on or []


class Workflow:
    """Base class for DQX workflows; subclasses declare tasks via @workflow_task."""

    def __init__(self, name: str):
        self._name = name

    @property
    def name(self):
        """Name of the workflow"""
        return self._name

    def tasks(self) -> Iterable[Task]:
        """List of tasks"""
        # Yield the Task object attached by @workflow_task to every public method.
        for attr in dir(self):
            if attr.startswith("_"):  # skip private methods
                continue
            candidate = getattr(self, attr)
            if hasattr(candidate, "__task__"):
                yield candidate.__task__


def workflow_task(fn=None, *, depends_on=None, job_cluster=Task.job_cluster) -> Callable[[Callable], Callable]:
    """Decorator that turns a Workflow method into a registered job task.

    Usable both bare (@workflow_task) and with arguments
    (@workflow_task(depends_on=[...], job_cluster="main")).
    """

    def register(func):
        """Register a task"""
        if not func.__doc__:
            raise SyntaxError(f"{func.__name__} must have some doc comment")
        this_class = func.__qualname__.split('.')[0]
        deps: list[str] = []
        if depends_on is not None:
            if not isinstance(depends_on, list):
                msg = "depends_on has to be a list"
                raise SyntaxError(msg)
            for dep in depends_on:
                # Only dependencies declared on the same workflow class count.
                other_class, task_name = dep.__qualname__.split('.')
                if other_class != this_class:
                    continue
                deps.append(task_name)
        func.__task__ = Task(
            workflow='',  # filled in later, once the owning workflow name is known
            name=func.__name__,
            doc=remove_extra_indentation(func.__doc__),
            fn=func,
            depends_on=deps,
            job_cluster=job_cluster,
        )
        return func

    # Bare-decorator form vs decorator-factory form.
    return register if fn is None else register(fn)
logger = logging.getLogger(__name__)

# How long test resources live before CI cleanup may purge them.
TEST_RESOURCE_PURGE_TIMEOUT = timedelta(hours=1)
TEST_NIGHTLY_CI_RESOURCES_PURGE_TIMEOUT = timedelta(hours=3)  # Buffer for debugging nightly integration test runs
# Jobs-service parameter templates, substituted by Databricks at run time.
EXTRA_TASK_PARAMS = {
    "job_id": "{{job_id}}",
    "run_id": "{{run_id}}",
    "start_time": "{{job.start_time.iso_datetime}}",
    "attempt": "{{job.repair_count}}",
    "parent_run_id": "{{parent_run_id}}",
}


class DeployedWorkflows:
    """Interacts with DQX workflows (jobs) already deployed in the workspace:
    triggers runs, reports their latest status, and relays remote task logs locally."""

    def __init__(self, ws: WorkspaceClient, install_state: InstallState):
        self._ws = ws
        self._install_state = install_state  # maps workflow name -> deployed job id

    def run_workflow(
        self,
        workflow: str,
        run_config_name: str,
        max_wait: timedelta = timedelta(minutes=20),
    ) -> int:
        """Trigger the deployed job for *workflow*, wait for completion and relay its logs.

        :param workflow: workflow name (key in the install state's job map).
        :param run_config_name: run config name, passed to the job as a named parameter.
        :param max_wait: how long to wait for a terminal state before timing out.
        :return: the run id of the triggered job run.
        :raises TimeoutError: when the run does not reach a terminal state within max_wait.
        """
        # this dunder variable is hiding this method from tracebacks, making it cleaner
        # for the user to see the actual error without too much noise.
        __tracebackhide__ = True
        logger.debug(__tracebackhide__)

        job_id = int(self._install_state.jobs[workflow])
        logger.debug(f"starting {workflow} workflow: {self._ws.config.host}#job/{job_id}")
        job_initial_run = self._ws.jobs.run_now(job_id, python_named_params={"run_config_name": run_config_name})
        run_id = job_initial_run.run_id
        run_url = f"{self._ws.config.host}#job/{job_id}/runs/{run_id}"
        logger.info(f"Started {workflow} workflow: {run_url}")

        try:
            logger.debug(f"Waiting for completion of {workflow} workflow: {run_url}")
            job_run = self._ws.jobs.wait_get_run_job_terminated_or_skipped(run_id=run_id, timeout=max_wait)
            self._log_completed_job(workflow, run_id, job_run)
            logger.info('---------- REMOTE LOGS --------------')
            self._relay_logs(workflow, run_id)
            logger.info('---------- END REMOTE LOGS ----------')
            return run_id
        except TimeoutError:
            # Still relay whatever logs were produced before the timeout.
            logger.warning(f"Timeout while waiting for {workflow} workflow to complete: {run_url}")
            logger.info('---------- REMOTE LOGS --------------')
            self._relay_logs(workflow, run_id)
            logger.info('------ END REMOTE LOGS (SO FAR) -----')
            raise
        except OperationFailed as err:
            # Translate the opaque job failure into the most specific exception we can infer.
            logger.info('---------- REMOTE LOGS --------------')
            self._relay_logs(workflow, run_id)
            logger.info('---------- END REMOTE LOGS ----------')
            job_run = self._ws.jobs.get_run(run_id)
            raise self._infer_error_from_job_run(job_run) from err

    @staticmethod
    def _log_completed_job(step: str, run_id: int, job_run: Run) -> None:
        # Summarise result state and duration of a finished run at INFO level.
        if job_run.state:
            result_state = job_run.state.result_state or "N/A"
            state_message = job_run.state.state_message
            state_description = f"{result_state} ({state_message})" if state_message else f"{result_state}"
            logger.info(f"Completed {step} workflow run {run_id} with state: {state_description}")
        else:
            logger.warning(f"Completed {step} workflow run {run_id} but end state is unknown.")
        if job_run.start_time or job_run.end_time:
            # The Jobs API reports epoch milliseconds.
            start_time = (
                datetime.fromtimestamp(job_run.start_time / 1000, tz=timezone.utc) if job_run.start_time else None
            )
            end_time = datetime.fromtimestamp(job_run.end_time / 1000, tz=timezone.utc) if job_run.end_time else None
            if job_run.run_duration:
                duration = timedelta(milliseconds=job_run.run_duration)
            elif start_time and end_time:
                duration = end_time - start_time
            else:
                duration = None
            logger.info(
                f"Completed {step} workflow run {run_id} duration: {duration or 'N/A'} ({start_time or 'N/A'} thru {end_time or 'N/A'})"
            )

    def latest_job_status(self) -> list[dict]:
        """Return one row per deployed workflow with its latest run state and start time."""
        latest_status = []
        for job, job_id in self._install_state.jobs.items():
            job_state = None
            start_time = None
            try:
                job_runs = list(self._ws.jobs.list_runs(job_id=int(job_id), limit=1))
            except InvalidParameterValue as e:
                # Stale install state: the job was deleted outside of our control.
                logger.warning(f"skipping {job}: {e}")
                continue
            if job_runs:
                state = job_runs[0].state
                if state and state.result_state:
                    job_state = state.result_state.name
                elif state and state.life_cycle_state:
                    job_state = state.life_cycle_state.name
                if job_runs[0].start_time:
                    start_time = job_runs[0].start_time / 1000
            latest_status.append(
                {
                    "workflow": job,
                    "workflow_id": job_id,
                    "state": "UNKNOWN" if not (job_runs and job_state) else job_state,
                    "started": (
                        "" if not (job_runs and start_time) else self._readable_timedelta(start_time)
                    ),
                }
            )
        return latest_status

    def relay_logs(self, workflow: str | None = None):
        """Relay logs of the given workflow's latest run; when *workflow* is omitted,
        picks the most recently started run across all deployed workflows."""
        latest_run = None
        if not workflow:
            runs = []
            for step in self._install_state.jobs:
                try:
                    _, latest_run = self._latest_job_run(step)
                    runs.append((step, latest_run))
                except InvalidParameterValue:
                    continue
            if not runs:
                logger.warning("No jobs to relay logs for")
                return
            runs = sorted(runs, key=lambda x: x[1].start_time, reverse=True)
            workflow, latest_run = runs[0]
        if not latest_run:
            assert workflow is not None
            _, latest_run = self._latest_job_run(workflow)
        self._relay_logs(workflow, latest_run.run_id)

    def _relay_logs(self, workflow, run_id):
        # Replay each remote record through a local logger named after its component;
        # MaxedStreamHandler caps the total amount of console output.
        for record in self._fetch_last_run_attempt_logs(workflow, run_id):
            task_logger = logging.getLogger(record.component)
            MaxedStreamHandler.install_handler(task_logger)
            task_logger.setLevel(logger.getEffectiveLevel())
            log_level = logging.getLevelName(record.level)
            task_logger.log(log_level, record.message)
        MaxedStreamHandler.uninstall_handlers()

    def _fetch_last_run_attempt_logs(self, workflow: str, run_id: str) -> Iterator[PartialLogRecord]:
        """Fetch the logs for the last run attempt."""
        run_folders = self._get_log_run_folders(workflow, run_id)
        if not run_folders:
            return
        # sort folders based on the last repair attempt
        last_attempt = sorted(run_folders, key=lambda _: int(_.split('-')[-1]), reverse=True)[0]
        for object_info in self._ws.workspace.list(last_attempt):
            if not object_info.path:
                continue
            if '.log' not in object_info.path:
                continue
            task_name = os.path.basename(object_info.path).split('.')[0]
            with self._ws.workspace.download(object_info.path) as raw_file:
                text_io = StringIO(raw_file.read().decode())
                for record in parse_logs(text_io):
                    # Qualify the component with the task that produced the record.
                    yield replace(record, component=f'{record.component}:{task_name}')
+ log_path_objects = list(self._ws.workspace.list(log_path)) + except ResourceDoesNotExist: + logger.warning(f"Cannot fetch logs as folder {log_path} does not exist") + return [] + run_folders = [] + for run_folder in log_path_objects: + if not run_folder.path or run_folder.object_type != ObjectType.DIRECTORY: + continue + if f"run-{run_id}-" not in run_folder.path: + continue + run_folders.append(run_folder.path) + return run_folders + + @staticmethod + def _readable_timedelta(epoch): + when = datetime.utcfromtimestamp(epoch) + duration = datetime.now() - when + data = {} + data["days"], remaining = divmod(duration.total_seconds(), 86_400) + data["hours"], remaining = divmod(remaining, 3_600) + data["minutes"], data["seconds"] = divmod(remaining, 60) + + time_parts = ((name, round(value)) for (name, value) in data.items()) + time_parts = [f"{value} {name[:-1] if value == 1 else name}" for name, value in time_parts if value > 0] + if len(time_parts) > 0: + time_parts.append("ago") + if time_parts: + return " ".join(time_parts) + return "less than 1 second ago" + + def _latest_job_run(self, workflow: str): + job_id = self._install_state.jobs.get(workflow) + if not job_id: + raise InvalidParameterValue("job does not exists hence skipping repair") + job_runs = list(self._ws.jobs.list_runs(job_id=job_id, limit=1)) + if not job_runs: + raise InvalidParameterValue("job is not initialized yet. 
    def _infer_error_from_job_run(self, job_run) -> Exception:
        """Collapse the per-task failures of *job_run* into a single exception:
        one real error is returned as-is; several are wrapped in ManyError;
        timeouts are kept apart so a lone real error is not drowned out."""
        errors: list[Exception] = []
        timeouts: list[DeadlineExceeded] = []
        assert job_run.tasks is not None
        for run_task in job_run.tasks:
            error = self._infer_error_from_task_run(run_task)
            if not error:
                continue
            if isinstance(error, DeadlineExceeded):
                timeouts.append(error)
                continue
            errors.append(error)
        assert job_run.state is not None
        assert job_run.state.state_message is not None
        if len(errors) == 1:
            return errors[0]
        all_errors = errors + timeouts
        if len(all_errors) == 0:
            # No task-level error found; fall back to the run-level state message.
            return Unknown(job_run.state.state_message)
        return ManyError(all_errors)

    def _infer_error_from_task_run(self, run_task: jobs.RunTask) -> Exception | None:
        """Map a failed or timed-out task run to an exception; None when the task succeeded."""
        if not run_task.state:
            return None
        if run_task.state.result_state == jobs.RunResultState.TIMEDOUT:
            msg = f"{run_task.task_key}: The run was stopped after reaching the timeout"
            return DeadlineExceeded(msg)
        if run_task.state.result_state != jobs.RunResultState.FAILED:
            return None
        assert run_task.run_id is not None
        run_output = self._ws.jobs.get_run_output(run_task.run_id)
        if not run_output:
            msg = f'No run output. {run_task.state.state_message}'
            return InternalError(msg)
        if logger.isEnabledFor(logging.DEBUG):
            # Surface the remote stack trace when debugging locally.
            if run_output.error_trace:
                sys.stderr.write(run_output.error_trace)
        if not run_output.error:
            msg = f'No error in run output. {run_task.state.state_message}'
            return InternalError(msg)
        return self._infer_task_exception(f"{run_task.task_key}: {run_output.error}")

    @staticmethod
    def _infer_task_exception(haystack: str) -> Exception:
        """Parse the remote error string and reconstruct the closest matching
        exception type; falls back to Unknown when nothing matches."""
        needles: list[type[Exception]] = [
            BadRequest,
            Unauthenticated,
            PermissionDenied,
            NotFound,
            ResourceConflict,
            TooManyRequests,
            Cancelled,
            databricks.sdk.errors.NotImplemented,
            InternalError,
            TemporarilyUnavailable,
            DeadlineExceeded,
            InvalidParameterValue,
            ResourceDoesNotExist,
            Aborted,
            AlreadyExists,
            ResourceAlreadyExists,
            ResourceExhausted,
            RequestLimitExceeded,
            Unknown,
            DataLoss,
            ValueError,
            KeyError,
        ]
        constructors: dict[re.Pattern, type[Exception]] = {
            re.compile(r".*\[TimeoutException] (.*)"): TimeoutError,
        }
        # Each exception class is matched by "<ClassName>: <message>" in the error text;
        # insertion order matters — the first pattern that matches wins.
        for klass in needles:
            constructors[re.compile(f".*{klass.__name__}: (.*)")] = klass
        for pattern, klass in constructors.items():
            match = pattern.match(haystack)
            if match:
                return klass(match.group(1))
        return Unknown(haystack)
    def remove_jobs(self, *, keep: set[str] | None = None) -> None:
        """Delete deployed jobs that are not in *keep*, skipping any job id that
        DQX does not own (corrupt or stale install state)."""
        for workflow_name, job_id in self._install_state.jobs.items():
            if keep and workflow_name in keep:
                continue
            try:
                if not self._is_managed_job_failsafe(int(job_id)):
                    logger.warning(f"Corrupt installation state. Skipping job_id={job_id} as it is not managed by DQX")
                    continue
                logger.info(f"Removing job_id={job_id}, as it is no longer needed")
                self._ws.jobs.delete(job_id)
            except InvalidParameterValue:
                logger.warning(f"step={workflow_name} does not exist anymore for some reason")
                continue

    def _is_testing(self):
        # The product name is patched in tests; "dqx" means a real installation.
        return self._product_info.product_name() != "dqx"

    @staticmethod
    def _is_nightly():
        # Set by the nightly CI pipeline — see TEST_NIGHTLY_CI_RESOURCES_PURGE_TIMEOUT.
        ci_env = os.getenv("TEST_NIGHTLY")
        return ci_env is not None and ci_env.lower() == "true"

    @classmethod
    def _get_test_purge_time(cls) -> str:
        """Compute the RemoveAfter tag value (YYYYMMDDHH, UTC) for test resources."""
        # Duplicate of mixins.fixtures.get_test_purge_time(); we don't want to import pytest as a transitive dependency.
        timeout = TEST_NIGHTLY_CI_RESOURCES_PURGE_TIMEOUT if cls._is_nightly() else TEST_RESOURCE_PURGE_TIMEOUT
        now = datetime.now(timezone.utc)
        purge_deadline = now + timeout
        # Round UP to the next hour boundary: that is when resources will be deleted.
        purge_hour = purge_deadline + (datetime.min.replace(tzinfo=timezone.utc) - purge_deadline) % timedelta(hours=1)
        return purge_hour.strftime("%Y%m%d%H")

    def _is_managed_job_failsafe(self, job_id: int) -> bool:
        # Treat "gone" or "invalid" jobs as not managed, rather than failing the install.
        try:
            return self._is_managed_job(job_id)
        except ResourceDoesNotExist:
            return False
        except InvalidParameterValue:
            return False

    def _is_managed_job(self, job_id: int) -> bool:
        # A job is "ours" when any of its tasks runs the databricks_labs_dqx wheel.
        job = self._ws.jobs.get(job_id)
        if not job.settings or not job.settings.tasks:
            return False
        for task in job.settings.tasks:
            if task.python_wheel_task and task.python_wheel_task.package_name == "databricks_labs_dqx":
                return True
        return False

    @property
    def _config_file(self):
        # Workspace path of the installed config.yml.
        return f"{self._installation.install_folder()}/config.yml"

    def _job_cluster_spark_conf(self, cluster_key: str):
        """Spark conf for a job cluster: single-node profile for "main", merged with
        any overrides from the run config (run config wins on key conflicts)."""
        conf_from_installation = self._run_config.spark_conf if self._run_config.spark_conf else {}
        if cluster_key == "main":
            spark_conf = {
                "spark.databricks.cluster.profile": "singleNode",
                "spark.master": "local[*]",
            }
            return spark_conf | conf_from_installation
        return conf_from_installation

    # Workflow creation might fail on an InternalError with no message
    @retried(on=[InternalError], timeout=timedelta(minutes=2))
    def _deploy_workflow(self, step_name: str, settings):
        """Create the job for *step_name*, or reset it in place when it already
        exists; re-creates it when the recorded job id turns out to be stale."""
        if step_name in self._install_state.jobs:
            try:
                job_id = int(self._install_state.jobs[step_name])
                logger.info(f"Updating configuration for step={step_name} job_id={job_id}")
                return self._ws.jobs.reset(job_id, jobs.JobSettings(**settings))
            except InvalidParameterValue:
                del self._install_state.jobs[step_name]
                logger.warning(f"step={step_name} does not exist anymore for some reason")
                return self._deploy_workflow(step_name, settings)
        logger.info(f"Creating new job configuration for step={step_name}")
        new_job = self._ws.jobs.create(**settings)
        assert new_job.job_id is not None
        self._install_state.jobs[step_name] = str(new_job.job_id)
        return None

    @staticmethod
    def _library_dep_order(library: str):
        # Sort key: SDK first, then blueprint, then everything else.
        match library:
            case library if 'sdk' in library:
                return 0
            case library if 'blueprint' in library:
                return 1
            case _:
                return 2

    def _upload_wheel(self):
        """Build and upload the DQX wheel; return its /Workspace-prefixed path(s)."""
        wheel_paths = []
        with self._wheels:
            # NOTE(review): wheel_paths is empty at this point, so this sort is a
            # no-op — looks vestigial from a multi-wheel variant; confirm and drop.
            wheel_paths.sort(key=WorkflowsDeployment._library_dep_order)
            wheel_paths.append(self._wheels.upload_to_wsfs())
        wheel_paths = [f"/Workspace{wheel}" for wheel in wheel_paths]
        return wheel_paths

    @staticmethod
    def _apply_cluster_overrides(
        settings: dict[str, Any],
        overrides: dict[str, str],
    ) -> dict:
        """Replace job clusters with pre-existing cluster ids per *overrides*
        (used to pin workflows to already-running clusters, e.g. in tests)."""
        settings["job_clusters"] = [_ for _ in settings["job_clusters"] if _.job_cluster_key not in overrides]
        for job_task in settings["tasks"]:
            if job_task.job_cluster_key is None:
                continue
            if job_task.job_cluster_key in overrides:
                job_task.existing_cluster_id = overrides[job_task.job_cluster_key]
                job_task.job_cluster_key = None
                job_task.libraries = None
        return settings
class MaxedStreamHandler(logging.StreamHandler):
    """A StreamHandler replacement that caps the total number of bytes emitted,
    so relaying large remote job logs cannot flood the local console/CI output.

    Fix over the original: ``_sent_bytes`` was checked but never incremented, so
    the cap only ever triggered for a single message larger than MAX_STREAM_SIZE.
    Every emitted message is now charged against the class-wide budget, which
    ``uninstall_handlers()`` resets.
    """

    MAX_STREAM_SIZE = 2**20 - 2**6  # 1 Mb minus some buffer
    # Replaced handlers, keyed by logger name, so uninstall can restore them.
    _installed_handlers: dict[str, tuple[logging.Logger, MaxedStreamHandler]] = {}
    # Class-wide byte budget shared by all installed handlers.
    _sent_bytes = 0

    @classmethod
    def install_handler(cls, logger_: logging.Logger):
        """Replace the first StreamHandler on *logger_* (or on the nearest ancestor
        that has one) with a MaxedStreamHandler, remembering the original."""
        if logger_.handlers:
            # already installed ?
            installed = next((h for h in logger_.handlers if isinstance(h, MaxedStreamHandler)), None)
            if installed:
                return
            # any handler to override ?
            handler = next((h for h in logger_.handlers if isinstance(h, logging.StreamHandler)), None)
            if handler:
                to_install = MaxedStreamHandler(handler)
                cls._installed_handlers[logger_.name] = (logger_, to_install)
                logger_.removeHandler(handler)
                logger_.addHandler(to_install)
                return
        # No stream handler here: walk up towards the root logger.
        if logger_.parent:
            cls.install_handler(logger_.parent)
        if logger_.root:
            cls.install_handler(logger_.root)

    @classmethod
    def uninstall_handlers(cls):
        """Restore every replaced handler and reset the shared byte budget."""
        for logger_, handler in cls._installed_handlers.values():
            logger_.removeHandler(handler)
            logger_.addHandler(handler.original_handler)
        cls._installed_handlers.clear()
        cls._sent_bytes = 0

    def __init__(self, original_handler: logging.StreamHandler):
        # NOTE(review): super().__init__() defaults the stream to sys.stderr rather
        # than original_handler.stream — confirm that is intended.
        super().__init__()
        self._original_handler = original_handler

    @property
    def original_handler(self):
        """The handler this one replaced; restored by uninstall_handlers()."""
        return self._original_handler

    def emit(self, record):
        try:
            msg = self.format(record) + self.terminator
            if self._prevent_overflow(msg):
                return
            self.stream.write(msg)
            self.flush()
        except RecursionError:  # See issue 36272
            raise
        # the below is copied from Python source
        # so ensuring not to break the logging logic
        except Exception:
            self.handleError(record)

    def _prevent_overflow(self, msg: str):
        """Return True (after writing a marker) when emitting *msg* would exceed the
        cap; otherwise charge the message against the shared budget."""
        data = msg.encode("utf-8")
        if self._sent_bytes + len(data) > self.MAX_STREAM_SIZE:
            # ensure readers are aware of why the logs are incomplete
            self.stream.write(f"MAX LOGS SIZE REACHED: {self._sent_bytes} bytes!!!")
            self.flush()
            return True
        # Bug fix: actually account for the bytes we are about to emit.
        MaxedStreamHandler._sent_bytes += len(data)
        return False
b/src/databricks/labs/dqx/profiler/dlt_generator.py @@ -169,7 +169,6 @@ def _generate_dlt_rules_sql(self, rules: list[DQProfile], action: str | None = N if expr == "": logger.info("Empty expression was generated for rule '{nm}' for column '{cl}'") continue - # TODO: generate constraint name in lower_case, etc. dlt_rule = f"CONSTRAINT {col_name}_{rule_name} EXPECT ({expr}){act_str}" dlt_rules.append(dlt_rule) diff --git a/src/databricks/labs/dqx/profiler/generator.py b/src/databricks/labs/dqx/profiler/generator.py index 92e1d65..2c51b51 100644 --- a/src/databricks/labs/dqx/profiler/generator.py +++ b/src/databricks/labs/dqx/profiler/generator.py @@ -1,6 +1,7 @@ import logging from databricks.labs.dqx.base import DQEngineBase +from databricks.labs.dqx.engine import DQEngine from databricks.labs.dqx.profiler.common import val_maybe_to_str from databricks.labs.dqx.profiler.profiler import DQProfile @@ -31,6 +32,9 @@ def generate_dq_rules(self, rules: list[DQProfile] | None = None, level: str = " if expr: dq_rules.append(expr) + status = DQEngine.validate_checks(dq_rules) + assert not status.has_errors + return dq_rules @staticmethod @@ -62,6 +66,9 @@ def dq_generate_min_max(col_name: str, level: str = "error", **params: dict): min_limit = params.get("min") max_limit = params.get("max") + if not isinstance(min_limit, int) or not isinstance(max_limit, int): + return None # TODO handle timestamp and dates: https://github.com/databrickslabs/dqx/issues/71 + if min_limit is not None and max_limit is not None: return { "check": { @@ -114,8 +121,7 @@ def dq_generate_is_not_null(col_name: str, level: str = "error", **params: dict) :param params: Additional parameters. :return: A dictionary representing the data quality rule. 
""" - if params: - pass + params = params or {} return { "check": {"function": "is_not_null", "arguments": {"col_name": col_name}}, "name": f"{col_name}_is_null", diff --git a/src/databricks/labs/dqx/profiler/runner.py b/src/databricks/labs/dqx/profiler/runner.py new file mode 100644 index 0000000..a842d41 --- /dev/null +++ b/src/databricks/labs/dqx/profiler/runner.py @@ -0,0 +1,76 @@ +from typing import Any +import logging +import yaml +from pyspark.sql import SparkSession + +from databricks.labs.dqx.utils import read_input_data +from databricks.labs.dqx.profiler.generator import DQGenerator +from databricks.labs.dqx.profiler.profiler import DQProfiler +from databricks.sdk import WorkspaceClient +from databricks.labs.blueprint.installation import Installation + + +logger = logging.getLogger(__name__) + + +class ProfilerRunner: + """Runs the DQX profiler on the input data and saves the generated checks and profile summary stats.""" + + def __init__( + self, + ws: WorkspaceClient, + spark: SparkSession, + installation: Installation, + profiler: DQProfiler, + generator: DQGenerator, + ): + self.spark = spark + self.ws = ws + self.installation = installation + self.profiler = profiler + self.generator = generator + + def run( + self, + input_location: str | None, + input_format: str | None = None, + ) -> tuple[list[dict], dict[str, Any]]: + """ + Run the DQX profiler on the input data and return the generated checks and profile summary stats. + + :param input_location: The location of the input data. + :param input_format: The format of the input data. + :return: A tuple containing the generated checks and profile summary statistics. 
+ """ + df = read_input_data(self.spark, input_location, input_format) + summary_stats, profiles = self.profiler.profile(df) + checks = self.generator.generate_dq_rules(profiles) # use default criticality level "error" + logger.info(f"Generated checks:\n{checks}") + logger.info(f"Generated summary statistics:\n{summary_stats}") + return checks, summary_stats + + def save( + self, + checks: list[dict], + summary_stats: dict[str, Any], + checks_file: str | None, + profile_summary_stats_file: str | None, + ) -> None: + """ + Save the generated checks and profile summary statistics to the specified files. + + :param checks: The generated checks. + :param summary_stats: The profile summary statistics. + :param checks_file: The file to save the checks to. + :param profile_summary_stats_file: The file to save the profile summary statistics to. + """ + if not checks_file: + raise ValueError("Check file not configured") + if not profile_summary_stats_file: + raise ValueError("Profile summary stats file not configured") + + install_folder = self.installation.install_folder() + logger.info(f"Uploading checks to {install_folder}/{checks_file}") + self.installation.upload(checks_file, yaml.safe_dump(checks).encode('utf-8')) + logger.info(f"Uploading profile summary stats to {install_folder}/{profile_summary_stats_file}") + self.installation.upload(profile_summary_stats_file, yaml.dump(summary_stats).encode('utf-8')) diff --git a/src/databricks/labs/dqx/profiler/workflow.py b/src/databricks/labs/dqx/profiler/workflow.py new file mode 100644 index 0000000..e7b357f --- /dev/null +++ b/src/databricks/labs/dqx/profiler/workflow.py @@ -0,0 +1,26 @@ +import logging + +from databricks.labs.dqx.contexts.workflows import RuntimeContext +from databricks.labs.dqx.installer.workflow_task import Workflow, workflow_task + + +logger = logging.getLogger(__name__) + + +class ProfilerWorkflow(Workflow): + def __init__(self): + super().__init__('profiler') + + @workflow_task + def profile(self, 
ctx: RuntimeContext): + """ + Profile the input data and save the generated checks and profile summary stats. + + :param ctx: Runtime context. + """ + run_config = ctx.run_config + checks, profile_summary_stats = ctx.profiler.run( + run_config.input_location, + run_config.input_format, + ) + ctx.profiler.save(checks, profile_summary_stats, run_config.checks_file, run_config.profile_summary_stats_file) diff --git a/src/databricks/labs/dqx/runtime.py b/src/databricks/labs/dqx/runtime.py new file mode 100644 index 0000000..b910c28 --- /dev/null +++ b/src/databricks/labs/dqx/runtime.py @@ -0,0 +1,97 @@ +import dataclasses +import logging +import os +import sys +from pathlib import Path + +from databricks.sdk.config import with_user_agent_extra + +from databricks.labs.dqx.__about__ import __version__ +from databricks.labs.dqx.profiler.workflow import ProfilerWorkflow +from databricks.labs.dqx.contexts.workflows import RuntimeContext +from databricks.labs.dqx.installer.workflow_task import Task, Workflow +from databricks.labs.dqx.installer.logs import TaskLogger + +logger = logging.getLogger(__name__) + + +class Workflows: + def __init__(self, workflows: list[Workflow]): + self._tasks: list[Task] = [] + self._workflows: dict[str, Workflow] = {} + for workflow in workflows: + self._workflows[workflow.name] = workflow + for task_definition in workflow.tasks(): + # Add the workflow name to the task definition, because we cannot access + # the workflow name from the method decorator + with_workflow = dataclasses.replace(task_definition, workflow=workflow.name) + self._tasks.append(with_workflow) + + @classmethod + def all(cls): + """Return all workflows.""" + return cls( + [ + ProfilerWorkflow(), + ] + ) + + def tasks(self) -> list[Task]: + """Return all tasks.""" + return self._tasks + + def trigger(self, *argv): + """Trigger a workflow.""" + named_parameters = self._parse_args(*argv) + config_path = Path(named_parameters["config"]) + ctx = RuntimeContext(named_parameters) 
+ install_dir = config_path.parent + task_name = named_parameters.get("task", "not specified") + workflow_name = named_parameters.get("workflow", "not specified") + attempt = named_parameters.get("attempt", "0") + if workflow_name not in self._workflows: + msg = f'Workflow "{workflow_name}" not found. Valid workflows are: {", ".join(self._workflows.keys())}' + raise KeyError(msg) + workflow = self._workflows[workflow_name] + + # both CLI commands and workflow names appear in telemetry under `cmd` + with_user_agent_extra("cmd", workflow_name) + # `{{parent_run_id}}` is the run of entire workflow, whereas `{{run_id}}` is the run of a task + job_run_id = named_parameters.get("parent_run_id", "unknown_run_id") + job_id = named_parameters.get("job_id", "unknown_job_id") + with TaskLogger( + install_dir, + workflow=workflow_name, + job_id=job_id, + task_name=task_name, + job_run_id=job_run_id, + log_level=ctx.config.log_level, + attempt=attempt, + ) as task_logger: + dqx_logger = logging.getLogger("databricks.labs.dqx") + dqx_logger.info(f"DQX v{__version__} After workflow finishes, see debug logs at {task_logger}") + current_task = getattr(workflow, task_name) + current_task(ctx) + return None + + @staticmethod + def _parse_args(*argv) -> dict[str, str]: + """Parse command line arguments""" + args = dict(a[2:].split("=") for a in argv if a[0:2] == "--") + if "config" not in args: + msg = "no --config specified" + raise KeyError(msg) + return args + + +def main(*argv): + """Main entry point.""" + if len(argv) == 0: + argv = sys.argv + Workflows.all().trigger(*argv) + + +if __name__ == "__main__": + if "DATABRICKS_RUNTIME_VERSION" not in os.environ: + raise SystemExit("Only intended to run in Databricks Runtime") + main(*sys.argv) diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py index 47486f5..7d2c26d 100644 --- a/src/databricks/labs/dqx/utils.py +++ b/src/databricks/labs/dqx/utils.py @@ -1,4 +1,10 @@ +import re from pyspark.sql import 
Column +from pyspark.sql import SparkSession + + +STORAGE_PATH_PATTERN = re.compile(r"^(/|s3:/|abfss:/|gs:/)") +UNITY_CATALOG_TABLE_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$") def get_column_name(col: Column) -> str: @@ -12,3 +18,56 @@ def get_column_name(col: Column) -> str: :return: Col name alias as str """ return str(col).removeprefix("Column<'").removesuffix("'>").split(" AS ")[-1] + + +def read_input_data(spark: SparkSession, input_location: str | None, input_format: str | None): + """ + Reads input data from the specified location and format. + + :param spark: SparkSession + :param input_location: The input data location. + :param input_format: The input data format. + """ + if not input_location: + raise ValueError("Input location not configured") + + if UNITY_CATALOG_TABLE_PATTERN.match(input_location): + return spark.read.table(input_location) # must provide 3-level Unity Catalog namespace + + if STORAGE_PATH_PATTERN.match(input_location): + if not input_format: + raise ValueError("Input format not configured") + return spark.read.format(str(input_format)).load(input_location) + + raise ValueError( + f"Invalid input location. It must be Unity Catalog table / view or storage location, " f"given {input_location}" + ) + + +def remove_extra_indentation(doc: str) -> str: + """ + Remove extra indentation from docstring. + + :param doc: Docstring + """ + lines = doc.splitlines() + stripped = [] + for line in lines: + if line.startswith(" " * 4): + stripped.append(line[4:]) + else: + stripped.append(line) + return "\n".join(stripped) + + +def extract_major_minor(version_string: str): + """ + Extracts the major and minor version from a version string. + + :param version_string: The version string to extract from. + :return: The major.minor version as a string, or None if not found. 
+ """ + match = re.search(r"(\d+\.\d+)", version_string) + if match: + return match.group(1) + return None diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d71b522..733fd63 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,5 +1,6 @@ import os import logging +import threading from pathlib import Path from collections.abc import Callable, Generator from functools import cached_property @@ -7,21 +8,26 @@ from unittest.mock import patch import pytest from databricks.labs.pytester.fixtures.baseline import factory -from databricks.labs.dqx.contexts.workflow_task import RuntimeContext +from databricks.labs.dqx.contexts.workflows import RuntimeContext from databricks.labs.dqx.__about__ import __version__ from databricks.sdk.service.workspace import ImportFormat from databricks.sdk import WorkspaceClient from databricks.labs.blueprint.wheels import ProductInfo from databricks.labs.dqx.config import WorkspaceConfig, RunConfig from databricks.labs.blueprint.installation import Installation, MockInstallation -from databricks.labs.dqx.install import WorkspaceInstaller, WorkspaceInstallation +from databricks.labs.dqx.installer.install import WorkspaceInstaller, WorkspaceInstallation from databricks.labs.blueprint.tui import MockPrompts +from databricks.labs.dqx.runtime import Workflows +from databricks.labs.dqx.installer.workflow_task import Task +from databricks.labs.dqx.installer.workflows_installer import WorkflowsDeployment + logging.getLogger("tests").setLevel("DEBUG") logging.getLogger("databricks.labs.dqx").setLevel("DEBUG") logger = logging.getLogger(__name__) +_lock = threading.Lock() @pytest.fixture @@ -137,6 +143,23 @@ def config(self) -> WorkspaceConfig: def product_info(self): return ProductInfo.for_testing(WorkspaceConfig) + @cached_property + def tasks(self) -> list[Task]: + return Workflows.all().tasks() + + @cached_property + def workflows_deployment(self) -> WorkflowsDeployment: + return 
WorkflowsDeployment( + self.config, + self.config.get_run_config().name, + self.installation, + self.install_state, + self.workspace_client, + self.product_info.wheels(self.workspace_client), + self.product_info, + self.tasks, + ) + @cached_property def prompts(self): return MockPrompts( @@ -159,6 +182,7 @@ def workspace_installation(self) -> WorkspaceInstallation: self.installation, self.install_state, self.workspace_client, + self.workflows_deployment, self.prompts, self.product_info, ) @@ -183,3 +207,39 @@ def installation_ctx( def webbrowser_open(): with patch("webbrowser.open") as mock_open: yield mock_open + + +@pytest.fixture +def setup_workflows(installation_ctx: MockInstallationContext, make_schema, make_table): + """ + Setup the workflows for the tests + + Existing cluster can be used by adding: + run_config.override_clusters = {Task.job_cluster: installation_ctx.workspace_client.config.cluster_id} + """ + # install dqx in the workspace + installation_ctx.workspace_installation.run() + + # prepare test data + catalog_name = "main" + schema = make_schema(catalog_name=catalog_name) + table = make_table( + catalog_name=catalog_name, + schema_name=schema.name, + ctas="SELECT * FROM VALUES (1, 'a'), (2, 'b'), (3, NULL) AS data(id, name)", + ) + + # update input location + config = installation_ctx.config + run_config = config.get_run_config() + run_config.input_location = table.full_name + installation_ctx.installation.save(installation_ctx.config) + + yield installation_ctx, run_config + + +def contains_expected_workflows(workflows, state): + for workflow in workflows: + if all(item in workflow.items() for item in state.items()): + return True + return False diff --git a/tests/integration/test_apply_checks.py b/tests/integration/test_apply_checks.py index fe5f281..388dd72 100644 --- a/tests/integration/test_apply_checks.py +++ b/tests/integration/test_apply_checks.py @@ -1,5 +1,6 @@ from pathlib import Path import pyspark.sql.functions as F +import pytest 
from pyspark.sql import Column from chispa.dataframe_comparer import assert_df_equality # type: ignore from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition @@ -90,6 +91,20 @@ def test_apply_checks(ws, spark): assert_df_equality(checked, expected, ignore_nullable=True) +def test_apply_checks_invalid_criticality(ws, spark): + dq_engine = DQEngine(ws) + test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA) + + checks = [ + DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")), + DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")), + DQRule(name="col_c_is_null_or_empty", criticality="invalid", check=is_not_null_and_not_empty("c")), + ] + + with pytest.raises(ValueError, match="Invalid criticality value: invalid"): + dq_engine.apply_checks(test_df, checks) + + def test_apply_checks_with_autogenerated_col_names(ws, spark): dq_engine = DQEngine(ws) test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA) diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index df969a3..1d35c58 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -1,11 +1,15 @@ import logging from dataclasses import dataclass + import yaml +from integration.conftest import contains_expected_workflows import pytest -from databricks.labs.dqx.cli import open_remote_config, installations, validate_checks +from databricks.labs.dqx.cli import open_remote_config, installations, validate_checks, profile, workflows, logs from databricks.labs.dqx.config import WorkspaceConfig from databricks.sdk.errors import NotFound +from databricks.labs.dqx.engine import DQEngine + logger = logging.getLogger(__name__) @@ -38,7 +42,7 @@ def test_installations_output_serde_error(ws, installation_ctx): @dataclass class InvalidConfig: __version__ = 
WorkspaceConfig.__version__ - fake: str | None = "fake" + fake = "fake" installation_ctx.installation.save(InvalidConfig(), filename=WorkspaceConfig.__file__) output = installations( @@ -50,13 +54,13 @@ class InvalidConfig: def test_validate_checks(ws, make_workspace_file, installation_ctx): installation_ctx.installation.save(installation_ctx.config) checks = [{"criticality": "warn", "check": {"function": "is_not_null", "arguments": {"col_name": "a"}}}] - run_config_name = "default" - run_config = installation_ctx.config.get_run_config(run_config_name) + + run_config = installation_ctx.config.get_run_config() checks_file = f"{installation_ctx.installation.install_folder()}/{run_config.checks_file}" make_workspace_file(path=checks_file, content=yaml.dump(checks)) errors_list = validate_checks( - installation_ctx.workspace_client, run_config=run_config_name, ctx=installation_ctx.workspace_installer + installation_ctx.workspace_client, run_config=run_config.name, ctx=installation_ctx.workspace_installer ) assert not errors_list @@ -85,11 +89,10 @@ def test_validate_checks_when_given_invalid_checks(ws, make_workspace_file, inst def test_validate_checks_invalid_run_config(ws, installation_ctx): installation_ctx.installation.save(installation_ctx.config) - run_config_name = "unavailable" with pytest.raises(ValueError, match="No run configurations available"): validate_checks( - installation_ctx.workspace_client, run_config=run_config_name, ctx=installation_ctx.workspace_installer + installation_ctx.workspace_client, run_config="unavailable", ctx=installation_ctx.workspace_installer ) @@ -98,3 +101,53 @@ def test_validate_checks_when_checks_file_missing(ws, installation_ctx): with pytest.raises(NotFound, match="Checks file checks.yml missing"): validate_checks(installation_ctx.workspace_client, ctx=installation_ctx.workspace_installer) + + +def test_profiler(ws, setup_workflows, caplog): + installation_ctx, run_config = setup_workflows + + 
profile(installation_ctx.workspace_client, run_config=run_config.name, ctx=installation_ctx.workspace_installer) + + checks = DQEngine(ws).load_checks_from_installation( + run_config_name=run_config.name, assume_user=True, product_name=installation_ctx.installation.product() + ) + assert checks, "Checks were not loaded correctly" + + install_folder = installation_ctx.installation.install_folder() + status = ws.workspace.get_status(f"{install_folder}/{run_config.profile_summary_stats_file}") + assert status, f"Profile summary stats file {run_config.profile_summary_stats_file} does not exist." + + with caplog.at_level(logging.INFO): + logs(installation_ctx.workspace_client, ctx=installation_ctx.workspace_installer) + + assert "Completed profiler workflow run" in caplog.text + + +def test_profiler_when_run_config_missing(ws, installation_ctx): + installation_ctx.workspace_installation.run() + + with pytest.raises(ValueError, match="No run configurations available"): + installation_ctx.deployed_workflows.run_workflow("profiler", run_config_name="unavailable") + + +def test_workflows(ws, installation_ctx): + installation_ctx.workspace_installation.run() + installed_workflows = workflows(installation_ctx.workspace_client, ctx=installation_ctx.workspace_installer) + + expected_workflows_state = [{'workflow': 'profiler', 'state': 'UNKNOWN', 'started': ''}] + for state in expected_workflows_state: + assert contains_expected_workflows(installed_workflows, state) + + +def test_workflows_not_installed(ws, installation_ctx): + installed_workflows = workflows(installation_ctx.workspace_client, ctx=installation_ctx.workspace_installer) + assert not installed_workflows + + +def test_logs(ws, installation_ctx, caplog): + installation_ctx.workspace_installation.run() + + with caplog.at_level(logging.INFO): + logs(installation_ctx.workspace_client, ctx=installation_ctx.workspace_installer) + + assert "No jobs to relay logs for" in caplog.text diff --git 
a/tests/integration/test_functions.py b/tests/integration/test_functions.py index f3b8b70..d7e180c 100644 --- a/tests/integration/test_functions.py +++ b/tests/integration/test_functions.py @@ -1,5 +1,3 @@ -from datetime import datetime - import pyspark.sql.functions as F from chispa.dataframe_comparer import assert_df_equality # type: ignore from databricks.labs.dqx.col_functions import ( @@ -202,7 +200,7 @@ def test_col_not_in_near_future(spark): def test_is_col_older_than_n_days_cur(spark): schema_dates = "a: string" - cur_date = datetime.now().strftime("%Y-%m-%d") + cur_date = spark.sql("SELECT current_date() AS current_date").collect()[0]['current_date'].strftime("%Y-%m-%d") test_df = spark.createDataFrame([["2023-01-10"], [None]], schema_dates) diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index 7ae2219..7234d60 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -1,13 +1,22 @@ import logging -from unittest.mock import patch +from unittest.mock import patch, create_autospec import pytest + +from integration.conftest import contains_expected_workflows import databricks -from databricks.labs.blueprint.installation import Installation +from databricks.labs.dqx.installer.workflows_installer import WorkflowsDeployment +from databricks.labs.blueprint.installation import Installation, MockInstallation +from databricks.labs.blueprint.wheels import WheelsV2 +from databricks.labs.dqx.installer.workflow_task import Task +from databricks.labs.blueprint.installer import InstallState from databricks.labs.blueprint.tui import MockPrompts from databricks.labs.blueprint.wheels import ProductInfo -from databricks.labs.dqx.config import WorkspaceConfig -from databricks.labs.dqx.install import WorkspaceInstaller +from databricks.labs.dqx.config import WorkspaceConfig, RunConfig +from databricks.labs.dqx.installer.install import WorkspaceInstaller from databricks.sdk.errors import 
NotFound +from databricks.sdk.service.jobs import CreateResponse +from databricks.sdk import WorkspaceClient + logger = logging.getLogger(__name__) @@ -63,6 +72,9 @@ def test_fresh_global_config_installation(ws, installation_ctx): installation_ctx.installation = Installation.assume_global(ws, product_name) installation_ctx.installation.save(installation_ctx.config) assert installation_ctx.workspace_installation.folder == f"/Shared/{product_name}" + assert installation_ctx.workspace_installer.installation + assert installation_ctx.workspace_installation.current(ws) + assert installation_ctx.workspace_installation.config == installation_ctx.config def test_fresh_user_config_installation(ws, installation_ctx): @@ -73,9 +85,20 @@ def test_fresh_user_config_installation(ws, installation_ctx): ) +def test_complete_installation(ws, installation_ctx): + installation_ctx.workspace_installer.run(installation_ctx.config) + assert installation_ctx.workspace_installer.installation + assert installation_ctx.deployed_workflows.latest_job_status() + + def test_installation(ws, installation_ctx): installation_ctx.workspace_installation.run() + workflows = installation_ctx.deployed_workflows.latest_job_status() + expected_workflows_state = [{'workflow': 'profiler', 'state': 'UNKNOWN', 'started': ''}] + assert ws.workspace.get_status(installation_ctx.workspace_installation.folder) + for state in expected_workflows_state: + assert contains_expected_workflows(workflows, state) def test_uninstallation(ws, installation_ctx): @@ -99,7 +122,22 @@ def test_global_installation_on_existing_global_install(ws, installation_ctx): ) installation_ctx.__dict__.pop("workspace_installer") installation_ctx.__dict__.pop("prompts") - installation_ctx.workspace_installer.configure() + + config = installation_ctx.workspace_installer.configure() + config.connect = None + assert config == WorkspaceConfig( + log_level='INFO', + run_configs=[ + RunConfig( + input_location="skipped", + input_format="delta", + 
output_table="skipped", + quarantine_table="skipped", + checks_file="checks.yml", + profile_summary_stats_file="profile_summary_stats.yml", + ) + ], + ) def test_user_installation_on_existing_global_install(ws, new_installation, make_random): @@ -193,3 +231,46 @@ def test_compare_remote_local_install_versions(ws, installation_ctx): installation_ctx.__dict__.pop("workspace_installer") installation_ctx.__dict__.pop("prompts") installation_ctx.workspace_installer.configure() + + +def test_installation_stores_install_state_keys(ws, installation_ctx): + """The installation should store the keys in the installation state.""" + expected_keys = ["jobs"] + installation_ctx.workspace_installation.run() + # Refresh the installation state since the installation context uses `@cached_property` + install_state = InstallState.from_installation(installation_ctx.installation) + for key in expected_keys: + assert hasattr(install_state, key), f"Missing key in install state: {key}" + assert getattr(install_state, key), f"Installation state is empty: {key}" + + +def side_effect_remove_after_in_tags_settings(**settings) -> CreateResponse: + tags = settings.get("tags", {}) + _ = tags["RemoveAfter"] # KeyError side effect + return CreateResponse(job_id=1) + + +def test_workflows_deployment_creates_jobs_with_remove_after_tag(): + ws = create_autospec(WorkspaceClient) + ws.jobs.create.side_effect = side_effect_remove_after_in_tags_settings + config = WorkspaceConfig([RunConfig()]) + mock_installation = MockInstallation() + install_state = InstallState.from_installation(mock_installation) + wheels = create_autospec(WheelsV2) + product_info = ProductInfo.for_testing(WorkspaceConfig) + tasks = [Task("workflow", "task", "docs", lambda *_: None)] + workflows_deployment = WorkflowsDeployment( + config, + config.get_run_config().name, + mock_installation, + install_state, + ws, + wheels, + product_info, + tasks=tasks, + ) + try: + workflows_deployment.create_jobs() + except KeyError as e: + assert 
False, f"RemoveAfter tag not present: {e}" + wheels.assert_not_called() diff --git a/tests/integration/test_profiler.py b/tests/integration/test_profiler.py index 23d7adc..dcc6f9e 100644 --- a/tests/integration/test_profiler.py +++ b/tests/integration/test_profiler.py @@ -7,6 +7,7 @@ def test_profiler(spark, ws): inp_schema = T.StructType( [ T.StructField("t1", T.IntegerType()), + T.StructField("t2", T.StringType()), T.StructField( "s1", T.StructType( @@ -25,6 +26,7 @@ def test_profiler(spark, ws): [ [ 1, + " test ", { "ns1": datetime.fromisoformat("2023-01-08T10:00:11+00:00"), "s2": {"ns2": "test", "ns3": date.fromisoformat("2023-01-08")}, @@ -32,6 +34,7 @@ def test_profiler(spark, ws): ], [ 2, + "test2", { "ns1": datetime.fromisoformat("2023-01-07T10:00:11+00:00"), "s2": {"ns2": "test2", "ns3": date.fromisoformat("2023-01-07")}, @@ -39,6 +42,7 @@ def test_profiler(spark, ws): ], [ 3, + None, { "ns1": datetime.fromisoformat("2023-01-06T10:00:11+00:00"), "s2": {"ns2": "test", "ns3": date.fromisoformat("2023-01-06")}, @@ -56,6 +60,7 @@ def test_profiler(spark, ws): DQProfile( name="min_max", column="t1", description="Real min/max values were used", parameters={"min": 1, "max": 3} ), + DQProfile(name='is_not_null_or_empty', column='t2', description=None, parameters={'trim_strings': True}), DQProfile(name="is_not_null", column="s1.ns1", description=None, parameters=None), DQProfile( name="min_max", @@ -77,6 +82,98 @@ def test_profiler(spark, ws): assert rules == expected_rules +def test_profiler_non_default_profile_options(spark, ws): + inp_schema = T.StructType( + [ + T.StructField("t1", T.IntegerType()), + T.StructField("t2", T.StringType()), + T.StructField( + "s1", + T.StructType( + [ + T.StructField("ns1", T.TimestampType()), + T.StructField( + "s2", + T.StructType([T.StructField("ns2", T.StringType()), T.StructField("ns3", T.DateType())]), + ), + ] + ), + ), + ] + ) + inp_df = spark.createDataFrame( + [ + [ + 1, + " test ", + { + "ns1": 
datetime.fromisoformat("2023-01-08T10:00:11+00:00"),
+                    "s2": {"ns2": "test", "ns3": date.fromisoformat("2023-01-08")},
+                },
+            ],
+            [
+                2,
+                " ",
+                {
+                    "ns1": datetime.fromisoformat("2023-01-07T10:00:11+00:00"),
+                    "s2": {"ns2": "test2", "ns3": date.fromisoformat("2023-01-07")},
+                },
+            ],
+            [
+                3,
+                None,
+                {
+                    "ns1": datetime.fromisoformat("2023-01-06T10:00:11+00:00"),
+                    "s2": {"ns2": "test", "ns3": date.fromisoformat("2023-01-06")},
+                },
+            ],
+        ],
+        schema=inp_schema,
+    )
+
+    profiler = DQProfiler(ws)
+
+    profiler.default_profile_options = {
+        "round": False,
+        "max_in_count": 1,
+        "distinct_ratio": 0.01,
+        "max_null_ratio": 0.01,  # Generate is_null if we have less than 1 percent of nulls
+        "remove_outliers": False,
+        "outlier_columns": ["t1", "s1"],  # remove outliers only in the listed columns (of appropriate type)
+        "num_sigmas": 1,  # number of sigmas to use when remove_outliers is True
+        "trim_strings": False,  # trim whitespace from strings
+        "max_empty_ratio": 0.01,
+    }
+
+    stats, rules = profiler.profile(inp_df)
+
+    expected_rules = [
+        DQProfile(name="is_not_null", column="t1", description=None, parameters=None),
+        DQProfile(
+            name="min_max", column="t1", description="Real min/max values were used", parameters={"min": 1, "max": 3}
+        ),
+        DQProfile(name='is_not_null_or_empty', column='t2', description=None, parameters={'trim_strings': False}),
+        DQProfile(name="is_not_null", column="s1.ns1", description=None, parameters=None),
+        DQProfile(
+            name="min_max",
+            column="s1.ns1",
+            description="Real min/max values were used",
+            parameters={'max': datetime(2023, 1, 8, 11, 0, 11), 'min': datetime(2023, 1, 6, 11, 0, 11)},
+        ),
+        DQProfile(name="is_not_null", column="s1.s2.ns2", description=None, parameters=None),
+        DQProfile(name="is_not_null", column="s1.s2.ns3", description=None, parameters=None),
+        DQProfile(
+            name="min_max",
+            column="s1.s2.ns3",
+            description="Real min/max values were used",
+            parameters={"min": date(2023, 1, 6), "max": date(2023, 1, 8)},
+        ),
+    ]
+    
print(stats) + assert len(stats.keys()) > 0 + assert rules == expected_rules + + def test_profiler_empty_df(spark, ws): test_df = spark.createDataFrame([], "data: string") diff --git a/tests/integration/test_profiler_runner.py b/tests/integration/test_profiler_runner.py new file mode 100644 index 0000000..65c32b2 --- /dev/null +++ b/tests/integration/test_profiler_runner.py @@ -0,0 +1,112 @@ +import sys +import pytest + +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.profiler.generator import DQGenerator +from databricks.labs.dqx.profiler.profiler import DQProfiler +from databricks.labs.dqx.profiler.runner import ProfilerRunner +from databricks.labs.dqx.profiler.workflow import ProfilerWorkflow + + +def test_profiler_runner_save_raise_error_when_check_file_missing(ws, spark, installation_ctx): + profiler = DQProfiler(ws) + generator = DQGenerator(ws) + runner = ProfilerRunner(ws, spark, installation_ctx.installation, profiler, generator) + + checks = [] + summary_stats = {} + checks_file = None + profile_summary_stats_file = "profile_summary_stats.yml" + + with pytest.raises(ValueError, match="Check file not configured"): + runner.save(checks, summary_stats, checks_file, profile_summary_stats_file) + + +def test_profiler_runner_save_raise_error_when_profile_summary_stats_file_missing(ws, spark, installation_ctx): + profiler = DQProfiler(ws) + generator = DQGenerator(ws) + runner = ProfilerRunner(ws, spark, installation_ctx.installation, profiler, generator) + + checks = [] + summary_stats = {} + checks_file = "checks.yml" + profile_summary_stats_file = None + + with pytest.raises(ValueError, match="Profile summary stats file not configured"): + runner.save(checks, summary_stats, checks_file, profile_summary_stats_file) + + +def test_profiler_runner_raise_error_when_profile_summary_stats_file_missing(ws, spark, installation_ctx): + profiler = DQProfiler(ws) + generator = DQGenerator(ws) + runner = ProfilerRunner(ws, spark, 
installation_ctx.installation, profiler, generator) + + checks = [ + { + "name": "col_a_is_null_or_empty", + "criticality": "error", + "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "a"}}, + }, + ] + summary_stats = { + 'a': { + 'count': 3, + 'mean': 2.0, + 'stddev': 1.0, + 'min': 1, + '25%': 1, + '50%': 2, + '75%': 3, + 'max': 3, + 'count_non_null': 3, + 'count_null': 0, + } + } + checks_file = "checks.yml" + profile_summary_stats_file = "profile_summary_stats.yml" + + runner.save(checks, summary_stats, checks_file, profile_summary_stats_file) + installation_ctx.installation.install_folder() + + install_folder = installation_ctx.installation.install_folder() + checks_file_status = ws.workspace.get_status(f"{install_folder}/{checks_file}") + assert checks_file_status, f"Checks not uploaded to {install_folder}/{checks_file}." + + summary_stats_file_status = ws.workspace.get_status(f"{install_folder}/{profile_summary_stats_file}") + assert ( + summary_stats_file_status + ), f"Profile summary stats not uploaded to {install_folder}/{profile_summary_stats_file}." 
+ + +def test_profiler_runner(ws, spark, installation_ctx, make_schema, make_table, make_random): + profiler = DQProfiler(ws) + generator = DQGenerator(ws) + runner = ProfilerRunner(ws, spark, installation_ctx.installation, profiler, generator) + + # prepare test data + catalog_name = "main" + schema = make_schema(catalog_name=catalog_name) + table = make_table( + catalog_name=catalog_name, + schema_name=schema.name, + ctas="SELECT * FROM VALUES (1, 'a'), (2, 'b'), (3, NULL) AS data(id, name)", + ) + + checks, summary_stats = runner.run(input_location=table.full_name) + + assert checks, "Checks were not generated correctly" + assert summary_stats, "Profile summary stats were not generated correctly" + + +def test_profiler_workflow(ws, spark, setup_workflows): + installation_ctx, run_config = setup_workflows + + sys.modules["pyspark.sql.session"] = spark + ctx = installation_ctx.replace(run_config=run_config) + + ProfilerWorkflow().profile(ctx) # type: ignore + + checks = DQEngine(ws).load_checks_from_installation( + run_config_name=run_config.name, assume_user=True, product_name=installation_ctx.installation.product() + ) + assert checks, "Checks were not loaded correctly" diff --git a/tests/integration/test_profiler_workflow.py b/tests/integration/test_profiler_workflow.py new file mode 100644 index 0000000..a5d76ff --- /dev/null +++ b/tests/integration/test_profiler_workflow.py @@ -0,0 +1,47 @@ +from datetime import timedelta + +import pytest + +from databricks.labs.dqx.engine import DQEngine + + +def test_profiler_workflow_e2e_when_missing_input_location_in_config(ws, setup_workflows): + installation_ctx, run_config = setup_workflows + + config = installation_ctx.config + run_config = config.get_run_config() + run_config.input_location = "invalid" + installation_ctx.installation.save(installation_ctx.config) + + with pytest.raises(ValueError) as failure: + installation_ctx.deployed_workflows.run_workflow("profiler", run_config.name) + + assert "Invalid input 
location." in str(failure.value) + + install_folder = installation_ctx.installation.install_folder() + workflow_run_logs = list(ws.workspace.list(f"{install_folder}/logs")) + assert len(workflow_run_logs) == 1 + + +def test_profiler_workflow_e2e_when_timeout(ws, setup_workflows): + installation_ctx, run_config = setup_workflows + + with pytest.raises(TimeoutError) as failure: + installation_ctx.deployed_workflows.run_workflow("profiler", run_config.name, max_wait=timedelta(seconds=0)) + + assert "timed out" in str(failure.value) + + +def test_profiler_workflow_e2e(ws, setup_workflows): + installation_ctx, run_config = setup_workflows + + installation_ctx.deployed_workflows.run_workflow("profiler", run_config.name) + + checks = DQEngine(ws).load_checks_from_installation( + run_config_name=run_config.name, assume_user=True, product_name=installation_ctx.installation.product() + ) + assert checks, "Checks were not loaded correctly" + + install_folder = installation_ctx.installation.install_folder() + status = ws.workspace.get_status(f"{install_folder}/{run_config.profile_summary_stats_file}") + assert status, f"Profile summary stats file {run_config.profile_summary_stats_file} does not exist." 
diff --git a/tests/integration/test_rules_generator.py b/tests/integration/test_rules_generator.py index 43c1b71..22c3a9a 100644 --- a/tests/integration/test_rules_generator.py +++ b/tests/integration/test_rules_generator.py @@ -64,22 +64,6 @@ def test_generate_dq_rules(ws): "name": "rate_code_id_isnt_in_range", "criticality": "error", }, - { - "check": { - "function": "not_less_than", - "arguments": {"col_name": "product_launch_date", "val": "2020-01-01"}, - }, - "name": "product_launch_date_not_less_than", - "criticality": "error", - }, - { - "check": { - "function": "not_greater_than", - "arguments": {"col_name": "product_expiry_ts", "val": "2020-01-01T00:00:00.000000"}, - }, - "name": "product_expiry_ts_not_greater_than", - "criticality": "error", - }, ] assert expectations == expected @@ -117,22 +101,6 @@ def test_generate_dq_rules_warn(ws): "name": "rate_code_id_isnt_in_range", "criticality": "warn", }, - { - "check": { - "function": "not_less_than", - "arguments": {"col_name": "product_launch_date", "val": "2020-01-01"}, - }, - "name": "product_launch_date_not_less_than", - "criticality": "warn", - }, - { - "check": { - "function": "not_greater_than", - "arguments": {"col_name": "product_expiry_ts", "val": "2020-01-01T00:00:00.000000"}, - }, - "name": "product_expiry_ts_not_greater_than", - "criticality": "warn", - }, ] assert expectations == expected @@ -141,3 +109,9 @@ def test_generate_dq_rules_logging(ws, caplog): generator = DQGenerator(ws) generator.generate_dq_rules(test_rules) assert "No rule 'is_random' for column 'vendor_id'. skipping..." 
in caplog.text + + +def test_generate_dq_no_rules(ws): + generator = DQGenerator(ws) + expectations = generator.generate_dq_rules(None, level="warn") + assert not expectations diff --git a/tests/integration/test_runtime_context.py b/tests/integration/test_runtime_context.py new file mode 100644 index 0000000..32cae64 --- /dev/null +++ b/tests/integration/test_runtime_context.py @@ -0,0 +1,61 @@ +import os +import base64 +import yaml +import pytest +from databricks.labs.dqx.config import WorkspaceConfig +from databricks.labs.dqx.contexts.workflows import RuntimeContext + + +@pytest.fixture +def save_local(ws, make_random): + temp_files = [] + + def _save_local(config_path): + temp_file = f"{make_random}.yml" + export = ws.workspace.export(config_path) + content = base64.b64decode(export.content).decode('utf-8') + yaml_content = yaml.safe_load(content) + with open(temp_file, 'w', encoding="utf-8") as local_file: + yaml.dump(yaml_content, local_file) + temp_files.append(temp_file) + return temp_file + + yield _save_local + + for temp_file in temp_files: + if os.path.exists(temp_file): + os.remove(temp_file) + + +def test_runtime_config(ws, installation_ctx, save_local): + installation_ctx.installation.save(installation_ctx.config) + run_config = installation_ctx.config.get_run_config() + + install_config_path = f"{installation_ctx.installation.install_folder()}/{WorkspaceConfig.__file__}" + local_config_path = save_local(install_config_path) + + runtime_context = RuntimeContext(named_parameters={"config": local_config_path, "run_config_name": run_config.name}) + + actual_config = runtime_context.config + actual_run_config = runtime_context.run_config + + assert actual_config + assert actual_config.get_run_config() == run_config + assert actual_run_config + assert actual_run_config == run_config + assert runtime_context.connect_config + assert runtime_context.workspace_client + assert runtime_context.workspace_id + assert runtime_context.installation + + +def 
test_runtime_config_when_missing_run_config(): + runtime_context = RuntimeContext(named_parameters={"config": "temp"}) + with pytest.raises(ValueError, match="Run config flag is required"): + run_config = runtime_context.run_config + assert not run_config + + +def test_runtime_parent_run_id(): + runtime_context = RuntimeContext(named_parameters={"parent_run_id": "1"}) + assert runtime_context.parent_run_id == 1 diff --git a/tests/test_data/checks.json b/tests/test_data/checks.json index e86a580..0300fb9 100644 --- a/tests/test_data/checks.json +++ b/tests/test_data/checks.json @@ -1,5 +1,5 @@ [ {"criticality":"error","check":{"function":"is_not_null","arguments":{"col_names":["col1","col2"]}}}, - {"name":"col_col3_is_null_or_empty","criticality":"error","check":{"function":"is_not_null_and_not_empty","arguments":{"col_name":"col3"}}}, + {"name":"col_col3_is_null_or_empty","criticality":"error","check":{"function":"is_not_null_and_not_empty","arguments":{"col_name":"col3", "trim_strings": true}}}, {"criticality":"warn","check":{"function":"value_is_in_list","arguments":{"col_name":"col4","allowed":[1,2]}}} ] \ No newline at end of file diff --git a/tests/test_data/checks.yml b/tests/test_data/checks.yml index cd5d54d..d9e0fc4 100644 --- a/tests/test_data/checks.yml +++ b/tests/test_data/checks.yml @@ -12,6 +12,7 @@ function: is_not_null_and_not_empty arguments: col_name: col3 + trim_strings: true - criticality: warn check: function: value_is_in_list diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..e0c0577 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,8 @@ +from unittest.mock import Mock +from pyspark.sql import SparkSession +import pytest + + +@pytest.fixture +def spark(): + return Mock(spec=SparkSession) diff --git a/tests/unit/load_checks_from_local_file.py b/tests/unit/load_checks_from_local_file.py index b7b3280..fcd91a7 100644 --- a/tests/unit/load_checks_from_local_file.py +++ 
b/tests/unit/load_checks_from_local_file.py @@ -11,7 +11,7 @@ { "name": "col_col3_is_null_or_empty", "criticality": "error", - "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "col3"}}, + "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "col3", "trim_strings": True}}, }, { "criticality": "warn", @@ -30,7 +30,6 @@ def test_load_check_from_local_file_json(): def test_load_check_from_local_file_yml(): file = BASE_PATH + "/test_data/checks.yml" checks = DQEngine.load_checks_from_local_file(file) - assert checks == EXPECTED_CHECKS, "The loaded checks do not match the expected checks." diff --git a/tests/unit/test_installer.py b/tests/unit/test_installer.py new file mode 100644 index 0000000..2977ba0 --- /dev/null +++ b/tests/unit/test_installer.py @@ -0,0 +1,51 @@ +from unittest.mock import patch, MagicMock +import pytest +from databricks.labs.dqx.installer.install import WorkspaceInstaller, ManyError +from databricks.sdk import WorkspaceClient + + +def test_installer_executed_outside_workspace(): + mock_ws_client = MagicMock(spec=WorkspaceClient) + with pytest.raises(SystemExit) as exc_info: + WorkspaceInstaller(mock_ws_client, environ={"DATABRICKS_RUNTIME_VERSION": "7.3"}) + assert str(exc_info.value) == "WorkspaceInstaller is not supposed to be executed in Databricks Runtime" + + +def test_configure_raises_timeout_error(): + mock_configure = MagicMock(side_effect=TimeoutError("Mocked timeout error")) + mock_ws_client = MagicMock(spec=WorkspaceClient) + installer = WorkspaceInstaller(mock_ws_client) + + with patch.object(installer, 'configure', mock_configure): + with pytest.raises(TimeoutError) as exc_info: + installer.configure() + + assert str(exc_info.value) == "Mocked timeout error" + + +def test_configure_raises_single_error(): + single_error = ValueError("Single error") + mock_configure = MagicMock(side_effect=ManyError([single_error])) + mock_ws_client = MagicMock(spec=WorkspaceClient) + installer = 
WorkspaceInstaller(mock_ws_client) + + with patch.object(installer, 'configure', mock_configure): + with pytest.raises(ManyError) as exc_info: + installer.configure() + + assert exc_info.value.errs == [single_error] + + +def test_configure_raises_many_errors(): + first_error = ValueError("First error") + second_error = ValueError("Second error") + errors = [first_error, second_error] + mock_configure = MagicMock(side_effect=ManyError(errors)) + mock_ws_client = MagicMock(spec=WorkspaceClient) + installer = WorkspaceInstaller(mock_ws_client) + + with patch.object(installer, 'configure', mock_configure): + with pytest.raises(ManyError) as exc_info: + installer.configure() + + assert exc_info.value.errs == errors diff --git a/tests/unit/test_logs.py b/tests/unit/test_logs.py new file mode 100644 index 0000000..b9f37cd --- /dev/null +++ b/tests/unit/test_logs.py @@ -0,0 +1,60 @@ +from pathlib import Path +from typing import TextIO +import re +from unittest.mock import create_autospec +from databricks.labs.dqx.installer.logs import TaskLogger, parse_logs, peak_multi_line_message + + +def test_task_logger_initialization(): + install_dir = Path("/fake/install/dir") + workflow = "test_workflow" + job_id = "123" + task_name = "test_task" + job_run_id = "456" + log_level = "DEBUG" + attempt = "1" + + task_logger = TaskLogger(install_dir, workflow, job_id, task_name, job_run_id, log_level, attempt) + + assert task_logger.log_file == install_dir / "logs" / workflow / f"run-{job_run_id}-{attempt}" / f"{task_name}.log" + + +def test_parse_invalie_logs(): + log_content = "invalid format" + log_file = create_autospec(TextIO) + log_file.readline.side_effect = log_content.splitlines(keepends=True) + [''] + assert not list(parse_logs(log_file)) + + +def test_parse_logs(): + log_content = """12:00:00 INFO [component] {thread} message +12:00:01 ERROR [component] {thread} another message +""" + log_file = create_autospec(TextIO) + log_file.readline.side_effect = 
log_content.splitlines(keepends=True) + [''] + parsed_logs = list(parse_logs(log_file)) + + assert len(parsed_logs) == 2 + assert parsed_logs[0].time.strftime("%H:%M:%S") == "12:00:00" + assert parsed_logs[0].level == "INFO" + assert parsed_logs[0].component == "component" + assert parsed_logs[0].message == "message" + assert parsed_logs[1].time.strftime("%H:%M:%S") == "12:00:01" + assert parsed_logs[1].level == "ERROR" + assert parsed_logs[1].component == "component" + assert parsed_logs[1].message == "another message" + + +def test_peak_multi_line_message(): + log_content = """message part 2 +12:00:00 INFO [component] {thread} message part 1 +""" + log_file = create_autospec(TextIO) + log_file.readline.side_effect = log_content.splitlines(keepends=True) + [''] + pattern = re.compile(r"(\d+:\d+:\d+)\s(\w+)\s\[(.+)\]\s\{\w+\}\s(.+)") + + line, match, multi_line_message = peak_multi_line_message(log_file, pattern) + + assert line == "12:00:00 INFO [component] {thread} message part 1\n" + assert match is not None + assert multi_line_message == "\nmessage part 2" diff --git a/tests/unit/test_profiler_common.py b/tests/unit/test_profiler_common.py new file mode 100644 index 0000000..2141fb3 --- /dev/null +++ b/tests/unit/test_profiler_common.py @@ -0,0 +1,56 @@ +import datetime +from databricks.labs.dqx.profiler.common import val_maybe_to_str, val_to_str + + +def test_val_to_str(): + # Test with datetime + date_time_val = datetime.datetime(2023, 10, 1, 12, 0, 0) + assert val_to_str(date_time_val) == "'2023-10-01T12:00:00.000000'" + assert val_to_str(date_time_val, include_sql_quotes=False) == "2023-10-01T12:00:00.000000" + + # Test with date + date_val = datetime.date(2023, 10, 1) + assert val_to_str(date_val) == "'2023-10-01'" + assert val_to_str(date_val, include_sql_quotes=False) == "2023-10-01" + + # Test with int + assert val_to_str(123) == "123" + + # Test with float + assert val_to_str(123.45) == "123.45" + + # Test with string + assert val_to_str("test") == 
"'test'" + assert val_to_str("test", include_sql_quotes=False) == "test" + + # Test with string containing special characters + assert val_to_str("test'string") == "'test\\'string'" + assert val_to_str("test\\string") == "'test\\\\string'" + + # Test with None + assert val_to_str(None) == "'None'" + assert val_to_str(None, include_sql_quotes=False) == "None" + + +def test_val_maybe_to_str(): + # Test with datetime + date_time_val = datetime.datetime(2023, 10, 1, 12, 0, 0) + assert val_maybe_to_str(date_time_val) == "'2023-10-01T12:00:00.000000'" + assert val_maybe_to_str(date_time_val, include_sql_quotes=False) == "2023-10-01T12:00:00.000000" + + # Test with date + date_val = datetime.date(2023, 10, 1) + assert val_maybe_to_str(date_val) == "'2023-10-01'" + assert val_maybe_to_str(date_val, include_sql_quotes=False) == "2023-10-01" + + # Test with int + assert val_maybe_to_str(123) == 123 + + # Test with float + assert val_maybe_to_str(123.45) == 123.45 + + # Test with string + assert val_maybe_to_str("test") == "test" + + # Test with None + assert val_maybe_to_str(None) is None diff --git a/tests/unit/test_runtime.py b/tests/unit/test_runtime.py new file mode 100644 index 0000000..232250f --- /dev/null +++ b/tests/unit/test_runtime.py @@ -0,0 +1,24 @@ +import sys +import pytest +from databricks.labs.dqx.runtime import main + + +def test_runtime_raises_key_error(): + with pytest.raises(KeyError, match=r'Workflow "invalid_workflow" not found.'): + main("--workflow=invalid_workflow", "--config=config_path") + + +def test_runtime_no_config(): + with pytest.raises(KeyError, match='no --config specified'): + main("--workflow=invalid_workflow") + + +def test_runtime_missing_config(): + with pytest.raises(FileNotFoundError, match='config_path'): + main("--workflow=profiler", "--config=config_path") + + +def test_runtime_args_provided_as_sys_args(): + with pytest.raises(FileNotFoundError, match='config_path'): + sys.argv = [__file__, "--workflow=profiler", 
"--config=config_path"] + main() diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index a932cbc..4ec59af 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,6 +1,6 @@ import pyspark.sql.functions as F - -from databricks.labs.dqx.utils import get_column_name +import pytest +from databricks.labs.dqx.utils import read_input_data, get_column_name, remove_extra_indentation, extract_major_minor def test_get_column_name(): @@ -25,3 +25,101 @@ def test_get_col_name_longer(): col = F.col("local") actual = get_column_name(col) assert actual == "local" + + +def test_read_input_data_unity_catalog_table(spark): + input_location = "catalog.schema.table" + input_format = None + spark.read.table.return_value = "dataframe" + + result = read_input_data(spark, input_location, input_format) + + spark.read.table.assert_called_once_with(input_location) + assert result == "dataframe" + + +def test_read_input_data_storage_path(spark): + input_location = "s3://bucket/path" + input_format = "delta" + spark.read.format.return_value.load.return_value = "dataframe" + + result = read_input_data(spark, input_location, input_format) + + spark.read.format.assert_called_once_with(input_format) + spark.read.format.return_value.load.assert_called_once_with(input_location) + assert result == "dataframe" + + +def test_read_input_data_workspace_file(spark): + input_location = "/folder/path" + input_format = "delta" + spark.read.format.return_value.load.return_value = "dataframe" + + result = read_input_data(spark, input_location, input_format) + + spark.read.format.assert_called_once_with(input_format) + spark.read.format.return_value.load.assert_called_once_with(input_location) + assert result == "dataframe" + + +def test_read_input_data_no_input_location(spark): + with pytest.raises(ValueError, match="Input location not configured"): + read_input_data(spark, None, None) + + +def test_read_input_data_no_input_format(spark): + input_location = "s3://bucket/path" + 
input_format = None + + with pytest.raises(ValueError, match="Input format not configured"): + read_input_data(spark, input_location, input_format) + + +def test_read_invalid_input_location(spark): + input_location = "invalid/location" + input_format = None + + with pytest.raises(ValueError, match="Invalid input location."): + read_input_data(spark, input_location, input_format) + + +def test_remove_extra_indentation_no_indentation(): + doc = "This is a test docstring." + expected = "This is a test docstring." + assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_with_indentation(): + doc = " This is a test docstring with indentation." + expected = "This is a test docstring with indentation." + assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_mixed_indentation(): + doc = " This is a test docstring with indentation.\nThis line has no indentation." + expected = "This is a test docstring with indentation.\nThis line has no indentation." + assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_multiple_lines(): + doc = " Line one.\n Line two.\n Line three." + expected = "Line one.\nLine two.\nLine three." 
+ assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_empty_string(): + doc = "" + expected = "" + assert remove_extra_indentation(doc) == expected + + +def test_extract_major_minor(): + assert extract_major_minor("1.2.3") == "1.2" + assert extract_major_minor("10.20.30") == "10.20" + assert extract_major_minor("v1.2.3") == "1.2" + assert extract_major_minor("version 1.2.3") == "1.2" + assert extract_major_minor("1.2") == "1.2" + assert extract_major_minor("1.2.3.4") == "1.2" + assert extract_major_minor("no version") is None + assert extract_major_minor("") is None + assert extract_major_minor("1") is None diff --git a/tests/unit/test_workflow_instaler.py b/tests/unit/test_workflow_instaler.py new file mode 100644 index 0000000..8712e34 --- /dev/null +++ b/tests/unit/test_workflow_instaler.py @@ -0,0 +1,40 @@ +from unittest.mock import patch, create_autospec +from datetime import timedelta, datetime, timezone +from databricks.labs.dqx.installer.workflows_installer import DeployedWorkflows +from databricks.sdk.service.jobs import Run, RunState, RunResultState +from databricks.sdk import WorkspaceClient +from databricks.labs.blueprint.installer import InstallState + + +def test_run_workflow(): + mock_ws = create_autospec(WorkspaceClient) + mock_install_state = create_autospec(InstallState) + mock_install_state.jobs = {'test_workflow': '123'} + + mock_run = create_autospec(Run) + mock_run.run_id = 456 + mock_run.state = RunState(result_state=RunResultState.SUCCESS, state_message="Completed successfully") + mock_run.start_time = datetime.now(tz=timezone.utc).timestamp() * 1000 + mock_run.end_time = datetime.now(tz=timezone.utc).timestamp() * 1000 + mock_run.run_duration = 1000 + + with ( + patch.object(mock_ws.jobs, 'run_now', return_value=mock_run), + patch.object(mock_ws.jobs, 'wait_get_run_job_terminated_or_skipped', return_value=mock_run), + ): + deployed_workflows = DeployedWorkflows(mock_ws, mock_install_state) + run_id = 
deployed_workflows.run_workflow('test_workflow', 'test_run_config') + + assert run_id == 456 + mock_ws.jobs.run_now.assert_called_once_with(123, python_named_params={'run_config_name': 'test_run_config'}) + mock_ws.jobs.wait_get_run_job_terminated_or_skipped.assert_called_once_with( + run_id=456, timeout=timedelta(minutes=20) + ) + + assert mock_run.state.result_state == RunResultState.SUCCESS + assert mock_run.state.state_message == "Completed successfully" + assert mock_run.start_time is not None + assert mock_run.end_time is not None + assert mock_run.run_duration == 1000 + assert mock_ws.jobs.run_now.called + assert mock_ws.jobs.wait_get_run_job_terminated_or_skipped.called diff --git a/tests/unit/test_workflow_task.py b/tests/unit/test_workflow_task.py new file mode 100644 index 0000000..ea7932b --- /dev/null +++ b/tests/unit/test_workflow_task.py @@ -0,0 +1,76 @@ +import pytest +from databricks.labs.dqx.installer.workflow_task import workflow_task, Task, Workflow + + +def test_dependencies(): + task_with_deps = Task( + workflow="test_workflow", + name="test_task", + doc="Test task with dependencies", + fn=lambda x: x, + depends_on=["task1", "task2"], + ) + + task_without_deps = Task( + workflow="test_workflow", name="test_task", doc="Test task without dependencies", fn=lambda x: x + ) + + assert task_with_deps.dependencies() == ["task1", "task2"] + assert task_without_deps.dependencies() == [] + + +def test_workflow_task_decorator(): + @workflow_task + def sample_task(): + """Sample task""" + + assert hasattr(sample_task, "__task__") + task = sample_task.__task__ + assert task.name == "sample_task" + assert task.doc == "Sample task" + assert task.fn + assert task.depends_on == [] + assert task.job_cluster == "main" + + +def test_workflow_task_register_task_without_doc(): + with pytest.raises(SyntaxError, match="must have some doc comment"): + + @workflow_task + def task_without_docstring(): + pass + + +def 
test_workflow_task_raises_syntax_error_for_depends_on(): + with pytest.raises(SyntaxError, match="depends_on has to be a list"): + + @workflow_task(depends_on="not_a_list") + def task_with_invalid_depends_on(): + """Task with invalid depends_on""" + + +class WorkflowTest(Workflow): + @workflow_task + def dependency_task(self): + """Dependency task""" + + @workflow_task(depends_on=[dependency_task]) + def main_task(self): + """Main task""" + + +def test_workflow_task_decorator_with_dependencies(): + main_task = WorkflowTest("test").main_task + assert hasattr(main_task, "__task__") + task = main_task.__task__ + assert task.name == "main_task" + assert task.doc == "Main task" + assert task.fn + assert task.depends_on == ["dependency_task"] + assert task.job_cluster == "main" + + +def test_workflow_task_returns_register(): + decorator = workflow_task() + assert callable(decorator) + assert decorator.__name__ == "register"