diff --git a/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py b/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py index 4f7eea40b6d1..136ac55210d8 100644 --- a/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +++ b/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py @@ -57,6 +57,22 @@ def connection(self) -> taosws.Connection: self._connection = self._create_connection() return self._connection + def with_retry_on_closed_connection(self, fn, **kwargs): + try: + return fn(self.connection, **kwargs) + except taosws.QueryError as err: + logger.warning(f"TDEngine query error: {err}") + if "Internal error:" in str(err): + logger.info("Retrying TDEngine query with a new connection") + try: + self._connection.close() + except Exception: + pass + self._connection = None + return fn(self.connection, **kwargs) + else: + raise + def _create_connection(self) -> taosws.Connection: """Establish a connection to the TSDB server.""" logger.debug("Creating a new connection to TDEngine", project=self.project) @@ -93,7 +109,9 @@ def create_tables(self): """Create TDEngine supertables.""" for table in self.tables: create_table_query = self.tables[table]._create_super_table_query() - self.connection.execute(create_table_query) + self.with_retry_on_closed_connection( + lambda conn: conn.execute(create_table_query) + ) def write_application_event( self, @@ -139,10 +157,14 @@ def write_application_event( ) create_table_sql = table._create_subtable_sql(subtable=table_name, values=event) - self.connection.execute(create_table_sql) + self.with_retry_on_closed_connection( + lambda conn: conn.execute(create_table_sql) + ) - insert_statement = table._insert_subtable_stmt( - self.connection, subtable=table_name, values=event + insert_statement = self.with_retry_on_closed_connection( + lambda conn: table._insert_subtable_stmt( + conn, subtable=table_name, values=event + ) ) insert_statement.add_batch() insert_statement.execute() @@ -210,7 +232,9 @@
def delete_tsdb_resources(self): get_subtable_names_query = self.tables[table]._get_subtables_query( values={mm_schemas.EventFieldType.PROJECT: self.project} ) - subtables = self.connection.query(get_subtable_names_query) + subtables = self.with_retry_on_closed_connection( + lambda conn: conn.query(get_subtable_names_query) + ) for subtable in subtables: drop_query = self.tables[table]._drop_subtable_query( subtable=subtable[0] @@ -289,7 +313,9 @@ def _get_records( ) logger.debug("Querying TDEngine", query=full_query) try: - query_result = self.connection.query(full_query) + query_result = self.with_retry_on_closed_connection( + lambda conn: conn.query(full_query) + ) except taosws.QueryError as e: raise mlrun.errors.MLRunInvalidArgumentError( f"Failed to query table {table} in database {self.database}, {str(e)}" diff --git a/tests/model_monitoring/db/tsdb/tdengine/test_tdengine_connector.py b/tests/model_monitoring/db/tsdb/tdengine/test_tdengine_connector.py index 2f3b78b91187..bd2ea6ef212e 100644 --- a/tests/model_monitoring/db/tsdb/tdengine/test_tdengine_connector.py +++ b/tests/model_monitoring/db/tsdb/tdengine/test_tdengine_connector.py @@ -104,3 +104,78 @@ def test_write_application_event(connector: TDEngineConnector) -> None: assert read_back_values.timestamp == end_infer_time assert read_back_values.value == result_value assert read_back_values.status == result_status + + +class Raiser: + def __init__(self, monkeypatch): + self.times_raised = 0 + self.monkeypatch = monkeypatch + + def raise_it(self, *args, **kwargs): + self.times_raised += 1 + self.monkeypatch.undo() + raise taosws.QueryError("Internal error: `sending on a closed channel`") + + +# ML-7991 +@pytest.mark.skipif(not is_tdengine_defined(), reason="TDEngine is not defined") +def test_write_application_event_with_channel_close( + connector: TDEngineConnector, monkeypatch +) -> None: + endpoint_id = "1" + app_name = "my_app" + result_name = "my_Result" + result_kind = 0 + start_infer_time = 
datetime(2024, 1, 1, tzinfo=timezone.utc) + end_infer_time = datetime(2024, 1, 1, second=1, tzinfo=timezone.utc) + result_status = 0 + result_value = 123 + data = { + "endpoint_id": endpoint_id, + "application_name": app_name, + "result_name": result_name, + "result_kind": result_kind, + "start_infer_time": start_infer_time, + "end_infer_time": end_infer_time, + "result_status": result_status, + # make sure we can write apostrophes (ML-7535) + "current_stats": """{"question": "Who wrote 'To Kill a Mockingbird'?"}""", + # TODO: add this back when extra data is supported (ML-7460) + # "result_extra_data": """{"question": "Who wrote 'To Kill a Mockingbird'?"}""", + "result_value": result_value, + } + connector.create_tables() + + raiser = Raiser(monkeypatch) + monkeypatch.setattr(taosws.Connection, "execute", raiser.raise_it) + + connector.write_application_event(data) + + assert raiser.times_raised == 1 + + read_back_results = connector.read_metrics_data( + endpoint_id=endpoint_id, + start=datetime(2023, 1, 1, 1, 0, 0), + end=datetime(2025, 1, 1, 1, 0, 0), + metrics=[ + ModelEndpointMonitoringMetric( + project=project, + app=app_name, + name=result_name, + full_name=f"{project}.{app_name}.result.{result_name}", + type=ModelEndpointMonitoringMetricType.RESULT, + ), + ], + type="results", + ) + assert len(read_back_results) == 1 + read_back_result = read_back_results[0] + assert read_back_result.full_name == f"{project}.{app_name}.result.{result_name}" + assert read_back_result.data + assert read_back_result.result_kind.value == result_kind + assert read_back_result.type == "result" + assert len(read_back_result.values) == 1 + read_back_values = read_back_result.values[0] + assert read_back_values.timestamp == end_infer_time + assert read_back_values.value == result_value + assert read_back_values.status == result_status