diff --git a/data_lineage/__init__.py b/data_lineage/__init__.py index cc8188c..0452ab2 100644 --- a/data_lineage/__init__.py +++ b/data_lineage/__init__.py @@ -4,11 +4,12 @@ import datetime import json import logging -from typing import Any, Dict, Generator, List, Optional, Type +from typing import Any, Dict, Generator, Generic, List, Optional, Type, TypeVar import requests from dbcat.catalog.models import JobExecutionStatus from furl import furl +from requests import HTTPError from data_lineage.graph import LineageGraph @@ -73,10 +74,13 @@ def __init__(self, session, attributes, obj_id, relationships): self._relationships = relationships def __getattr__(self, item): + logging.debug("Attributes: {}".format(self._attributes)) if item == "id": return self._obj_id - elif item in self._attributes.keys(): + elif self._attributes and item in self._attributes.keys(): return self._attributes[item] + elif self._relationships and item in self._relationships.keys(): + return self._relationships[item] raise AttributeError @@ -120,6 +124,9 @@ def __init__(self, session, attributes, obj_id, relationships): super().__init__(session, attributes, obj_id, relationships) +ModelType = TypeVar("ModelType", bound=BaseModel) + + class Catalog: def __init__(self, url: str): self._base_url = furl(url) / "api/v1/catalog" @@ -134,14 +141,39 @@ def _build_url(self, *urls) -> str: logging.debug(built_url) return built_url + str_to_type = { + "sources": Source, + "schemata": Schema, + } + + def _resolve_relationships(self, relationships) -> Dict[str, BaseModel]: + resolved: Dict[str, BaseModel] = {} + for key, value in relationships.items(): + logging.debug("Resolving {}:{}".format(key, value)) + if value["data"]: + resolved[key] = self._obj_factory( + value["data"], + Catalog.str_to_type[value["data"]["type"]], + resolve_relationships=False, + ) + + return resolved + def _obj_factory( - self, payload: Dict[str, Any], clazz: Type[BaseModel] - ) -> BaseModel: + self, + payload: Dict[str, Any], + clazz: Type[ModelType], + resolve_relationships=False, + ) -> ModelType: + resolved = None + if resolve_relationships and payload.get("relationships"): + resolved = self._resolve_relationships(payload.get("relationships")) + return clazz( session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], + attributes=payload.get("attributes"), + obj_id=payload.get("id"), + relationships=resolved, ) def _iterate(self, payload: Dict[str, Any], clazz: Type[BaseModel]): @@ -161,14 +193,20 @@ def _index(self, path: str, clazz: Type[BaseModel]): logging.debug(response.json()) return self._iterate(response.json(), clazz) - def _get(self, path: str, obj_id: int) -> Dict[Any, Any]: + def _get( + self, + path: str, + obj_id: int, + clazz: Type[ModelType], + resolve_relationships=False, + ) -> ModelType: response = self._session.get(self._build_url(path, str(obj_id))) json_response = response.json() logging.debug(json_response) - - if "error" in json_response: - raise RuntimeError(json_response["error"]) - return json_response["data"] + response.raise_for_status() + return self._obj_factory( + json_response["data"], clazz, resolve_relationships=resolve_relationships + ) @staticmethod def _one(response): @@ -204,6 +242,15 @@ def _post(self, path: str, data: Dict[str, Any], type: str) -> Dict[Any, Any]: logging.debug(json_response) return json_response["data"] + def _patch(self, path: str, obj_id: int, data: Dict[str, Any], type: str): + payload = {"data": {"type": type, "attributes": data, "id": obj_id}} + response = self._session.patch( + url=self._build_url(path, str(obj_id)), + data=json.dumps(payload, default=str), + ) + response.raise_for_status() + return + def get_sources(self) -> Generator[Any, Any, None]: return self._index("sources", Source) @@ -226,63 +273,28 @@ def get_column_lineages(self): return self._index("column_lineages", ColumnLineage) def get_source_by_id(self, obj_id) -> Source: - payload = self._get("sources", obj_id) - return Source( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._get("sources", obj_id, Source) def get_schema_by_id(self, obj_id) -> Schema: - payload = self._get("schemata", obj_id) - return Schema( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._get("schemata", obj_id, Schema) def get_table_by_id(self, obj_id) -> Table: - payload = self._get("tables", obj_id) - return Table( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._get("tables", obj_id, Table) def get_column_by_id(self, obj_id) -> Column: - payload = self._get("columns", obj_id) - return Column( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._get("columns", obj_id, Column) def get_job_by_id(self, obj_id) -> Job: - payload = self._get("jobs", obj_id) - return Job( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=None, - ) + return self._get("jobs", obj_id, Job) def get_job_execution_by_id(self, obj_id) -> JobExecution: - payload = self._get("job_executions", obj_id) - return JobExecution( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=None, - ) + return self._get("job_executions", obj_id, JobExecution) def get_column_lineage(self, job_ids: List[int]) -> List[ColumnLineage]: params = {"job_ids": job_ids} response = self._session.get(self._build_url("column_lineage"), params=params) logging.debug(response.json()) + response.raise_for_status() return [ ColumnLineage( session=self._session, @@ -299,12 +311,8 @@ def get_source(self, name) -> Source: payload = self._search_one("sources", filters) except NoResultFound: raise SourceNotFound("Source not found: source_name={}".format(name)) - return Source( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + + return self._obj_factory(payload, Source) def get_schema(self, source_name: str, schema_name: str) -> Schema: name_filter = dict(name="name", op="eq", val=schema_name) @@ -321,12 +329,7 @@ def get_schema(self, source_name: str, schema_name: str) -> Schema: source_name, schema_name ) ) - return Schema( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._obj_factory(payload, Schema) def get_table(self, source_name: str, schema_name: str, table_name: str) -> Table: schema = self.get_schema(source_name, schema_name) @@ -343,17 +346,12 @@ def get_table(self, source_name: str, schema_name: str, table_name: str) -> Tabl source_name, schema_name, table_name ) ) - return Table( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._obj_factory(payload, Table) def get_columns_for_table(self, table: Table): return self._index("tables/{}/columns".format(table.id), Column) - def get_column(self, source_name, schema_name, table_name, column_name): + def get_column(self, source_name, schema_name, table_name, column_name) -> Column: table = self.get_table(source_name, schema_name, table_name) name_filter = dict(name="name", op="eq", val=column_name) table_filter = dict(name="table_id", op="eq", val=str(table.id)) @@ -367,49 +365,30 @@ def get_column(self, source_name, schema_name, table_name, column_name): source_name, schema_name, table_name, column_name ) ) - return Column( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._obj_factory(payload, Column) def add_source(self, name: str, source_type: str, **kwargs) -> Source: data = {"name": name, "source_type": source_type, **kwargs} payload = self._post(path="sources", data=data, type="sources") - return Source( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._obj_factory(payload, Source) def scan_source(self, source: Source) -> bool: payload = {"id": source.id} response = self._session.post( url=self._build_url("scanner"), data=json.dumps(payload) ) + response.raise_for_status() return response.status_code == 200 def add_schema(self, name: str, source: Source) -> Schema: data = {"name": name, "source_id": source.id} payload = self._post(path="schemata", data=data, type="schemata") - return Schema( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._obj_factory(payload, Schema) def add_table(self, name: str, schema: Schema) -> Table: data = {"name": name, "schema_id": schema.id} payload = self._post(path="tables", data=data, type="tables") - return Table( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._obj_factory(payload, Table) def add_column( self, name: str, data_type: str, sort_order: int, table: Table @@ -421,23 +400,12 @@ def add_column( "sort_order": sort_order, } payload = self._post(path="columns", data=data, type="columns") - return Column( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=payload["relationships"], - ) + return self._obj_factory(payload, Column) def add_job(self, name: str, context: Dict[Any, Any]) -> Job: data = {"name": name, "context": context} payload = self._post(path="jobs", data=data, type="jobs") - print(payload) - return Job( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=None, - ) + return self._obj_factory(payload, Job) def add_job_execution( self, @@ -453,12 +421,7 @@ def add_job_execution( "status": status.name, } payload = self._post(path="job_executions", data=data, type="job_executions") - return JobExecution( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=None, - ) + return self._obj_factory(payload, JobExecution) def add_column_lineage( self, @@ -474,21 +437,38 @@ def add_column_lineage( "context": context, } payload = self._post(path="column_lineage", data=data, type="column_lineage") - return ColumnLineage( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=None, - ) + return self._obj_factory(payload, ColumnLineage) - def update_source(self, source: Source, schema: Schema): - data = {"source_id": source.id, "schema_id": schema.id} - payload = self._post(path="default_schema", data=data, type="default_schema") - return DefaultSchema( - session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], - relationships=None, + def update_source(self, source: Source, schema: Schema) -> DefaultSchema: + try: + current_obj = self._get( + path="default_schema", + obj_id=source.id, + clazz=DefaultSchema, + resolve_relationships=True, + ) + if current_obj.schema.id == schema.id: + return current_obj + except HTTPError as error: + if error.response.status_code == 404: + data = {"source_id": source.id, "schema_id": schema.id} + payload = self._post( + path="default_schema", data=data, type="default_schema" + ) + return self._obj_factory( + payload, DefaultSchema, resolve_relationships=True + ) + + # Patch + data = {"schema_id": schema.id} + self._patch( + path="default_schema", data=data, type="default_schema", obj_id=source.id + ) + return self._get( + path="default_schema", + obj_id=source.id, + clazz=DefaultSchema, + resolve_relationships=True, ) @@ -529,8 +509,8 @@ def analyze( payload = response.json()["data"] return JobExecution( session=self._session, - attributes=payload["attributes"], - obj_id=payload["id"], + attributes=payload.get("attributes"), + obj_id=payload.get("id"), relationships=None, ) diff --git a/test/test_server.py b/test/test_server.py index 5038330..8c5f869 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -51,6 +51,7 @@ def test_get_columns(rest_catalog): def test_get_source_by_id(rest_catalog): source = rest_catalog.get_source_by_id(1) + print(source.__class__.__name__) assert source.name == "test" assert source.fqdn == "test" assert source.source_type == "redshift" @@ -197,6 +198,23 @@ def test_add_source_snowflake(rest_catalog): assert sf_conn.warehouse == "db_warehouse" +def test_update_source(rest_catalog): + glue_conn = rest_catalog.add_source(name="gl_2", source_type="glue") + schema_1 = rest_catalog.add_schema("schema_1", glue_conn) + + default_schema = rest_catalog.update_source(glue_conn, schema_1) + + assert default_schema.source.id == glue_conn.id + assert default_schema.schema.id == schema_1.id + + schema_2 = rest_catalog.add_schema("schema_2", glue_conn) + + default_schema = rest_catalog.update_source(glue_conn, schema_2) + + assert default_schema.source.id == glue_conn.id + assert default_schema.schema.id == schema_2.id + + def load_edges(catalog, expected_edges, job_execution_id): column_edge_ids = [] for edge in expected_edges: