Skip to content

Commit

Permalink
SNOW-638840: Version 1.4.0 breaks SnowflakeDialect._get_schema_column…
Browse files Browse the repository at this point in the history
…s function when column of type DATE exists in schema (#334)
  • Loading branch information
sfc-gh-aling authored Aug 5, 2022
1 parent ca9ed7f commit 50a8b4e
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 1 deletion.
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Source code is also available at:

- v1.4.1(Unreleased)
- snowflake-sqlalchemy is now SQLAlchemy 2.0 compatible.
- Fixed a bug that `DATE` should not be removed from `SnowflakeDialect.ischema_names`.

- v1.4.0(July 20, 2022)

Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BINARY,
BOOLEAN,
CHAR,
DATE,
DATETIME,
DECIMAL,
FLOAT,
Expand Down Expand Up @@ -79,6 +80,7 @@
"BOOLEAN": BOOLEAN,
"CHAR": CHAR,
"CHARACTER": CHAR,
"DATE": DATE,
"DATETIME": DATETIME,
"DEC": DECIMAL,
"DECIMAL": DECIMAL,
Expand Down
28 changes: 27 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .conftest import create_engine_with_future_flag as create_engine
from .conftest import get_engine
from .parameters import CONNECTION_PARAMETERS
from .util import random_string
from .util import ischema_names_baseline, random_string

THIS_DIR = os.path.dirname(os.path.realpath(__file__))

Expand Down Expand Up @@ -1604,3 +1604,29 @@ def test_empty_comments(engine_testaccount):
assert all([c["comment"] is None for c in columns])
finally:
conn.execute(text(f"drop table public.{table_name}"))


def test_column_type_schema(engine_testaccount):
with engine_testaccount.connect() as conn:
table_name = random_string(5)
conn.exec_driver_sql(
f"""\
CREATE TEMP TABLE {table_name} (
C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 DEC,
C9 DECIMAL, C10 DOUBLE,
-- C11 FIXED, # SQL compilation error: Unsupported data type 'FIXED'.
C12 FLOAT, C13 INT, C14 INTEGER, C15 NUMBER, C16 REAL, C17 BYTEINT, C18 SMALLINT,
C19 STRING, C20 TEXT, C21 TIME, C22 TIMESTAMP, C23 TIMESTAMP_TZ, C24 TIMESTAMP_LTZ,
C25 TIMESTAMP_NTZ, C26 TINYINT, C27 VARBINARY, C28 VARCHAR, C29 VARIANT,
C30 OBJECT, C31 ARRAY, C32 GEOGRAPHY
)
"""
)

table_reflected = Table(
table_name, MetaData(), autoload=True, autoload_with=conn
)
columns = table_reflected.columns
assert (
len(columns) == len(ischema_names_baseline) - 1
) # -1 because FIXED is not supported
11 changes: 11 additions & 0 deletions tests/test_unit_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,21 @@
#

import snowflake.sqlalchemy
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect

from .util import ischema_names_baseline


def test_type_synonyms():
from snowflake.sqlalchemy.snowdialect import ischema_names

for k, _ in ischema_names.items():
assert getattr(snowflake.sqlalchemy, k) is not None


def test_type_baseline():
assert set(SnowflakeDialect.ischema_names.keys()) == set(
ischema_names_baseline.keys()
)
for k, v in SnowflakeDialect.ischema_names.items():
assert issubclass(v, ischema_names_baseline[k])
64 changes: 64 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,70 @@
import string
from typing import Sequence

from sqlalchemy.types import (
BIGINT,
BINARY,
BOOLEAN,
CHAR,
DATE,
DATETIME,
DECIMAL,
FLOAT,
INTEGER,
REAL,
SMALLINT,
TIME,
TIMESTAMP,
VARCHAR,
)

from snowflake.sqlalchemy.custom_types import (
ARRAY,
GEOGRAPHY,
OBJECT,
TIMESTAMP_LTZ,
TIMESTAMP_NTZ,
TIMESTAMP_TZ,
VARIANT,
)

ischema_names_baseline = {
"BIGINT": BIGINT,
"BINARY": BINARY,
# 'BIT': BIT,
"BOOLEAN": BOOLEAN,
"CHAR": CHAR,
"CHARACTER": CHAR,
"DATE": DATE,
"DATETIME": DATETIME,
"DEC": DECIMAL,
"DECIMAL": DECIMAL,
"DOUBLE": FLOAT,
"FIXED": DECIMAL,
"FLOAT": FLOAT,
"INT": INTEGER,
"INTEGER": INTEGER,
"NUMBER": DECIMAL,
# 'OBJECT': ?
"REAL": REAL,
"BYTEINT": SMALLINT,
"SMALLINT": SMALLINT,
"STRING": VARCHAR,
"TEXT": VARCHAR,
"TIME": TIME,
"TIMESTAMP": TIMESTAMP,
"TIMESTAMP_TZ": TIMESTAMP_TZ,
"TIMESTAMP_LTZ": TIMESTAMP_LTZ,
"TIMESTAMP_NTZ": TIMESTAMP_NTZ,
"TINYINT": SMALLINT,
"VARBINARY": BINARY,
"VARCHAR": VARCHAR,
"VARIANT": VARIANT,
"OBJECT": OBJECT,
"ARRAY": ARRAY,
"GEOGRAPHY": GEOGRAPHY,
}


def random_string(
length: int,
Expand Down

0 comments on commit 50a8b4e

Please sign in to comment.