diff --git a/providers/tests/grpc/hooks/test_grpc.py b/providers/tests/grpc/hooks/test_grpc.py index 8536188bb135..ed185b4bf282 100644 --- a/providers/tests/grpc/hooks/test_grpc.py +++ b/providers/tests/grpc/hooks/test_grpc.py @@ -59,21 +59,21 @@ def stream_call(self, data): return ["streaming", "call"] -class TestGrpcHook: - def setup_method(self): - self.channel_mock = mock.patch("grpc.Channel").start() +@pytest.fixture +def channel_mock(): + """We mock run_command to capture its call args; it returns nothing so mock training is unnecessary.""" + with patch("grpc.Channel") as grpc_channel: + yield grpc_channel - def custom_conn_func(self, _): - mocked_channel = self.channel_mock.return_value - return mocked_channel +class TestGrpcHook: @mock.patch("grpc.insecure_channel") @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel): + def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel, channel_mock): conn = get_airflow_connection() mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_insecure_channel.return_value = mocked_channel channel = hook.get_conn() @@ -84,11 +84,11 @@ def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel): @mock.patch("grpc.insecure_channel") @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_connection_with_port(self, mock_get_connection, mock_insecure_channel): + def test_connection_with_port(self, mock_get_connection, mock_insecure_channel, channel_mock): conn = get_airflow_connection_with_port() mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_insecure_channel.return_value = mocked_channel channel = hook.get_conn() @@ -102,13 +102,13 @@ def test_connection_with_port(self, mock_get_connection, mock_insecure_channel): @mock.patch("grpc.ssl_channel_credentials") @mock.patch("grpc.secure_channel") def test_connection_with_ssl( - self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open + self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock ): conn = get_airflow_connection(auth_type="SSL", credential_pem_file="pem") mock_get_connection.return_value = conn mock_open.return_value = StringIO("credential") hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_secure_channel.return_value = mocked_channel mock_credential_object = "test_credential_object" mock_channel_credentials.return_value = mock_credential_object @@ -126,13 +126,13 @@ def test_connection_with_ssl( @mock.patch("grpc.ssl_channel_credentials") @mock.patch("grpc.secure_channel") def test_connection_with_tls( - self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open + self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock ): conn = get_airflow_connection(auth_type="TLS", credential_pem_file="pem") mock_get_connection.return_value = conn mock_open.return_value = StringIO("credential") hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_secure_channel.return_value = mocked_channel mock_credential_object = "test_credential_object" mock_channel_credentials.return_value = mock_credential_object @@ -150,12 +150,17 @@ def test_connection_with_tls( @mock.patch("google.auth.default") @mock.patch("google.auth.transport.grpc.secure_authorized_channel") def test_connection_with_jwt( - self, mock_secure_channel, mock_google_default_auth, mock_google_cred, mock_get_connection + self, + mock_secure_channel, + mock_google_default_auth, + mock_google_cred, + mock_get_connection, + channel_mock, ): conn = get_airflow_connection(auth_type="JWT_GOOGLE") mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_secure_channel.return_value = mocked_channel mock_credential_object = "test_credential_object" mock_google_default_auth.return_value = (mock_credential_object, "") @@ -173,12 +178,17 @@ def test_connection_with_jwt( @mock.patch("google.auth.default") @mock.patch("google.auth.transport.grpc.secure_authorized_channel") def test_connection_with_google_oauth( - self, mock_secure_channel, mock_google_default_auth, mock_google_auth_request, mock_get_connection + self, + mock_secure_channel, + mock_google_default_auth, + mock_google_auth_request, + mock_get_connection, + channel_mock, ): conn = get_airflow_connection(auth_type="OATH_GOOGLE", scopes="grpc,gcs") mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_secure_channel.return_value = mocked_channel mock_credential_object = "test_credential_object" mock_google_default_auth.return_value = (mock_credential_object, "") @@ -192,18 +202,22 @@ def test_connection_with_google_oauth( assert channel == mocked_channel @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_custom_connection(self, mock_get_connection): + def test_custom_connection(self, mock_get_connection, channel_mock): + def custom_conn_func(_): + mocked_channel = channel_mock.return_value + return mocked_channel + conn = get_airflow_connection("CUSTOM") mock_get_connection.return_value = conn - mocked_channel = self.channel_mock.return_value - hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func) + mocked_channel = channel_mock.return_value + hook = GrpcHook("grpc_default", custom_connection_func=custom_conn_func) channel = hook.get_conn() assert channel == mocked_channel @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_custom_connection_with_no_connection_func(self, mock_get_connection): + def test_custom_connection_with_no_connection_func(self, mock_get_connection, channel_mock): conn = get_airflow_connection("CUSTOM") mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") @@ -212,7 +226,7 @@ def test_custom_connection_with_no_connection_func(self, mock_get_connection): hook.get_conn() @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_connection_type_not_supported(self, mock_get_connection): + def test_connection_type_not_supported(self, mock_get_connection, channel_mock): conn = get_airflow_connection("NOT_SUPPORT") mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") @@ -224,11 +238,11 @@ def test_connection_type_not_supported(self, mock_get_connection): @mock.patch("airflow.hooks.base.BaseHook.get_connection") @mock.patch("grpc.insecure_channel") def test_connection_with_interceptors( - self, mock_insecure_channel, mock_get_connection, mock_intercept_channel + self, mock_insecure_channel, mock_get_connection, mock_intercept_channel, channel_mock ): conn = get_airflow_connection() mock_get_connection.return_value = conn - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value hook = GrpcHook("grpc_default", interceptors=["test1"]) mock_insecure_channel.return_value = mocked_channel mock_intercept_channel.return_value = mocked_channel @@ -240,7 +254,7 @@ def test_connection_with_interceptors( @mock.patch("airflow.hooks.base.BaseHook.get_connection") @mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn") - def test_simple_run(self, mock_get_conn, mock_get_connection): + def test_simple_run(self, mock_get_conn, mock_get_connection, channel_mock): conn = get_airflow_connection() mock_get_connection.return_value = conn mocked_channel = mock.Mock() @@ -255,7 +269,7 @@ def test_simple_run(self, mock_get_conn, mock_get_connection): @mock.patch("airflow.hooks.base.BaseHook.get_connection") @mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn") - def test_stream_run(self, mock_get_conn, mock_get_connection): + def test_stream_run(self, mock_get_conn, mock_get_connection, channel_mock): conn = get_airflow_connection() mock_get_connection.return_value = conn mocked_channel = mock.Mock() @@ -279,13 +293,13 @@ def test_stream_run(self, mock_get_conn, mock_get_connection): ], ) @patch("airflow.providers.grpc.hooks.grpc.grpc.insecure_channel") - def test_backcompat_prefix_works(self, channel_mock, uri): + def test_backcompat_prefix_works(self, insecure_channel_mock, uri): with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}): hook = GrpcHook("my_conn") hook.get_conn() - channel_mock.assert_called_with("abc:50") + insecure_channel_mock.assert_called_with("abc:50") - def test_backcompat_prefix_both_prefers_short(self): + def test_backcompat_prefix_both_prefers_short(self, channel_mock): with patch.dict( os.environ, {"AIRFLOW_CONN_MY_CONN": "a://abc:50?extra__grpc__auth_type=non-pref&auth_type=pref"},