From 4567b3e02266ab897c1a26c2edb1a85342065b5a Mon Sep 17 00:00:00 2001
From: David Siklosi
Date: Fri, 26 Apr 2024 12:48:23 +0200
Subject: [PATCH 1/2] Switching to official batch operator

---
 .../operator_creators/batch_creator.py       |  28 +-
 .../operator_creators/spark_creator.py       |   2 +-
 .../airflow/operators/awsbatch_operator.py   | 275 ++++++------------
 3 files changed, 105 insertions(+), 200 deletions(-)

diff --git a/dagger/dag_creator/airflow/operator_creators/batch_creator.py b/dagger/dag_creator/airflow/operator_creators/batch_creator.py
index a3d2534..0cfe9fb 100644
--- a/dagger/dag_creator/airflow/operator_creators/batch_creator.py
+++ b/dagger/dag_creator/airflow/operator_creators/batch_creator.py
@@ -1,5 +1,9 @@
+from pathlib import Path
+from datetime import timedelta
+
 from dagger.dag_creator.airflow.operator_creator import OperatorCreator
 from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator
+from dagger import conf
 
 
 class BatchCreator(OperatorCreator):
@@ -8,6 +12,20 @@ class BatchCreator(OperatorCreator):
     def __init__(self, task, dag):
         super().__init__(task, dag)
 
