Skip to content

Commit

Permalink
Added allways_reuse_session in SASStudioOperator to configure a task …
Browse files Browse the repository at this point in the history
…to reuse a compute session or run in parallel (#24)

* Added allways_reuse_session

* Better management of compute sessions. A compute session is now killed if needed after a task has completed

* New version 0.0.11

---------

Co-authored-by: sdktjj <[email protected]>
  • Loading branch information
torbenjuul and sdktjj authored Nov 8, 2023
1 parent 8c0ad5a commit 4ed8b18
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 35 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = sas-airflow-provider
version = 0.0.10
version = 0.0.11
author = SAS
author_email = [email protected]
description = Enables execution of Studio Flows and Jobs from Airflow
Expand Down
109 changes: 83 additions & 26 deletions src/sas_airflow_provider/operators/sas_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from sas_airflow_provider.hooks.sas import SasHook
from sas_airflow_provider.util.util import dump_logs, create_or_connect_to_session
from sas_airflow_provider.util.util import dump_logs, create_or_connect_to_session, end_compute_session

# main API URI for Code Gen
URI_BASE = "/studioDevelopment/code"
Expand All @@ -36,7 +36,20 @@
JES_URI = "/jobExecution"
JOB_URI = f"{JES_URI}/jobs"

def on_success(context):
    """Task-success callback: release this task's compute session.

    A reused (named) session is kept alive so later tasks can attach to it;
    externally managed sessions are never killed by _clean_up.
    """
    task = context["task"]
    task._clean_up(also_kill_reused_session=False)

def on_failure(context):
    """Task-failure callback: tear down the task's compute session.

    Kills the session even when it is a reused (named) one; externally
    managed sessions are never killed by _clean_up.
    """
    task = context["task"]
    task._clean_up(also_kill_reused_session=True)

def on_retry(context):
    """Task-retry callback: tear down the task's compute session.

    Kills the session even when it is a reused (named) one; externally
    managed sessions are never killed by _clean_up.
    """
    task = context["task"]
    task._clean_up(also_kill_reused_session=True)



class SASStudioOperator(BaseOperator):
"""
Executes a SAS Studio flow or a SAS program
Expand All @@ -60,22 +73,28 @@ class SASStudioOperator(BaseOperator):
suitable default is used (see DEFAULT_COMPUTE_CONTEXT NAME).
:param env_vars: (optional) Dictionary of environment variables to set before running the flow.
:param macro_vars: (optional) Dictionary of macro variables to set before running the flow.
:param compute_session_id: (optional) Compute session id to use. If not specified, one will be created using the
default session name (see AIRFLOW_SESSION_NAME). Note that the name and the id are not the same. The name
will always be the value of AIRFLOW_SESSION_NAME, which means that if you don't supply a session id, then
this named session will be created or re-used. The advantage is that the same session can be re-used between
tasks. The disadvantage is that it offers less flexibility in terms of having multiple sessions.
:param allways_reuse_session: (optional) Specify true to always reuse the same Compute Session across all tasks. The name
of the session will be the default session name (see AIRFLOW_SESSION_NAME), which means that if you don't supply a session id in compute_session_id,
then this named session will be created and later re-used between tasks. The disadvantage is that it offers less flexibility in terms of
having multiple sessions (parallelism). Default value is False, meaning a new unnamed compute session will always be created
UNLESS a session id is specified in compute_session_id.
:param compute_session_id: (optional) Compute Session id to use for the task. If a Session Id is specified, this will override allways_reuse_session.
Use SASComputeCreateSession Operator to define a task that will create the session. This gives full flexibility in how compute sessions are used.
The id of the session created by SASComputeCreateSession will be made available as XCom variable 'compute_session_id'
for subsequent use by SASStudio Operator tasks. Tip: set the value to "{{ ti.xcom_pull(key='compute_session_id', task_ids=['<task_id>'])|first}}" to get the X-Com value.
:param output_macro_var_prefix: (optional) string. If this has a value, then any macro variables which start
with this prefix will be retrieved from the session after the code has executed and will be returned as XComs
:param unknown_state_timeout: (optional) number of seconds to continue polling for the state of a running job if the state is
temporary unobtainable. When unknown_state_timeout is reached without the state being retrievable, the operator
will throw an AirflowFailException and the task will be marked as failed.
Default value is 0, meaning the task will fail immediately if the state could not be retrieved.
Default value is 0, meaning the task will fail immediately if the state could not be retrieved.
"""

ui_color = "#CCE5FF"
ui_fgcolor = "black"



template_fields: Sequence[str] = ("env_vars", "macro_vars", "compute_session_id", "path")

def __init__(
Expand All @@ -90,6 +109,7 @@ def __init__(
compute_context=DEFAULT_COMPUTE_CONTEXT_NAME,
env_vars=None,
macro_vars=None,
allways_reuse_session=False,
compute_session_id="",
output_macro_var_prefix="",
unknown_state_timeout=0,
Expand All @@ -110,10 +130,22 @@ def __init__(
self.env_vars = env_vars
self.macro_vars = macro_vars
self.connection = None
self.compute_session_id = compute_session_id
self.allways_reuse_session = allways_reuse_session

self.external_managed_session = False
self.compute_session_id = None
if compute_session_id:
self.compute_session_id = compute_session_id
self.external_managed_session=True

self.output_macro_var_prefix = output_macro_var_prefix.upper()
self.unknown_state_timeout=max(unknown_state_timeout,0)

# Use hooks to clean up
self.on_success_callback=[on_success]
self.on_failure_callback=[on_failure]
self.on_retry_callback=[on_retry]

def execute(self, context):
if self.path_type not in ['compute', 'content', 'raw']:
raise AirflowFailException("Path type is invalid. Valid values are 'compute', 'content' or 'raw'")
Expand All @@ -127,6 +159,16 @@ def execute(self, context):
h = SasHook(self.connection_name)
self.connection = h.get_conn()

# Create compute session
if not self.compute_session_id:
compute_session = create_or_connect_to_session(self.connection,
self.compute_context_name,
AIRFLOW_SESSION_NAME if self.allways_reuse_session else None)
self.compute_session_id = compute_session["id"]
else:
self.log.info(f"Compute Session {self.compute_session_id} was provided")

# Generate SAS code
if self.path_type == "raw":
code = self.path
else:
Expand Down Expand Up @@ -157,7 +199,7 @@ def execute(self, context):
except Exception as e:
raise AirflowException(f"SASStudioOperator error: {str(e)}")


# Kick off the JES job.
job, success = self._run_job_and_wait(jr, 10)
job_state = job["state"]
Expand All @@ -172,7 +214,11 @@ def execute(self, context):

# set output variables
if success and self.output_macro_var_prefix and self.compute_session_id:
self._set_output_variables(context)
try:
self._set_output_variables(context)
except Exception as e:
raise AirflowException(f"SASStudioOperator error: {str(e)}")


# raise exception in Airflow if SAS Studio Flow ended execution with "failed" "canceled" or "timed out" state
# support retry for 'failed' (typically there is an ERROR in the log) and 'timed out'
Expand All @@ -185,9 +231,31 @@ def execute(self, context):

elif job_state == "timed out":
raise AirflowException("SAS Studio Execution has timed out. See log for details ")

return 1

def on_kill(self) -> None:
    # Airflow invokes this when the task is externally killed; tear down the
    # compute session (including a reused/named one) so it does not linger.
    # Externally managed sessions are left alone by _clean_up itself.
    self._clean_up(also_kill_reused_session=True)

def _clean_up(self, also_kill_reused_session=False):
    """Delete the compute session owned by this task, if any.

    Unnamed sessions (allways_reuse_session is False) are always deleted.
    A reused (named) session is deleted only when also_kill_reused_session
    is True. Externally managed sessions (an id supplied by the user via
    compute_session_id) are never deleted, as that may prevent a restart.

    :param also_kill_reused_session: when True, also delete a session that
        is shared across tasks via allways_reuse_session.
    """
    if self.compute_session_id and not self.external_managed_session:
        # Equivalent to the original condition
        # (also_kill_reused_session and allways_reuse_session) or not allways_reuse_session:
        # kill the session unless it is a reused one we were asked to keep.
        if also_kill_reused_session or not self.allways_reuse_session:
            try:
                self.log.info(f"Deleting session with id {self.compute_session_id}")
                if end_compute_session(self.connection, self.compute_session_id):
                    self.log.info("Compute session successfully deleted")
                else:
                    self.log.info("Unable to delete compute session. You may need to kill the session manually")
            except Exception as e:
                # Best-effort cleanup: log the reason but never fail the task here.
                self.log.info(f"Unable to delete compute session ({e}). You may need to kill the session manually")
            finally:
                # Forget the id either way so a later callback does not retry the delete.
                self.compute_session_id = None

def _add_airflow_env_vars(self):
for x in ['AIRFLOW_CTX_DAG_OWNER',
'AIRFLOW_CTX_DAG_ID',
Expand Down Expand Up @@ -217,25 +285,14 @@ def _get_pre_code(self):
return pre_code

def _generate_object_code(self):

uri = URI_BASE
uri=URI_BASE

if self.path_type == "compute":
self.log.info("Code Generation for Studio object stored in Compute file system")

# if session id is provided, use it, otherwise create a session
if not self.compute_session_id:
self.log.info("Create or connect to session")
compute_session = create_or_connect_to_session(self.connection,
self.compute_context_name, AIRFLOW_SESSION_NAME)
self.compute_session_id = compute_session["id"]
else:
self.log.info("Session ID was provided")

uri = f"{URI_BASE}?sessionId={self.compute_session_id}"
self.log.info("Code Generation for Studio object stored in Compute file system")
else:
self.log.info("Code generation for Studio object stored in Content")

media_type = "application/vnd.sas.dataflow"
if self.exec_type == "program":
media_type = "application/vnd.sas.program"
Expand Down Expand Up @@ -296,7 +353,7 @@ def _run_job_and_wait(self, job_request: dict, poll_interval: int) -> (dict, boo
# Print the log location to the DAG-log, in case the user needs access to the SAS-log while it is running.
if "logLocation" in job:
log_location=job["logLocation"];
self.log.info(f"While the job is running the SAS-log formated at JSON can be found at URI: {log_location}?limit=9999999")
self.log.info(f"While the job is running, the SAS-log formated as JSON can be found at URI: {log_location}?limit=9999999")
except Exception as e:
countUnknownState = countUnknownState + 1
self.log.info(f'HTTP Call failed with error "{e}". Will set state=unknown and continue checking...')
Expand Down
29 changes: 21 additions & 8 deletions src/sas_airflow_provider/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,24 +136,30 @@ def find_named_compute_session(session: requests.Session, name: str) -> dict:
raise RuntimeError(f"Find sessions failed: {response.status_code}")
sessions = response.json()
if sessions["count"] > 0:
print(f"Existing session named '{name}' was found")
print(f"Existing compute session named '{name}' with id {sessions['items'][0]['id']} was found")
return sessions["items"][0]
return {}

def create_or_connect_to_session(session: requests.Session, context_name: str, name: str) -> dict:
def create_or_connect_to_session(session: requests.Session, context_name: str, name = None) -> dict:
"""
Connect to an existing compute session by name. If that named session does not exist,
one is created using the context name supplied
:param session: rest session that includes oauth token
:param context_name: the context name to use to create the session if the session was not found
:param name: name of session to find
:return: session object
"""
compute_session = find_named_compute_session(session, name)
if compute_session:
return compute_session
if name != None:
compute_session = find_named_compute_session(session, name)
if compute_session:
return compute_session

print(f"Compute session named '{name}' does not exist, a new one will be created")
else:
print(f"A new unnamed compute session will be created")


print(f"Compute session named '{name}' does not exist, a new one will be created")
# find compute context
response = session.get("/compute/contexts", params={"filter": f'eq("name","{context_name}")'})
if not response.ok:
Expand All @@ -165,7 +171,11 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n

# create session with given context
uri = f'/compute/contexts/{sas_context["id"]}/sessions'
session_request = {"version": 1, "name": name}
if name != None:
session_request = {"version": 1, "name": name}
else:
# Create a unnamed session
session_request = {"version": 1}

headers = {"Content-Type": "application/vnd.sas.compute.session.request+json"}

Expand All @@ -175,7 +185,10 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n
if response.status_code != 201:
raise RuntimeError(f"Failed to create session: {response.text}")

return response.json()
json_response=response.json()
print(f"Compute session {json_response['id']} created")

return json_response

def end_compute_session(session: requests.Session, id):
uri = f'/compute/sessions/{id}'
Expand Down

0 comments on commit 4ed8b18

Please sign in to comment.