Skip to content
This repository has been archived by the owner on Sep 3, 2022. It is now read-only.

Commit

Permalink
Support for fetching gcs dag location from Composer environment
Browse files Browse the repository at this point in the history
  • Loading branch information
Rajiv Bharadwaja committed Nov 10, 2017
1 parent a14ae52 commit 94a3a75
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 20 deletions.
3 changes: 2 additions & 1 deletion google/datalab/contrib/pipeline/composer/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class Api(object):

_DEFAULT_TIMEOUT = 60000

def environment_details_get(self, zone, environment):
@staticmethod
def environment_details_get(zone, environment):
""" Issues a request to fetch the details of a Composer environment
Args:
Expand Down
37 changes: 25 additions & 12 deletions google/datalab/contrib/pipeline/composer/_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# the License.

import google.cloud.storage as gcs
from google.datalab.contrib.pipeline.composer._api import Api


class Composer(object):
Expand All @@ -29,22 +30,34 @@ def __init__(self, zone, environment):
"""
self._zone = zone
self._environment = environment
self._gcs_dag_location = None

def deploy(self, name, dag_string):
  """ Uploads a DAG definition into the Composer environment's GCS dag folder.

  Args:
    name: base name of the DAG file; the object is created as '<name>.py'.
    dag_string: the Python source of the DAG to upload.

  Raises:
    ValueError: if the dag location reported by the Composer environment
      cannot be parsed into a bucket name and path (None or malformed).
  """
  client = gcs.Client()
  try:
    gcs_dag_location_splits = self.gcs_dag_location.split('/')
    bucket_name = gcs_dag_location_splits[2]
    # Usually the splits are like ['gs:', '', 'foo_bucket', 'dags']. But we could have additional
    # parts after the bucket. In those cases, the final file path needs to include those as well.
    additional_parts = ''
    if len(gcs_dag_location_splits) > 4:
      additional_parts = '/' + '/'.join(gcs_dag_location_splits[4:])
    # Reuse the splits computed above instead of re-splitting the location string.
    filename = gcs_dag_location_splits[3] + additional_parts + '/{0}.py'.format(name)
  except (AttributeError, IndexError):
    # AttributeError: gcs_dag_location is None; IndexError: too few '/' parts.
    raise ValueError('Error in dag location from Composer environment {0}'.format(
        self._environment))

  bucket = client.get_bucket(bucket_name)
  blob = gcs.Blob(filename, bucket)
  blob.upload_from_string(dag_string)

@property
def gcs_dag_location(self):
  """ The 'gs://...' location of this environment's DAG folder.

  Fetched lazily from the Composer API on first access and cached on the
  instance for subsequent calls.

  Raises:
    ValueError: if the environment's config does not report a dag location.
  """
  if not self._gcs_dag_location:
    environment_details = Api.environment_details_get(self._zone, self._environment)
    if 'config' not in environment_details \
        or 'gcsDagLocation' not in environment_details.get('config'):
      raise ValueError('Dag location unavailable from Composer environment {0}'.format(
          self._environment))
    self._gcs_dag_location = environment_details['config']['gcsDagLocation']
  return self._gcs_dag_location
9 changes: 8 additions & 1 deletion tests/bigquery/pipeline_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def test_get_execute_parameters(self, mock_notebook_item):

self.assertDictEqual(actual_execute_config, expected_execute_config)

