diff --git a/google/datalab/contrib/pipeline/composer/_api.py b/google/datalab/contrib/pipeline/composer/_api.py
index 487902fb5..ef98f4a3f 100644
--- a/google/datalab/contrib/pipeline/composer/_api.py
+++ b/google/datalab/contrib/pipeline/composer/_api.py
@@ -22,7 +22,8 @@ class Api(object):
 
   _DEFAULT_TIMEOUT = 60000
 
-  def environment_details_get(self, zone, environment):
+  @staticmethod
+  def environment_details_get(zone, environment):
     """ Issues a request to load data from GCS to a BQ table
 
     Args:
diff --git a/google/datalab/contrib/pipeline/composer/_composer.py b/google/datalab/contrib/pipeline/composer/_composer.py
index ab4cc323b..13a61d91a 100644
--- a/google/datalab/contrib/pipeline/composer/_composer.py
+++ b/google/datalab/contrib/pipeline/composer/_composer.py
@@ -11,6 +11,7 @@
 # the License.
 
 import google.cloud.storage as gcs
+from google.datalab.contrib.pipeline.composer._api import Api
 
 
 class Composer(object):
@@ -29,22 +30,34 @@ def __init__(self, zone, environment):
     """
     self._zone = zone
     self._environment = environment
+    self._gcs_dag_location = None
 
   def deploy(self, name, dag_string):
     client = gcs.Client()
-    bucket = client.get_bucket(self.bucket_name)
-    filename = 'dags/{0}.py'.format(name)
+    try:
+      gcs_dag_location_splits = self.gcs_dag_location.split('/')
+      bucket_name = gcs_dag_location_splits[2]
+      # Usually the splits are like ['gs:', '', 'foo_bucket', 'dags']. But we could have additional
+      # parts after the bucket. In those cases, the final file path needs to include those as well
+      additional_parts = ''
+      if len(gcs_dag_location_splits) > 4:
+        additional_parts = '/' + '/'.join(gcs_dag_location_splits[4:])
+      filename = self.gcs_dag_location.split('/')[3] + additional_parts + '/{0}.py'.format(name)
+    except (AttributeError, IndexError):
+      raise ValueError('Error in dag location from Composer environment {0}'.format(
+          self._environment))
+
+    bucket = client.get_bucket(bucket_name)
     blob = gcs.Blob(filename, bucket)
     blob.upload_from_string(dag_string)
 
   @property
-  def bucket_name(self):
-    # TODO(rajivpb): Get this programmatically from the Composer API
-    return 'airflow-staging-test36490808-bucket'
-
-  @property
-  def get_bucket_name(self):
-    # environment_details = Api().environment_details_get(self._zone, self._environment)
-
-    # TODO(rajivpb): Get this programmatically from the Composer API
-    return 'airflow-staging-test36490808-bucket'
+  def gcs_dag_location(self):
+    if not self._gcs_dag_location:
+      environment_details = Api.environment_details_get(self._zone, self._environment)
+      if 'config' not in environment_details \
+          or 'gcsDagLocation' not in environment_details.get('config'):
+        raise ValueError('Dag location unavailable from Composer environment {0}'.format(
+            self._environment))
+      self._gcs_dag_location = environment_details['config']['gcsDagLocation']
+    return self._gcs_dag_location
diff --git a/tests/bigquery/pipeline_tests.py b/tests/bigquery/pipeline_tests.py
index 1ef88ed0c..27edf2e82 100644
--- a/tests/bigquery/pipeline_tests.py
+++ b/tests/bigquery/pipeline_tests.py
@@ -435,6 +435,7 @@ def test_get_execute_parameters(self, mock_notebook_item):
 
     self.assertDictEqual(actual_execute_config, expected_execute_config)
 
+  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.environment_details_get')
   @mock.patch('google.datalab.Context.default')
   @mock.patch('google.datalab.utils.commands.notebook_environment')
   @mock.patch('google.datalab.utils.commands.get_notebook_item')
@@ -445,7 +446,7 @@
   @mock.patch('google.cloud.storage.Client.get_bucket')
   def test_pipeline_cell_golden(self, mock_client_get_bucket, mock_client, mock_blob_class,
                                 mock_get_table, mock_table_exists, mock_notebook_item,
-                                mock_environment, mock_default_context):
+                                mock_environment, mock_default_context, mock_composer_env):
     table = google.datalab.bigquery.Table('project.test.table')
     mock_get_table.return_value = table
     mock_table_exists.return_value = True
@@ -454,6 +455,12 @@ def test_pipeline_cell_golden(self, mock_client_get_bucket, mock_client, mock_bl
     mock_client_get_bucket.return_value = mock.Mock(spec=google.cloud.storage.Bucket)
     mock_blob = mock_blob_class.return_value
 
+    mock_composer_env.return_value = {
+      'config': {
+        'gcsDagLocation': 'gs://foo_bucket/dags'
+      }
+    }
+
     env = {
       'endpoint': 'Interact2',
       'job_id': '1234',
diff --git a/tests/pipeline/composer_api_tests.py b/tests/pipeline/composer_api_tests.py
index 93c782fd9..0f0397ace 100644
--- a/tests/pipeline/composer_api_tests.py
+++ b/tests/pipeline/composer_api_tests.py
@@ -48,7 +48,7 @@ def validate(self, mock_http_request, expected_url, expected_args=None, expected
   @mock.patch('google.datalab.utils.Http.request')
   def test_environment_details_get(self, mock_http_request, mock_context_default):
     mock_context_default.return_value = TestCases._create_context()
-    Api().environment_details_get('ZONE', 'ENVIRONMENT')
+    Api.environment_details_get('ZONE', 'ENVIRONMENT')
     self.validate(mock_http_request,
                   'https://composer.googleapis.com/v1alpha1/projects/test_project/locations/ZONE/'
                   'environments/ENVIRONMENT', expected_args={'timeoutMs': 60000})
diff --git a/tests/pipeline/composer_tests.py b/tests/pipeline/composer_tests.py
index 22d3dab71..57592ef24 100644
--- a/tests/pipeline/composer_tests.py
+++ b/tests/pipeline/composer_tests.py
@@ -13,19 +13,54 @@
 import unittest
 import mock
 
-import google.auth
-import google.datalab.utils
 from google.datalab.contrib.pipeline.composer._composer import Composer
 
 
 class TestCases(unittest.TestCase):
 
-  @mock.patch('google.cloud.storage.Client')
+  @mock.patch('google.datalab.Context.default')
   @mock.patch('google.cloud.storage.Blob')
+  @mock.patch('google.cloud.storage.Client')
   @mock.patch('google.cloud.storage.Client.get_bucket')
-  def test_deploy(self, mock_client_get_bucket, mock_blob_class, mock_client):
-    mock_client_get_bucket.return_value = mock.Mock(spec=google.cloud.storage.Bucket)
+  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.environment_details_get')
+  def test_deploy(self, mock_environment_details, mock_client_get_bucket, mock_client,
+                  mock_blob_class, mock_default_context):
+    mock_environment_details.return_value = {
+      'config': {
+        'gcsDagLocation': 'gs://foo_bucket/dags'
+      }
+    }
+    test_composer = Composer('foo_zone', 'foo_environment')
+    test_composer.deploy('foo_name', 'foo_dag_string')
+    mock_blob_class.assert_called_with('dags/foo_name.py', mock.ANY)
     mock_blob = mock_blob_class.return_value
+    mock_blob.upload_from_string.assert_called_with('foo_dag_string')
+
+    mock_environment_details.return_value = {
+      'config': {
+        'gcsDagLocation': 'gs://foo_bucket/foo_random/dags'
+      }
+    }
     test_composer = Composer('foo_zone', 'foo_environment')
     test_composer.deploy('foo_name', 'foo_dag_string')
+    mock_blob_class.assert_called_with('foo_random/dags/foo_name.py', mock.ANY)
+    mock_blob = mock_blob_class.return_value
     mock_blob.upload_from_string.assert_called_with('foo_dag_string')
+
+    # API returns empty result
+    mock_environment_details.return_value = {}
+    test_composer = Composer('foo_zone', 'foo_environment')
+    with self.assertRaisesRegexp(
+        ValueError, 'Dag location unavailable from Composer environment foo_environment'):
+      test_composer.deploy('foo_name', 'foo_dag_string')
+
+    # GCS file-path is None
+    mock_environment_details.return_value = {
+      'config': {
+        'gcsDagLocation': None
+      }
+    }
+    test_composer = Composer('foo_zone', 'foo_environment')
+    with self.assertRaisesRegexp(
+        ValueError, 'Error in dag location from Composer environment foo_environment'):
+      test_composer.deploy('foo_name', 'foo_dag_string')
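For illustration, a minimal standalone sketch of the gcsDagLocation parsing that the new Composer.deploy performs (the helper name split_dag_location is hypothetical and not part of the patch):

# A minimal sketch of the bucket/path split in Composer.deploy above, assuming
# gcsDagLocation values of the form 'gs://<bucket>/<path...>'. The function
# name split_dag_location is hypothetical, used only for illustration.
def split_dag_location(gcs_dag_location, name):
  """Maps a gcsDagLocation and a DAG name to (bucket_name, file_path)."""
  splits = gcs_dag_location.split('/')  # e.g. ['gs:', '', 'foo_bucket', 'dags']
  bucket_name = splits[2]
  # Path segments beyond the first one after the bucket stay in the file path.
  additional_parts = '/' + '/'.join(splits[4:]) if len(splits) > 4 else ''
  return bucket_name, splits[3] + additional_parts + '/{0}.py'.format(name)


print(split_dag_location('gs://foo_bucket/dags', 'foo_name'))
# -> ('foo_bucket', 'dags/foo_name.py')
print(split_dag_location('gs://foo_bucket/foo_random/dags', 'foo_name'))
# -> ('foo_bucket', 'foo_random/dags/foo_name.py')

These two mappings are exactly the cases exercised by the first two assertions in the updated test_deploy.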