Skip to content

Commit

Permalink
fix type hints for optional dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonLuttenberger committed Aug 23, 2024
1 parent ce231d8 commit 5438d7b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 12 deletions.
15 changes: 11 additions & 4 deletions awswrangler/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
Expand All @@ -23,9 +24,15 @@
from awswrangler._config import apply_configs
from awswrangler._sql_utils import identifier

__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"]
if TYPE_CHECKING:
try:
import oracledb
except ImportError:
pass
else:
oracledb = _utils.import_optional_dependency("oracledb")

oracledb = _utils.import_optional_dependency("oracledb")
__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"]

_logger: logging.Logger = logging.getLogger(__name__)
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
Expand Down Expand Up @@ -167,7 +174,7 @@ def connect(
Examples
--------
>>> import awswrangler as wr
>>> with wr.oracle.connect(connection="MY_GLUE_CONNECTION") as con"
>>> with wr.oracle.connect(connection="MY_GLUE_CONNECTION") as con:
... with con.cursor() as cursor:
... cursor.execute("SELECT 1 FROM DUAL")
... print(cursor.fetchall())
Expand All @@ -190,7 +197,7 @@ def connect(
)
# oracledb.connect does not have a call_timeout attribute, it has to be set separatly
oracle_connection.call_timeout = call_timeout
return oracle_connection
return oracle_connection # type: ignore[no-any-return]


@overload
Expand Down
16 changes: 11 additions & 5 deletions awswrangler/redshift/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@

from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

import boto3

from awswrangler import _databases as _db_utils
from awswrangler import _utils, exceptions

redshift_connector = _utils.import_optional_dependency("redshift_connector")
if TYPE_CHECKING:
try:
import redshift_connector
except ImportError:
pass
else:
redshift_connector = _utils.import_optional_dependency("redshift_connector")


def _validate_connection(con: "redshift_connector.Connection") -> None: # type: ignore[name-defined]
def _validate_connection(con: "redshift_connector.Connection") -> None:
if not isinstance(con, redshift_connector.Connection):
raise exceptions.InvalidConnection(
"Invalid 'conn' argument, please pass a "
Expand All @@ -33,7 +39,7 @@ def connect(
max_prepared_statements: int = 1000,
tcp_keepalive: bool = True,
**kwargs: Any,
) -> "redshift_connector.Connection": # type: ignore[name-defined]
) -> "redshift_connector.Connection":
"""Return a redshift_connector connection from a Glue Catalog or Secret Manager.
Note
Expand Down Expand Up @@ -144,7 +150,7 @@ def connect_temp(
max_prepared_statements: int = 1000,
tcp_keepalive: bool = True,
**kwargs: Any,
) -> "redshift_connector.Connection": # type: ignore[name-defined]
) -> "redshift_connector.Connection":
"""Return a redshift_connector temporary connection (No password required).
https://github.com/aws/amazon-redshift-python-driver
Expand Down
10 changes: 8 additions & 2 deletions awswrangler/redshift/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from typing import Any, Iterator, Literal
from typing import TYPE_CHECKING, Any, Iterator, Literal

import boto3
import pyarrow as pa
Expand All @@ -17,7 +17,13 @@
from ._connect import _validate_connection
from ._utils import _make_s3_auth_string

redshift_connector = _utils.import_optional_dependency("redshift_connector")
if TYPE_CHECKING:
try:
import redshift_connector
except ImportError:
pass
else:
redshift_connector = _utils.import_optional_dependency("redshift_connector")

_logger: logging.Logger = logging.getLogger(__name__)

Expand Down
9 changes: 8 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,14 @@
typehints_use_signature = True
typehints_use_signature_return = True

autodoc_mock_imports = ["pyodbc"]
autodoc_mock_imports = [
"opensearchpy",
"oracledb",
"pg8000",
"pymysql",
"pyodbc",
"redshift_connector",
]


def setup(app):
Expand Down

0 comments on commit 5438d7b

Please sign in to comment.