@mock.patch('google.datalab.contrib.pipeline.composer._api.Api.environment_details_get')
@mock.patch('google.datalab.Context.default')
@mock.patch('google.datalab.utils.commands.notebook_environment')
@mock.patch('google.datalab.utils.commands.get_notebook_item')
Expand All @@ -445,7 +446,7 @@ def test_get_execute_parameters(self, mock_notebook_item):
@mock.patch('google.cloud.storage.Client.get_bucket')
def test_pipeline_cell_golden(self, mock_client_get_bucket, mock_client, mock_blob_class,
mock_get_table, mock_table_exists, mock_notebook_item,
mock_environment, mock_default_context):
mock_environment, mock_default_context, mock_composer_env):
table = google.datalab.bigquery.Table('project.test.table')
mock_get_table.return_value = table
mock_table_exists.return_value = True
Expand All @@ -454,6 +455,12 @@ def test_pipeline_cell_golden(self, mock_client_get_bucket, mock_client, mock_bl
mock_client_get_bucket.return_value = mock.Mock(spec=google.cloud.storage.Bucket)
mock_blob = mock_blob_class.return_value

mock_composer_env.return_value = {
'config': {
'gcsDagLocation': 'gs://foo_bucket/dags'
}
}

env = {
'endpoint': 'Interact2',
'job_id': '1234',
Expand Down
2 changes: 1 addition & 1 deletion tests/pipeline/composer_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def validate(self, mock_http_request, expected_url, expected_args=None, expected
@mock.patch('google.datalab.utils.Http.request')
def test_environment_details_get(self, mock_http_request, mock_context_default):
mock_context_default.return_value = TestCases._create_context()
Api().environment_details_get('ZONE', 'ENVIRONMENT')
Api.environment_details_get('ZONE', 'ENVIRONMENT')
self.validate(mock_http_request,
'https://composer.googleapis.com/v1alpha1/projects/test_project/locations/ZONE/'
'environments/ENVIRONMENT', expected_args={'timeoutMs': 60000})
Expand Down
45 changes: 40 additions & 5 deletions tests/pipeline/composer_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,54 @@
import unittest
import mock

import google.auth
import google.datalab.utils
from google.datalab.contrib.pipeline.composer._composer import Composer


class TestCases(unittest.TestCase):
  """ Tests for Composer.deploy covering the supported gcsDagLocation shapes. """

  # mock.patch decorators apply bottom-up, so the parameter order below is the
  # reverse of the decorator order.
  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.cloud.storage.Blob')
  @mock.patch('google.cloud.storage.Client')
  @mock.patch('google.cloud.storage.Client.get_bucket')
  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.environment_details_get')
  def test_deploy(self, mock_environment_details, mock_client_get_bucket, mock_client,
                  mock_blob_class, mock_default_context):
    # Standard dag location: gs://<bucket>/dags
    mock_environment_details.return_value = {
        'config': {
            'gcsDagLocation': 'gs://foo_bucket/dags'
        }
    }
    test_composer = Composer('foo_zone', 'foo_environment')
    test_composer.deploy('foo_name', 'foo_dag_string')
    mock_blob_class.assert_called_with('dags/foo_name.py', mock.ANY)
    mock_blob = mock_blob_class.return_value
    mock_blob.upload_from_string.assert_called_with('foo_dag_string')

    # Dag location with extra path parts after the bucket.
    mock_environment_details.return_value = {
        'config': {
            'gcsDagLocation': 'gs://foo_bucket/foo_random/dags'
        }
    }
    test_composer = Composer('foo_zone', 'foo_environment')
    test_composer.deploy('foo_name', 'foo_dag_string')
    mock_blob_class.assert_called_with('foo_random/dags/foo_name.py', mock.ANY)
    mock_blob = mock_blob_class.return_value
    mock_blob.upload_from_string.assert_called_with('foo_dag_string')

    # API returns empty result
    mock_environment_details.return_value = {}
    test_composer = Composer('foo_zone', 'foo_environment')
    with self.assertRaisesRegexp(
        ValueError, 'Dag location unavailable from Composer environment foo_environment'):
      test_composer.deploy('foo_name', 'foo_dag_string')

    # GCS file-path is None
    mock_environment_details.return_value = {
        'config': {
            'gcsDagLocation': None
        }
    }
    test_composer = Composer('foo_zone', 'foo_environment')
    with self.assertRaisesRegexp(
        ValueError, 'Error in dag location from Composer environment foo_environment'):
      test_composer.deploy('foo_name', 'foo_dag_string')

0 comments on commit 94a3a75

Please sign in to comment.