Switching to official batch operator #36

Merged · 2 commits · Apr 26, 2024
28 changes: 23 additions & 5 deletions dagger/dag_creator/airflow/operator_creators/batch_creator.py
@@ -1,5 +1,9 @@
+from pathlib import Path
 from datetime import timedelta

 from dagger.dag_creator.airflow.operator_creator import OperatorCreator
 from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator
+from dagger import conf


 class BatchCreator(OperatorCreator):
@@ -8,6 +12,20 @@ class BatchCreator(OperatorCreator):
     def __init__(self, task, dag):
         super().__init__(task, dag)

+    @staticmethod
+    def _validate_job_name(job_name, absolute_job_name):
+        if not absolute_job_name and not job_name:
+            raise Exception("Both job_name and absolute_job_name cannot be null")
+
+        if absolute_job_name is not None:
+            return absolute_job_name
+
+        job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/")
+        assert (
+            job_path.is_dir()
+        ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`"
+        return job_name
+
     def _generate_command(self):
         command = [self._task.executable_prefix, self._task.executable]
         for param_name, param_value in self._template_parameters.items():
@@ -21,16 +39,16 @@ def _create_operator(self, **kwargs):
         overrides = self._task.overrides
         overrides.update({"command": self._generate_command()})

+        job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name)
         batch_op = AWSBatchOperator(
             dag=self._dag,
             task_id=self._task.name,
-            job_name=self._task.job_name,
-            absolute_job_name=self._task.absolute_job_name,
+            job_name=self._task.name,
+            job_definition=job_name,
             region_name=self._task.region_name,
-            cluster_name=self._task.cluster_name,
             job_queue=self._task.job_queue,
-            overrides=overrides,
+            container_overrides=overrides,
+            awslogs_enabled=True,
             **kwargs,
         )

         return batch_op
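To make the relocated validation concrete, here is a minimal, self-contained sketch of how `_validate_job_name` resolves names; `dags` stands in for `conf.DAGS_DIR` and the `etl-daily-load` job name is a made-up example, not a value from this repository:

from pathlib import Path

DAGS_DIR = "dags"  # stand-in for conf.DAGS_DIR


def validate_job_name(job_name, absolute_job_name=None):
    """Illustrative mirror of BatchCreator._validate_job_name."""
    if not absolute_job_name and not job_name:
        raise Exception("Both job_name and absolute_job_name cannot be null")

    # An absolute job name is returned as-is and skips the folder check.
    if absolute_job_name is not None:
        return absolute_job_name

    # Every dash becomes a path separator, so the name must map to an existing
    # folder under the DAGs directory: "etl-daily-load" -> dags/etl/daily/load
    job_path = Path(DAGS_DIR) / job_name.replace("-", "/")
    assert job_path.is_dir(), f"Job name `{job_name}` points to a non-existing folder `{job_path}`"
    return job_name


# validate_job_name("etl-daily-load") only passes if dags/etl/daily/load exists;
# the returned name is what the creator now hands to AWSBatchOperator as job_definition.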
@@ -113,7 +113,7 @@ def _create_operator(self, **kwargs):
                 job_name=job_name,
                 region_name=self._task.region_name,
                 job_queue=self._task.job_queue,
-                overrides=overrides,
+                container_overrides=overrides,
                 **kwargs,
             )
         elif self._task.spark_engine == "glue":
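Both creators now pass the task's overrides through the official operator's `container_overrides` argument, which the Amazon provider forwards to boto3's `submit_job` as `containerOverrides`. A representative payload is sketched below; the key names come from the Batch API, while every value (including the example command) is illustrative only:

# Shape of the dict that ends up in containerOverrides; the creators always set
# "command" via _generate_command(), the rest comes from the task's overrides.
container_overrides = {
    "command": ["python", "main.py", "--execution-date", "{{ ds }}"],
    "environment": [
        {"name": "ENV", "value": "production"},
    ],
    "resourceRequirements": [
        {"type": "VCPU", "value": "2"},
        {"type": "MEMORY", "value": "4096"},
    ],
}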
275 changes: 81 additions & 194 deletions dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,203 +1,90 @@
 from pathlib import Path
-from time import sleep

-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.operators.batch import BatchOperator
+from airflow.utils.context import Context
 from airflow.exceptions import AirflowException
-from airflow.utils.decorators import apply_defaults

-from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator
-from dagger.dag_creator.airflow.utils.decorators import lazy_property
 from dagger import conf


-class AWSBatchOperator(DaggerBaseOperator):
-    """
-    Execute a job on AWS Batch Service
-    .. warning: the queue parameter was renamed to job_queue to segregate the
-        internal CeleryExecutor queue from the AWS Batch internal queue.
-    :param job_name: the name for the job that will run on AWS Batch
-    :type job_name: str
-    :param job_definition: the job definition name on AWS Batch
-    :type job_definition: str
-    :param job_queue: the queue name on AWS Batch
-    :type job_queue: str
-    :param overrides: the same parameter that boto3 will receive on
-        containerOverrides (templated):
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
-    :type overrides: dict
-    :param max_retries: exponential backoff retries while waiter is not
-        merged, 4200 = 48 hours
-    :type max_retries: int
-    :param aws_conn_id: connection id of AWS credentials / region name. If None,
-        credential boto3 strategy will be used
-        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
-    :type aws_conn_id: str
-    :param region_name: region name to use in AWS Hook.
-        Override the region_name in connection (if provided)
-    :type region_name: str
-    :param cluster_name: Batch cluster short name or arn
-    :type region_name: str
-    """

-    ui_color = "#c3dae0"
-    client = None
-    arn = None
-    template_fields = ("overrides",)
+from airflow.providers.amazon.aws.links.batch import (
+    BatchJobDefinitionLink,
+    BatchJobQueueLink,
+)
+from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink

-    @apply_defaults
-    def __init__(
-        self,
-        job_queue,
-        job_name=None,
-        absolute_job_name=None,
-        overrides=None,
-        job_definition=None,
-        aws_conn_id=None,
-        region_name=None,
-        cluster_name=None,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-        self.job_name = self._validate_job_name(job_name, absolute_job_name)
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.cluster_name = cluster_name
-        self.job_definition = job_definition or self.job_name
-        self.job_queue = job_queue
-        self.overrides = overrides or {}
-        self.job_id = None

-    @lazy_property
-    def batch_client(self):
-        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="batch").get_client_type(
-            region_name=self.region_name)

-    @lazy_property
-    def logs_client(self):
-        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="logs").get_client_type(
-            region_name=self.region_name)

-    @lazy_property
-    def ecs_client(self):
-        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="ecs").get_client_type(
-            region_name=self.region_name)

+class AWSBatchOperator(BatchOperator):
     @staticmethod
     def _validate_job_name(job_name, absolute_job_name):
         if absolute_job_name is None and job_name is None:
             raise Exception("Both job_name and absolute_job_name cannot be null")

         if absolute_job_name is not None:
             return absolute_job_name

         job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/")
         assert (
             job_path.is_dir()
         ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`"
         return job_name

-    def execute(self, context):
-        self.task_instance = context["ti"]
-        self.log.info(
-            "\n"
-            f"\n\tJob name: {self.job_name}"
-            f"\n\tJob queue: {self.job_queue}"
-            f"\n\tJob definition: {self.job_definition}"
-            "\n"
-        )

-        res = self.batch_client.submit_job(
-            jobName=self.job_name,
-            jobQueue=self.job_queue,
-            jobDefinition=self.job_definition,
-            containerOverrides=self.overrides,
-        )
-        self.job_id = res["jobId"]
-        self.log.info(
-            "\n"
-            f"\n\tJob ID: {self.job_id}"
-            "\n"
-        )
-        self.poll_task()

-    def poll_task(self):
-        log_offset = 0
-        print_logs_url = True

-        while True:
-            res = self.batch_client.describe_jobs(jobs=[self.job_id])

-            if len(res["jobs"]) == 0:
-                sleep(3)
-                continue

-            job = res["jobs"][0]
-            job_status = job["status"]
-            log_stream_name = job["container"].get("logStreamName")

-            if print_logs_url and log_stream_name:
-                print_logs_url = False
-                self.log.info(
-                    "\n"
-                    f"\n\tLogs at: https://{self.region_name}.console.aws.amazon.com/cloudwatch/home?"
-                    f"region={self.region_name}#logEventViewer:group=/aws/batch/job;stream={log_stream_name}"
-                    "\n"
-                )
+    @staticmethod
+    def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
+        return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}"

+    def monitor_job(self, context: Context):
+        """Monitor an AWS Batch job.
+        This can raise an exception or an AirflowTaskTimeout if the task was
+        created with ``execution_timeout``.
+        """
+        if not self.job_id:
+            raise AirflowException("AWS Batch job - job_id was not found")

+        try:
+            job_desc = self.hook.get_job_description(self.job_id)
+            job_definition_arn = job_desc["jobDefinition"]
+            job_queue_arn = job_desc["jobQueue"]
+            self.log.info(
+                "AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r",
+                self.job_id,
+                job_definition_arn,
+                job_queue_arn,
+            )
+        except KeyError:
+            self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id)
+        else:
+            BatchJobDefinitionLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                job_definition_arn=job_definition_arn,
+            )
+            BatchJobQueueLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                job_queue_arn=job_queue_arn,
+            )

-            if job_status in ("RUNNING", "FAILED", "SUCCEEDED") and log_stream_name:
-                try:
-                    log_offset = self.print_logs(log_stream_name, log_offset)
-                except self.logs_client.exceptions.ResourceNotFoundException:
-                    pass
+        if self.awslogs_enabled:
+            if self.waiters:
+                self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
+            else:
-            self.log.info(f"Job status: {job_status}")

-            if job_status == "FAILED":
-                status_reason = res["jobs"][0]["statusReason"]
-                exit_code = res["jobs"][0]["container"].get("exitCode")
-                reason = res["jobs"][0]["container"].get("reason", "")
-                failure_msg = f"Status: {status_reason} | Exit code: {exit_code} | Reason: {reason}"
-                container_instance_arn = job["container"]["containerInstanceArn"]
-                self.retry_check(container_instance_arn)
-                raise AirflowException(failure_msg)

-            if job_status == "SUCCEEDED":
-                self.log.info("AWS Batch Job has been successfully executed")
-                return

-            sleep(7.5)
+                self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
+        else:
+            if self.waiters:
+                self.waiters.wait_for_job(self.job_id)
+            else:
+                self.hook.wait_for_job(self.job_id)

+        awslogs = []
+        try:
+            awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
+        except AirflowException as ae:
+            self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)

+        if awslogs:
+            self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
+            link_builder = CloudWatchEventsLink()
+            for log in awslogs:
+                self.log.info(self._format_cloudwatch_link(**log))
+            if len(awslogs) > 1:
+                # there can be several log streams on multi-node jobs
+                self.log.warning(
+                    "out of all those logs, we can only link to one in the UI. Using the first one."
+                )

-    def retry_check(self, container_instance_arn):
-        res = self.ecs_client.describe_container_instances(
-            cluster=self.cluster_name, containerInstances=[container_instance_arn]
-        )
-        instance_status = res["containerInstances"][0]["status"]
-        if instance_status != "ACTIVE":
-            self.log.warning(
-                f"Instance in {instance_status} state: setting the task up for retry..."
-            )
+            CloudWatchEventsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                **awslogs[0],
+            )
-            self.retries += self.task_instance.try_number + 1
-            self.task_instance.max_tries = self.retries

-    def print_logs(self, log_stream_name, log_offset):
-        logs = self.logs_client.get_log_events(
-            logGroupName="/aws/batch/job",
-            logStreamName=log_stream_name,
-            startFromHead=True,
-        )

-        for event in logs["events"][log_offset:]:
-            self.log.info(event["message"])

-        log_offset = len(logs["events"])
-        return log_offset

-    def on_kill(self):
-        res = self.batch_client.terminate_job(
-            jobId=self.job_id, reason="Task killed by the user"
-        )
-        self.log.info(res)
+        self.hook.check_job_success(self.job_id)
+        self.log.info("AWS Batch job (%s) succeeded", self.job_id)