Skip to content

Commit

Permalink
feat: fixes with mypy, add tox, remove 3.7 support (#212)
Browse files Browse the repository at this point in the history
Added mypy to the pre-commit workflow. Closes #98.

Added a mypy check to the ci_workflow.yml tests using tox. This also runs other
linting as well.

Removed Python 3.7 in the process so this closes #141 as well.
  • Loading branch information
sebastianswms authored Nov 1, 2023
1 parent 84c3faa commit a99fe7f
Show file tree
Hide file tree
Showing 8 changed files with 677 additions and 301 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/ci_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ on:
inputs: {}

jobs:
pytest:
tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- name: Fix key permissions
Expand All @@ -37,9 +37,12 @@ jobs:
- name: Install dependencies
run: |
poetry install
- name: Test with pytest
- name: Run pytest
run: |
poetry run pytest --capture=no
- name: Run lint
run: |
poetry run tox -e lint
integration:
runs-on: ubuntu-latest
Expand Down
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,14 @@ repos:
hooks:
- id: pyupgrade
args: [--py37-plus]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.6.1'
hooks:
- id: mypy
exclude: tests
additional_dependencies:
- types-paramiko
- types-simplejson
- types-sqlalchemy
- types-jsonschema
801 changes: 548 additions & 253 deletions poetry.lock

Large diffs are not rendered by default.

20 changes: 17 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ classifiers = [
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
Expand All @@ -31,7 +30,7 @@ packages = [
]

[tool.poetry.dependencies]
python = "<3.12,>=3.7.1"
python = "<3.12,>=3.8.1"
requests = "^2.25.1"
singer-sdk = ">=0.28,<0.34"
psycopg2-binary = "2.9.9"
Expand All @@ -40,8 +39,23 @@ sshtunnel = "0.4.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"
mypy = "^1.0"
mypy = "^1.6.1"
remote-pdb="2.1.0"
black = "^23.1.0"
flake8 = "^6.0.0"
isort = "^5.10.1"
tox = "^4"
types-paramiko = "^3.3.0.0"
types-simplejson = "^3.19.0.2"
types-sqlalchemy = "^1.4.53.38"
types-jsonschema = "^4.19.0.3"

[tool.mypy]
exclude = "tests"

[[tool.mypy.overrides]]
module = ["sshtunnel"]
ignore_missing_imports = true

[tool.isort]
profile = "black"
Expand Down
34 changes: 17 additions & 17 deletions target_postgres/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ class PostgresConnector(SQLConnector):
allow_merge_upsert: bool = True # Whether MERGE UPSERT is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.

def __init__(self, config: dict | None = None) -> None:
def __init__(self, config: dict) -> None:
"""Initialize a connector to a Postgres database.
Args:
config: Configuration for the connector. Defaults to None.
config: Configuration for the connector.
"""
url: URL = make_url(self.get_sqlalchemy_url(config=config))
ssh_config = config.get("ssh_tunnel", {})
self.ssh_tunnel = None
self.ssh_tunnel: SSHTunnelForwarder

if ssh_config.get("enable", False):
# Return a new URL with SSH tunnel parameters
self.ssh_tunnel: SSHTunnelForwarder = SSHTunnelForwarder(
self.ssh_tunnel = SSHTunnelForwarder(
ssh_address_or_host=(ssh_config["host"], ssh_config["port"]),
ssh_username=ssh_config["username"],
ssh_private_key=self.guess_key_type(ssh_config["private_key"]),
Expand All @@ -78,7 +78,7 @@ def __init__(self, config: dict | None = None) -> None:
sqlalchemy_url=url.render_as_string(hide_password=False),
)

def prepare_table(
def prepare_table( # type: ignore[override]
self,
full_table_name: str,
schema: dict,
Expand All @@ -102,7 +102,7 @@ def prepare_table(
"""
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sqlalchemy.MetaData(schema=schema_name)
table: sqlalchemy.Table = None
table: sqlalchemy.Table
if not self.table_exists(full_table_name=full_table_name):
table = self.create_empty_table(
table_name=table_name,
Expand All @@ -120,7 +120,7 @@ def prepare_table(
] # So we don't mess up the casing of the Table reference
for property_name, property_def in schema["properties"].items():
self.prepare_column(
schema_name=schema_name,
schema_name=cast(str, schema_name),
table=table,
column_name=property_name,
sql_type=self.to_sql_type(property_def),
Expand Down Expand Up @@ -149,7 +149,7 @@ def copy_table_structure(
"""
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sqlalchemy.MetaData(schema=schema_name)
new_table: sqlalchemy.Table = None
new_table: sqlalchemy.Table
columns = []
if self.table_exists(full_table_name=full_table_name):
raise RuntimeError("Table already exists")
Expand Down Expand Up @@ -305,7 +305,7 @@ def pick_best_sql_type(sql_type_array: list):
return obj
return TEXT()

def create_empty_table(
def create_empty_table( # type: ignore[override]
self,
table_name: str,
meta: sqlalchemy.MetaData,
Expand Down Expand Up @@ -359,7 +359,7 @@ def create_empty_table(
new_table.create(bind=connection)
return new_table

def prepare_column(
def prepare_column( # type: ignore[override]
self,
schema_name: str,
table: sqlalchemy.Table,
Expand Down Expand Up @@ -396,7 +396,7 @@ def prepare_column(
connection=connection,
)

def _create_empty_column(
def _create_empty_column( # type: ignore[override]
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -426,7 +426,7 @@ def _create_empty_column(
)
connection.execute(column_add_ddl)

def get_column_add_ddl(
def get_column_add_ddl( # type: ignore[override]
self,
table_name: str,
schema_name: str,
Expand Down Expand Up @@ -459,7 +459,7 @@ def get_column_add_ddl(
},
)

def _adapt_column_type(
def _adapt_column_type( # type: ignore[override]
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -521,7 +521,7 @@ def _adapt_column_type(
)
connection.execute(alter_column_ddl)

def get_column_alter_ddl(
def get_column_alter_ddl( # type: ignore[override]
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -687,7 +687,7 @@ def catch_signal(self, signum, frame) -> None:
"""
exit(1) # Calling this to be sure atexit is called, so clean_up gets called

def _get_column_type(
def _get_column_type( # type: ignore[override]
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -721,7 +721,7 @@ def _get_column_type(

return t.cast(sqlalchemy.types.TypeEngine, column.type)

def get_table_columns(
def get_table_columns( # type: ignore[override]
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -754,7 +754,7 @@ def get_table_columns(
or col_meta["name"].casefold() in {col.casefold() for col in column_names}
}

def column_exists(
def column_exists( # type: ignore[override]
self,
full_table_name: str,
column_name: str,
Expand Down
39 changes: 26 additions & 13 deletions target_postgres/sinks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Postgres target sink class, which handles writing streams."""

import uuid
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union, cast

import sqlalchemy
from pendulum import now
Expand Down Expand Up @@ -33,6 +33,15 @@ def append_only(self, value: bool) -> None:
"""Set the append_only attribute."""
self._append_only = value

@property
def connector(self) -> PostgresConnector:
"""Return the connector object.
Returns:
The connector object.
"""
return cast(PostgresConnector, self._connector)

def setup(self) -> None:
"""Set up Sink.
Expand Down Expand Up @@ -108,7 +117,7 @@ def generate_temp_table_name(self):
# in postgres, used a guid just in case we are using the same session
return f"{str(uuid.uuid4()).replace('-','_')}"

def bulk_insert_records(
def bulk_insert_records( # type: ignore[override]
self,
table: sqlalchemy.Table,
schema: dict,
Expand All @@ -132,9 +141,12 @@ def bulk_insert_records(
True if table exists, False if not, None if unsure or undetectable.
"""
columns = self.column_representation(schema)
insert = self.generate_insert_statement(
table.name,
columns,
insert: str = cast(
str,
self.generate_insert_statement(
table.name,
columns,
),
)
self.logger.info("Inserting with SQL: %s", insert)
# Only one record per PK, we want to take the last one
Expand Down Expand Up @@ -165,7 +177,7 @@ def upsert(
from_table: sqlalchemy.Table,
to_table: sqlalchemy.Table,
schema: dict,
join_keys: List[Column],
join_keys: List[str],
connection: sqlalchemy.engine.Connection,
) -> Optional[int]:
"""Merge upsert data from one table to another.
Expand All @@ -191,16 +203,17 @@ def upsert(
connection.execute(insert_stmt)
else:
join_predicates = []
to_table_key: sqlalchemy.Column
for key in join_keys:
from_table_key: sqlalchemy.Column = from_table.columns[key]
to_table_key: sqlalchemy.Column = to_table.columns[key]
to_table_key = to_table.columns[key]
join_predicates.append(from_table_key == to_table_key)

join_condition = sqlalchemy.and_(*join_predicates)

where_predicates = []
for key in join_keys:
to_table_key: sqlalchemy.Column = to_table.columns[key]
to_table_key = to_table.columns[key]
where_predicates.append(to_table_key.is_(None))
where_condition = sqlalchemy.and_(*where_predicates)

Expand Down Expand Up @@ -246,7 +259,7 @@ def column_representation(
def generate_insert_statement(
self,
full_table_name: str,
columns: List[Column],
columns: List[Column], # type: ignore[override]
) -> Union[str, Executable]:
"""Generate an insert statement for the given records.
Expand Down Expand Up @@ -323,9 +336,9 @@ def activate_version(self, new_version: int) -> None:
column_name=self.version_column_name,
connection=connection,
):
self.connector.prepare_column(
self.connector.prepare_column( # type: ignore[call-arg]
self.full_table_name,
self.version_column_name,
self.version_column_name, # type: ignore[arg-type]
sql_type=integer_type,
connection=connection,
)
Expand All @@ -346,9 +359,9 @@ def activate_version(self, new_version: int) -> None:
column_name=self.soft_delete_column_name,
connection=connection,
):
self.connector.prepare_column(
self.connector.prepare_column( # type: ignore[call-arg]
self.full_table_name,
self.soft_delete_column_name,
self.soft_delete_column_name, # type: ignore[arg-type]
sql_type=datetime_type,
connection=connection,
)
Expand Down
13 changes: 1 addition & 12 deletions target_postgres/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
variables.
validate_config: True to require validation of config settings.
"""
self.max_parallelism = 1
super().__init__(
config=config,
parse_env_config=parse_env_config,
Expand Down Expand Up @@ -307,15 +308,3 @@ def __init__(
),
).to_dict()
default_sink_class = PostgresSink

@property
def max_parallelism(self) -> int:
"""Get max parallel sinks.
The default is 8 if not overridden.
Returns:
Max number of sinks that can be drained in parallel.
"""
# https://github.com/MeltanoLabs/target-postgres/issues/3
return 1
Loading

0 comments on commit a99fe7f

Please sign in to comment.