diff --git a/apps/api/src/python/query/index.ts b/apps/api/src/python/query/index.ts
index 1d53b6c8..1263224c 100644
--- a/apps/api/src/python/query/index.ts
+++ b/apps/api/src/python/query/index.ts
@@ -196,10 +196,6 @@ export async function makeQuery(
     sessionId,
     code,
     (outputs) => {
-      if (error || aborted) {
-        return
-      }
-
       for (const output of outputs) {
         if (output.type === 'stdio' && output.name === 'stdout') {
           const lines = output.text.trim().split('\n')
@@ -208,15 +204,21 @@ export async function makeQuery(
             const parsed = JSON.parse(line.trim())
             switch (parsed.type) {
               case 'success':
-                onProgress(parsed)
-                result = parsed
+                if (!aborted) {
+                  onProgress(parsed)
+                  result = parsed
+                }
                 break
               case 'syntax-error':
-                result = parsed
+                if (!aborted) {
+                  result = parsed
+                }
                 break
               case 'abort-error':
-                result = parsed
-                aborted = true
+                if (!aborted) {
+                  result = parsed
+                  aborted = true
+                }
                 break
               case 'log':
                 logger().debug(
@@ -230,7 +232,10 @@ export async function makeQuery(
                 )
                 break
               default:
-                error = new Error('Unexpected output: ' + line)
+                if (!aborted) {
+                  error = new Error('Unexpected output: ' + line)
+                }
+                break
             }
           } catch {}
         }
diff --git a/apps/api/src/python/query/sqlalchemy.ts b/apps/api/src/python/query/sqlalchemy.ts
index 70c90b13..5694ea8f 100644
--- a/apps/api/src/python/query/sqlalchemy.ts
+++ b/apps/api/src/python/query/sqlalchemy.ts
@@ -56,6 +56,7 @@ def briefer_make_sqlalchemy_query():
     import time
     from datetime import datetime
     from datetime import timedelta
+    import multiprocessing

     print(json.dumps({"type": "log", "message": "Starting SQLAlchemy query"}))

@@ -74,14 +75,6 @@ def briefer_make_sqlalchemy_query():
         df.columns = new_cols
         return df

-    def _briefer_cancel_sqlalchemy_query(engine, job_id, datasource_type):
-        with engine.connect() as conn:
-            if datasource_type == "snowflake":
-                conn.execute(text(f"SELECT SYSTEM$CANCEL_QUERY('{job_id}');"))
-            else:
-                conn.execute(text(f"SELECT pg_cancel_backend(pid) FROM pg_stat_activity WHERE query LIKE '%{job_id}%';"))
-
-    aborted = False
     dump_file_base = f'/home/jupyteruser/.briefer/query-${queryId}'
     parquet_file_path = f'{dump_file_base}.parquet.gzip'
     csv_file_path = f'{dump_file_base}.csv'
@@ -123,165 +116,200 @@ def briefer_make_sqlalchemy_query():
             df[column] = df[column].astype(str)
         return df

+    def run_query(queue, engine, job_id, datasource_type, flag_file_path):
+        aborted = False
+        try:
+            # if oracle, initialize the oracle client
+            if datasource_type == "oracle":
+                import oracledb
+                oracledb.init_oracle_client()

-    try:
-        job_id = ${JSON.stringify(jobId)}
-        flag_file_path = ${JSON.stringify(flagFilePath)}
-        os.makedirs('/home/jupyteruser/.briefer', exist_ok=True)
-        print(json.dumps({"type": "log", "message": "Creating flag file"}))
-        open(flag_file_path, "a").close()
+            try:
+                with engine.connect() as conn:
+                    print(json.dumps({"type": "log", "message": "Running query"}))
+                    chunks = pd.read_sql_query(text(${JSON.stringify(
+                      renderedQuery
+                    )}), con=conn, chunksize=100000)
+                    rows = None
+                    columns = None
+                    last_emitted_at = 0
+                    count = 0
+                    df = pd.DataFrame()
+                    print(json.dumps({"type": "log", "message": "Iterating over chunks"}))
+                    for chunk in chunks:
+                        if not os.path.exists(flag_file_path):
+                            aborted = True
+                            break

-        print(json.dumps({"type": "log", "message": "Connecting to database"}))
-        engine = create_engine(${JSON.stringify(databaseUrl)})
+                        count += len(chunk)
+                        print(json.dumps({"type": "log", "message": f"Got chunk {len(chunk)} rows"}))
+                        chunk = rename_duplicates(chunk)
+                        df = convert_df(pd.concat([df, chunk], ignore_index=True))
+                        if rows is None:
+                            rows = json.loads(df.head(250).to_json(orient='records', date_format="iso"))

-        # if oracle, initialize the oracle client
-        if ${JSON.stringify(dataSourceType)} == "oracle":
-            import oracledb
-            oracledb.init_oracle_client()
+                            # convert all values to string to make sure we preserve the python values
+                            # when displaying this data in the browser
+                            for row in rows:
+                                for key in row:
+                                    row[key] = str(row[key])

-        try:
-            with engine.connect() as conn:
-                print(json.dumps({"type": "log", "message": "Running query"}))
-                chunks = pd.read_sql_query(text(${JSON.stringify(
-                  renderedQuery
-                )}), con=conn, chunksize=100000)
-                rows = None
-                columns = None
-                last_emitted_at = 0
-                count = 0
-                df = pd.DataFrame()
-                print(json.dumps({"type": "log", "message": "Iterating over chunks"}))
-                for chunk in chunks:
-                    if not os.path.exists(flag_file_path):
-                        _briefer_cancel_sqlalchemy_query(engine, job_id, ${JSON.stringify(
-                          dataSourceType
-                        )})
-                        aborted = True
-                        break
-
-                    count += len(chunk)
-                    print(json.dumps({"type": "log", "message": f"Got chunk {len(chunk)} rows"}))
-                    chunk = rename_duplicates(chunk)
-                    df = convert_df(pd.concat([df, chunk], ignore_index=True))
-                    if rows is None:
-                        rows = json.loads(df.head(250).to_json(orient='records', date_format="iso"))
-
-                        # convert all values to string to make sure we preserve the python values
-                        # when displaying this data in the browser
-                        for row in rows:
-                            for key in row:
-                                row[key] = str(row[key])
-
-                    if columns is None:
-                        columns = [{"name": col, "type": dtype.name} for col, dtype in chunk.dtypes.items()]
-
-                    for col in columns:
-                        if col["name"] not in chunk.columns:
-                            continue
-
-                        categories = col.get("categories", [])
-                        if len(categories) >= 1000:
-                            continue
-
-                        dtype = chunk[col["name"]].dtype
-                        if pd.api.types.is_string_dtype(dtype) or pd.api.types.is_categorical_dtype(dtype):
-                            try:
-                                chunk_categories = chunk[col["name"]].dropna().unique()
-                                categories.extend(list(chunk_categories))
+                        if columns is None:
+                            columns = [{"name": col, "type": dtype.name} for col, dtype in chunk.dtypes.items()]

-                                # use dict.fromkeys instead of set to keep the order
-                                categories = list(dict.fromkeys(categories))
+                        for col in columns:
+                            if col["name"] not in chunk.columns:
+                                continue

-                                categories = categories[:1000]
-                                col["categories"] = categories
-                            except:
-                                pass
+                            categories = col.get("categories", [])
+                            if len(categories) >= 1000:
+                                continue

-                    # only emit every 1 second
-                    now = time.time()
-                    if now - last_emitted_at > 1:
+                            dtype = chunk[col["name"]].dtype
+                            if pd.api.types.is_string_dtype(dtype) or pd.api.types.is_categorical_dtype(dtype):
+                                try:
+                                    chunk_categories = chunk[col["name"]].dropna().unique()
+                                    categories.extend(list(chunk_categories))
+
+                                    # use dict.fromkeys instead of set to keep the order
+                                    categories = list(dict.fromkeys(categories))
+
+                                    categories = categories[:1000]
+                                    col["categories"] = categories
+                                except:
+                                    pass
+
+                        # only emit every 1 second
+                        now = time.time()
+                        if now - last_emitted_at > 1:
+                            result = {
+                                "type": "success",
+                                "rows": rows,
+                                "columns": columns,
+                                "count": count
+                            }
+                            print(json.dumps(result, ensure_ascii=False, default=str))
+                            last_emitted_at = now
+
+                    duration_ms = None
+                    # query trino to get query execution time
+                    if datasource_type == "trino":
+                        start_time = datetime.now()
+
+                        while datetime.now() - start_time < timedelta(seconds=5):
+                            try:
+                                execution_time_query = f"""
+                                SELECT created, "end"
+                                FROM system.runtime.queries
+                                WHERE query LIKE '%{job_id}%'
+                                """
+                                result = conn.execute(text(execution_time_query)).fetchone()
+                                if not result[1]:
+                                    print(json.dumps({"type": "log", "message": f"Query execution time not available yet"}))
+                                    time.sleep(0.2)
+                                    continue
+
+                                time_span = result[1] - result[0]
+                                duration_ms = int(time_span.total_seconds() * 1000)
+                                break
+                            except Exception as e:
+                                print(json.dumps({"type": "log", "message": f"Failed to get query execution time: {str(e)}"}))
+                                break
+
+
+                    # make sure .briefer directory exists
+                    os.makedirs('/home/jupyteruser/.briefer', exist_ok=True)
+
+                    # write to parquet
+                    print(json.dumps({"type": "log", "message": f"Dumping {len(df)} rows as parquet."}))
+                    df.to_parquet(parquet_file_path, compression='gzip', index=False)
+
+                    # write to csv
+                    print(json.dumps({"type": "log", "message": f"Dumping {len(df)} rows as csv."}))
+                    df.to_csv(csv_file_path, index=False)
+
+                    if aborted or not os.path.exists(flag_file_path):
+                        print(json.dumps({"type": "log", "message": f"Query aborted 1 {aborted} {os.path.exists(flag_file_path)}"}))
                         result = {
-                            "type": "success",
-                            "rows": rows,
-                            "columns": columns,
-                            "count": count
+                            "type": "abort-error",
+                            "message": "Query aborted",
                         }
-                        print(json.dumps(result, ensure_ascii=False, default=str))
-                        last_emitted_at = now
-
-                duration_ms = None
-                # query trino to get query execution time
-                if ${JSON.stringify(dataSourceType)} == "trino":
-                    start_time = datetime.now()
-
-                    while datetime.now() - start_time < timedelta(seconds=5):
-                        try:
-                            execution_time_query = f"""
-                            SELECT created, "end"
-                            FROM system.runtime.queries
-                            WHERE query LIKE '%{job_id}%'
-                            """
-                            result = conn.execute(text(execution_time_query)).fetchone()
-                            if not result[1]:
-                                print(json.dumps({"type": "log", "message": f"Query execution time not available yet"}))
-                                time.sleep(0.2)
-                                continue
+                        print(json.dumps(result, default=str))
+                        return

-                            time_span = result[1] - result[0]
-                            duration_ms = int(time_span.total_seconds() * 1000)
-                            break
-                        except Exception as e:
-                            print(json.dumps({"type": "log", "message": f"Failed to get query execution time: {str(e)}"}))
-                            break
+                    result = {
+                        "type": "success",
+                        "rows": rows,
+                        "columns": columns,
+                        "count": count,
+                        "durationMs": duration_ms,
+                    }
+                    print(json.dumps(result, ensure_ascii=False, default=str))
+                    queue.put(None)
+            except (DatabaseError, DBAPIError) as e:
+                if isinstance(e.__cause__, QueryCanceled):
+                    error = {
+                        "type": "abort-error",
+                        "message": "Query aborted",
+                    }
+                    print(json.dumps(error, default=str))
+                else:
+                    error = {
+                        "type": "syntax-error",
+                        "message": str(e)
+                    }
+                    print(json.dumps(error, default=str))
+                queue.put(None)
+        except Exception as e:
+            queue.put(e)

-                # make sure .briefer directory exists
-                os.makedirs('/home/jupyteruser/.briefer', exist_ok=True)
+    job_id = ${JSON.stringify(jobId)}
+    datasource_type = ${JSON.stringify(dataSourceType)}
+    flag_file_path = ${JSON.stringify(flagFilePath)}
+    print(json.dumps({"type": "log", "message": "Connecting to database"}))
+    engine = create_engine(${JSON.stringify(databaseUrl)})

-                # write to parquet
-                print(json.dumps({"type": "log", "message": f"Dumping {len(df)} rows as parquet."}))
-                df.to_parquet(parquet_file_path, compression='gzip', index=False)
+    process = None
+    def abort():
+        print(json.dumps({"type": "log", "message": "Query aborted 2"}))
+        result = {
+            "type": "abort-error",
+            "message": "Query aborted",
+        }
+        print(json.dumps(result, default=str))
+        if process is not None and process.is_alive():
+            process.terminate()
+            process.join()

-                # write to csv
-                print(json.dumps({"type": "log", "message": f"Dumping {len(df)} rows as csv."}))
-                df.to_csv(csv_file_path, index=False)
+    try:
+        os.makedirs('/home/jupyteruser/.briefer', exist_ok=True)
+        print(json.dumps({"type": "log", "message": "Creating flag file"}))
+        open(flag_file_path, "a").close()

-                if aborted or not os.path.exists(flag_file_path):
-                    print(json.dumps({"type": "log", "message": "Query aborted"}))
-                    result = {
-                        "type": "abort-error",
-                        "message": "Query aborted",
-                    }
-                    print(json.dumps(result, default=str))
-                    return
-
-                result = {
-                    "type": "success",
-                    "rows": rows,
-                    "columns": columns,
-                    "count": count,
-                    "durationMs": duration_ms,
-                }
-                print(json.dumps(result, ensure_ascii=False, default=str))
-        finally:
-            print(json.dumps({"type": "log", "message": "Disposing of engine"}))
-            engine.dispose()
-            if os.path.exists(flag_file_path):
-                print(json.dumps({"type": "log", "message": "Removing flag file"}))
-                os.remove(flag_file_path)
-    except (DatabaseError, DBAPIError) as e:
-        if isinstance(e.__cause__, QueryCanceled):
-            error = {
-                "type": "abort-error",
-                "message": "Query aborted",
-            }
-            print(json.dumps(error, default=str))
-        else:
-            error = {
-                "type": "syntax-error",
-                "message": str(e)
-            }
-            print(json.dumps(error, default=str))
+        queue = multiprocessing.Queue()
+        process = multiprocessing.Process(target=run_query, args=(queue, engine, job_id, datasource_type, flag_file_path))
+        process.start()
+
+        while process.is_alive():
+            print(json.dumps({"type": "log", "message": "Outer process, waiting for inner process to finish"}))
+            if not os.path.exists(flag_file_path):
+                print(json.dumps({"type": "log", "message": "Outer process detected that flag file does not exist, aborting query"}))
+                abort()
+                break
+            time.sleep(0.5)
+        result = queue.get_nowait()
+        if result and isinstance(result, Exception):
+            raise result
+    except KeyboardInterrupt:
+        print(json.dumps({"type": "log", "message": "Outer process caught KeyboardInterrupt"}))
+        abort()
+    finally:
+        print(json.dumps({"type": "log", "message": "Disposing of engine"}))
+        engine.dispose()
+        if os.path.exists(flag_file_path):
+            print(json.dumps({"type": "log", "message": "Removing flag file"}))
+            os.remove(flag_file_path)

 briefer_make_sqlalchemy_query()`
diff --git a/apps/api/src/yjs/v2/executor/executor.ts b/apps/api/src/yjs/v2/executor/executor.ts
index 9ad597cc..fe128489 100644
--- a/apps/api/src/yjs/v2/executor/executor.ts
+++ b/apps/api/src/yjs/v2/executor/executor.ts
@@ -262,7 +262,9 @@ export class Executor {
           break
         case 'unknown':
         case 'aborting':
-          // TODO: when getting here we should try to abort one more time
+          // This means we looped the executor and found an aborting item
+          // TODO:
+          // We should make sure that the jupyter queue is empty before we advance
           current.setCompleted('aborted')
           break
         default:
diff --git a/apps/api/src/yjs/v2/executor/sql.ts b/apps/api/src/yjs/v2/executor/sql.ts
index bdd0885f..ad66dd74 100644
--- a/apps/api/src/yjs/v2/executor/sql.ts
+++ b/apps/api/src/yjs/v2/executor/sql.ts
@@ -14,16 +14,11 @@ import {
   renameDataFrame,
 } from '../../../python/query/index.js'
 import { logger } from '../../../logger.js'
-import {
-  DataFrame,
-  OnboardingTutorialStep,
-  RunQueryResult,
-} from '@briefer/types'
+import { DataFrame, RunQueryResult } from '@briefer/types'
 import { SQLEvents } from '../../../events/index.js'
 import { WSSharedDocV2 } from '../index.js'
 import { updateDataframes } from './index.js'
 import { advanceTutorial } from '../../../tutorials.js'
-import { IOServer } from '../../../websocket/index.js'
 import { broadcastTutorialStepStates } from '../../../websocket/workspace/tutorial.js'

 export type SQLEffects = {
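
Note on the new abort flow in sqlalchemy.ts: the blocking SQLAlchemy/pandas work now runs in a child process, the parent polls the flag file and terminates the child when the file disappears, and errors travel back to the parent through a multiprocessing.Queue. Below is a minimal, self-contained sketch of that pattern, not code from this patch; run_blocking_query, the sleep standing in for the real query, and the /tmp flag path are assumed placeholder names.

# Illustrative sketch of the flag-file + child-process abort pattern (assumed names).
import multiprocessing
import os
import queue
import time


def run_blocking_query(errors):
    # Stands in for pd.read_sql_query and the chunk loop; puts None on a clean
    # finish and the exception itself when something unexpected goes wrong.
    try:
        time.sleep(30)
        errors.put(None)
    except Exception as e:
        errors.put(e)


def run_with_flag_file(flag_file_path):
    # The flag file exists for as long as the query is allowed to keep running.
    open(flag_file_path, "a").close()
    errors = multiprocessing.Queue()
    process = multiprocessing.Process(target=run_blocking_query, args=(errors,))
    process.start()
    try:
        while process.is_alive():
            if not os.path.exists(flag_file_path):
                # Deleting the flag file requests an abort: kill the worker
                # instead of trying to cancel the statement server-side.
                process.terminate()
                process.join()
                return "aborted"
            time.sleep(0.5)
        try:
            error = errors.get_nowait()
        except queue.Empty:
            # The child finished without reporting anything.
            return "success"
        if isinstance(error, Exception):
            raise error
        return "success"
    finally:
        if os.path.exists(flag_file_path):
            os.remove(flag_file_path)


if __name__ == "__main__":
    # Guard is required for multiprocessing on spawn-based platforms.
    print(run_with_flag_file("/tmp/briefer-abort-flag"))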
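One detail worth noting about this shape: Queue.get_nowait() raises queue.Empty when the child put nothing on the queue (for example, after being terminated mid-query), so the parent has to treat an empty queue as "no error" rather than letting the exception escape; the sketch above guards for that case.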