Merge pull request #36 from chocoapp/bugfix/DATA-1804_missing_batch_logs
Switching to official batch operator
siklosid authored Apr 26, 2024
2 parents 52e7572 + a8e6471 commit 3c5fd12
Showing 3 changed files with 105 additions and 200 deletions.
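For context, the change swaps the hand-rolled submit/poll/print-logs operator for a thin subclass of the official BatchOperator from apache-airflow-providers-amazon, letting the provider handle waiting and CloudWatch log streaming. Below is a minimal sketch of how a Batch task is wired up after this change, assuming a recent provider version that supports container_overrides and awslogs_enabled; the DAG, queue and job definition names are placeholders.

from datetime import datetime

from airflow import DAG

from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator

with DAG("batch_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    run_job = AWSBatchOperator(
        task_id="run_batch_job",
        job_name="run_batch_job",                 # name shown in the AWS Batch console
        job_definition="my-team-my-job",          # placeholder Batch job definition
        job_queue="my-batch-queue",               # placeholder Batch job queue
        region_name="eu-west-1",
        container_overrides={"command": ["python", "main.py", "--ds", "{{ ds }}"]},
        awslogs_enabled=True,                     # provider fetches CloudWatch logs while waiting
    )
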
28 changes: 23 additions & 5 deletions dagger/dag_creator/airflow/operator_creators/batch_creator.py
@@ -1,5 +1,9 @@
from pathlib import Path
from datetime import timedelta

from dagger.dag_creator.airflow.operator_creator import OperatorCreator
from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator
from dagger import conf


class BatchCreator(OperatorCreator):
@@ -8,6 +12,20 @@ class BatchCreator(OperatorCreator):
def __init__(self, task, dag):
super().__init__(task, dag)

@staticmethod
def _validate_job_name(job_name, absolute_job_name):
if not absolute_job_name and not job_name:
raise Exception("Both job_name and absolute_job_name cannot be null")

if absolute_job_name is not None:
return absolute_job_name

job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/")
assert (
job_path.is_dir()
), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`"
return job_name

def _generate_command(self):
command = [self._task.executable_prefix, self._task.executable]
for param_name, param_value in self._template_parameters.items():
@@ -21,16 +39,16 @@ def _create_operator(self, **kwargs):
overrides = self._task.overrides
overrides.update({"command": self._generate_command()})

+        job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name)
         batch_op = AWSBatchOperator(
             dag=self._dag,
             task_id=self._task.name,
-            job_name=self._task.job_name,
-            absolute_job_name=self._task.absolute_job_name,
+            job_name=self._task.name,
+            job_definition=job_name,
             region_name=self._task.region_name,
-            cluster_name=self._task.cluster_name,
             job_queue=self._task.job_queue,
-            overrides=overrides,
+            container_overrides=overrides,
+            awslogs_enabled=True,
             **kwargs,
         )

return batch_op
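For illustration, here is roughly how the creator's new _validate_job_name resolves its two inputs. The values below are made up, assuming conf.DAGS_DIR points at /usr/local/airflow/dags and that the folder my_team/my_job exists under it.

from dagger.dag_creator.airflow.operator_creators.batch_creator import BatchCreator

# Hypothetical values for illustration only.

BatchCreator._validate_job_name(None, "my-registered-job-definition")
# -> "my-registered-job-definition"  (absolute_job_name wins and skips the path check)

BatchCreator._validate_job_name("my_team-my_job", None)
# -> "my_team-my_job"  (dashes map to /usr/local/airflow/dags/my_team/my_job, which must exist)

BatchCreator._validate_job_name(None, None)
# raises Exception("Both job_name and absolute_job_name cannot be null")
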
2 changes: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def _create_operator(self, **kwargs):
job_name=job_name,
region_name=self._task.region_name,
job_queue=self._task.job_queue,
-            overrides=overrides,
+            container_overrides=overrides,
**kwargs,
)
elif self._task.spark_engine == "glue":
275 changes: 81 additions & 194 deletions dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,203 +1,90 @@
from pathlib import Path
from time import sleep

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.operators.batch import BatchOperator
from airflow.utils.context import Context
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults

from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator
from dagger.dag_creator.airflow.utils.decorators import lazy_property
from dagger import conf


class AWSBatchOperator(DaggerBaseOperator):
"""
Execute a job on AWS Batch Service
.. warning: the queue parameter was renamed to job_queue to segregate the
internal CeleryExecutor queue from the AWS Batch internal queue.
:param job_name: the name for the job that will run on AWS Batch
:type job_name: str
:param job_definition: the job definition name on AWS Batch
:type job_definition: str
:param job_queue: the queue name on AWS Batch
:type job_queue: str
:param overrides: the same parameter that boto3 will receive on
containerOverrides (templated):
http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
:type overrides: dict
:param max_retries: exponential backoff retries while waiter is not
merged, 4200 = 48 hours
:type max_retries: int
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
:type aws_conn_id: str
:param region_name: region name to use in AWS Hook.
Override the region_name in connection (if provided)
:type region_name: str
:param cluster_name: Batch cluster short name or arn
:type cluster_name: str
"""

ui_color = "#c3dae0"
client = None
arn = None
template_fields = ("overrides",)
from airflow.providers.amazon.aws.links.batch import (
BatchJobDefinitionLink,
BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink

@apply_defaults
def __init__(
self,
job_queue,
job_name=None,
absolute_job_name=None,
overrides=None,
job_definition=None,
aws_conn_id=None,
region_name=None,
cluster_name=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.job_name = self._validate_job_name(job_name, absolute_job_name)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.cluster_name = cluster_name
self.job_definition = job_definition or self.job_name
self.job_queue = job_queue
self.overrides = overrides or {}
self.job_id = None

@lazy_property
def batch_client(self):
return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="batch").get_client_type(
region_name=self.region_name)

@lazy_property
def logs_client(self):
return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="logs").get_client_type(
region_name=self.region_name)

@lazy_property
def ecs_client(self):
return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="ecs").get_client_type(
region_name=self.region_name)

class AWSBatchOperator(BatchOperator):
@staticmethod
def _validate_job_name(job_name, absolute_job_name):
if absolute_job_name is None and job_name is None:
raise Exception("Both job_name and absolute_job_name cannot be null")

if absolute_job_name is not None:
return absolute_job_name

job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/")
assert (
job_path.is_dir()
), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`"
return job_name

def execute(self, context):
self.task_instance = context["ti"]
self.log.info(
"\n"
f"\n\tJob name: {self.job_name}"
f"\n\tJob queue: {self.job_queue}"
f"\n\tJob definition: {self.job_definition}"
"\n"
)

res = self.batch_client.submit_job(
jobName=self.job_name,
jobQueue=self.job_queue,
jobDefinition=self.job_definition,
containerOverrides=self.overrides,
)
self.job_id = res["jobId"]
self.log.info(
"\n"
f"\n\tJob ID: {self.job_id}"
"\n"
)
self.poll_task()

def poll_task(self):
log_offset = 0
print_logs_url = True

while True:
res = self.batch_client.describe_jobs(jobs=[self.job_id])

if len(res["jobs"]) == 0:
sleep(3)
continue

job = res["jobs"][0]
job_status = job["status"]
log_stream_name = job["container"].get("logStreamName")

if print_logs_url and log_stream_name:
print_logs_url = False
self.log.info(
"\n"
f"\n\tLogs at: https://{self.region_name}.console.aws.amazon.com/cloudwatch/home?"
f"region={self.region_name}#logEventViewer:group=/aws/batch/job;stream={log_stream_name}"
"\n"
)
def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}"

def monitor_job(self, context: Context):
"""Monitor an AWS Batch job.
This can raise an exception or an AirflowTaskTimeout if the task was
created with ``execution_timeout``.
"""
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")

try:
job_desc = self.hook.get_job_description(self.job_id)
job_definition_arn = job_desc["jobDefinition"]
job_queue_arn = job_desc["jobQueue"]
self.log.info(
"AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r",
self.job_id,
job_definition_arn,
job_queue_arn,
)
except KeyError:
self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id)
else:
BatchJobDefinitionLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_definition_arn=job_definition_arn,
)
BatchJobQueueLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_queue_arn=job_queue_arn,
)

if job_status in ("RUNNING", "FAILED", "SUCCEEDED") and log_stream_name:
try:
log_offset = self.print_logs(log_stream_name, log_offset)
except self.logs_client.exceptions.ResourceNotFoundException:
pass
if self.awslogs_enabled:
if self.waiters:
self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
self.log.info(f"Job status: {job_status}")

if job_status == "FAILED":
status_reason = res["jobs"][0]["statusReason"]
exit_code = res["jobs"][0]["container"].get("exitCode")
reason = res["jobs"][0]["container"].get("reason", "")
failure_msg = f"Status: {status_reason} | Exit code: {exit_code} | Reason: {reason}"
container_instance_arn = job["container"]["containerInstanceArn"]
self.retry_check(container_instance_arn)
raise AirflowException(failure_msg)

if job_status == "SUCCEEDED":
self.log.info("AWS Batch Job has been successfully executed")
return

sleep(7.5)
self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
if self.waiters:
self.waiters.wait_for_job(self.job_id)
else:
self.hook.wait_for_job(self.job_id)

awslogs = []
try:
awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
except AirflowException as ae:
self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)

if awslogs:
self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
link_builder = CloudWatchEventsLink()
for log in awslogs:
self.log.info(self._format_cloudwatch_link(**log))
if len(awslogs) > 1:
# there can be several log streams on multi-node jobs
self.log.warning(
"out of all those logs, we can only link to one in the UI. Using the first one."
)

def retry_check(self, container_instance_arn):
res = self.ecs_client.describe_container_instances(
cluster=self.cluster_name, containerInstances=[container_instance_arn]
)
instance_status = res["containerInstances"][0]["status"]
if instance_status != "ACTIVE":
self.log.warning(
f"Instance in {instance_status} state: setting the task up for retry..."
CloudWatchEventsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
**awslogs[0],
)
self.retries += self.task_instance.try_number + 1
self.task_instance.max_tries = self.retries

def print_logs(self, log_stream_name, log_offset):
logs = self.logs_client.get_log_events(
logGroupName="/aws/batch/job",
logStreamName=log_stream_name,
startFromHead=True,
)

for event in logs["events"][log_offset:]:
self.log.info(event["message"])

log_offset = len(logs["events"])
return log_offset

def on_kill(self):
res = self.batch_client.terminate_job(
jobId=self.job_id, reason="Task killed by the user"
)
self.log.info(res)
self.hook.check_job_success(self.job_id)
self.log.info("AWS Batch job (%s) succeeded", self.job_id)
