Skip to content

Commit

Permalink
hotfix: snowflake python connector default config dir (#125)
Browse files Browse the repository at this point in the history
<!--- Provide a general summary of your changes in the Title above -->

## Description
Snowflake python connector have a logic to get default[ root directory
for
configurations](https://github.com/snowflakedb/snowflake-connector-python/blob/9ddb2050cde13819a289fb6982b48431f253a53e/src/snowflake/connector/sf_dirs.py#L47):
```python
def _resolve_platform_dirs() -> PlatformDirsProto:
    """Decide on what PlatformDirs class to use.

    In case a folder exists (which can be customized with the environmental
    variable `SNOWFLAKE_HOME`) we use that directory as all platform
    directories. If this folder does not exist we'll fall back to platformdirs
    defaults.

    This helper function was introduced to make this code testable.
    """
    platformdir_kwargs = {
        "appname": "snowflake",
        "appauthor": False,
    }
    snowflake_home = pathlib.Path(
        os.environ.get("SNOWFLAKE_HOME", "~/.snowflake/"),
    ).expanduser()
    if snowflake_home.exists():
        return SFPlatformDirs(
            str(snowflake_home),
            **platformdir_kwargs,
        )
    else:
        # In case SNOWFLAKE_HOME does not exist we fall back to using
        # platformdirs to determine where system files should be placed. Please
        # see docs for all the directories defined in the module at
        # https://platformdirs.readthedocs.io/
        return PlatformDirs(**platformdir_kwargs)
```

Currently in databricks jobs execution this one is being set to
`'/root/.snowflake'` which is not allowed to be accessed.

The fix is to catch the error and provide `tmp` folder instead of
`root`.

## Related Issue
#124 

## Motivation and Context
Be able to use snowflake python in databricks

## How Has This Been Tested?
Added new tests

## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all
the boxes that apply: -->
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)

## Checklist:
<!--- Go over all the following points, and put an `x` in all the boxes
that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're
here to help! -->
- [x] My code follows the code style of this project.
- [ ] My change requires a change to the documentation.
- [ ] I have updated the documentation accordingly.
- [ ] I have read the **CONTRIBUTING** document.
- [x] I have added tests to cover my changes.
- [x] All new and existing tests passed.
  • Loading branch information
mikita-sakalouski authored Nov 25, 2024
1 parent 6faa20d commit 9496eb5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
59 changes: 58 additions & 1 deletion src/koheesio/integrations/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from typing import Any, Dict, Generator, List, Optional, Set, Union
from abc import ABC
from contextlib import contextmanager
import os
import tempfile
from types import ModuleType
from urllib.parse import urlparse

Expand All @@ -61,6 +63,7 @@
field_validator,
model_validator,
)
from koheesio.spark.utils.common import on_databricks

__all__ = [
"GrantPrivilegesOnFullyQualifiedObject",
Expand All @@ -79,6 +82,38 @@
# Turning off too-many-lines because we are defining a lot of classes in this file


def __check_access_snowflake_config_dir() -> bool:
"""Check if the Snowflake configuration directory is accessible
Returns
-------
bool
True if the Snowflake configuration directory is accessible, otherwise False
Raises
------
RuntimeError
If `snowflake-connector-python` is not installed
"""
check_result = False

try:
from snowflake.connector.sf_dirs import _resolve_platform_dirs # noqa: F401

_resolve_platform_dirs().user_config_path
check_result = True
except PermissionError as e:
warn(f"Snowflake configuration directory is not accessible. Please check the permissions.Catched error: {e}")
except (ImportError, ModuleNotFoundError) as e:
raise RuntimeError(
"You need to have the `snowflake-connector-python` package installed to use the Snowflake steps that are"
"based around SnowflakeRunQueryPython. You can install this in Koheesio by adding `koheesio[snowflake]` to "
"your package dependencies.",
) from e

return check_result


def safe_import_snowflake_connector() -> Optional[ModuleType]:
"""Validate that the Snowflake connector is installed
Expand All @@ -87,7 +122,17 @@ def safe_import_snowflake_connector() -> Optional[ModuleType]:
Optional[ModuleType]
The Snowflake connector module if it is installed, otherwise None
"""
is_accessable_sf_conf_dir = __check_access_snowflake_config_dir()

if not is_accessable_sf_conf_dir and on_databricks():
snowflake_home: str = tempfile.mkdtemp(prefix="snowflake_tmp_", dir="/tmp") # nosec B108:ignore bandit check for CWE-377
os.environ["SNOWFLAKE_HOME"] = snowflake_home
warn(f"Getting error for snowflake config directory. Going to use temp directory `{snowflake_home}` instead.")
elif not is_accessable_sf_conf_dir:
raise PermissionError("Snowflake configuration directory is not accessible. Please check the permissions.")

try:
# Keep the import here as it is perfroming resolution of snowflake configuration directory
from snowflake import connector as snowflake_connector

return snowflake_connector
Expand Down Expand Up @@ -336,8 +381,18 @@ def conn(self) -> Generator:
self.log.info(f"Connected to Snowflake account: {sf_options['account']}")

try:
from snowflake.connector.connection import logger as snowflake_logger

_preserve_snowflake_logger = snowflake_logger
snowflake_logger = self.log
snowflake_logger.debug("Replace snowflake logger with Koheesio logger")
yield _conn
finally:
if _preserve_snowflake_logger:
if snowflake_logger:
snowflake_logger.debug("Restore snowflake logger")
snowflake_logger = _preserve_snowflake_logger

if _conn:
_conn.close()

Expand All @@ -348,7 +403,9 @@ def get_query(self) -> str:
def execute(self) -> None:
"""Execute the query"""
with self.conn as conn:
cursors = conn.execute_string(self.get_query())
cursors = conn.execute_string(
self.get_query(),
)
for cursor in cursors:
self.log.debug(f"Cursor executed: {cursor}")
self.output.results.extend(cursor.fetchall())
Expand Down
16 changes: 16 additions & 0 deletions tests/snowflake/test_snowflake.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa: F811
from copy import deepcopy
import os
from unittest import mock

import pytest
Expand All @@ -14,6 +15,7 @@
SnowflakeRunQueryPython,
SnowflakeStep,
SnowflakeTableStep,
safe_import_snowflake_connector,
)
from koheesio.integrations.snowflake.test_utils import mock_query

Expand Down Expand Up @@ -272,3 +274,17 @@ def test_initialization(self):
"""Test that the table is correctly set"""
kls = SnowflakeTableStep(**COMMON_OPTIONS, table="table")
assert kls.table == "table"


class TestSnowflakeConfigDir:
@mock.patch("koheesio.integrations.snowflake.__check_access_snowflake_config_dir", return_value=False)
@mock.patch("koheesio.integrations.snowflake.on_databricks", return_value=True)
def test_initialization_on_databricks(self, mock_on_databricks, mock_check_access):
"""Test that the config dir is correctly set"""
safe_import_snowflake_connector()
assert os.environ["SNOWFLAKE_HOME"].startswith("/tmp/snowflake_tmp_")

def test_initialization(self):
origin_snowflake_home = os.environ.get("SNOWFLAKE_HOME")
safe_import_snowflake_connector()
assert os.environ.get("SNOWFLAKE_HOME") == origin_snowflake_home

0 comments on commit 9496eb5

Please sign in to comment.