Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Record/Replay of Queries #107

Closed
wants to merge 9 commits into from
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240426-142511.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Support record/replay mode.
time: 2024-04-26T14:25:11.251251-04:00
custom:
  Author: peterallenwebb
  Issue: "407"
6 changes: 0 additions & 6 deletions dbt/__init__.py

This file was deleted.

8 changes: 0 additions & 8 deletions dbt/adapters/__init__.py

This file was deleted.

67 changes: 67 additions & 0 deletions dbt/adapters/record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import dataclasses
from io import StringIO
import json
import re
from typing import Any, Optional, Mapping

from agate import Table

from dbt_common.events.contextvars import get_node_info
from dbt_common.record import Record, Recorder

from dbt.adapters.contracts.connection import AdapterResponse


@dataclasses.dataclass
class QueryRecordParams:
    """Parameters that identify one recorded query.

    Two instances are considered the same query (see ``_matches``) when they
    share a ``node_unique_id`` and their SQL is equal after comments and
    whitespace have been stripped.
    """

    sql: str
    auto_begin: bool = False
    fetch: bool = False
    limit: Optional[int] = None
    node_unique_id: Optional[str] = None

    def __post_init__(self):
        """Fill in ``node_unique_id`` from ``get_node_info()`` when not given."""
        if self.node_unique_id is None:
            node_info = get_node_info()
            # Falls back to an empty string when get_node_info() returns
            # nothing truthy.
            self.node_unique_id = node_info["unique_id"] if node_info else ""

    @staticmethod
    def _clean_up_sql(sql: str) -> str:
        """Return ``sql`` with comments and all whitespace removed.

        The result is only meant for equality comparison, not execution.

        NOTE: this is a purely textual clean-up; comment markers that appear
        inside SQL string literals are stripped as well.
        """
        # Strip both comment styles in a single left-to-right pass so that a
        # "--" inside a /* ... */ block cannot swallow the closing "*/", and a
        # "--" comment on the last line is removed even without a trailing
        # newline. "[^\n]" is unaffected by DOTALL, so line comments still
        # stop at the end of their line.
        without_comments = re.sub(r"/\*.*?\*/|--[^\n]*", "", sql, flags=re.DOTALL)
        # Drop every kind of whitespace (spaces, tabs, CR/LF), not just " "
        # and "\n", so formatting differences never break a match.
        return re.sub(r"\s+", "", without_comments)

    def _matches(self, other: "QueryRecordParams") -> bool:
        """Return True when ``other`` is the same query for the same node."""
        if self.node_unique_id != other.node_unique_id:
            return False
        return self._clean_up_sql(self.sql) == self._clean_up_sql(other.sql)


@dataclasses.dataclass
class QueryRecordResult:
    """Result of a recorded query: the adapter response and the result table.

    Either field may be ``None``; ``None`` is serialized as ``None`` and
    restored as ``None``.
    """

    adapter_response: Optional["AdapterResponse"]
    table: Optional["Table"]

    def _to_dict(self) -> Any:
        """Serialize to a dict with keys ``adapter_response`` and ``table``.

        ``table`` is stored as the JSON text written by ``Table.to_json``;
        ``adapter_response`` is stored as the output of its ``to_dict()``.
        """
        table_json: Optional[str] = None
        if self.table is not None:
            buf = StringIO()
            self.table.to_json(buf)
            table_json = buf.getvalue()

        response = self.adapter_response
        return {
            "adapter_response": response.to_dict() if response is not None else None,
            "table": table_json,
        }

    @classmethod
    def _from_dict(cls, dct: Mapping) -> "QueryRecordResult":
        """Rebuild a result from the mapping produced by ``_to_dict``.

        Raises:
            KeyError: if ``adapter_response`` or ``table`` is missing.
        """
        raw_response = dct["adapter_response"]
        raw_table = dct["table"]

        # Mirror _to_dict: a stored None means the value was absent.
        response = AdapterResponse.from_dict(raw_response) if raw_response is not None else None
        table = Table.from_object(json.loads(raw_table)) if raw_table is not None else None

        # Use cls rather than the hard-coded class name so subclasses
        # round-trip to their own type.
        return cls(adapter_response=response, table=table)


class QueryRecord(Record):
    """Record type pairing query parameters with their recorded result."""

    # Parameter and result classes used for records of this type.
    params_cls = QueryRecordParams
    result_cls = QueryRecordResult


# Register the record type with the Recorder at import time.
Recorder.register_record_type(QueryRecord)
3 changes: 3 additions & 0 deletions dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtInternalError, NotImplementedError
from dbt_common.record import record_function
from dbt_common.utils import cast_to_str

from dbt.adapters.base import BaseConnectionManager
Expand All @@ -19,6 +20,7 @@
SQLQuery,
SQLQueryStatus,
)
from dbt.adapters.record import QueryRecord

if TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -143,6 +145,7 @@ def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> "agate.Tab

return table_from_data_flat(data, column_names)

@record_function(QueryRecord, method=True, tuple_result=True)
def execute(
self,
sql: str,
Expand Down
Loading