Skip to content

Commit

Permalink
Fix side effect of bad grpc.Chanel mocking (#44029)
Browse files Browse the repository at this point in the history
The grpc.Channel has been patched but not relased in the test_grpc.py
and it could have caused other tests failing - when they were run
later in the same interpreter. For example it failed in in #44011 in the
#44011 (comment)

Patching is now fixed via using fixtures.
  • Loading branch information
potiuk authored Nov 14, 2024
1 parent 339bc77 commit 7728139
Showing 1 changed file with 44 additions and 30 deletions.
74 changes: 44 additions & 30 deletions providers/tests/grpc/hooks/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,21 @@ def stream_call(self, data):
return ["streaming", "call"]


class TestGrpcHook:
def setup_method(self):
self.channel_mock = mock.patch("grpc.Channel").start()
@pytest.fixture
def channel_mock():
"""We mock run_command to capture its call args; it returns nothing so mock training is unnecessary."""
with patch("grpc.Channel") as grpc_channel:
yield grpc_channel

def custom_conn_func(self, _):
mocked_channel = self.channel_mock.return_value
return mocked_channel

class TestGrpcHook:
@mock.patch("grpc.insecure_channel")
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):
def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel, channel_mock):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_insecure_channel.return_value = mocked_channel

channel = hook.get_conn()
Expand All @@ -84,11 +84,11 @@ def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):

@mock.patch("grpc.insecure_channel")
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
def test_connection_with_port(self, mock_get_connection, mock_insecure_channel, channel_mock):
conn = get_airflow_connection_with_port()
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_insecure_channel.return_value = mocked_channel

channel = hook.get_conn()
Expand All @@ -102,13 +102,13 @@ def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
@mock.patch("grpc.ssl_channel_credentials")
@mock.patch("grpc.secure_channel")
def test_connection_with_ssl(
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock
):
conn = get_airflow_connection(auth_type="SSL", credential_pem_file="pem")
mock_get_connection.return_value = conn
mock_open.return_value = StringIO("credential")
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_channel_credentials.return_value = mock_credential_object
Expand All @@ -126,13 +126,13 @@ def test_connection_with_ssl(
@mock.patch("grpc.ssl_channel_credentials")
@mock.patch("grpc.secure_channel")
def test_connection_with_tls(
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock
):
conn = get_airflow_connection(auth_type="TLS", credential_pem_file="pem")
mock_get_connection.return_value = conn
mock_open.return_value = StringIO("credential")
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_channel_credentials.return_value = mock_credential_object
Expand All @@ -150,12 +150,17 @@ def test_connection_with_tls(
@mock.patch("google.auth.default")
@mock.patch("google.auth.transport.grpc.secure_authorized_channel")
def test_connection_with_jwt(
self, mock_secure_channel, mock_google_default_auth, mock_google_cred, mock_get_connection
self,
mock_secure_channel,
mock_google_default_auth,
mock_google_cred,
mock_get_connection,
channel_mock,
):
conn = get_airflow_connection(auth_type="JWT_GOOGLE")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_google_default_auth.return_value = (mock_credential_object, "")
Expand All @@ -173,12 +178,17 @@ def test_connection_with_jwt(
@mock.patch("google.auth.default")
@mock.patch("google.auth.transport.grpc.secure_authorized_channel")
def test_connection_with_google_oauth(
self, mock_secure_channel, mock_google_default_auth, mock_google_auth_request, mock_get_connection
self,
mock_secure_channel,
mock_google_default_auth,
mock_google_auth_request,
mock_get_connection,
channel_mock,
):
conn = get_airflow_connection(auth_type="OATH_GOOGLE", scopes="grpc,gcs")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_google_default_auth.return_value = (mock_credential_object, "")
Expand All @@ -192,18 +202,22 @@ def test_connection_with_google_oauth(
assert channel == mocked_channel

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_custom_connection(self, mock_get_connection):
def test_custom_connection(self, mock_get_connection, channel_mock):
def custom_conn_func(_):
mocked_channel = channel_mock.return_value
return mocked_channel

conn = get_airflow_connection("CUSTOM")
mock_get_connection.return_value = conn
mocked_channel = self.channel_mock.return_value
hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func)
mocked_channel = channel_mock.return_value
hook = GrpcHook("grpc_default", custom_connection_func=custom_conn_func)

channel = hook.get_conn()

assert channel == mocked_channel

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_custom_connection_with_no_connection_func(self, mock_get_connection):
def test_custom_connection_with_no_connection_func(self, mock_get_connection, channel_mock):
conn = get_airflow_connection("CUSTOM")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
Expand All @@ -212,7 +226,7 @@ def test_custom_connection_with_no_connection_func(self, mock_get_connection):
hook.get_conn()

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_connection_type_not_supported(self, mock_get_connection):
def test_connection_type_not_supported(self, mock_get_connection, channel_mock):
conn = get_airflow_connection("NOT_SUPPORT")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
Expand All @@ -224,11 +238,11 @@ def test_connection_type_not_supported(self, mock_get_connection):
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("grpc.insecure_channel")
def test_connection_with_interceptors(
self, mock_insecure_channel, mock_get_connection, mock_intercept_channel
self, mock_insecure_channel, mock_get_connection, mock_intercept_channel, channel_mock
):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
hook = GrpcHook("grpc_default", interceptors=["test1"])
mock_insecure_channel.return_value = mocked_channel
mock_intercept_channel.return_value = mocked_channel
Expand All @@ -240,7 +254,7 @@ def test_connection_with_interceptors(

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn")
def test_simple_run(self, mock_get_conn, mock_get_connection):
def test_simple_run(self, mock_get_conn, mock_get_connection, channel_mock):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
mocked_channel = mock.Mock()
Expand All @@ -255,7 +269,7 @@ def test_simple_run(self, mock_get_conn, mock_get_connection):

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn")
def test_stream_run(self, mock_get_conn, mock_get_connection):
def test_stream_run(self, mock_get_conn, mock_get_connection, channel_mock):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
mocked_channel = mock.Mock()
Expand All @@ -279,13 +293,13 @@ def test_stream_run(self, mock_get_conn, mock_get_connection):
],
)
@patch("airflow.providers.grpc.hooks.grpc.grpc.insecure_channel")
def test_backcompat_prefix_works(self, channel_mock, uri):
def test_backcompat_prefix_works(self, insecure_channel_mock, uri):
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
hook = GrpcHook("my_conn")
hook.get_conn()
channel_mock.assert_called_with("abc:50")
insecure_channel_mock.assert_called_with("abc:50")

def test_backcompat_prefix_both_prefers_short(self):
def test_backcompat_prefix_both_prefers_short(self, channel_mock):
with patch.dict(
os.environ,
{"AIRFLOW_CONN_MY_CONN": "a://abc:50?extra__grpc__auth_type=non-pref&auth_type=pref"},
Expand Down

0 comments on commit 7728139

Please sign in to comment.