From aadc6c6c493c32c0f9686a6b7dddfb14c04614ca Mon Sep 17 00:00:00 2001
From: Rajiv Bharadwaja
Date: Thu, 9 Nov 2017 12:31:14 -0800
Subject: [PATCH] (Incomplete) Support for fetching GCS DAG location from
 Composer environment

---
 .../datalab/contrib/pipeline/composer/_api.py |  5 ++-
 .../contrib/pipeline/composer/_composer.py    | 33 ++++++++-----
 tests/bigquery/pipeline_tests.py              |  9 ++-
 tests/pipeline/composer_api_tests.py          |  2 +-
 tests/pipeline/composer_tests.py              | 51 ++++++++++++++++++-
 5 files changed, 83 insertions(+), 17 deletions(-)

diff --git a/google/datalab/contrib/pipeline/composer/_api.py b/google/datalab/contrib/pipeline/composer/_api.py
index 487902fb5..ef98f4a3f 100644
--- a/google/datalab/contrib/pipeline/composer/_api.py
+++ b/google/datalab/contrib/pipeline/composer/_api.py
@@ -22,7 +22,8 @@ class Api(object):
 
   _DEFAULT_TIMEOUT = 60000
 
-  def environment_details_get(self, zone, environment):
-    """ Issues a request to load data from GCS to a BQ table
+  @staticmethod
+  def environment_details_get(zone, environment):
+    """ Issues a request to fetch the details of a Composer environment
 
     Args:
diff --git a/google/datalab/contrib/pipeline/composer/_composer.py b/google/datalab/contrib/pipeline/composer/_composer.py
index ab4cc323b..cf0f3f618 100644
--- a/google/datalab/contrib/pipeline/composer/_composer.py
+++ b/google/datalab/contrib/pipeline/composer/_composer.py
@@ -11,6 +11,7 @@
 # the License.
 
 import google.cloud.storage as gcs
+from google.datalab.contrib.pipeline.composer._api import Api
 
 
 class Composer(object):
@@ -29,22 +30,30 @@ def __init__(self, zone, environment):
     """
     self._zone = zone
     self._environment = environment
+    self._gcs_dag_location = None
 
   def deploy(self, name, dag_string):
     client = gcs.Client()
-    bucket = client.get_bucket(self.bucket_name)
-    filename = 'dags/{0}.py'.format(name)
+    try:
+      # gs://<bucket>/<path> splits into ['gs:', '', <bucket>, <path parts>...]
+      gcs_dag_location_parts = self.gcs_dag_location.split('/')
+      bucket_name = gcs_dag_location_parts[2]
+      filename = '/'.join(gcs_dag_location_parts[3:]) + '/{0}.py'.format(name)
+    except (AttributeError, IndexError):
+      raise ValueError('Error in dag location from Composer environment {0}'.format(
+          self._environment))
+
+    bucket = client.get_bucket(bucket_name)
     blob = gcs.Blob(filename, bucket)
     blob.upload_from_string(dag_string)
 
   @property
-  def bucket_name(self):
-    # TODO(rajivpb): Get this programmatically from the Composer API
-    return 'airflow-staging-test36490808-bucket'
-
-  @property
-  def get_bucket_name(self):
-    # environment_details = Api().environment_details_get(self._zone, self._environment)
-
-    # TODO(rajivpb): Get this programmatically from the Composer API
-    return 'airflow-staging-test36490808-bucket'
+  def gcs_dag_location(self):
+    if not self._gcs_dag_location:
+      environment_details = Api.environment_details_get(self._zone, self._environment)
+      if 'config' not in environment_details \
+          or 'gcsDagLocation' not in environment_details.get('config'):
+        raise ValueError('Dag location unavailable from Composer environment {0}'.format(
+            self._environment))
+      self._gcs_dag_location = environment_details['config']['gcsDagLocation']
+    return self._gcs_dag_location
diff --git a/tests/bigquery/pipeline_tests.py b/tests/bigquery/pipeline_tests.py
index 1ef88ed0c..27edf2e82 100644
--- a/tests/bigquery/pipeline_tests.py
+++ b/tests/bigquery/pipeline_tests.py
@@ -435,6 +435,7 @@ def test_get_execute_parameters(self, mock_notebook_item):
 
     self.assertDictEqual(actual_execute_config, expected_execute_config)
 
+  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.environment_details_get')
   @mock.patch('google.datalab.Context.default')
   @mock.patch('google.datalab.utils.commands.notebook_environment')
   @mock.patch('google.datalab.utils.commands.get_notebook_item')
@@ -445,7 +446,7 @@ def test_get_execute_parameters(self, mock_notebook_item):
   @mock.patch('google.cloud.storage.Client.get_bucket')
   def test_pipeline_cell_golden(self, mock_client_get_bucket, mock_client, mock_blob_class,
                                 mock_get_table, mock_table_exists, mock_notebook_item,
-                                mock_environment, mock_default_context):
+                                mock_environment, mock_default_context, mock_composer_env):
     table = google.datalab.bigquery.Table('project.test.table')
     mock_get_table.return_value = table
     mock_table_exists.return_value = True
@@ -454,6 +455,12 @@ def test_pipeline_cell_golden(self, mock_client_get_bucket, mock_client, mock_bl
     mock_client_get_bucket.return_value = mock.Mock(spec=google.cloud.storage.Bucket)
     mock_blob = mock_blob_class.return_value
 
+    mock_composer_env.return_value = {
+      'config': {
+        'gcsDagLocation': 'gs://foo_bucket/dags'
+      }
+    }
+
     env = {
       'endpoint': 'Interact2',
       'job_id': '1234',
diff --git a/tests/pipeline/composer_api_tests.py b/tests/pipeline/composer_api_tests.py
index 93c782fd9..0f0397ace 100644
--- a/tests/pipeline/composer_api_tests.py
+++ b/tests/pipeline/composer_api_tests.py
@@ -48,7 +48,7 @@ def validate(self, mock_http_request, expected_url, expected_args=None, expected
   @mock.patch('google.datalab.utils.Http.request')
   def test_environment_details_get(self, mock_http_request, mock_context_default):
     mock_context_default.return_value = TestCases._create_context()
-    Api().environment_details_get('ZONE', 'ENVIRONMENT')
+    Api.environment_details_get('ZONE', 'ENVIRONMENT')
     self.validate(mock_http_request,
                   'https://composer.googleapis.com/v1alpha1/projects/test_project/locations/ZONE/'
                   'environments/ENVIRONMENT', expected_args={'timeoutMs': 60000})
diff --git a/tests/pipeline/composer_tests.py b/tests/pipeline/composer_tests.py
index 22d3dab71..10f326860 100644
--- a/tests/pipeline/composer_tests.py
+++ b/tests/pipeline/composer_tests.py
@@ -20,12 +20,61 @@ class TestCases(unittest.TestCase):
 
+  @mock.patch('google.datalab.Context.default')
   @mock.patch('google.cloud.storage.Client')
   @mock.patch('google.cloud.storage.Blob')
   @mock.patch('google.cloud.storage.Client.get_bucket')
-  def test_deploy(self, mock_client_get_bucket, mock_blob_class, mock_client):
+  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.environment_details_get')
+  def test_deploy(self, mock_environment_details, mock_client_get_bucket, mock_blob_class,
+                  mock_client, mock_default_context):
     mock_client_get_bucket.return_value = mock.Mock(spec=google.cloud.storage.Bucket)
     mock_blob = mock_blob_class.return_value
+    mock_environment_details.return_value = {
+      'config': {
+        'gcsDagLocation': 'gs://foo_bucket/dags'
+      }
+    }
     test_composer = Composer('foo_zone', 'foo_environment')
     test_composer.deploy('foo_name', 'foo_dag_string')
+    mock_blob_class.assert_called_with('dags/foo_name.py', mock.ANY)
     mock_blob.upload_from_string.assert_called_with('foo_dag_string')
+
+    # Happy path: DAG location with a nested folder
+    mock_environment_details.return_value = {
+      'config': {
+        'gcsDagLocation': 'gs://foo_bucket/foo_random/dags'
+      }
+    }
+    test_composer = Composer('foo_zone', 'foo_environment')
+    test_composer.deploy('foo_name', 'foo_dag_string')
+    mock_blob_class.assert_called_with('foo_random/dags/foo_name.py', mock.ANY)
+    mock_blob.upload_from_string.assert_called_with('foo_dag_string')
+
+    # Environment details without a config section
+    mock_environment_details.return_value = {}
+    test_composer = Composer('foo_zone', 'foo_environment')
+    with self.assertRaisesRegexp(
+        ValueError, 'Dag location unavailable from Composer environment foo_environment'):
+      test_composer.deploy('foo_name', 'foo_dag_string')
+
+    # DAG location that is None
+    mock_environment_details.return_value = {
+      'config': {
+        'gcsDagLocation': None
+      }
+    }
+    test_composer = Composer('foo_zone', 'foo_environment')
+    with self.assertRaisesRegexp(
+        ValueError, 'Error in dag location from Composer environment foo_environment'):
+      test_composer.deploy('foo_name', 'foo_dag_string')
+
+    # Malformed DAG location
+    mock_environment_details.return_value = {
+      'config': {
+        'gcsDagLocation': 'gs:/'
+      }
+    }
+    test_composer = Composer('foo_zone', 'foo_environment')
+    with self.assertRaisesRegexp(
+        ValueError, 'Error in dag location from Composer environment foo_environment'):
+      test_composer.deploy('foo_name', 'foo_dag_string')
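
Usage sketch (not part of the patch): a minimal example of how the new
gcs_dag_location property and deploy() are expected to fit together, assuming
a hypothetical environment 'my-environment' in zone 'us-central1-a' and an
already-authenticated Datalab context.

  from google.datalab.contrib.pipeline.composer._composer import Composer

  composer = Composer('us-central1-a', 'my-environment')

  # First access calls Api.environment_details_get once and caches the
  # returned gcsDagLocation; later accesses reuse the cached value.
  print(composer.gcs_dag_location)  # e.g. gs://<environment-bucket>/dags

  # Parses the location into bucket and path, then uploads the DAG source
  # as <gcsDagLocation>/hello_dag.py via google.cloud.storage.
  composer.deploy('hello_dag', 'print("hello dag")')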