From 60418497a655bd16dc915194f01b0170aabcbaee Mon Sep 17 00:00:00 2001 From: Ranjeet Date: Thu, 3 Oct 2024 14:24:46 +0530 Subject: [PATCH] feat: refactored adding support for RETL and Profiles --- .github/workflows/test.yaml | 9 +- .gitignore | 3 +- Makefile | 2 +- examples/profiles_sample_dag.py | 34 ++ ...{retl_sample.dag.py => retl_sample_dag.py} | 29 +- examples/sample_dag.py | 24 - pyproject.toml | 32 ++ requirements.txt | 5 +- rudder_airflow_provider/__init__.py | 6 +- rudder_airflow_provider/hooks/rudderstack.py | 386 ++++++++++---- .../operators/rudderstack.py | 150 ++++-- .../test/hooks/test_rudderstack_hook.py | 495 ++++++++++-------- .../operators/test_rudderstack_operator.py | 128 +++-- rudder_airflow_provider/version.py | 2 +- setup.cfg | 38 -- 15 files changed, 848 insertions(+), 495 deletions(-) create mode 100644 examples/profiles_sample_dag.py rename examples/{retl_sample.dag.py => retl_sample_dag.py} (54%) delete mode 100644 examples/sample_dag.py delete mode 100644 setup.cfg diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3dff387..1a41112 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,6 +19,13 @@ jobs: - name: Install dependencies run: | pip3 install -r requirements.txt - - name: Test with unittest + - name: Test with pytest run: | + pip3 install pytest-cov make test + - name: Upload Coverage to Codecov + uses: codecov/codecov-action@v4 + with: + fail_ci_if_error: true + files: ./coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index ec58a85..1ed4cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.egg-info **/__pycache__ .vscode +.venv/ # Distribution / packaging @@ -36,4 +37,4 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ -cover/ +cover/ \ No newline at end of file diff --git a/Makefile b/Makefile index cf74765..e34bdcc 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,3 @@ .PHONY: test test: - python3 -m unittest discover -s rudder_airflow_provider/test + pytest --cov=rudder_airflow_provider rudder_airflow_provider/test --cov-report=xml diff --git a/examples/profiles_sample_dag.py b/examples/profiles_sample_dag.py new file mode 100644 index 0000000..d0937e2 --- /dev/null +++ b/examples/profiles_sample_dag.py @@ -0,0 +1,34 @@ +from datetime import datetime, timedelta + +from airflow import DAG + +from rudder_airflow_provider.operators.rudderstack import RudderstackProfilesOperator + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +with DAG( + "rudderstack-profiles-sample", + default_args=default_args, + description="A simple tutorial DAG", + schedule_interval=timedelta(days=1), + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["rs-profiles"], +) as dag: + # profile_id is template field + rs_operator = RudderstackProfilesOperator( + profile_id="{{ var.value.profile_id }}", + task_id="", + connection_id="", + ) + +if __name__ == "__main__": + dag.test() diff --git a/examples/retl_sample.dag.py b/examples/retl_sample_dag.py similarity index 54% rename from examples/retl_sample.dag.py rename to examples/retl_sample_dag.py index c817f44..61126e8 100644 --- a/examples/retl_sample.dag.py +++ b/examples/retl_sample_dag.py @@ -5,30 +5,31 @@ from rudder_airflow_provider.operators.rudderstack import RudderstackRETLOperator default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'email': ['airflow@example.com'], - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5) + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), } -with DAG('rudderstack-sample', +with DAG( + "rudderstack-sample", default_args=default_args, - description='A simple tutorial DAG', + description="A simple tutorial DAG", schedule_interval=timedelta(days=1), start_date=datetime(2021, 1, 1), catchup=False, - tags=['rs']) as dag: + tags=["rs"], +) as dag: # retl_connection_id, sync_type are template fields rs_operator = RudderstackRETLOperator( retl_connection_id="{{ var.value.retl_connection_id }}", - task_id='', - connection_id='', + task_id="", + connection_id="", sync_type="{{ var.value.sync_type }}", - wait_for_completion=True ) if __name__ == "__main__": - dag.test() \ No newline at end of file + dag.test() diff --git a/examples/sample_dag.py b/examples/sample_dag.py deleted file mode 100644 index 5240aa4..0000000 --- a/examples/sample_dag.py +++ /dev/null @@ -1,24 +0,0 @@ -from datetime import datetime, timedelta - -from airflow import DAG - -from rudder_airflow_provider.operators.rudderstack import RudderstackOperator - -default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'email': ['airflow@example.com'], - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5) -} - -with DAG('rudderstack-sample', - default_args=default_args, - description='A simple tutorial DAG', - schedule_interval=timedelta(days=1), - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['rs']) as dag: - rs_operator = RudderstackOperator(source_id='', task_id='') \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 374b58c..6ec9fd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,3 +4,35 @@ requires = [ "wheel" ] build-backend = "setuptools.build_meta" + +[build-system] +requires = [ + "setuptools >= 61.0", + "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "rudderstack-airflow-provider" +version = "2.0.0" +readme = "README.md" +license = {file = "LICENSE"} +description = "Apache airflow provider for managing Reverse ETL syncs and Profiles runs in RudderStack." +keywords = [ "airflow", "orchestration", "rudderstack"] +classifiers = [ + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", +] +dependencies = [ + "apache-airflow", + "pytest", + "requests", + "responses", + "setuptools" +] +requires-python = ">= 3.6" + +[tool.setuptools.packages.find] +exclude = *test* + +[project.entry-points.apache_airflow_provider] +provider_info = "sample_provider.__init__:get_provider_info" diff --git a/requirements.txt b/requirements.txt index 4d34b38..6602c44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -apache-airflow == 2.8.0 -requests == 2.28.2 +apache-airflow == 2.10.0 +requests == 2.32.3 setuptools == 65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability +pytest==7.3.1 diff --git a/rudder_airflow_provider/__init__.py b/rudder_airflow_provider/__init__.py index d59897f..96fd90c 100644 --- a/rudder_airflow_provider/__init__.py +++ b/rudder_airflow_provider/__init__.py @@ -1,6 +1,6 @@ def get_provider_info(): return { - 'package-name': 'rudderstack-airflow-provider', - 'name': 'rudderstack-airflow-provider', - 'description': 'Apache airflow provider for Rudderstack' + "package-name": "rudderstack-airflow-provider", + "name": "rudderstack-airflow-provider", + "description": "Apache airflow provider for Rudderstack", } diff --git a/rudder_airflow_provider/hooks/rudderstack.py b/rudder_airflow_provider/hooks/rudderstack.py index a5081cf..d6c08e9 100644 --- a/rudder_airflow_provider/hooks/rudderstack.py +++ b/rudder_airflow_provider/hooks/rudderstack.py @@ -1,129 +1,297 @@ -from enum import Enum import logging import time - +import datetime +from typing import Any, Dict, Mapping, Optional +from urllib.parse import urljoin from airflow.exceptions import AirflowException from airflow.providers.http.hooks.http import HttpHook import requests -STATUS_FINISHED = 'finished' -STATUS_POLL_INTERVAL = 10 +DEFAULT_POLL_INTERVAL_SECONDS = 10 +DEFAULT_REQUEST_MAX_RETRIES = 3 +DEFAULT_RETRY_DELAY = 1 +DEFAULT_REQUEST_TIMEOUT = 30 +DEFAULT_RUDDERSTACK_API_ENDPOINT = "https://api.rudderstack.com" + + +class RETLSyncStatus: + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + + +class RETLSyncType: + INCREMENTAL = "incremental" + FULL = "full" -class RETLSyncStatus(Enum): - ''' - Enum for retl sync status - ''' - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' -class RudderstackHook(HttpHook): - ''' - Hook for rudderstack public API - ''' +class ProfilesRunStatus: + RUNNING = "running" + FINISHED = "finished" + FAILED = "failed" - def __init__(self, connection_id: str) -> None: + +class BaseRudderStackHook(HttpHook): + """ + BaseRudderStackHook to interact with RudderStack API. + :params connection_id: `Conn ID` of the Connection to be used to configure this hook. + :params request_retry_delay: Time (in seconds) to wait between each request retry.. + :params request_timeout: Time (in seconds) after which the requests to RudderStack are declared timed out. + :params request_max_retries: The maximum number of times requests to the RudderStack API should be retried before failng. + """ + + def __init__( + self, + connection_id: str, + request_retry_delay: int = 1, + request_timeout: int = 30, + request_max_retries: int = 3, + ) -> None: self.connection_id = connection_id + self.request_retry_delay = request_retry_delay + self.request_timeout = request_timeout + self.request_max_retries = request_max_retries super().__init__(http_conn_id=self.connection_id) - def trigger_sync(self, source_id:str) -> str | None: - ''' - trigger sync for a source - ''' - self.method = 'POST' - sync_endpoint = f"/v2/sources/{source_id}/start" - headers = self.get_request_headers() - logging.info('triggering sync for sourceId: %s, endpoint: %s', - source_id, sync_endpoint) - resp = self.run(endpoint=sync_endpoint, headers=headers, - extra_options={"check_response": False}) - if resp.status_code in (200, 204, 201): - logging.info('Job triggered for sourceId: %s', source_id) - return resp.json().get('runId') - elif resp.status_code == 409: - logging.info('Job is already running for sourceId: %s', source_id) - else: - raise AirflowException(f"Error while starting sync for sourceId: {source_id}, response: {resp.status_code}") - - def poll_for_status(self, source_id, run_id: str): - ''' - polls for sync status - ''' - status_endpoint = f"/v2/sources/{source_id}/runs/{run_id}/status" - headers = self.get_request_headers() - while True: - self.method = 'GET' - resp = self.run(endpoint=status_endpoint, headers=headers).json() - job_status = resp['status'] - logging.info('sync status for sourceId: %s, runId: %s, status: %s', - source_id, run_id, job_status) - - if job_status == STATUS_FINISHED: - if resp.get('error'): - raise AirflowException( - f"sync for sourceId: {source_id} failed with error: {resp['error']}") - - logging.info('sync finished for sourceId: %s, runId: %s', source_id, run_id) - break - time.sleep(STATUS_POLL_INTERVAL) - - def get_request_headers(self) -> dict: - access_token = self.get_access_token() - return { - 'authorization': f"Bearer {access_token}", - 'Content-Type': 'application/json' - } - - def get_access_token(self) -> str: - ''' - returns rudderstack access token - ''' + def _get_access_token(self) -> str: + """ + returns rudderstack access token + """ conn = self.get_connection(self.connection_id) return conn.password - def get_api_base_url(self): - ''' - returns base api url - ''' + def _get_api_base_url(self): + """ + returns base api url + """ conn = self.get_connection(self.connection_id) return conn.host - def trigger_retl_sync(self, retl_connection_id, sync_type: str): - ''' - trigger sync for a retl source - ''' - base_url = self.get_api_base_url().rstrip('/') - sync_endpoint = f"/v2/retl-connections/{retl_connection_id}/start" - headers = self.get_request_headers() - logging.info('triggering sync for retl connection, endpoint: %s', sync_endpoint) - sync_req_data = { "syncType": sync_type } - resp = requests.post(f"{base_url}{sync_endpoint}", json=sync_req_data, headers=headers) - if resp.status_code == 200: - logging.info('Job triggered for retl connection, syncId: %s', resp.json().get("syncId")) - return resp.json().get('syncId') - else: - error = resp.json().get('error', None) - raise AirflowException(f"Error while starting sync for retl, response: {resp.status_code}, error: {error}") - - - def poll_retl_sync_status(self, retl_connection_id, sync_id: str): - ''' - polls for retl sync status - ''' - base_url = self.get_api_base_url().rstrip('/') + def _get_request_headers(self) -> dict: + """ + Returns the request headers to be used by the hook. + """ + access_token = self._get_access_token() + return { + "authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + def make_request( + self, + endpoint: str, + method: str = "POST", + data: Optional[Mapping[str, object]] = None, + ): + """Prepares and makes request to RudderStack API endpoint. + + Args: + method (str): The http method to be used for this request (e.g. "GET", "POST"). + endpoint (str): The RudderStack API endpoint to send request to. + data (Optional(Mapping)): Data to pass in request to the API endpoint. + + Returns: + Dict[str, Any]: Parsed json data from the response for this request. + """ + url = urljoin(self._get_api_base_url(), endpoint) + headers = self._get_request_headers() + num_retries = 0 + while True: + try: + request_args: Dict[str, Any] = dict( + method=method, + url=url, + headers=headers, + timeout=self.request_timeout, + ) + if data is not None: + request_args["json"] = data + response = requests.request(**request_args) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logging.error(f"Request to url: {url} failed: {e}") + if num_retries == self.request_max_retries: + break + num_retries += 1 + time.sleep(self.request_retry_delay) + + raise AirflowException( + f"Exceeded max number of retries for connectionId: {self.connection_id}" + ) + + +class RudderStackRETLHook(BaseRudderStackHook): + """ + RudderStackRETLHook to interact with RudderStack RETL API. + :params connection_id: `Conn ID` of the Connection to be used to configure this hook. + :params request_retry_delay: Time (in seconds) to wait between each request retry.. + :params request_timeout: Time (in seconds) after which the requests to RudderStack are declared timed out. + :params request_max_retries: The maximum number of times requests to the RudderStack API should be retried before failng. + """ + + def __init__( + self, + connection_id: str, + request_retry_delay: int = DEFAULT_RETRY_DELAY, + request_timeout: int = DEFAULT_REQUEST_TIMEOUT, + request_max_retries: int = DEFAULT_REQUEST_MAX_RETRIES, + poll_timeout: float = None, + poll_interval: float = DEFAULT_POLL_INTERVAL_SECONDS, + ) -> None: + super().__init__( + connection_id, request_retry_delay, request_timeout, request_max_retries + ) + self.poll_timeout = poll_timeout + self.poll_interval = poll_interval + + def start_sync(self, retl_connection_id, sync_type: Optional[str] = None) -> str: + """Triggers a sync and returns runId if successful, else raises Failure. + + Args: + retl_connection_id (str): connetionId for an RETL sync. + sync_type (str): (optional) full or incremental. Default is None. + + Returns: + sync_id of the sync started. + """ + if sync_type is not None and sync_type not in [ + RETLSyncType.INCREMENTAL, + RETLSyncType.FULL, + ]: + raise AirflowException(f"Invalid sync type: {sync_type}") + if not retl_connection_id: + raise AirflowException("retl_connection_id is required") + self.log.info("Triggering sync for retl connection id: %s", retl_connection_id) + + data = {} + if sync_type is not None: + data = {"syncType": sync_type} + return self.make_request( + endpoint=f"/v2/retl-connections/{retl_connection_id}/start", + data=data, + )["syncId"] + + def poll_sync(self, retl_connection_id, sync_id: str) -> Dict[str, Any]: + """Polls for completion of a sync. If poll_timeout is set, raises Failure after timeout. + + Args: + retl_connection_id (str): connetionId for an RETL sync. + sync_type (str): (optional) full or incremental. Default is None. + Returns: + Dict[str, Any]: Parsed json output from syncs endpoint. + """ + if not retl_connection_id: + raise AirflowException( + "retl_connection_id is required to poll status of sync run" + ) + if not sync_id: + raise AirflowException("sync_id is required to poll status of sync run") + status_endpoint = f"/v2/retl-connections/{retl_connection_id}/syncs/{sync_id}" - headers = self.get_request_headers() + poll_start = datetime.datetime.now() + while True: + resp = self.make_request(endpoint=status_endpoint, method="GET") + sync_status = resp["status"] + self.log.info( + f"Polled status for syncId: {sync_id} for retl connection: {retl_connection_id}, status: {sync_status}" + ) + if sync_status == RETLSyncStatus.SUCCEEDED: + self._log.info( + f"Sync finished for retl connection: {retl_connection_id}, syncId: {sync_id}" + ) + return resp + elif sync_status == RETLSyncStatus.FAILED: + error_msg = resp.get("error", None) + raise AirflowException( + f"Sync for retl connection: {retl_connection_id}, syncId: {sync_id} failed with error: {error_msg}" + ) + if ( + self.poll_timeout + and datetime.datetime.now() + > poll_start + datetime.timedelta(seconds=self.poll_timeout) + ): + raise AirflowException( + f"Polling for syncId: {sync_id} for retl connection: {retl_connection_id} timed out" + ) + time.sleep(self.poll_interval) + + +class RudderStackProfilesHook(BaseRudderStackHook): + """ + RudderStackRETLHook to interact with RudderStack RETL API. + :params connection_id: `Conn ID` of the Connection to be used to configure this hook. + :params request_retry_delay: Time (in seconds) to wait between each request retry.. + :params request_timeout: Time (in seconds) after which the requests to RudderStack are declared timed out. + :params request_max_retries: The maximum number of times requests to the RudderStack API should be retried before failng. + """ + + def __init__( + self, + connection_id: str, + request_retry_delay: int = DEFAULT_RETRY_DELAY, + request_timeout: int = DEFAULT_REQUEST_TIMEOUT, + request_max_retries: int = DEFAULT_REQUEST_MAX_RETRIES, + poll_timeout: float = None, + poll_interval: float = DEFAULT_POLL_INTERVAL_SECONDS, + ) -> None: + super().__init__( + connection_id, request_retry_delay, request_timeout, request_max_retries + ) + self.poll_timeout = poll_timeout + self.poll_interval = poll_interval + + def start_profile_run(self, profile_id: str): + """Triggers a profile run and returns runId if successful, else raises Failure. + + Args: + profile_id (str): Profile ID + """ + if not profile_id: + raise AirflowException("profile_id is required to start a profile run") + self.log.info(f"Triggering profile run for profile id: {profile_id}") + return self.make_request( + endpoint=f"/v2/sources/{profile_id}/start", + )["runId"] + + def poll_profile_run(self, profile_id: str, run_id: str) -> Dict[str, Any]: + """Polls for completion of a profile run. If poll_timeout is set, raises Failure after timeout. + + Args: + profile_id (str): Profile ID + run_id (str): Run ID + Returns: + Dict[str, Any]: Parsed json output from profile run endpoint. + """ + if not profile_id: + raise AirflowException("profile_id is required to start a profile run") + if not run_id: + raise AirflowException("run_id is required to poll status of profile run") + + status_endpoint = f"/v2/sources/{profile_id}/runs/{run_id}/status" + poll_start = datetime.datetime.now() while True: - resp = requests.get(f"{base_url}{status_endpoint}", headers=headers) - if resp.status_code != 200: - error_msg = resp.json().get('error', None) - raise AirflowException(f"Error while fetching sync status for retl, status: {resp.status_code}, error: {error_msg}") - job_status = resp.json()['status'] - logging.info('sync status for retl connection, sycId: %s, status: %s', sync_id, job_status) - if job_status == RETLSyncStatus.SUCCEEDED.value: - logging.info('sync finished for retl connection, syncId: %s', sync_id) - break - elif job_status == RETLSyncStatus.FAILED.value: - error_msg = resp.json().get('error', None) - raise AirflowException(f"sync for retl connection failed with error: {error_msg}") - time.sleep(STATUS_POLL_INTERVAL) \ No newline at end of file + resp = self.make_request(endpoint=status_endpoint, method="GET") + run_status = resp["status"] + self.log.info( + f"Polled status for runId: {run_id} for profile: {profile_id}, status: {run_status}" + ) + if run_status == ProfilesRunStatus.FINISHED: + self.log.info( + f"Profile run finished for profile: {profile_id}, runId: {run_id}" + ) + return resp + elif run_status == ProfilesRunStatus.FAILED: + error_msg = resp.get("error", None) + raise AirflowException( + f"Profile run for profile: {profile_id}, runId: {run_id} failed with error: {error_msg}" + ) + if ( + self.poll_timeout + and datetime.datetime.now() + > poll_start + datetime.timedelta(seconds=self.poll_timeout) + ): + raise AirflowException( + f"Polling for runId: {run_id} for profile: {profile_id} timed out" + ) + time.sleep(self.poll_interval) diff --git a/rudder_airflow_provider/operators/rudderstack.py b/rudder_airflow_provider/operators/rudderstack.py index 77ec514..1a2fdac 100644 --- a/rudder_airflow_provider/operators/rudderstack.py +++ b/rudder_airflow_provider/operators/rudderstack.py @@ -1,65 +1,119 @@ import logging from airflow.models import baseoperator -from rudder_airflow_provider.hooks.rudderstack import RudderstackHook +from typing import Optional +from rudder_airflow_provider.hooks.rudderstack import ( + RudderStackRETLHook, + RudderStackProfilesHook, + DEFAULT_REQUEST_MAX_RETRIES, + DEFAULT_POLL_INTERVAL_SECONDS, + DEFAULT_RETRY_DELAY, + DEFAULT_REQUEST_TIMEOUT, +) -RUDDERTACK_DEFAULT_CONNECTION_ID = 'rudderstack_default' -RETL_SYNC_TYPE_FULL = 'full' -RETL_SYNC_TYPE_INCREMENTAL = 'incremental' - - -class RudderstackOperator(baseoperator.BaseOperator): - ''' - Rudderstack operator for airflow DAGs - ''' - def __init__(self, source_id: str, connection_id: str = RUDDERTACK_DEFAULT_CONNECTION_ID, - wait_for_completion: bool = False, **kwargs): - ''' - Initialize rudderstack operator - ''' - super().__init__(**kwargs) - self.connection_id = connection_id - self.source_id = source_id - self.wait_for_completion = wait_for_completion - - def execute(self, context): - ''' - Executes rudderstack operator - ''' - rs_hook = RudderstackHook(connection_id=self.connection_id) - run_id = rs_hook.trigger_sync(self.source_id) - if self.wait_for_completion and run_id is not None: - logging.info('waiting for sync to complete for sourceId: %s, runId: %s', self.source_id, run_id) - rs_hook.poll_for_status(self.source_id, run_id) +RUDDERTACK_DEFAULT_CONNECTION_ID = "rudderstack_default" class RudderstackRETLOperator(baseoperator.BaseOperator): - template_fields = ('retl_connection_id', 'sync_type') - - ''' + template_fields = "retl_connection_id" + + """ Rudderstack operator for RETL connnections :param retl_connection_id: unique id of the retl connection - :param sync_type: type of sync to trigger. Possible values are 'full' or 'incremental' + :param sync_type: type of sync to trigger. Default is None and is recommended. Possible values are incremental or full. :param connection_id: airflow connection id for rudderstack API - :param wait_for_completion: wait for sync to complete. Default is False - ''' - def __init__(self, - retl_connection_id: str, - sync_type: str = 'incremental', - connection_id: str = RUDDERTACK_DEFAULT_CONNECTION_ID, - wait_for_completion: bool = False, - **kwargs): - ''' - Initialize rudderstack operator - ''' + :param wait_for_completion: wait for sync to complete. Default is True. + """ + + def __init__( + self, + retl_connection_id: str, + sync_type: Optional[str] = None, + connection_id: str = RUDDERTACK_DEFAULT_CONNECTION_ID, + wait_for_completion: bool = True, + request_retry_delay: int = DEFAULT_RETRY_DELAY, + request_timeout: int = DEFAULT_REQUEST_TIMEOUT, + request_max_retries: int = DEFAULT_REQUEST_MAX_RETRIES, + poll_timeout: float = None, + poll_interval: float = DEFAULT_POLL_INTERVAL_SECONDS, + **kwargs, + ): + """ + Initialize rudderstack operator + """ super().__init__(**kwargs) self.wait_for_completion = wait_for_completion self.connection_id = connection_id self.retl_connection_id = retl_connection_id self.sync_type = sync_type + self.request_retry_delay = request_retry_delay + self.request_timeout = request_timeout + self.request_max_retries = request_max_retries + self.poll_timeout = poll_timeout + self.poll_interval = poll_interval + + def execute(self, context): + rs_hook = RudderStackRETLHook( + connection_id=self.connection_id, + request_retry_delay=self.request_retry_delay, + request_timeout=self.request_timeout, + request_max_retries=self.request_max_retries, + poll_timeout=self.poll_timeout, + poll_interval=self.poll_interval, + ) + sync_id = rs_hook.start_sync(self.retl_connection_id, self.sync_type) + if self.wait_for_completion: + self.log.info( + f"poll and wait for sync to finish for retl-connectionId: {self.retl_connection_id}, syncId: {sync_id}" + ) + rs_hook.poll_sync(self.retl_connection_id, sync_id) + + +class RudderstackProfilesOperator(baseoperator.BaseOperator): + template_fields = "profile_id" + + """ + Rudderstack operator for Profiles + :param profile_id: profile id to trigger + :param wait_for_completion: wait for sync to complete. Default is True + """ + + def __init__( + self, + profile_id: str, + connection_id: str = RUDDERTACK_DEFAULT_CONNECTION_ID, + wait_for_completion: bool = True, + request_retry_delay: int = DEFAULT_RETRY_DELAY, + request_timeout: int = DEFAULT_REQUEST_TIMEOUT, + request_max_retries: int = DEFAULT_REQUEST_MAX_RETRIES, + poll_timeout: float = None, + poll_interval: float = DEFAULT_POLL_INTERVAL_SECONDS, + **kwargs, + ): + """ + Initialize rudderstack operator + """ + super().__init__(**kwargs) + self.wait_for_completion = wait_for_completion + self.connection_id = connection_id + self.profile_id = profile_id + self.request_retry_delay = request_retry_delay + self.request_timeout = request_timeout + self.request_max_retries = request_max_retries + self.poll_timeout = poll_timeout + self.poll_interval = poll_interval def execute(self, context): - rs_hook = RudderstackHook(connection_id=self.connection_id) - sync_id = rs_hook.trigger_retl_sync(self.retl_connection_id, self.sync_type) + rs_profiles_hook = RudderStackProfilesHook( + connection_id=self.connection_id, + request_retry_delay=self.request_retry_delay, + request_timeout=self.request_timeout, + request_max_retries=self.request_max_retries, + poll_timeout=self.poll_timeout, + poll_interval=self.poll_interval, + ) + profile_run_id = rs_profiles_hook.start_profile_run(self.profile_id) if self.wait_for_completion: - logging.info('waiting for sync to complete for retl-connecion: %s, syncId: %s', self.retl_connection_id, sync_id) - rs_hook.poll_retl_sync_status(self.retl_connection_id, sync_id) \ No newline at end of file + logging.info( + f"Poll and wait for profiles run to finish for profilesId: {self.profile_id}, runId: {profile_run_id}" + ) + rs_profiles_hook.poll_profile_run(self.profile_id, profile_run_id) diff --git a/rudder_airflow_provider/test/hooks/test_rudderstack_hook.py b/rudder_airflow_provider/test/hooks/test_rudderstack_hook.py index 2925b04..9c1e929 100644 --- a/rudder_airflow_provider/test/hooks/test_rudderstack_hook.py +++ b/rudder_airflow_provider/test/hooks/test_rudderstack_hook.py @@ -1,211 +1,288 @@ -import unittest -from unittest import mock +import pytest +from unittest.mock import patch, MagicMock from airflow.exceptions import AirflowException - from airflow.models.connection import Connection -from requests.models import Response -from rudder_airflow_provider.hooks.rudderstack import STATUS_POLL_INTERVAL, RudderstackHook - - -class RudderstackHookTest(unittest.TestCase): - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - def test_get_access_token(self, mocked_http: mock.Mock): - rudder_connection = Connection(password='some-password') - mocked_http.return_value = rudder_connection - hook = RudderstackHook('rudderstack_connection') - self.assertEqual(hook.get_access_token(), 'some-password') - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.run') - def test_trigger_sync(self, mock_run: mock.Mock, mock_connection: mock.Mock): - source_id = 'some-source-id' - access_token = 'some-password' - hook = RudderstackHook('rudderstack_connection') - mock_connection.return_value = Connection(password=access_token) - sync_endpoint = f"/v2/sources/{source_id}/start" - start_resp = Response() - start_resp.json = mock.MagicMock(return_value={'runId': 'some-run-id'}) - start_resp.status_code = 204 - mock_run.return_value = start_resp - run_id = hook.trigger_sync(source_id) - expected_headers = { - 'authorization': f"Bearer {access_token}", - 'Content-Type': 'application/json' - } - mock_run.assert_called_once_with(endpoint=sync_endpoint, headers=expected_headers, extra_options={'check_response': False}) - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.run') - def test_trigger_sync_conflict_status(self, mock_run: mock.Mock, mock_connection: mock.Mock): - source_id = 'some-source-id' - access_token = 'some-password' - hook = RudderstackHook('rudderstack_connection') - mock_connection.return_value = Connection(password=access_token) - sync_endpoint = f"/v2/sources/{source_id}/start" - start_resp = Response() - start_resp.status_code = 409 - mock_run.return_value = start_resp - run_id = hook.trigger_sync(source_id) - self.assertIsNone(run_id) - expected_headers = { - 'authorization': f"Bearer {access_token}", - 'Content-Type': 'application/json' - } - mock_run.assert_called_once_with(endpoint=sync_endpoint, headers=expected_headers, extra_options={'check_response': False}) - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.run') - def test_trigger_sync_error_status(self, mock_run: mock.Mock, mock_connection: mock.Mock): - source_id = 'some-source-id' - access_token = 'some-password' - hook = RudderstackHook('rudderstack_connection') - mock_connection.return_value = Connection(password=access_token) - sync_endpoint = f"/v2/sources/{source_id}/start" - start_resp = Response() - start_resp.status_code = 500 - mock_run.return_value = start_resp - self.assertRaises(AirflowException, hook.trigger_sync, source_id) - expected_headers = { - 'authorization': f"Bearer {access_token}", - 'Content-Type': 'application/json' - } - mock_run.assert_called_once_with(endpoint=sync_endpoint, headers=expected_headers, extra_options={'check_response': False}) - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.run') - def test_triger_sync_exception(self, mock_run: mock.Mock, mock_connection: mock.Mock): - source_id = 'some-source-id' - access_token = 'some-password' - mock_connection.return_value = Connection(password=access_token) - mock_run.side_effect = AirflowException() - hook = RudderstackHook('rudderstack_connection') - self.assertRaises(AirflowException, hook.trigger_sync, source_id) - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.run') - def test_poll_status(self, mock_run: mock.Mock, mock_connection: mock.Mock): - source_id = 'some-source-id' - run_id = 'some-run-id' - access_token = 'some-password' - status_endpoint = f"/v2/sources/{source_id}/runs/{run_id}/status" - finished_status_response = Response() - finished_status_response.status_code = 200 - finished_status_response.json = mock.MagicMock(return_value={'status': 'finished'}) - mock_run.return_value = finished_status_response - mock_connection.return_value = Connection(password=access_token) - hook = RudderstackHook('rudderstack_connection') - hook.poll_for_status(source_id, run_id) - expected_headers = { - 'authorization': f"Bearer {access_token}", - 'Content-Type': 'application/json' - } - mock_run.assert_called_once_with(endpoint=status_endpoint, headers=expected_headers) - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.run') - def test_poll_status_failure(self, mock_run: mock.Mock, mock_connection: mock.Mock): - source_id = 'some-source-id' - run_id = 'some-run-id' - access_token = 'some-password' - status_endpoint = f"/v2/sources/{source_id}/runs/{run_id}/status" - finished_status_response = Response() - finished_status_response.status_code = 200 - finished_status_response.json = mock.MagicMock( - return_value={'status': 'finished', 'error': 'some-eror'}) - mock_run.return_value = finished_status_response - mock_connection.return_value = Connection(password=access_token) - hook = RudderstackHook('rudderstack_connection') - self.assertRaises(AirflowException, hook.poll_for_status, source_id, run_id) - expected_headers = { - 'authorization': f"Bearer {access_token}", - 'Content-Type': 'application/json' - } - mock_run.assert_called_once_with(endpoint=status_endpoint, headers=expected_headers) - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.requests.post') - def test_retl_trigger_sync(self, mock_post: mock.Mock, mock_connection: mock.Mock): - mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {'syncId': 'some-sync-id'} - mock_connection.side_effect = [Connection(password='some-password', host='https://some-url.com'), Connection(password='some-password')] - retl_connection_id = 'some-connection-id' - sync_type = 'full' - base_url = 'https://some-url.com' - retl_sync_endpoint = f"/v2/retl-connections/{retl_connection_id}/start" - hook = RudderstackHook('rudderstack_connection') - sync_id = hook.trigger_retl_sync(retl_connection_id, sync_type) - self.assertEqual(sync_id, 'some-sync-id') - mock_post.assert_called_once_with(f"{base_url}{retl_sync_endpoint}", json={'syncType': sync_type}, - headers={'authorization' : f"Bearer some-password", 'Content-Type': 'application/json'}) - mock_connection.assert_called_with('rudderstack_connection') - - # test 409 retl trigger sync and expect AirflowException - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.requests.post') - def test_retl_trigger_sync_conflict(self, mock_post: mock.Mock, mock_connection: mock.Mock): - mock_post.return_value.status_code = 409 - mock_connection.side_effect = [Connection(password='some-password', host='https://some-url.com'), Connection(password='some-password')] - retl_connection_id = 'some-connection-id' - sync_type = 'full' - base_url = 'https://some-url.com' - retl_sync_endpoint = f"/v2/retl-connections/{retl_connection_id}/start" - hook = RudderstackHook('rudderstack_connection') - self.assertRaises(AirflowException, hook.trigger_retl_sync, retl_connection_id, sync_type) - mock_post.assert_called_once_with(f"{base_url}{retl_sync_endpoint}", json={'syncType': sync_type}, - headers={'authorization' : f"Bearer some-password", 'Content-Type': 'application/json'}) - mock_connection.assert_called_with('rudderstack_connection') - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.requests.get') - def test_poll_for_retl_sync_status(self, mock_get: mock.Mock, mock_connection: mock.Mock): - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = {'status': 'succeeded'} - mock_connection.side_effect = [Connection(password='some-password', host='https://some-url.com'), Connection(password='some-password')] - retl_connection_id = 'some-connection-id' - sync_id = 'some-sync-id' - base_url = 'https://some-url.com' - retl_sync_status_endpoint = f"/v2/retl-connections/{retl_connection_id}/syncs/{sync_id}" - hook = RudderstackHook('rudderstack_connection') - hook.poll_retl_sync_status(retl_connection_id, sync_id) - mock_get.assert_called_once_with(f"{base_url}{retl_sync_status_endpoint}", headers={'authorization' : f"Bearer some-password", 'Content-Type': 'application/json'}) - mock_connection.assert_called_with('rudderstack_connection') - - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.requests.get') - def test_poll_for_retl_sync_status_failed(self, mock_get: mock.Mock, mock_connection: mock.Mock): - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = {'status': 'failed'} - mock_connection.side_effect = [Connection(password='some-password', host='https://some-url.com'), Connection(password='some-password')] - retl_connection_id = 'some-connection-id' - sync_id = 'some-sync-id' - base_url = 'https://some-url.com' - retl_sync_status_endpoint = f"/v2/retl-connections/{retl_connection_id}/syncs/{sync_id}" - hook = RudderstackHook('rudderstack_connection') - self.assertRaises(AirflowException, hook.poll_retl_sync_status, retl_connection_id, sync_id) - mock_get.assert_called_once_with(f"{base_url}{retl_sync_status_endpoint}", headers={'authorization' : f"Bearer some-password", 'Content-Type': 'application/json'}) - mock_connection.assert_called_with('rudderstack_connection') - - @mock.patch('rudder_airflow_provider.hooks.rudderstack.HttpHook.get_connection') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.requests.get') - @mock.patch('rudder_airflow_provider.hooks.rudderstack.time.sleep') - def test_poll_for_retl_sync_status_running(self, mock_sleep: mock.Mock, mock_get: mock.Mock, mock_connection: mock.Mock): - mock_get.return_value.status_code = 200 - mock_get.return_value.json.side_effect = [{'status': 'running'}, {'status': 'succeeded'}] - mock_connection.side_effect = [Connection(password='some-password', host='https://some-url.com'), Connection(password='some-password')] - retl_connection_id = 'some-connection-id' - sync_id = 'some-sync-id' - base_url = 'https://some-url.com' - retl_sync_status_endpoint = f"/v2/retl-connections/{retl_connection_id}/syncs/{sync_id}" - hook = RudderstackHook('rudderstack_connection') - hook.poll_retl_sync_status(retl_connection_id, sync_id) - mock_get.assert_called_with(f"{base_url}{retl_sync_status_endpoint}", headers={'authorization' : f"Bearer some-password", 'Content-Type': 'application/json'}) - mock_connection.assert_called_with('rudderstack_connection') - self.assertEqual(mock_get.call_count, 2) - self.assertEqual(mock_sleep.call_count, 1) - mock_sleep.assert_called_with(STATUS_POLL_INTERVAL) - - -if __name__ == '__main__': - unittest.main() +from requests.exceptions import RequestException +from requests.exceptions import Timeout +from rudder_airflow_provider.hooks.rudderstack import ( + BaseRudderStackHook, + RudderStackRETLHook, + RETLSyncStatus, + RudderStackProfilesHook, + ProfilesRunStatus, +) + +# Mocking constants for testing +TEST_AIRFLOW_CONN_ID = "airflow_conn_id" +TEST_RETL_CONN_ID = "test_retl_conn_id" +TEST_PROFILE_ID = "test_profile_id" +TEST_RETL_SYNC_RUN_ID = "test_retl_sync_id" +TEST_PROFILE_RUN_ID = "test_profile_run_id" +TEST_ACCESS_TOKEN = "test_access_token" +TEST_BASE_URL = "http://test.rudderstack.api" + + +# Mocking connection and responses +@pytest.fixture +def airflow_connection(): + return Connection( + conn_id=TEST_AIRFLOW_CONN_ID, host=TEST_BASE_URL, password=TEST_ACCESS_TOKEN + ) + + +# BaseRudderStackHook tests +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +def test_get_access_token(mock_connection, airflow_connection): + basehook = BaseRudderStackHook(TEST_AIRFLOW_CONN_ID) + mock_connection.return_value = airflow_connection + assert basehook._get_access_token() == TEST_ACCESS_TOKEN + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +def test_get_api_base_url(mock_connection, airflow_connection): + mock_connection.return_value = airflow_connection + basehook = BaseRudderStackHook(TEST_AIRFLOW_CONN_ID) + assert basehook._get_api_base_url() == TEST_BASE_URL + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +def test_get_request_headers(mock_connection, airflow_connection): + mock_connection.return_value = airflow_connection + basehook = BaseRudderStackHook(TEST_AIRFLOW_CONN_ID) + assert basehook._get_request_headers() == { + "authorization": f"Bearer {TEST_ACCESS_TOKEN}", + "Content-Type": "application/json", + } + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_make_request_success(mock_request, mock_connection, airflow_connection): + mock_request.return_value = MagicMock( + status_code=200, json=lambda: {"result": "success"} + ) + mock_connection.return_value = airflow_connection + basehook = BaseRudderStackHook(TEST_AIRFLOW_CONN_ID) + result = basehook.make_request("/endpoint", method="GET") + assert result == {"result": "success"} + mock_request.assert_called_once_with( + method="GET", + url=TEST_BASE_URL + "/endpoint", + headers={ + "authorization": f"Bearer {TEST_ACCESS_TOKEN}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_make_request_failure(mock_request, mock_connection, airflow_connection): + mock_request.side_effect = RequestException("Request failed") + mock_connection.return_value = airflow_connection + basehook = BaseRudderStackHook(TEST_AIRFLOW_CONN_ID) + + with pytest.raises(AirflowException, match="Exceeded max number of retries"): + basehook.make_request("/endpoint") + assert mock_request.call_count == 4 + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_make_request_success_after_retry( + mock_request, mock_connection, airflow_connection +): + mock_request.side_effect = [ + Timeout(), + Timeout(), + MagicMock(status_code=200, json=lambda: {"result": "success"}), + ] + mock_connection.return_value = airflow_connection + basehook = BaseRudderStackHook(TEST_AIRFLOW_CONN_ID) + response = basehook.make_request(endpoint="/test-endpoint", method="GET") + assert response == {"result": "success"} + assert mock_request.call_count == 3 + + +# RudderStackRETLHook tests +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_start_sync(mock_request, mock_connection, airflow_connection): + mock_request.return_value = MagicMock( + status_code=200, json=lambda: {"syncId": TEST_RETL_SYNC_RUN_ID} + ) + mock_connection.return_value = airflow_connection + retl_hook = RudderStackRETLHook(TEST_AIRFLOW_CONN_ID) + sync_id = retl_hook.start_sync(TEST_RETL_CONN_ID) + assert sync_id == TEST_RETL_SYNC_RUN_ID + + mock_request.assert_called_once_with( + method="POST", + url=f"{TEST_BASE_URL}/v2/retl-connections/{TEST_RETL_CONN_ID}/start", + headers={ + "authorization": f"Bearer {TEST_ACCESS_TOKEN}", + "Content-Type": "application/json", + }, + timeout=30, + json={}, + ) + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +def test_start_sync_invalid_parameters(mock_connection, airflow_connection): + mock_connection.return_value = airflow_connection + retl_hook = RudderStackRETLHook(TEST_AIRFLOW_CONN_ID) + with pytest.raises(AirflowException, match="Invalid sync type: invalid_sync_type"): + retl_hook.start_sync(TEST_RETL_CONN_ID, "invalid_sync_type") + + with pytest.raises(AirflowException, match="retl_connection_id is required"): + retl_hook.start_sync("") + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_poll_sync_success(mock_request, mock_connection, airflow_connection): + mock_request.side_effect = [ + MagicMock( + status_code=200, + json=lambda: { + "id": TEST_RETL_SYNC_RUN_ID, + "status": RETLSyncStatus.RUNNING, + }, + ), + MagicMock( + status_code=200, + json=lambda: { + "id": TEST_RETL_SYNC_RUN_ID, + "status": RETLSyncStatus.SUCCEEDED, + }, + ), + ] + mock_connection.return_value = airflow_connection + retl_hook = RudderStackRETLHook( + connection_id=TEST_AIRFLOW_CONN_ID, poll_interval=0.1 + ) + result = retl_hook.poll_sync(TEST_RETL_CONN_ID, TEST_RETL_SYNC_RUN_ID) + assert mock_request.call_count == 2 + mock_request.assert_called_with( + method="GET", + url=f"{TEST_BASE_URL}/v2/retl-connections/{TEST_RETL_CONN_ID}/syncs/{TEST_RETL_SYNC_RUN_ID}", + headers={ + "authorization": f"Bearer {TEST_ACCESS_TOKEN}", + "Content-Type": "application/json", + }, + timeout=30, + ) + assert result == {"id": TEST_RETL_SYNC_RUN_ID, "status": RETLSyncStatus.SUCCEEDED} + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_poll_sync_timeout(mock_request, mock_connection, airflow_connection): + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"id": TEST_RETL_SYNC_RUN_ID, "status": RETLSyncStatus.RUNNING}, + ) + mock_connection.return_value = airflow_connection + retl_hook = RudderStackRETLHook( + connection_id=TEST_AIRFLOW_CONN_ID, poll_interval=0.1, poll_timeout=0.3 + ) + with pytest.raises( + AirflowException, + match="Polling for syncId: test_retl_sync_id for retl connection: test_retl_conn_id timed out", + ): + retl_hook.poll_sync(TEST_RETL_CONN_ID, TEST_RETL_SYNC_RUN_ID) + assert mock_request.call_count <= 4 + + +# RudderStackProfilesHook tests +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_start_profile_run(mock_request, mock_connection, airflow_connection): + mock_request.return_value = MagicMock( + status_code=200, json=lambda: {"runId": TEST_PROFILE_RUN_ID} + ) + mock_connection.return_value = airflow_connection + profiles_hook = RudderStackProfilesHook(TEST_AIRFLOW_CONN_ID) + run_id = profiles_hook.start_profile_run(TEST_PROFILE_ID) + assert run_id == TEST_PROFILE_RUN_ID + mock_request.assert_called_once_with( + method="POST", + url=f"{TEST_BASE_URL}/v2/sources/{TEST_PROFILE_ID}/start", + headers={ + "authorization": f"Bearer {TEST_ACCESS_TOKEN}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +def test_start_sprofile_run_invalid_parameters(mock_connection, airflow_connection): + mock_connection.return_value = airflow_connection + profiles_hook = RudderStackProfilesHook(TEST_PROFILE_ID) + with pytest.raises( + AirflowException, match="profile_id is required to start a profile run" + ): + profiles_hook.start_profile_run("") + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_poll_profile_run_success(mock_request, mock_connection, airflow_connection): + mock_request.side_effect = [ + MagicMock( + status_code=200, + json=lambda: { + "id": TEST_PROFILE_RUN_ID, + "status": ProfilesRunStatus.RUNNING, + }, + ), + MagicMock( + status_code=200, + json=lambda: { + "id": TEST_PROFILE_RUN_ID, + "status": ProfilesRunStatus.FINISHED, + }, + ), + ] + mock_connection.return_value = airflow_connection + profiles_hook = RudderStackProfilesHook( + connection_id=TEST_AIRFLOW_CONN_ID, poll_interval=0.1 + ) + result = profiles_hook.poll_profile_run(TEST_PROFILE_ID, TEST_PROFILE_RUN_ID) + assert mock_request.call_count == 2 + mock_request.assert_called_with( + method="GET", + url=f"{TEST_BASE_URL}/v2/sources/{TEST_PROFILE_ID}/runs/{TEST_PROFILE_RUN_ID}/status", + headers={ + "authorization": f"Bearer {TEST_ACCESS_TOKEN}", + "Content-Type": "application/json", + }, + timeout=30, + ) + assert result == {"id": TEST_PROFILE_RUN_ID, "status": ProfilesRunStatus.FINISHED} + + +@patch("airflow.providers.http.hooks.http.HttpHook.get_connection") +@patch("requests.request") +def test_poll_profile_run_timeout(mock_request, mock_connection, airflow_connection): + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"id": TEST_PROFILE_RUN_ID, "status": ProfilesRunStatus.RUNNING}, + ) + mock_connection.return_value = airflow_connection + profiles_hook = RudderStackProfilesHook( + connection_id=TEST_AIRFLOW_CONN_ID, poll_interval=0.1, poll_timeout=0.3 + ) + with pytest.raises( + AirflowException, + match="Polling for runId: test_profile_run_id for profile: test_profile_id timed out", + ): + profiles_hook.poll_profile_run(TEST_PROFILE_ID, TEST_PROFILE_RUN_ID) + assert mock_request.call_count <= 4 + + +if __name__ == "__main__": + pytest.main() diff --git a/rudder_airflow_provider/test/operators/test_rudderstack_operator.py b/rudder_airflow_provider/test/operators/test_rudderstack_operator.py index 94f5be6..add3fcd 100644 --- a/rudder_airflow_provider/test/operators/test_rudderstack_operator.py +++ b/rudder_airflow_provider/test/operators/test_rudderstack_operator.py @@ -1,44 +1,84 @@ -import unittest -from unittest import mock - -from rudder_airflow_provider.operators.rudderstack import RudderstackOperator - - -class TestRudderstackOperator(unittest.TestCase): - - @mock.patch('rudder_airflow_provider.operators.rudderstack.RudderstackHook.poll_for_status') - @mock.patch('rudder_airflow_provider.operators.rudderstack.RudderstackHook.trigger_sync') - def test_operator_trigger_sync_without_wait(self, mock_hook_sync: mock.Mock, - mock_poll_status: mock.Mock): - mock_hook_sync.return_value = 'some-run-id' - operator = RudderstackOperator(source_id='some-source-id', - wait_for_completion=False, task_id='some-task-id') - operator.execute(context=None) - mock_hook_sync.assert_called_once() - mock_poll_status.assert_not_called() - - @mock.patch('rudder_airflow_provider.operators.rudderstack.RudderstackHook.poll_for_status') - @mock.patch('rudder_airflow_provider.operators.rudderstack.RudderstackHook.trigger_sync') - def test_operator_trigger_sync_with_wait(self, mock_hook_sync: mock.Mock, - mock_poll_status: mock.Mock): - mock_hook_sync.return_value = 'some-run-id' - mock_poll_status.return_value = None - operator = RudderstackOperator(source_id='some-source-id', - wait_for_completion=True, task_id='some-task-id') - operator.execute(context=None) - mock_hook_sync.assert_called_once() - mock_poll_status.assert_called_once() - - # checks if poll_for_status is not called if run_id is None (possible if sync is already running) - @mock.patch('rudder_airflow_provider.operators.rudderstack.RudderstackHook.poll_for_status') - @mock.patch('rudder_airflow_provider.operators.rudderstack.RudderstackHook.trigger_sync') - def test_operator_no_polling_if_run_not_started(self, mock_hook_sync: mock.Mock, mock_poll_status: mock.Mock): - operator = RudderstackOperator(source_id='some-source-id', - wait_for_completion=True, task_id='some-task-id') - mock_hook_sync.return_value = None - operator.execute(context=None) - mock_hook_sync.assert_called_once() - mock_poll_status.assert_not_called() - -if __name__ == '__main__': - unittest.main() \ No newline at end of file +import pytest +from unittest.mock import patch, MagicMock +from airflow.exceptions import AirflowException +from rudder_airflow_provider.operators.rudderstack import ( + RudderstackRETLOperator, + RudderstackProfilesOperator +) +from rudder_airflow_provider.hooks.rudderstack import ( + RETLSyncStatus, + ProfilesRunStatus +) + +# Constants for test cases +TEST_RETL_CONNECTION_ID = "test_retl_connection" +TEST_PROFILE_ID = "test_profile_id" +TEST_SYNC_ID = "test_sync_id" +TEST_PROFILES_RUN_ID = "test_run_id" + + +# Test RudderstackRETLOperator +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackRETLHook.poll_sync') +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackRETLHook.start_sync') +def test_retl_operator_execute_without_wait(mock_start_sync, mock_poll_sync): + mock_start_sync.return_value = TEST_SYNC_ID + retl_operator = RudderstackRETLOperator(retl_connection_id=TEST_RETL_CONNECTION_ID, + wait_for_completion=False, + task_id='some-task-id') + retl_operator.execute(context=None) + mock_start_sync.assert_called_once_with(TEST_RETL_CONNECTION_ID, None) + mock_poll_sync.assert_not_called() + + +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackRETLHook.poll_sync') +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackRETLHook.start_sync') +def test_retl_operator_execute_with_wait(mock_start_sync, mock_poll_sync): + mock_start_sync.return_value = TEST_SYNC_ID + mock_poll_sync.return_value = [ + { + "id": TEST_SYNC_ID, + "job_id": TEST_RETL_CONNECTION_ID, + "status": RETLSyncStatus.SUCCEEDED, + } + ] + retl_operator = RudderstackRETLOperator(retl_connection_id=TEST_RETL_CONNECTION_ID, + task_id='some-task-id') + retl_operator.execute(context=None) + + mock_start_sync.assert_called_once_with(TEST_RETL_CONNECTION_ID, None) + mock_poll_sync.assert_called_once_with(TEST_RETL_CONNECTION_ID, TEST_SYNC_ID) + + +# Test RudderstackProfilesOperator +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackProfilesHook.poll_profile_run') +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackProfilesHook.start_profile_run') +def test_profiles_operator_execute_without_wait(mock_profile_run, mock_poll_profile_run): + mock_profile_run.return_value = TEST_PROFILES_RUN_ID + profiles_operator = RudderstackProfilesOperator(profile_id=TEST_PROFILE_ID, + wait_for_completion=False, + task_id='some-task-id') + profiles_operator.execute(context=None) + mock_profile_run.assert_called_once_with(TEST_PROFILE_ID) + mock_poll_profile_run.assert_not_called() + + +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackProfilesHook.poll_profile_run') +@patch('rudder_airflow_provider.hooks.rudderstack.RudderStackProfilesHook.start_profile_run') +def test_profiles_operator_execute_with_wait(mock_profile_run, mock_poll_profile_run): + mock_profile_run.return_value = TEST_PROFILES_RUN_ID + mock_poll_profile_run.return_value = [ + { + "id": TEST_PROFILES_RUN_ID, + "job_id": TEST_PROFILE_ID, + "status": ProfilesRunStatus.FINISHED, + } + ] + profiles_operator = RudderstackProfilesOperator(profile_id=TEST_PROFILE_ID, + task_id='some-task-id') + profiles_operator.execute(context=None) + + mock_profile_run.assert_called_once_with(TEST_PROFILE_ID) + mock_poll_profile_run.assert_called_once_with(TEST_PROFILE_ID, TEST_PROFILES_RUN_ID) + +if __name__ == "__main__": + pytest.main() \ No newline at end of file diff --git a/rudder_airflow_provider/version.py b/rudder_airflow_provider/version.py index 1a72d32..6849410 100644 --- a/rudder_airflow_provider/version.py +++ b/rudder_airflow_provider/version.py @@ -1 +1 @@ -__version__ = '1.1.0' +__version__ = "1.1.0" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a6f258f..0000000 --- a/setup.cfg +++ /dev/null @@ -1,38 +0,0 @@ -[metadata] -name = rudderstack-airflow-provider -version = attr: rudder_airflow_provider.version.__version__ -description = airflow provider for rudderstack -license = MIT -license_file = LICENSE -long_description = file: README.md -long_description_content_type = text/markdown -repository = https://github.com/rudderlabs/rudder-airflow-provider -classifiers = - Development Status :: 5 - Production/Stable - Intended Audience :: Developers - License :: OSI Approved :: MIT License - Operating System :: OS Independent - Programming Language :: Python - Programming Language :: Python :: 3.6 - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 -keywords = - rudder - rudderstack - airflow - apache-airflow - -[options] -python_requires >= 3.6 -packages = find: -install_requires = - apache-airflow >= 2.5.3 - apache-airflow-providers-http >= 4.2.0 - requests >= 2.28.2 - -[options.packages.find] -exclude= *test* - -[options.entry_points] -apache_airflow_provider= - provider_info=rudder_airflow_provider:get_provider_info