+    @staticmethod
+    def _validate_job_name(job_name, absolute_job_name):
+        if not absolute_job_name and not job_name:
+            raise Exception("Both job_name and absolute_job_name cannot be null")
+
+        if absolute_job_name is not None:
+            return absolute_job_name
+
+        job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/")
+        assert (
+            job_path.is_dir()
+        ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`"
+        return job_name
+
     def _generate_command(self):
         command = [self._task.executable_prefix, self._task.executable]
         for param_name, param_value in self._template_parameters.items():
@@ -21,16 +39,16 @@ def _create_operator(self, **kwargs):
         overrides = self._task.overrides
         overrides.update({"command": self._generate_command()})
 
+        job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name)
         batch_op = AWSBatchOperator(
             dag=self._dag,
             task_id=self._task.name,
-            job_name=self._task.job_name,
-            absolute_job_name=self._task.absolute_job_name,
+            job_name=self._task.name,
+            job_definition=job_name,
             region_name=self._task.region_name,
-            cluster_name=self._task.cluster_name,
             job_queue=self._task.job_queue,
-            overrides=overrides,
+            container_overrides=overrides,
+            awslogs_enabled=True,
             **kwargs,
         )
-
         return batch_op
diff --git a/dagger/dag_creator/airflow/operator_creators/spark_creator.py b/dagger/dag_creator/airflow/operator_creators/spark_creator.py
index 2bb41e9..c48ebda 100644
--- a/dagger/dag_creator/airflow/operator_creators/spark_creator.py
+++ b/dagger/dag_creator/airflow/operator_creators/spark_creator.py
@@ -113,7 +113,7 @@ def _create_operator(self, **kwargs):
                 job_name=job_name,
                 region_name=self._task.region_name,
                 job_queue=self._task.job_queue,
-                overrides=overrides,
+                container_overrides=overrides,
                 **kwargs,
             )
         elif self._task.spark_engine == "glue":
diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index a267ba7..b2f4bb3 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,203 +1,90 @@
-from pathlib import Path
-from time import sleep
-
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.operators.batch import BatchOperator
+from airflow.utils.context import Context
 from airflow.exceptions import AirflowException
-from airflow.utils.decorators import apply_defaults
-
-from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator
-from dagger.dag_creator.airflow.utils.decorators import lazy_property
-from dagger import conf
-
-
-class AWSBatchOperator(DaggerBaseOperator):
-    """
-    Execute a job on AWS Batch Service
-
-    .. warning: the queue parameter was renamed to job_queue to segregate the
-        internal CeleryExecutor queue from the AWS Batch internal queue.
-
-    :param job_name: the name for the job that will run on AWS Batch
-    :type job_name: str
-    :param job_definition: the job definition name on AWS Batch
-    :type job_definition: str
-    :param job_queue: the queue name on AWS Batch
-    :type job_queue: str
-    :param overrides: the same parameter that boto3 will receive on
-        containerOverrides (templated):
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
-    :type overrides: dict
-    :param max_retries: exponential backoff retries while waiter is not
-        merged, 4200 = 48 hours
-    :type max_retries: int
-    :param aws_conn_id: connection id of AWS credentials / region name. If None,
-        credential boto3 strategy will be used
-        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
-    :type aws_conn_id: str
-    :param region_name: region name to use in AWS Hook.
-        Override the region_name in connection (if provided)
-    :type region_name: str
-    :param cluster_name: Batch cluster short name or arn
-    :type region_name: str
-
-    """
-
-    ui_color = "#c3dae0"
-    client = None
-    arn = None
-    template_fields = ("overrides",)
+from airflow.providers.amazon.aws.links.batch import (
+    BatchJobDefinitionLink,
+    BatchJobQueueLink,
+)
+from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
 
-    @apply_defaults
-    def __init__(
-        self,
-        job_queue,
-        job_name=None,
-        absolute_job_name=None,
-        overrides=None,
-        job_definition=None,
-        aws_conn_id=None,
-        region_name=None,
-        cluster_name=None,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-        self.job_name = self._validate_job_name(job_name, absolute_job_name)
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.cluster_name = cluster_name
-        self.job_definition = job_definition or self.job_name
-        self.job_queue = job_queue
-        self.overrides = overrides or {}
-        self.job_id = None
-
-    @lazy_property
-    def batch_client(self):
-        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="batch").get_client_type(
-            region_name=self.region_name)
-
-    @lazy_property
-    def logs_client(self):
-        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="logs").get_client_type(
-            region_name=self.region_name)
-
-    @lazy_property
-    def ecs_client(self):
-        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="ecs").get_client_type(
-            region_name=self.region_name)
 
+class AWSBatchOperator(AWSBatchOperator):
     @staticmethod
-    def _validate_job_name(job_name, absolute_job_name):
-        if absolute_job_name is None and job_name is None:
-            raise Exception("Both job_name and absolute_job_name cannot be null")
-
-        if absolute_job_name is not None:
-            return absolute_job_name
-
-        job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/")
-        assert (
-            job_path.is_dir()
-        ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`"
-        return job_name
-
-    def execute(self, context):
-        self.task_instance = context["ti"]
-        self.log.info(
-            "\n"
-            f"\n\tJob name: {self.job_name}"
-            f"\n\tJob queue: {self.job_queue}"
-            f"\n\tJob definition: {self.job_definition}"
-            "\n"
-        )
-
-        res = self.batch_client.submit_job(
-            jobName=self.job_name,
-            jobQueue=self.job_queue,
-            jobDefinition=self.job_definition,
-            containerOverrides=self.overrides,
-        )
-        self.job_id = res["jobId"]
-        self.log.info(
-            "\n"
-            f"\n\tJob ID: {self.job_id}"
-            "\n"
-        )
-        self.poll_task()
-
-    def poll_task(self):
-        log_offset = 0
-        print_logs_url = True
-
-        while True:
-            res = self.batch_client.describe_jobs(jobs=[self.job_id])
-
-            if len(res["jobs"]) == 0:
-                sleep(3)
-                continue
-
-            job = res["jobs"][0]
-            job_status = job["status"]
-            log_stream_name = job["container"].get("logStreamName")
-
-            if print_logs_url and log_stream_name:
-                print_logs_url = False
-                self.log.info(
-                    "\n"
-                    f"\n\tLogs at: https://{self.region_name}.console.aws.amazon.com/cloudwatch/home?"
-                    f"region={self.region_name}#logEventViewer:group=/aws/batch/job;stream={log_stream_name}"
-                    "\n"
-                )
+    def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
+        return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}"
+
+    def monitor_job(self, context: Context):
+        """Monitor an AWS Batch job.
+
+        This can raise an exception or an AirflowTaskTimeout if the task was
+        created with ``execution_timeout``.
+        """
+        if not self.job_id:
+            raise AirflowException("AWS Batch job - job_id was not found")
+
+        try:
+            job_desc = self.hook.get_job_description(self.job_id)
+            job_definition_arn = job_desc["jobDefinition"]
+            job_queue_arn = job_desc["jobQueue"]
+            self.log.info(
+                "AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r",
+                self.job_id,
+                job_definition_arn,
+                job_queue_arn,
+            )
+        except KeyError:
+            self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id)
+        else:
+            BatchJobDefinitionLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                job_definition_arn=job_definition_arn,
+            )
+            BatchJobQueueLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                job_queue_arn=job_queue_arn,
+            )
 
