diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index 90dc8c36..fd5ccd96 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -364,6 +364,7 @@ RUN pip install mpld3 \ eli5 \ kaggle \ kagglehub \ + google-generativeai \ mock \ pytest && \ /tmp/clean-layer.sh diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 2f621f7d..6ac7400e 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -7,6 +7,8 @@ import importlib import importlib.machinery +import wrapt + class GcpModuleFinder(importlib.abc.MetaPathFinder): _MODULES = [ 'google.cloud.bigquery', @@ -73,3 +75,41 @@ def exec_module(self, module): if not hasattr(sys, 'frozen'): sys.meta_path.insert(0, GcpModuleFinder()) + +@wrapt.when_imported('google.generativeai') +def post_import_logic(module): + if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: + return + if (os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or + os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or + os.getenv('KAGGLE_DATA_PROXY_URL') == None): + return + + old_configure = module.configure + + def new_configure(*args, **kwargs): + if ('default_metadata' in kwargs): + default_metadata = kwargs['default_metadata'] + else: + default_metadata = [] + default_metadata.append(("x-kaggle-proxy-data", os.environ['KAGGLE_DATA_PROXY_TOKEN'])) + user_secrets_token = os.environ['KAGGLE_USER_SECRETS_TOKEN'] + default_metadata.append(('x-kaggle-authorization', f'Bearer {user_secrets_token}')) + kwargs['default_metadata'] = default_metadata + + if ('client_options' in kwargs): + client_options = kwargs['client_options'] + else: + client_options = {} + client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] + if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None: + client_options['api_endpoint'] += '/palmapi' + kwargs['transport'] = 'rest' + elif 'transport' in kwargs and kwargs['transport'] == 'rest': + client_options['api_endpoint'] += '/palmapi' + kwargs['client_options'] = client_options + + old_configure(*args, **kwargs) + + module.configure = new_configure + module.configure() # generativeai can use GOOGLE_API_KEY env variable, so make sure we have the other configs set diff --git a/tests/test_google_generativeai_patch.py b/tests/test_google_generativeai_patch.py new file mode 100644 index 00000000..68e766c2 --- /dev/null +++ b/tests/test_google_generativeai_patch.py @@ -0,0 +1,54 @@ +import json +import unittest +import threading + +from test.support.os_helper import EnvironmentVarGuard +from urllib.parse import urlparse + +from http.server import BaseHTTPRequestHandler, HTTPServer + +class HTTPHandler(BaseHTTPRequestHandler): + called = False + path = None + headers = {} + + def do_HEAD(self): + self.send_response(200) + + def do_GET(self): + HTTPHandler.path = self.path + HTTPHandler.headers = self.headers + HTTPHandler.called = True + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + +class TestGoogleGenerativeAiPatch(unittest.TestCase): + endpoint = "http://127.0.0.1:80" + + def test_proxy_enabled(self): + env = EnvironmentVarGuard() + secrets_token = "secrets_token" + proxy_token = "proxy_token" + env.set("KAGGLE_USER_SECRETS_TOKEN", secrets_token) + env.set("KAGGLE_DATA_PROXY_TOKEN", proxy_token) + env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) + env.set("KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY", "True") + server_address = urlparse(self.endpoint) + with env: + with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: + threading.Thread(target=httpd.serve_forever).start() + import google.generativeai as palm + api_key = "NotARealAPIKey" + palm.configure(api_key = api_key) + try: + for _ in palm.list_models(): + pass + except: + pass + httpd.shutdown() + self.assertTrue(HTTPHandler.called) + self.assertIn("/palmapi", HTTPHandler.path) + self.assertEqual(proxy_token, HTTPHandler.headers["x-kaggle-proxy-data"]) + self.assertEqual("Bearer {}".format(secrets_token), HTTPHandler.headers["x-kaggle-authorization"]) + self.assertEqual(api_key, HTTPHandler.headers["x-goog-api-key"]) diff --git a/tests/test_google_generativeai_patch_disabled.py b/tests/test_google_generativeai_patch_disabled.py new file mode 100644 index 00000000..65e02845 --- /dev/null +++ b/tests/test_google_generativeai_patch_disabled.py @@ -0,0 +1,43 @@ +import json +import unittest +import threading + +from test.support.os_helper import EnvironmentVarGuard +from urllib.parse import urlparse + +from http.server import BaseHTTPRequestHandler, HTTPServer + +class HTTPHandler(BaseHTTPRequestHandler): + called = False + + def do_HEAD(self): + self.send_response(200) + + def do_GET(self): + HTTPHandler.called = True + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + +class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase): + endpoint = "http://127.0.0.1:80" + + def test_disabled(self): + env = EnvironmentVarGuard() + env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar") + env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar") + env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) + env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True") + server_address = urlparse(self.endpoint) + with env: + with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: + threading.Thread(target=httpd.serve_forever).start() + import google.generativeai as palm + palm.configure(api_key = "NotARealAPIKey") + try: + for _ in palm.list_models(): + pass + except: + pass + httpd.shutdown() + self.assertFalse(HTTPHandler.called) \ No newline at end of file