diff --git a/setup.cfg b/setup.cfg index 44add39..e6f2845 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = sas-airflow-provider -version = 0.0.10 +version = 0.0.11 author = SAS author_email = andrew.shakinovsky@sas.com description = Enables execution of Studio Flows and Jobs from Airflow diff --git a/src/sas_airflow_provider/operators/sas_studio.py b/src/sas_airflow_provider/operators/sas_studio.py index feeab01..3823b82 100644 --- a/src/sas_airflow_provider/operators/sas_studio.py +++ b/src/sas_airflow_provider/operators/sas_studio.py @@ -24,7 +24,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from sas_airflow_provider.hooks.sas import SasHook -from sas_airflow_provider.util.util import dump_logs, create_or_connect_to_session +from sas_airflow_provider.util.util import dump_logs, create_or_connect_to_session, end_compute_session # main API URI for Code Gen URI_BASE = "/studioDevelopment/code" @@ -36,7 +36,20 @@ JES_URI = "/jobExecution" JOB_URI = f"{JES_URI}/jobs" +def on_success(context): + # Only kill session when not reused or external managed + context['task']._clean_up(also_kill_reused_session=False) +def on_failure(context): + # Kill all sessions except external managed + context['task']._clean_up(also_kill_reused_session=True) + +def on_retry(context): + # Kill all sessions except external managed + context['task']._clean_up(also_kill_reused_session=True) + + + class SASStudioOperator(BaseOperator): """ Executes a SAS Studio flow or a SAS program @@ -60,22 +73,28 @@ class SASStudioOperator(BaseOperator): suitable default is used (see DEFAULT_COMPUTE_CONTEXT NAME). :param env_vars: (optional) Dictionary of environment variables to set before running the flow. :param macro_vars: (optional) Dictionary of macro variables to set before running the flow. - :param compute_session_id: (optional) Compute session id to use. 
If not specified, one will be created using the - default session name (see AIRFLOW_SESSION_NAME). Note that the name and the id are not the same. The name - will always be the value of AIRFLOW_SESSION_NAME, which means that if you don't supply a session id, then - this named session will be created or re-used. The advantage is that the same session can be re-used between - tasks. The disadvantage is that it offers less flexibility in terms of having multiple sessions. + :param allways_reuse_session: (optional) Specify true to always reuse the same Compute Session across all tasks. The name + of the session will be the default session name (see AIRFLOW_SESSION_NAME), which means that if you don't supply a session id in compute_session_id, + then this named session will be created and later re-used between tasks. The disadvantage is that it offers less flexibility in terms of + having multiple sessions (parallelism). Default value is False meaning a new unnamed compute session will always be created + UNLESS a session id is specified in compute_session_id. + :param compute_session_id: (optional) Compute Session id to use for the task. If a Session Id is specified, this will override allways_reuse_session. + Use SASComputeCreateSession Operator to define a task that will create the session. This gives full flexibility in how compute sessions are used. + The id of the session created by SASComputeCreateSession will be made available as XCom variable 'compute_session_id' + for subsequent use by SASStudio Operator tasks. Tip: set the value to "{{ ti.xcom_pull(key='compute_session_id', task_ids=[''])|first}}" to get the X-Com value. :param output_macro_var_prefix: (optional) string. 
If this has a value, then any macro variables which start with this prefix will be retrieved from the session after the code has executed and will be returned as XComs :param unknown_state_timeout: (optional) number of seconds to continue polling for the state of a running job if the state is temporary unobtainable. When unknown_state_timeout is reached without the state being retrievable, the operator will throw an AirflowFailException and the task will be marked as failed. - Default value is 0, meaning the task will fail immediately if the state could not be retrieved. + Default value is 0, meaning the task will fail immediately if the state could not be retrieved. """ ui_color = "#CCE5FF" ui_fgcolor = "black" + + template_fields: Sequence[str] = ("env_vars", "macro_vars", "compute_session_id", "path") def __init__( @@ -90,6 +109,7 @@ def __init__( compute_context=DEFAULT_COMPUTE_CONTEXT_NAME, env_vars=None, macro_vars=None, + allways_reuse_session=False, compute_session_id="", output_macro_var_prefix="", unknown_state_timeout=0, @@ -110,10 +130,22 @@ def __init__( self.env_vars = env_vars self.macro_vars = macro_vars self.connection = None - self.compute_session_id = compute_session_id + self.allways_reuse_session = allways_reuse_session + + self.external_managed_session = False + self.compute_session_id = None + if compute_session_id: + self.compute_session_id = compute_session_id + self.external_managed_session=True + self.output_macro_var_prefix = output_macro_var_prefix.upper() self.unknown_state_timeout=max(unknown_state_timeout,0) + # Use hooks to clean up + self.on_success_callback=[on_success] + self.on_failure_callback=[on_failure] + self.on_retry_callback=[on_retry] + def execute(self, context): if self.path_type not in ['compute', 'content', 'raw']: raise AirflowFailException("Path type is invalid. 
Valid values are 'compute', 'content' or 'raw'") @@ -127,6 +159,16 @@ def execute(self, context): h = SasHook(self.connection_name) self.connection = h.get_conn() + # Create compute session + if not self.compute_session_id: + compute_session = create_or_connect_to_session(self.connection, + self.compute_context_name, + AIRFLOW_SESSION_NAME if self.allways_reuse_session else None) + self.compute_session_id = compute_session["id"] + else: + self.log.info(f"Compute Session {self.compute_session_id} was provided") + + # Generate SAS code if self.path_type == "raw": code = self.path else: @@ -157,7 +199,7 @@ def execute(self, context): except Exception as e: raise AirflowException(f"SASStudioOperator error: {str(e)}") - + # Kick off the JES job. job, success = self._run_job_and_wait(jr, 10) job_state = job["state"] @@ -172,7 +214,11 @@ def execute(self, context): # set output variables if success and self.output_macro_var_prefix and self.compute_session_id: - self._set_output_variables(context) + try: + self._set_output_variables(context) + except Exception as e: + raise AirflowException(f"SASStudioOperator error: {str(e)}") + # raise exception in Airflow if SAS Studio Flow ended execution with "failed" "canceled" or "timed out" state # support retry for 'failed' (typically there is an ERROR in the log) and 'timed out' @@ -185,9 +231,31 @@ def execute(self, context): elif job_state == "timed out": raise AirflowException("SAS Studio Execution has timed out. 
See log for details ") - + return 1 + def on_kill(self) -> None: + self._clean_up(also_kill_reused_session=True) + + def _clean_up(self, also_kill_reused_session=False): + # Always kill unnamed sessions (allways_reuse_session is false) + # however is also_kill_reused_session is specified also kill the reuse session + # newer kill external managed sessions, as this may prevent restart + if self.compute_session_id and self.external_managed_session==False: + if (also_kill_reused_session and self.allways_reuse_session) or self.allways_reuse_session==False: + try: + self.log.info(f"Deleting session with id {self.compute_session_id}") + success_end = end_compute_session(self.connection, self.compute_session_id) + if success_end: + self.log.info(f"Compute session succesfully deleted") + else: + self.log.info(f"Unable to delete compute session. You may need to kill the session manually") + self.compute_session_id=None + + except Exception as e: + self.log.info(f"Unable to delete compute session. 
You may need to kill the session manually") + self.compute_session_id=None + def _add_airflow_env_vars(self): for x in ['AIRFLOW_CTX_DAG_OWNER', 'AIRFLOW_CTX_DAG_ID', @@ -217,25 +285,14 @@ def _get_pre_code(self): return pre_code def _generate_object_code(self): - - uri = URI_BASE + uri=URI_BASE if self.path_type == "compute": - self.log.info("Code Generation for Studio object stored in Compute file system") - - # if session id is provided, use it, otherwise create a session - if not self.compute_session_id: - self.log.info("Create or connect to session") - compute_session = create_or_connect_to_session(self.connection, - self.compute_context_name, AIRFLOW_SESSION_NAME) - self.compute_session_id = compute_session["id"] - else: - self.log.info("Session ID was provided") - uri = f"{URI_BASE}?sessionId={self.compute_session_id}" + self.log.info("Code Generation for Studio object stored in Compute file system") else: self.log.info("Code generation for Studio object stored in Content") - + media_type = "application/vnd.sas.dataflow" if self.exec_type == "program": media_type = "application/vnd.sas.program" @@ -296,7 +353,7 @@ def _run_job_and_wait(self, job_request: dict, poll_interval: int) -> (dict, boo # Print the log location to the DAG-log, in case the user needs access to the SAS-log while it is running. if "logLocation" in job: log_location=job["logLocation"]; - self.log.info(f"While the job is running the SAS-log formated at JSON can be found at URI: {log_location}?limit=9999999") + self.log.info(f"While the job is running, the SAS-log formated as JSON can be found at URI: {log_location}?limit=9999999") except Exception as e: countUnknownState = countUnknownState + 1 self.log.info(f'HTTP Call failed with error "{e}". 
Will set state=unknown and continue checking...') diff --git a/src/sas_airflow_provider/util/util.py b/src/sas_airflow_provider/util/util.py index 706f7dd..0f7165a 100644 --- a/src/sas_airflow_provider/util/util.py +++ b/src/sas_airflow_provider/util/util.py @@ -136,11 +136,11 @@ def find_named_compute_session(session: requests.Session, name: str) -> dict: raise RuntimeError(f"Find sessions failed: {response.status_code}") sessions = response.json() if sessions["count"] > 0: - print(f"Existing session named '{name}' was found") + print(f"Existing compute session named '{name}' with id {sessions['items'][0]['id']} was found") return sessions["items"][0] return {} -def create_or_connect_to_session(session: requests.Session, context_name: str, name: str) -> dict: +def create_or_connect_to_session(session: requests.Session, context_name: str, name = None) -> dict: """ Connect to an existing compute session by name. If that named session does not exist, one is created using the context name supplied @@ -148,12 +148,18 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n :param context_name: the context name to use to create the session if the session was not found :param name: name of session to find :return: session object + """ - compute_session = find_named_compute_session(session, name) - if compute_session: - return compute_session + if name != None: + compute_session = find_named_compute_session(session, name) + if compute_session: + return compute_session + + print(f"Compute session named '{name}' does not exist, a new one will be created") + else: + print(f"A new unnamed compute session will be created") + - print(f"Compute session named '{name}' does not exist, a new one will be created") # find compute context response = session.get("/compute/contexts", params={"filter": f'eq("name","{context_name}")'}) if not response.ok: @@ -165,7 +171,11 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n # 
create session with given context uri = f'/compute/contexts/{sas_context["id"]}/sessions' - session_request = {"version": 1, "name": name} + if name != None: + session_request = {"version": 1, "name": name} + else: + # Create an unnamed session + session_request = {"version": 1} headers = {"Content-Type": "application/vnd.sas.compute.session.request+json"} @@ -175,7 +185,10 @@ if response.status_code != 201: raise RuntimeError(f"Failed to create session: {response.text}") - return response.json() + json_response=response.json() + print(f"Compute session {json_response['id']} created") + + return json_response def end_compute_session(session: requests.Session, id): uri = f'/compute/sessions/{id}'