forked from siklosid/dagger
Merge pull request #36 from chocoapp/bugfix/DATA-1804_missing_batch_logs
Switching to official batch operator
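In short, the custom submit-and-poll implementation is replaced by a subclass of the official Amazon provider operator. A rough sketch of the new shape (simplified from the diff below, not the exact final code; it assumes the apache-airflow-providers-amazon package is installed):

from airflow.providers.amazon.aws.operators.batch import BatchOperator

class AWSBatchOperator(BatchOperator):
    # Job submission, waiting and CloudWatch log streaming now come from the
    # provider's BatchOperator; monitor_job is overridden (see the diff below)
    # to also persist job definition, job queue and log links for the UI.
    def monitor_job(self, context):
        ...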
Showing 3 changed files with 105 additions and 200 deletions.
dagger/dag_creator/airflow/operators/awsbatch_operator.py (275 changes: 81 additions & 194 deletions)
@@ -1,203 +1,90 @@
from pathlib import Path
from time import sleep

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.operators.batch import BatchOperator
from airflow.utils.context import Context
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults

from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator
from dagger.dag_creator.airflow.utils.decorators import lazy_property
from dagger import conf


class AWSBatchOperator(DaggerBaseOperator):
    """
    Execute a job on AWS Batch Service.

    .. warning:: the queue parameter was renamed to job_queue to segregate the
        internal CeleryExecutor queue from the AWS Batch internal queue.

    :param job_name: the name for the job that will run on AWS Batch
    :type job_name: str
    :param job_definition: the job definition name on AWS Batch
    :type job_definition: str
    :param job_queue: the queue name on AWS Batch
    :type job_queue: str
    :param overrides: the same parameter that boto3 will receive on
        containerOverrides (templated):
        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
    :type overrides: dict
    :param max_retries: exponential backoff retries while waiter is not
        merged, 4200 = 48 hours
    :type max_retries: int
    :param aws_conn_id: connection id of AWS credentials / region name. If None,
        the default boto3 credential strategy will be used
        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
    :type aws_conn_id: str
    :param region_name: region name to use in AWS Hook.
        Overrides the region_name in the connection (if provided)
    :type region_name: str
    :param cluster_name: Batch cluster short name or ARN
    :type cluster_name: str
    """

    ui_color = "#c3dae0"
    client = None
    arn = None
    template_fields = ("overrides",)

from airflow.providers.amazon.aws.links.batch import (
    BatchJobDefinitionLink,
    BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink

    @apply_defaults
    def __init__(
        self,
        job_queue,
        job_name=None,
        absolute_job_name=None,
        overrides=None,
        job_definition=None,
        aws_conn_id=None,
        region_name=None,
        cluster_name=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.job_name = self._validate_job_name(job_name, absolute_job_name)
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.cluster_name = cluster_name
        self.job_definition = job_definition or self.job_name
        self.job_queue = job_queue
        self.overrides = overrides or {}
        self.job_id = None

    @lazy_property
    def batch_client(self):
        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="batch").get_client_type(
            region_name=self.region_name)

    @lazy_property
    def logs_client(self):
        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="logs").get_client_type(
            region_name=self.region_name)

    @lazy_property
    def ecs_client(self):
        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="ecs").get_client_type(
            region_name=self.region_name)

class AWSBatchOperator(BatchOperator):
    @staticmethod
    def _validate_job_name(job_name, absolute_job_name):
        if absolute_job_name is None and job_name is None:
            raise Exception("Both job_name and absolute_job_name cannot be null")

        if absolute_job_name is not None:
            return absolute_job_name

        job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/")
        assert (
            job_path.is_dir()
        ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`"
        return job_name

    def execute(self, context):
        self.task_instance = context["ti"]
        self.log.info(
            "\n"
            f"\n\tJob name: {self.job_name}"
            f"\n\tJob queue: {self.job_queue}"
            f"\n\tJob definition: {self.job_definition}"
            "\n"
        )

        res = self.batch_client.submit_job(
            jobName=self.job_name,
            jobQueue=self.job_queue,
            jobDefinition=self.job_definition,
            containerOverrides=self.overrides,
        )
        self.job_id = res["jobId"]
        self.log.info(
            "\n"
            f"\n\tJob ID: {self.job_id}"
            "\n"
        )
        self.poll_task()

    def poll_task(self):
        log_offset = 0
        print_logs_url = True

        while True:
            res = self.batch_client.describe_jobs(jobs=[self.job_id])

            if len(res["jobs"]) == 0:
                sleep(3)
                continue

            job = res["jobs"][0]
            job_status = job["status"]
            log_stream_name = job["container"].get("logStreamName")

            if print_logs_url and log_stream_name:
                print_logs_url = False
                self.log.info(
                    "\n"
                    f"\n\tLogs at: https://{self.region_name}.console.aws.amazon.com/cloudwatch/home?"
                    f"region={self.region_name}#logEventViewer:group=/aws/batch/job;stream={log_stream_name}"
                    "\n"
                )

    @staticmethod
    def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
        return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}"

    def monitor_job(self, context: Context):
        """Monitor an AWS Batch job.
        This can raise an exception or an AirflowTaskTimeout if the task was
        created with ``execution_timeout``.
        """
        if not self.job_id:
            raise AirflowException("AWS Batch job - job_id was not found")

        try:
            job_desc = self.hook.get_job_description(self.job_id)
            job_definition_arn = job_desc["jobDefinition"]
            job_queue_arn = job_desc["jobQueue"]
            self.log.info(
                "AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r",
                self.job_id,
                job_definition_arn,
                job_queue_arn,
            )
        except KeyError:
            self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id)
        else:
            BatchJobDefinitionLink.persist(
                context=context,
                operator=self,
                region_name=self.hook.conn_region_name,
                aws_partition=self.hook.conn_partition,
                job_definition_arn=job_definition_arn,
            )
            BatchJobQueueLink.persist(
                context=context,
                operator=self,
                region_name=self.hook.conn_region_name,
                aws_partition=self.hook.conn_partition,
                job_queue_arn=job_queue_arn,
            )

            if job_status in ("RUNNING", "FAILED", "SUCCEEDED") and log_stream_name:
                try:
                    log_offset = self.print_logs(log_stream_name, log_offset)
                except self.logs_client.exceptions.ResourceNotFoundException:
                    pass

        if self.awslogs_enabled:
            if self.waiters:
                self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
            else:
                self.log.info(f"Job status: {job_status}")

            if job_status == "FAILED":
                status_reason = res["jobs"][0]["statusReason"]
                exit_code = res["jobs"][0]["container"].get("exitCode")
                reason = res["jobs"][0]["container"].get("reason", "")
                failure_msg = f"Status: {status_reason} | Exit code: {exit_code} | Reason: {reason}"
                container_instance_arn = job["container"]["containerInstanceArn"]
                self.retry_check(container_instance_arn)
                raise AirflowException(failure_msg)

            if job_status == "SUCCEEDED":
                self.log.info("AWS Batch Job has been successfully executed")
                return

            sleep(7.5)
                self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
        else:
            if self.waiters:
                self.waiters.wait_for_job(self.job_id)
            else:
                self.hook.wait_for_job(self.job_id)

        awslogs = []
        try:
            awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
        except AirflowException as ae:
            self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)

        if awslogs:
            self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
            link_builder = CloudWatchEventsLink()
            for log in awslogs:
                self.log.info(self._format_cloudwatch_link(**log))
            if len(awslogs) > 1:
                # there can be several log streams on multi-node jobs
                self.log.warning(
                    "out of all those logs, we can only link to one in the UI. Using the first one."
                )

    def retry_check(self, container_instance_arn):
        res = self.ecs_client.describe_container_instances(
            cluster=self.cluster_name, containerInstances=[container_instance_arn]
        )
        instance_status = res["containerInstances"][0]["status"]
        if instance_status != "ACTIVE":
            self.log.warning(
                f"Instance in {instance_status} state: setting the task up for retry..."
            )
            CloudWatchEventsLink.persist(
                context=context,
                operator=self,
                region_name=self.hook.conn_region_name,
                aws_partition=self.hook.conn_partition,
                **awslogs[0],
            )
            self.retries += self.task_instance.try_number + 1
            self.task_instance.max_tries = self.retries

    def print_logs(self, log_stream_name, log_offset):
        logs = self.logs_client.get_log_events(
            logGroupName="/aws/batch/job",
            logStreamName=log_stream_name,
            startFromHead=True,
        )

        for event in logs["events"][log_offset:]:
            self.log.info(event["message"])

        log_offset = len(logs["events"])
        return log_offset

    def on_kill(self):
        res = self.batch_client.terminate_job(
            jobId=self.job_id, reason="Task killed by the user"
        )
        self.log.info(res)
        self.hook.check_job_success(self.job_id)
        self.log.info("AWS Batch job (%s) succeeded", self.job_id)