Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix side effect of bad grpc.Chanel mocking #44029

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 44 additions & 30 deletions providers/tests/grpc/hooks/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,21 @@ def stream_call(self, data):
return ["streaming", "call"]


class TestGrpcHook:
def setup_method(self):
self.channel_mock = mock.patch("grpc.Channel").start()
@pytest.fixture
def channel_mock():
"""We mock run_command to capture its call args; it returns nothing so mock training is unnecessary."""
with patch("grpc.Channel") as grpc_channel:
yield grpc_channel

def custom_conn_func(self, _):
mocked_channel = self.channel_mock.return_value
return mocked_channel

class TestGrpcHook:
@mock.patch("grpc.insecure_channel")
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):
def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel, channel_mock):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_insecure_channel.return_value = mocked_channel

channel = hook.get_conn()
Expand All @@ -84,11 +84,11 @@ def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):

@mock.patch("grpc.insecure_channel")
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
def test_connection_with_port(self, mock_get_connection, mock_insecure_channel, channel_mock):
conn = get_airflow_connection_with_port()
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_insecure_channel.return_value = mocked_channel

channel = hook.get_conn()
Expand All @@ -102,13 +102,13 @@ def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
@mock.patch("grpc.ssl_channel_credentials")
@mock.patch("grpc.secure_channel")
def test_connection_with_ssl(
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock
):
conn = get_airflow_connection(auth_type="SSL", credential_pem_file="pem")
mock_get_connection.return_value = conn
mock_open.return_value = StringIO("credential")
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_channel_credentials.return_value = mock_credential_object
Expand All @@ -126,13 +126,13 @@ def test_connection_with_ssl(
@mock.patch("grpc.ssl_channel_credentials")
@mock.patch("grpc.secure_channel")
def test_connection_with_tls(
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open
self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock
):
conn = get_airflow_connection(auth_type="TLS", credential_pem_file="pem")
mock_get_connection.return_value = conn
mock_open.return_value = StringIO("credential")
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_channel_credentials.return_value = mock_credential_object
Expand All @@ -150,12 +150,17 @@ def test_connection_with_tls(
@mock.patch("google.auth.default")
@mock.patch("google.auth.transport.grpc.secure_authorized_channel")
def test_connection_with_jwt(
self, mock_secure_channel, mock_google_default_auth, mock_google_cred, mock_get_connection
self,
mock_secure_channel,
mock_google_default_auth,
mock_google_cred,
mock_get_connection,
channel_mock,
):
conn = get_airflow_connection(auth_type="JWT_GOOGLE")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_google_default_auth.return_value = (mock_credential_object, "")
Expand All @@ -173,12 +178,17 @@ def test_connection_with_jwt(
@mock.patch("google.auth.default")
@mock.patch("google.auth.transport.grpc.secure_authorized_channel")
def test_connection_with_google_oauth(
self, mock_secure_channel, mock_google_default_auth, mock_google_auth_request, mock_get_connection
self,
mock_secure_channel,
mock_google_default_auth,
mock_google_auth_request,
mock_get_connection,
channel_mock,
):
conn = get_airflow_connection(auth_type="OATH_GOOGLE", scopes="grpc,gcs")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
mock_secure_channel.return_value = mocked_channel
mock_credential_object = "test_credential_object"
mock_google_default_auth.return_value = (mock_credential_object, "")
Expand All @@ -192,18 +202,22 @@ def test_connection_with_google_oauth(
assert channel == mocked_channel

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_custom_connection(self, mock_get_connection):
def test_custom_connection(self, mock_get_connection, channel_mock):
def custom_conn_func(_):
mocked_channel = channel_mock.return_value
return mocked_channel

conn = get_airflow_connection("CUSTOM")
mock_get_connection.return_value = conn
mocked_channel = self.channel_mock.return_value
hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func)
mocked_channel = channel_mock.return_value
hook = GrpcHook("grpc_default", custom_connection_func=custom_conn_func)

channel = hook.get_conn()

assert channel == mocked_channel

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_custom_connection_with_no_connection_func(self, mock_get_connection):
def test_custom_connection_with_no_connection_func(self, mock_get_connection, channel_mock):
conn = get_airflow_connection("CUSTOM")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
Expand All @@ -212,7 +226,7 @@ def test_custom_connection_with_no_connection_func(self, mock_get_connection):
hook.get_conn()

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_connection_type_not_supported(self, mock_get_connection):
def test_connection_type_not_supported(self, mock_get_connection, channel_mock):
conn = get_airflow_connection("NOT_SUPPORT")
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
Expand All @@ -224,11 +238,11 @@ def test_connection_type_not_supported(self, mock_get_connection):
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("grpc.insecure_channel")
def test_connection_with_interceptors(
self, mock_insecure_channel, mock_get_connection, mock_intercept_channel
self, mock_insecure_channel, mock_get_connection, mock_intercept_channel, channel_mock
):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
mocked_channel = self.channel_mock.return_value
mocked_channel = channel_mock.return_value
hook = GrpcHook("grpc_default", interceptors=["test1"])
mock_insecure_channel.return_value = mocked_channel
mock_intercept_channel.return_value = mocked_channel
Expand All @@ -240,7 +254,7 @@ def test_connection_with_interceptors(

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn")
def test_simple_run(self, mock_get_conn, mock_get_connection):
def test_simple_run(self, mock_get_conn, mock_get_connection, channel_mock):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
mocked_channel = mock.Mock()
Expand All @@ -255,7 +269,7 @@ def test_simple_run(self, mock_get_conn, mock_get_connection):

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn")
def test_stream_run(self, mock_get_conn, mock_get_connection):
def test_stream_run(self, mock_get_conn, mock_get_connection, channel_mock):
conn = get_airflow_connection()
mock_get_connection.return_value = conn
mocked_channel = mock.Mock()
Expand All @@ -279,13 +293,13 @@ def test_stream_run(self, mock_get_conn, mock_get_connection):
],
)
@patch("airflow.providers.grpc.hooks.grpc.grpc.insecure_channel")
def test_backcompat_prefix_works(self, channel_mock, uri):
def test_backcompat_prefix_works(self, insecure_channel_mock, uri):
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
hook = GrpcHook("my_conn")
hook.get_conn()
channel_mock.assert_called_with("abc:50")
insecure_channel_mock.assert_called_with("abc:50")

def test_backcompat_prefix_both_prefers_short(self):
def test_backcompat_prefix_both_prefers_short(self, channel_mock):
with patch.dict(
os.environ,
{"AIRFLOW_CONN_MY_CONN": "a://abc:50?extra__grpc__auth_type=non-pref&auth_type=pref"},
Expand Down