diff --git a/fakesnow/conn.py b/fakesnow/conn.py index 46faea7..043be74 100644 --- a/fakesnow/conn.py +++ b/fakesnow/conn.py @@ -42,7 +42,9 @@ def __init__( # information_schema.schemata below we use upper-case to match any existing duckdb # catalog or schemas like "information_schema" self.database = database and database.upper() - self.schema = schema and schema.upper() + self._schema = schema and ( + "_FS_INFORMATION_SCHEMA" if schema.upper() == "INFORMATION_SCHEMA" else schema.upper() + ) self.database_set = False self.schema_set = False @@ -69,24 +71,24 @@ def __init__( if ( create_schema and self.database - and self.schema + and self._schema and not duck_conn.execute( f"""select * from information_schema.schemata - where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self.schema}'""" + where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self._schema}'""" ).fetchone() ): - duck_conn.execute(f"CREATE SCHEMA {self.database}.{self.schema}") + duck_conn.execute(f"CREATE SCHEMA {self.database}.{self._schema}") # set database and schema if both exist if ( self.database - and self.schema + and self._schema and duck_conn.execute( f"""select * from information_schema.schemata - where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self.schema}'""" + where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self._schema}'""" ).fetchone() ): - duck_conn.execute(f"SET schema='{self.database}.{self.schema}'") + duck_conn.execute(f"SET schema='{self.database}.{self._schema}'") self.database_set = True self.schema_set = True # set database if only that exists @@ -149,3 +151,7 @@ def is_closed(self) -> bool: def rollback(self) -> None: self.cursor().execute("ROLLBACK") + + @property + def schema(self) -> str | None: + return "INFORMATION_SCHEMA" if self._schema == "_FS_INFORMATION_SCHEMA" else self._schema diff --git a/fakesnow/cursor.py b/fakesnow/cursor.py index 0e122c2..02dcd44 100644 --- a/fakesnow/cursor.py +++ b/fakesnow/cursor.py @@ -284,7 +284,7 @@ def _execute(self, transformed: exp.Expression, params: Sequence[Any] | dict[Any self._conn.database_set = True elif set_schema := transformed.args.get("set_schema"): - self._conn.schema = set_schema + self._conn._schema = set_schema # noqa: SLF001 self._conn.schema_set = True elif create_db_name := transformed.args.get("create_db_name"): @@ -334,10 +334,10 @@ def _execute(self, transformed: exp.Expression, params: Sequence[Any] | dict[Any # if dropping the current database/schema then reset conn metadata if cmd == "DROP DATABASE" and ident == self._conn.database: self._conn.database = None - self._conn.schema = None + self._conn._schema = None # noqa: SLF001 elif cmd == "DROP SCHEMA" and ident == self._conn.schema: - self._conn.schema = None + self._conn._schema = None # noqa: SLF001 if table_comment := cast(tuple[exp.Table, str], transformed.args.get("table_comment")): # record table comment diff --git a/fakesnow/info_schema.py b/fakesnow/info_schema.py index 843e3c1..274b760 100644 --- a/fakesnow/info_schema.py +++ b/fakesnow/info_schema.py @@ -102,7 +102,7 @@ 'STANDARD' as type from system.information_schema.schemata where catalog_name not in ('memory', 'system', 'temp', '_fs_global') - and schema_name = 'information_schema' + and schema_name = 'main' """ ) diff --git a/fakesnow/transforms.py b/fakesnow/transforms.py index 9c57b3d..4400f51 100644 --- a/fakesnow/transforms.py +++ b/fakesnow/transforms.py @@ -1015,13 +1015,16 @@ def show_objects_tables(expression: exp.Expression, current_database: str | None SQL_SHOW_SCHEMAS = """ select to_timestamp(0)::timestamptz as 'created_on', - schema_name as 'name', + case + when schema_name = '_fs_information_schema' then 'information_schema' + else schema_name + end as 'name', NULL as 'kind', catalog_name as 'database_name', NULL as 'schema_name' from information_schema.schemata where not catalog_name in ('memory', 'system', 'temp') - and not schema_name in ('main', 'pg_catalog', '_fs_information_schema') + and not schema_name in ('main', 'pg_catalog') """ diff --git a/pyproject.toml b/pyproject.toml index dae6cb5..4dad5d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ classifiers = ["License :: OSI Approved :: MIT License"] keywords = ["snowflake", "snowflakedb", "fake", "local", "mock", "testing"] requires-python = ">=3.9" dependencies = [ - "duckdb~=1.1.3", + "duckdb~=1.2.0", "pyarrow", "snowflake-connector-python", "sqlglot~=26.6.0", diff --git a/tests/test_fakes.py b/tests/test_fakes.py index c51d8b3..268cfcd 100644 --- a/tests/test_fakes.py +++ b/tests/test_fakes.py @@ -663,11 +663,11 @@ def test_regex_substr(cur: snowflake.connector.cursor.SnowflakeCursor): def test_random(cur: snowflake.connector.cursor.SnowflakeCursor): cur.execute("select random(420)") - assert cur.fetchall() == [(-2595895151578578944,)] + assert cur.fetchall() == [(-4068260216279105536,)] cur.execute("select random(420)") - assert cur.fetchall() == [(-2595895151578578944,)] + assert cur.fetchall() == [(-4068260216279105536,)] cur.execute("select random(419)") - assert cur.fetchall() == [(4590143504000221184,)] + assert cur.fetchall() == [(1460638274662493184,)] assert cur.execute("select random()").fetchall() != cur.execute("select random()").fetchall() @@ -685,7 +685,7 @@ def test_rowcount(cur: snowflake.connector.cursor.SnowflakeCursor): def test_sample(cur: snowflake.connector.cursor.SnowflakeCursor): cur.execute("create table example(id int)") cur.execute("insert into example select * from (VALUES (1), (2), (3), (4));") - cur.execute("select * from example SAMPLE (50) SEED (420)") + cur.execute("select * from example SAMPLE (50) SEED (999)") # sampling small sizes isn't exact assert cur.fetchall() == [(1,), (2,), (3,)] diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c059cc0..dd0373a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -682,7 +682,7 @@ def test_show_objects_tables() -> None: def test_show_schemas() -> None: assert ( sqlglot.parse_one("show terse schemas in database db1", read="snowflake").transform(show_schemas).sql() - == """SELECT CAST(UNIX_TO_TIME(0) AS TIMESTAMPTZ) AS "created_on", schema_name AS "name", NULL AS "kind", catalog_name AS "database_name", NULL AS "schema_name" FROM information_schema.schemata WHERE NOT catalog_name IN ('memory', 'system', 'temp') AND NOT schema_name IN ('main', 'pg_catalog', '_fs_information_schema') AND catalog_name = 'db1'""" # noqa: E501 + == """SELECT CAST(UNIX_TO_TIME(0) AS TIMESTAMPTZ) AS "created_on", CASE WHEN schema_name = '_fs_information_schema' THEN 'information_schema' ELSE schema_name END AS "name", NULL AS "kind", catalog_name AS "database_name", NULL AS "schema_name" FROM information_schema.schemata WHERE NOT catalog_name IN ('memory', 'system', 'temp') AND NOT schema_name IN ('main', 'pg_catalog') AND catalog_name = 'db1'""" # noqa: E501 )