From bb9b855334a48c3a05e2e1a1319913bf830bbc36 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Wed, 29 Nov 2023 20:42:27 +0000 Subject: [PATCH 01/20] Patching for Python package --- Dockerfile.tmpl | 7 +++++++ patches/generativeaipatch.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 patches/generativeaipatch.py diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index 6b2d3a5d..6e000740 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -634,6 +634,10 @@ RUN mkdir -p ~/src && git clone https://github.com/SohierDane/BigQuery_Helper ~/ sed -i 's/)/packages=["bq_helper"])/g' ~/src/BigQuery_Helper/setup.py && \ pip install -e ~/src/BigQuery_Helper && \ /tmp/clean-layer.sh + +RUN pip install wrapt \ + google-generativeai && \ + /tmp/clean-layer.sh # Add BigQuery client proxy settings ENV PYTHONUSERBASE "/root/.local" @@ -647,6 +651,9 @@ ADD patches/sitecustomize.py /root/.local/lib/python3.10/site-packages/sitecusto # Override default imagemagick policies ADD patches/imagemagick-policy.xml /etc/ImageMagick-6/policy.xml +# Add generativeai Python patch +AND patches/generativeaipatch.py /root/.local/lib/python3.10/site-packages/generativeaipatch.py + # Add Kaggle module resolver ADD patches/kaggle_module_resolver.py /opt/conda/lib/python3.10/site-packages/tensorflow_hub/kaggle_module_resolver.py RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /opt/conda/lib/python3.10/site-packages/tensorflow_hub/config.py && \ diff --git a/patches/generativeaipatch.py b/patches/generativeaipatch.py new file mode 100644 index 00000000..40c23dd0 --- /dev/null +++ b/patches/generativeaipatch.py @@ -0,0 +1,24 @@ +import wrapt +import os + +@wrapt.when_imported('google.generativeai') +def post_import_logic(module): + old_configure = module.configure + def new_configure(*args, **kwargs): + if ('default_metadata' in kwargs): + default_metadata = kwargs['default_metadata'] + else: + default_metadata = [] + kwargs['transport'] = 'rest' # Only support REST requests for now + default_metadata.append(("x-kaggle-proxy-data", os.environ['KAGGLE_DATA_PROXY_TOKEN'])) + default_metadata.append(('x-kaggle-authorization', f'Bearer {os.environ['KAGGLE_USER_SECRETS_TOKEN']}')) + kwargs['default_metadata'] = default_metadata + if ('client_options' in kwargs): + client_options = kwargs['client_options'] + else: + client_options = {} + client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] + '/palmapi' + kwargs['client_options'] = client_options + old_configure(*args, **kwargs) + module.configure = new_configure + module.configure() # generativeai can use GOOGLE_API_KEY env variable, so make sure we have the other configs set From 6508445399a8ba11edb919ef6228a1872b4bb0f8 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Wed, 29 Nov 2023 21:27:50 +0000 Subject: [PATCH 02/20] Typo fix --- Dockerfile.tmpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index 6e000740..f4207db3 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -652,7 +652,7 @@ ADD patches/sitecustomize.py /root/.local/lib/python3.10/site-packages/sitecusto ADD patches/imagemagick-policy.xml /etc/ImageMagick-6/policy.xml # Add generativeai Python patch -AND patches/generativeaipatch.py /root/.local/lib/python3.10/site-packages/generativeaipatch.py +ADD patches/generativeaipatch.py /root/.local/lib/python3.10/site-packages/generativeaipatch.py # Add Kaggle module resolver ADD patches/kaggle_module_resolver.py /opt/conda/lib/python3.10/site-packages/tensorflow_hub/kaggle_module_resolver.py From 684dc6a636c83cb0211f1c6c40466e9777132793 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Thu, 30 Nov 2023 06:21:17 +0000 Subject: [PATCH 03/20] Run python script instead of just adding it to docker image --- Dockerfile.tmpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index f4207db3..6342ebfd 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -652,7 +652,7 @@ ADD patches/sitecustomize.py /root/.local/lib/python3.10/site-packages/sitecusto ADD patches/imagemagick-policy.xml /etc/ImageMagick-6/policy.xml # Add generativeai Python patch -ADD patches/generativeaipatch.py /root/.local/lib/python3.10/site-packages/generativeaipatch.py +CMD python patches/generativeaipatch.py # Add Kaggle module resolver ADD patches/kaggle_module_resolver.py /opt/conda/lib/python3.10/site-packages/tensorflow_hub/kaggle_module_resolver.py From a49c250613fc6bf9c3b9b1c1e1942edfb2db166a Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Thu, 30 Nov 2023 15:47:50 +0000 Subject: [PATCH 04/20] Try using ADD and CMD --- Dockerfile.tmpl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index 6342ebfd..ae5ab937 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -652,7 +652,8 @@ ADD patches/sitecustomize.py /root/.local/lib/python3.10/site-packages/sitecusto ADD patches/imagemagick-policy.xml /etc/ImageMagick-6/policy.xml # Add generativeai Python patch -CMD python patches/generativeaipatch.py +ADD patches/generativeaipatch.py /root/.local/lib/python3.10/site-packages/generativeaipatch.py +CMD ["python", "/root/.local/lib/python3.10/site-packages/generativeaipatch.py"] # Add Kaggle module resolver ADD patches/kaggle_module_resolver.py /opt/conda/lib/python3.10/site-packages/tensorflow_hub/kaggle_module_resolver.py From de9912c3cd2a4255b6a5e7676af268d2a9eafa8e Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Thu, 30 Nov 2023 19:59:26 +0000 Subject: [PATCH 05/20] Add logic to sitecustomize --- Dockerfile.tmpl | 4 ---- patches/generativeaipatch.py | 24 ------------------------ patches/sitecustomize.py | 24 ++++++++++++++++++++++++ 3 files changed, 24 insertions(+), 28 deletions(-) delete mode 100644 patches/generativeaipatch.py diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index ae5ab937..6553b8c7 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -651,10 +651,6 @@ ADD patches/sitecustomize.py /root/.local/lib/python3.10/site-packages/sitecusto # Override default imagemagick policies ADD patches/imagemagick-policy.xml /etc/ImageMagick-6/policy.xml -# Add generativeai Python patch -ADD patches/generativeaipatch.py /root/.local/lib/python3.10/site-packages/generativeaipatch.py -CMD ["python", "/root/.local/lib/python3.10/site-packages/generativeaipatch.py"] - # Add Kaggle module resolver ADD patches/kaggle_module_resolver.py /opt/conda/lib/python3.10/site-packages/tensorflow_hub/kaggle_module_resolver.py RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /opt/conda/lib/python3.10/site-packages/tensorflow_hub/config.py && \ diff --git a/patches/generativeaipatch.py b/patches/generativeaipatch.py deleted file mode 100644 index 40c23dd0..00000000 --- a/patches/generativeaipatch.py +++ /dev/null @@ -1,24 +0,0 @@ -import wrapt -import os - -@wrapt.when_imported('google.generativeai') -def post_import_logic(module): - old_configure = module.configure - def new_configure(*args, **kwargs): - if ('default_metadata' in kwargs): - default_metadata = kwargs['default_metadata'] - else: - default_metadata = [] - kwargs['transport'] = 'rest' # Only support REST requests for now - default_metadata.append(("x-kaggle-proxy-data", os.environ['KAGGLE_DATA_PROXY_TOKEN'])) - default_metadata.append(('x-kaggle-authorization', f'Bearer {os.environ['KAGGLE_USER_SECRETS_TOKEN']}')) - kwargs['default_metadata'] = default_metadata - if ('client_options' in kwargs): - client_options = kwargs['client_options'] - else: - client_options = {} - client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] + '/palmapi' - kwargs['client_options'] = client_options - old_configure(*args, **kwargs) - module.configure = new_configure - module.configure() # generativeai can use GOOGLE_API_KEY env variable, so make sure we have the other configs set diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 2f621f7d..3a72e19c 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -7,6 +7,8 @@ import importlib import importlib.machinery +import wrapt + class GcpModuleFinder(importlib.abc.MetaPathFinder): _MODULES = [ 'google.cloud.bigquery', @@ -73,3 +75,25 @@ def exec_module(self, module): if not hasattr(sys, 'frozen'): sys.meta_path.insert(0, GcpModuleFinder()) + +@wrapt.when_imported('google.generativeai') +def post_import_logic(module): + old_configure = module.configure + def new_configure(*args, **kwargs): + if ('default_metadata' in kwargs): + default_metadata = kwargs['default_metadata'] + else: + default_metadata = [] + kwargs['transport'] = 'rest' # Only support REST requests for now + default_metadata.append(("x-kaggle-proxy-data", os.environ['KAGGLE_DATA_PROXY_TOKEN'])) + default_metadata.append(('x-kaggle-authorization', f'Bearer {os.environ['KAGGLE_USER_SECRETS_TOKEN']}')) + kwargs['default_metadata'] = default_metadata + if ('client_options' in kwargs): + client_options = kwargs['client_options'] + else: + client_options = {} + client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] + '/palmapi' + kwargs['client_options'] = client_options + old_configure(*args, **kwargs) + module.configure = new_configure + module.configure() # generativeai can use GOOGLE_API_KEY env variable, so make sure we have the other configs set From 3477d4dd7c7929343fa3f7de57be0adadfb4617c Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Thu, 30 Nov 2023 20:00:39 +0000 Subject: [PATCH 06/20] Escape path --- patches/sitecustomize.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 3a72e19c..cb2d2b76 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -78,6 +78,8 @@ def exec_module(self, module): @wrapt.when_imported('google.generativeai') def post_import_logic(module): + if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') == 1: + return old_configure = module.configure def new_configure(*args, **kwargs): if ('default_metadata' in kwargs): From a3c4b160d86d9597e39e049c9f80bc6d31ce2940 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Thu, 30 Nov 2023 22:31:12 +0000 Subject: [PATCH 07/20] Address comments --- patches/sitecustomize.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index cb2d2b76..e396820f 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -80,6 +80,8 @@ def exec_module(self, module): def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') == 1: return + if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == 0: + return old_configure = module.configure def new_configure(*args, **kwargs): if ('default_metadata' in kwargs): From 4c4240a55f9da8749416fb7a00937248221c8833 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Thu, 30 Nov 2023 22:33:34 +0000 Subject: [PATCH 08/20] Check for another env variable --- patches/sitecustomize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index e396820f..0b4e8b41 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -80,7 +80,7 @@ def exec_module(self, module): def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') == 1: return - if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == 0: + if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == 0 || os.getenv('KAGGLE_USER_SECRETS_TOKEN') == 0: return old_configure = module.configure def new_configure(*args, **kwargs): From 5a1504a784566cfc6f24fd049dec0311a03fec65 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Fri, 1 Dec 2023 15:28:09 +0000 Subject: [PATCH 09/20] Address comments --- patches/sitecustomize.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 0b4e8b41..19d2f297 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -78,9 +78,9 @@ def exec_module(self, module): @wrapt.when_imported('google.generativeai') def post_import_logic(module): - if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') == 1: + if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: return - if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == 0 || os.getenv('KAGGLE_USER_SECRETS_TOKEN') == 0: + if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None || os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None || os.getenv('KAGGLE_DATA_PROXY_URL') == None: return old_configure = module.configure def new_configure(*args, **kwargs): @@ -97,6 +97,8 @@ def new_configure(*args, **kwargs): else: client_options = {} client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] + '/palmapi' + if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None: + client_options['api_endpoint'] += '/palmapi' kwargs['client_options'] = client_options old_configure(*args, **kwargs) module.configure = new_configure From 4c5f7cd0497771008aecb6205b38ced23dc00f66 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Fri, 1 Dec 2023 19:18:12 +0000 Subject: [PATCH 10/20] Build fix hopefully --- Dockerfile.tmpl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index 61785fc5..f40652d1 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -364,6 +364,8 @@ RUN pip install mpld3 \ eli5 \ kaggle \ kagglehub \ + google-generativeai \ + wrapt \ mock \ pytest && \ /tmp/clean-layer.sh @@ -635,10 +637,6 @@ RUN mkdir -p ~/src && git clone https://github.com/SohierDane/BigQuery_Helper ~/ sed -i 's/)/packages=["bq_helper"])/g' ~/src/BigQuery_Helper/setup.py && \ pip install -e ~/src/BigQuery_Helper && \ /tmp/clean-layer.sh - -RUN pip install wrapt \ - google-generativeai && \ - /tmp/clean-layer.sh # Add BigQuery client proxy settings ENV PYTHONUSERBASE "/root/.local" From 7d4f55db18b11c86eb559925380fd691542bb168 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Fri, 1 Dec 2023 21:03:46 +0000 Subject: [PATCH 11/20] Add unit test for headers --- patches/sitecustomize.py | 9 +++-- tests/test_google_generativeai_patch.py | 47 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 tests/test_google_generativeai_patch.py diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 19d2f297..31758cca 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -79,8 +79,10 @@ def exec_module(self, module): @wrapt.when_imported('google.generativeai') def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: + print('disabled google ai integration') return - if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None || os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None || os.getenv('KAGGLE_DATA_PROXY_URL') == None: + if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or os.getenv('KAGGLE_DATA_PROXY_URL') == None: + print('one of the tokens is not available') return old_configure = module.configure def new_configure(*args, **kwargs): @@ -90,13 +92,14 @@ def new_configure(*args, **kwargs): default_metadata = [] kwargs['transport'] = 'rest' # Only support REST requests for now default_metadata.append(("x-kaggle-proxy-data", os.environ['KAGGLE_DATA_PROXY_TOKEN'])) - default_metadata.append(('x-kaggle-authorization', f'Bearer {os.environ['KAGGLE_USER_SECRETS_TOKEN']}')) + user_secrets_token = os.environ['KAGGLE_USER_SECRETS_TOKEN'] + default_metadata.append(('x-kaggle-authorization', f'Bearer {user_secrets_token}')) kwargs['default_metadata'] = default_metadata if ('client_options' in kwargs): client_options = kwargs['client_options'] else: client_options = {} - client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] + '/palmapi' + client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None: client_options['api_endpoint'] += '/palmapi' kwargs['client_options'] = client_options diff --git a/tests/test_google_generativeai_patch.py b/tests/test_google_generativeai_patch.py new file mode 100644 index 00000000..c1ae8fe9 --- /dev/null +++ b/tests/test_google_generativeai_patch.py @@ -0,0 +1,47 @@ +import unittest +import threading +from test.support.os_helper import EnvironmentVarGuard +from urllib.parse import urlparse + +from http.server import BaseHTTPRequestHandler, HTTPServer + +class HTTPHandler(BaseHTTPRequestHandler): + called = False + data_proxy_token_header_found = False + user_secrets_token_header_found = False + + def do_HEAD(self): + self.send_response(200) + + def do_GET(self): + HTTPHandler.called = True + HTTPHandler.data_proxy_token_header_found = any( + k for k in self.headers if k == "x-kaggle-proxy-data") + HTTPHandler.user_secrets_token_header_found = any( + k for k in self.headers if k == "x-kaggle-authorization") + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + +class TestGoogleGenerativeAiPatch(unittest.TestCase): + def test_headers_are_set(self): + endpoint = "http://127.0.0.1:80" + env = EnvironmentVarGuard() + env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar') + env.set('KAGGLE_DATA_PROXY_TOKEN', 'foobar') + env.set('KAGGLE_DATA_PROXY_URL', endpoint) + server_address = urlparse(endpoint) + with env: + with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: + threading.Thread(target=httpd.serve_forever).start() + import google.generativeai as palm + palm.configure(api_key = "NotARealAPIKey") + try: + for _ in palm.list_models(): + pass + except: + pass + httpd.shutdown() + self.assertTrue(HTTPHandler.data_proxy_token_header_found) + self.assertTrue(HTTPHandler.user_secrets_token_header_found) + From a06ec16ec4240c55c85e1a3a205c8a2865f8a6a4 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Fri, 1 Dec 2023 21:11:37 +0000 Subject: [PATCH 12/20] bug fix --- patches/sitecustomize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 31758cca..ea5c61ac 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -90,7 +90,6 @@ def new_configure(*args, **kwargs): default_metadata = kwargs['default_metadata'] else: default_metadata = [] - kwargs['transport'] = 'rest' # Only support REST requests for now default_metadata.append(("x-kaggle-proxy-data", os.environ['KAGGLE_DATA_PROXY_TOKEN'])) user_secrets_token = os.environ['KAGGLE_USER_SECRETS_TOKEN'] default_metadata.append(('x-kaggle-authorization', f'Bearer {user_secrets_token}')) @@ -102,6 +101,7 @@ def new_configure(*args, **kwargs): client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None: client_options['api_endpoint'] += '/palmapi' + kwargs['transport'] = 'rest' # Only support REST requests for now kwargs['client_options'] = client_options old_configure(*args, **kwargs) module.configure = new_configure From 0f7eaa554ec5c036109bc2f51c0a10b8265b2ae2 Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Fri, 1 Dec 2023 22:02:23 +0000 Subject: [PATCH 13/20] Fix import issue and test --- Dockerfile.tmpl | 1 - patches/sitecustomize.py | 2 -- tests/test_google_generativeai_patch.py | 1 + 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index f40652d1..fd5ccd96 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -365,7 +365,6 @@ RUN pip install mpld3 \ kaggle \ kagglehub \ google-generativeai \ - wrapt \ mock \ pytest && \ /tmp/clean-layer.sh diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index ea5c61ac..9e6af437 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -79,10 +79,8 @@ def exec_module(self, module): @wrapt.when_imported('google.generativeai') def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: - print('disabled google ai integration') return if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or os.getenv('KAGGLE_DATA_PROXY_URL') == None: - print('one of the tokens is not available') return old_configure = module.configure def new_configure(*args, **kwargs): diff --git a/tests/test_google_generativeai_patch.py b/tests/test_google_generativeai_patch.py index c1ae8fe9..0f93ca2c 100644 --- a/tests/test_google_generativeai_patch.py +++ b/tests/test_google_generativeai_patch.py @@ -30,6 +30,7 @@ def test_headers_are_set(self): env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar') env.set('KAGGLE_DATA_PROXY_TOKEN', 'foobar') env.set('KAGGLE_DATA_PROXY_URL', endpoint) + env.set('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY', 'True') server_address = urlparse(endpoint) with env: with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: From f1559e00b8e3f80b52b13371dbb79df43921965b Mon Sep 17 00:00:00 2001 From: Prathamesh Bang Date: Fri, 1 Dec 2023 22:18:06 +0000 Subject: [PATCH 14/20] small modification --- patches/sitecustomize.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 9e6af437..583644df 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -99,7 +99,9 @@ def new_configure(*args, **kwargs): client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None: client_options['api_endpoint'] += '/palmapi' - kwargs['transport'] = 'rest' # Only support REST requests for now + kwargs['transport'] = 'rest' + elif 'transport' in kwargs and kwargs['transport'] == 'rest': + client_options['api_endpoint'] += '/palmapi' kwargs['client_options'] = client_options old_configure(*args, **kwargs) module.configure = new_configure From f89b6a2f36def9bb39f4b0ad73033441bf50b722 Mon Sep 17 00:00:00 2001 From: Philippe Modard Date: Sat, 2 Dec 2023 14:15:29 +0000 Subject: [PATCH 15/20] add a test for disabled proxy for generativeai api --- patches/sitecustomize.py | 10 ++++- tests/test_google_generativeai_patch.py | 40 +++++++++-------- ...test_google_generativeai_patch_disabled.py | 43 +++++++++++++++++++ 3 files changed, 75 insertions(+), 18 deletions(-) create mode 100644 tests/test_google_generativeai_patch_disabled.py diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 583644df..0a4cd7e5 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -79,10 +79,15 @@ def exec_module(self, module): @wrapt.when_imported('google.generativeai') def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: + print("PHILMOD EARLY EXIT") return - if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or os.getenv('KAGGLE_DATA_PROXY_URL') == None: + if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or + os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or + os.getenv('KAGGLE_DATA_PROXY_URL') == None: return + old_configure = module.configure + def new_configure(*args, **kwargs): if ('default_metadata' in kwargs): default_metadata = kwargs['default_metadata'] @@ -92,6 +97,7 @@ def new_configure(*args, **kwargs): user_secrets_token = os.environ['KAGGLE_USER_SECRETS_TOKEN'] default_metadata.append(('x-kaggle-authorization', f'Bearer {user_secrets_token}')) kwargs['default_metadata'] = default_metadata + if ('client_options' in kwargs): client_options = kwargs['client_options'] else: @@ -103,6 +109,8 @@ def new_configure(*args, **kwargs): elif 'transport' in kwargs and kwargs['transport'] == 'rest': client_options['api_endpoint'] += '/palmapi' kwargs['client_options'] = client_options + old_configure(*args, **kwargs) + module.configure = new_configure module.configure() # generativeai can use GOOGLE_API_KEY env variable, so make sure we have the other configs set diff --git a/tests/test_google_generativeai_patch.py b/tests/test_google_generativeai_patch.py index 0f93ca2c..64c51b92 100644 --- a/tests/test_google_generativeai_patch.py +++ b/tests/test_google_generativeai_patch.py @@ -1,5 +1,7 @@ +import json import unittest import threading + from test.support.os_helper import EnvironmentVarGuard from urllib.parse import urlparse @@ -7,42 +9,46 @@ class HTTPHandler(BaseHTTPRequestHandler): called = False - data_proxy_token_header_found = False - user_secrets_token_header_found = False + path = None + headers = {} def do_HEAD(self): self.send_response(200) def do_GET(self): + HTTPHandler.path = self.path + HTTPHandler.headers = self.headers HTTPHandler.called = True - HTTPHandler.data_proxy_token_header_found = any( - k for k in self.headers if k == "x-kaggle-proxy-data") - HTTPHandler.user_secrets_token_header_found = any( - k for k in self.headers if k == "x-kaggle-authorization") self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() -class TestGoogleGenerativeAiPatch(unittest.TestCase): +class TestGoogleGenerativeAiPatch(unittest.TestCase): + endpoint = "http://127.0.0.1:80" + def test_headers_are_set(self): - endpoint = "http://127.0.0.1:80" env = EnvironmentVarGuard() - env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar') - env.set('KAGGLE_DATA_PROXY_TOKEN', 'foobar') - env.set('KAGGLE_DATA_PROXY_URL', endpoint) - env.set('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY', 'True') - server_address = urlparse(endpoint) + secrets_token = "secrets_token" + proxy_token = "proxy_token" + env.set("KAGGLE_USER_SECRETS_TOKEN", secrets_token) + env.set("KAGGLE_DATA_PROXY_TOKEN", proxy_token) + env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) + env.set("KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY", "True") + server_address = urlparse(self.endpoint) with env: with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: threading.Thread(target=httpd.serve_forever).start() import google.generativeai as palm - palm.configure(api_key = "NotARealAPIKey") + api_key = "NotARealAPIKey" + palm.configure(api_key = api_key) try: for _ in palm.list_models(): pass except: pass httpd.shutdown() - self.assertTrue(HTTPHandler.data_proxy_token_header_found) - self.assertTrue(HTTPHandler.user_secrets_token_header_found) - + self.assertTrue(HTTPHandler.called) + self.assertIn("/palmapi", HTTPHandler.path) + self.assertEqual(proxy_token, HTTPHandler.headers["x-kaggle-proxy-data"]) + self.assertEqual("Bearer {}".format(secrets_token), HTTPHandler.headers["x-kaggle-authorization"]) + self.assertEqual(api_key, HTTPHandler.headers["x-goog-api-key"]) diff --git a/tests/test_google_generativeai_patch_disabled.py b/tests/test_google_generativeai_patch_disabled.py new file mode 100644 index 00000000..3090b1fb --- /dev/null +++ b/tests/test_google_generativeai_patch_disabled.py @@ -0,0 +1,43 @@ +import json +import unittest +import threading + +from test.support.os_helper import EnvironmentVarGuard +from urllib.parse import urlparse + +from http.server import BaseHTTPRequestHandler, HTTPServer + +class HTTPHandler(BaseHTTPRequestHandler): + called = False + + def do_HEAD(self): + self.send_response(200) + + def do_GET(self): + HTTPHandler.called = True + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + +class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase): + endpoint = "http://127.0.0.1:80" + + def test_disabled(self): + env = EnvironmentVarGuard() + env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar") + env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar") + env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) + env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True") + server_address = urlparse(self.endpoint) + with env: + with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: + threading.Thread(target=httpd.serve_forever).start() + import google.generativeai as palm + palm.configure(api_key = "NotARealAPIKey") + try: + for _ in palm.list_models(): + pass + except: + pass + httpd.shutdown() + self.assertFalse(HTTPHandler.called) From a9d45f9037cf9b607e6eb705decf5efc50b094c1 Mon Sep 17 00:00:00 2001 From: Philippe Modard Date: Sat, 2 Dec 2023 14:19:41 +0000 Subject: [PATCH 16/20] remove print --- patches/sitecustomize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 0a4cd7e5..2b6e9e35 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -79,7 +79,6 @@ def exec_module(self, module): @wrapt.when_imported('google.generativeai') def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: - print("PHILMOD EARLY EXIT") return if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or From f7b69ea6e926906ced3efe96db6228dec6486663 Mon Sep 17 00:00:00 2001 From: Philmod Date: Sat, 2 Dec 2023 18:43:32 +0100 Subject: [PATCH 17/20] change name of test --- tests/test_google_generativeai_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_google_generativeai_patch.py b/tests/test_google_generativeai_patch.py index 64c51b92..68e766c2 100644 --- a/tests/test_google_generativeai_patch.py +++ b/tests/test_google_generativeai_patch.py @@ -26,7 +26,7 @@ def do_GET(self): class TestGoogleGenerativeAiPatch(unittest.TestCase): endpoint = "http://127.0.0.1:80" - def test_headers_are_set(self): + def test_proxy_enabled(self): env = EnvironmentVarGuard() secrets_token = "secrets_token" proxy_token = "proxy_token" From dfb9145654435cfb31bd2611db069721d58c2c0a Mon Sep 17 00:00:00 2001 From: Philmod Date: Sun, 3 Dec 2023 11:24:57 +0100 Subject: [PATCH 18/20] remove disabled test to check if that's the reason it's broken on ci --- ...test_google_generativeai_patch_disabled.py | 43 ------------------- 1 file changed, 43 deletions(-) delete mode 100644 tests/test_google_generativeai_patch_disabled.py diff --git a/tests/test_google_generativeai_patch_disabled.py b/tests/test_google_generativeai_patch_disabled.py deleted file mode 100644 index 3090b1fb..00000000 --- a/tests/test_google_generativeai_patch_disabled.py +++ /dev/null @@ -1,43 +0,0 @@ -import json -import unittest -import threading - -from test.support.os_helper import EnvironmentVarGuard -from urllib.parse import urlparse - -from http.server import BaseHTTPRequestHandler, HTTPServer - -class HTTPHandler(BaseHTTPRequestHandler): - called = False - - def do_HEAD(self): - self.send_response(200) - - def do_GET(self): - HTTPHandler.called = True - self.send_response(200) - self.send_header("Content-type", "application/json") - self.end_headers() - -class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase): - endpoint = "http://127.0.0.1:80" - - def test_disabled(self): - env = EnvironmentVarGuard() - env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar") - env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar") - env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) - env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True") - server_address = urlparse(self.endpoint) - with env: - with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: - threading.Thread(target=httpd.serve_forever).start() - import google.generativeai as palm - palm.configure(api_key = "NotARealAPIKey") - try: - for _ in palm.list_models(): - pass - except: - pass - httpd.shutdown() - self.assertFalse(HTTPHandler.called) From ff335d68b202a40b600be64dc159e55b7e7403ce Mon Sep 17 00:00:00 2001 From: Philippe Modard Date: Sun, 3 Dec 2023 14:42:54 +0000 Subject: [PATCH 19/20] readd disabled test --- patches/sitecustomize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 2b6e9e35..6ac7400e 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -80,9 +80,9 @@ def exec_module(self, module): def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: return - if os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or + if (os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or - os.getenv('KAGGLE_DATA_PROXY_URL') == None: + os.getenv('KAGGLE_DATA_PROXY_URL') == None): return old_configure = module.configure From 10b9a815f9efc79b4e91179dd970a02a515b92ea Mon Sep 17 00:00:00 2001 From: Philippe Modard Date: Sun, 3 Dec 2023 14:43:02 +0000 Subject: [PATCH 20/20] readd disabled test --- ...test_google_generativeai_patch_disabled.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/test_google_generativeai_patch_disabled.py diff --git a/tests/test_google_generativeai_patch_disabled.py b/tests/test_google_generativeai_patch_disabled.py new file mode 100644 index 00000000..65e02845 --- /dev/null +++ b/tests/test_google_generativeai_patch_disabled.py @@ -0,0 +1,43 @@ +import json +import unittest +import threading + +from test.support.os_helper import EnvironmentVarGuard +from urllib.parse import urlparse + +from http.server import BaseHTTPRequestHandler, HTTPServer + +class HTTPHandler(BaseHTTPRequestHandler): + called = False + + def do_HEAD(self): + self.send_response(200) + + def do_GET(self): + HTTPHandler.called = True + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + +class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase): + endpoint = "http://127.0.0.1:80" + + def test_disabled(self): + env = EnvironmentVarGuard() + env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar") + env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar") + env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) + env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True") + server_address = urlparse(self.endpoint) + with env: + with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: + threading.Thread(target=httpd.serve_forever).start() + import google.generativeai as palm + palm.configure(api_key = "NotARealAPIKey") + try: + for _ in palm.list_models(): + pass + except: + pass + httpd.shutdown() + self.assertFalse(HTTPHandler.called) \ No newline at end of file