-            if job_status in ("RUNNING", "FAILED", "SUCCEEDED") and log_stream_name:
-                try:
-                    log_offset = self.print_logs(log_stream_name, log_offset)
-                except self.logs_client.exceptions.ResourceNotFoundException:
-                    pass
+        if self.awslogs_enabled:
+            if self.waiters:
+                self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
             else:
-                self.log.info(f"Job status: {job_status}")
-
-            if job_status == "FAILED":
-                status_reason = res["jobs"][0]["statusReason"]
-                exit_code = res["jobs"][0]["container"].get("exitCode")
-                reason = res["jobs"][0]["container"].get("reason", "")
-                failure_msg = f"Status: {status_reason} | Exit code: {exit_code} | Reason: {reason}"
-                container_instance_arn = job["container"]["containerInstanceArn"]
-                self.retry_check(container_instance_arn)
-                raise AirflowException(failure_msg)
-
-            if job_status == "SUCCEEDED":
-                self.log.info("AWS Batch Job has been successfully executed")
-                return
-
-            sleep(7.5)
+                self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
+        else:
+            if self.waiters:
+                self.waiters.wait_for_job(self.job_id)
+            else:
+                self.hook.wait_for_job(self.job_id)
+
+        awslogs = []
+        try:
+            awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
+        except AirflowException as ae:
+            self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
+
+        if awslogs:
+            self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
+            link_builder = CloudWatchEventsLink()
+            for log in awslogs:
+                self.log.info(self._format_cloudwatch_link(**log))
+            if len(awslogs) > 1:
+                # there can be several log streams on multi-node jobs
+                self.log.warning(
+                    "out of all those logs, we can only link to one in the UI. Using the first one."
+                )
 
-    def retry_check(self, container_instance_arn):
-        res = self.ecs_client.describe_container_instances(
-            cluster=self.cluster_name, containerInstances=[container_instance_arn]
-        )
-        instance_status = res["containerInstances"][0]["status"]
-        if instance_status != "ACTIVE":
-            self.log.warning(
-                f"Instance in {instance_status} state: setting the task up for retry..."
+            CloudWatchEventsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                **awslogs[0],
             )
-            self.retries += self.task_instance.try_number + 1
-            self.task_instance.max_tries = self.retries
-
-    def print_logs(self, log_stream_name, log_offset):
-        logs = self.logs_client.get_log_events(
-            logGroupName="/aws/batch/job",
-            logStreamName=log_stream_name,
-            startFromHead=True,
-        )
-
-        for event in logs["events"][log_offset:]:
-            self.log.info(event["message"])
-
-        log_offset = len(logs["events"])
-        return log_offset
 
-    def on_kill(self):
-        res = self.batch_client.terminate_job(
-            jobId=self.job_id, reason="Task killed by the user"
-        )
-        self.log.info(res)
+        self.hook.check_job_success(self.job_id)
+        self.log.info("AWS Batch job (%s) succeeded", self.job_id)

From a8e647184a76c5dc3be96b4f6938c9a038356282 Mon Sep 17 00:00:00 2001
From: David Siklosi
Date: Fri, 26 Apr 2024 12:55:35 +0200
Subject: [PATCH 2/2] Complete renaming of classes

---
 dagger/dag_creator/airflow/operators/awsbatch_operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index b2f4bb3..23b3596 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -8,7 +8,7 @@
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
 
 
-class AWSBatchOperator(AWSBatchOperator):
+class AWSBatchOperator(BatchOperator):
     @staticmethod
     def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
         return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}"