Skip to content
This repository has been archived by the owner on Sep 3, 2022. It is now read-only.

Commit

Permalink
Refactor all gcs_dag_location checks into the property.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rajiv Bharadwaja committed Nov 10, 2017
1 parent bb69916 commit 58b9167
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 25 deletions.
34 changes: 18 additions & 16 deletions google/datalab/contrib/pipeline/composer/_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# the License.

import google.cloud.storage as gcs
import re
from google.datalab.contrib.pipeline.composer._api import Api


Expand All @@ -21,6 +22,8 @@ class Composer(object):
This object can be used to generate the python airflow spec.
"""

gcs_file_regexp = re.compile('gs://.*')

def __init__(self, zone, environment):
""" Initializes an instance of a Composer object.
Expand All @@ -34,30 +37,29 @@ def __init__(self, zone, environment):

def deploy(self, name, dag_string):
client = gcs.Client()
try:
gcs_dag_location_splits = self.gcs_dag_location.split('/')
bucket_name = gcs_dag_location_splits[2]
# Usually the splits are like ['gs:', '', 'foo_bucket', 'dags']. But we could have additional
# parts after the bucket. In those cases, the final file path needs to include those as well
additional_parts = ''
if len(gcs_dag_location_splits) > 4:
additional_parts = '/' + '/'.join(gcs_dag_location_splits[4:])
filename = self.gcs_dag_location.split('/')[3] + additional_parts + '/{0}.py'.format(name)
except (AttributeError, IndexError):
raise ValueError('Error in dag location from Composer environment {0}'.format(
self._environment))
_, _, bucket_name, file_path = self.gcs_dag_location.split('/', 3) # setting maxsplit to 3
file_name = '{0}{1}.py'.format(file_path, name)

bucket = client.get_bucket(bucket_name)
blob = gcs.Blob(filename, bucket)
blob = gcs.Blob(file_name, bucket)
blob.upload_from_string(dag_string)

@property
def gcs_dag_location(self):
if not self._gcs_dag_location:
environment_details = Api.environment_details_get(self._zone, self._environment)
if 'config' not in environment_details \
or 'gcsDagLocation' not in environment_details.get('config'):

if ('config' not in environment_details or
'gcsDagLocation' not in environment_details.get('config')):
raise ValueError('Dag location unavailable from Composer environment {0}'.format(
self._environment))
self._gcs_dag_location = environment_details['config']['gcsDagLocation']
gcs_dag_location = environment_details['config']['gcsDagLocation']

if gcs_dag_location is None or not self.gcs_file_regexp.match(gcs_dag_location):
raise ValueError(
'Dag location {0} from Composer environment {1} is in incorrect format'.format(
gcs_dag_location, self._environment))

self._gcs_dag_location = gcs_dag_location + '/'

return self._gcs_dag_location
75 changes: 66 additions & 9 deletions tests/pipeline/composer_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,31 @@ def test_deploy(self, mock_environment_details, mock_client, mock_blob_class,
mock_blob = mock_blob_class.return_value
mock_blob.upload_from_string.assert_called_with('foo_dag_string')

# GCS dag location has additional parts
# Only bucket with no path
mock_environment_details.return_value = {
'config': {
'gcsDagLocation': 'gs://foo_bucket/foo_random/dags'
'gcsDagLocation': 'gs://foo_bucket'
}
}
test_composer = Composer('foo_zone', 'foo_environment')
test_composer.deploy('foo_name', 'foo_dag_string')
mock_client.return_value.get_bucket.assert_called_with('foo_bucket')
mock_blob_class.assert_called_with('foo_random/dags/foo_name.py', mock.ANY)
mock_blob_class.assert_called_with('foo_name.py', mock.ANY)
mock_blob = mock_blob_class.return_value
mock_blob.upload_from_string.assert_called_with('foo_dag_string')

# GCS file-path is None
# GCS dag location has additional parts
mock_environment_details.return_value = {
'config': {
'gcsDagLocation': None
'gcsDagLocation': 'gs://foo_bucket/foo_random/dags'
}
}
test_composer = Composer('foo_zone', 'foo_environment')
with self.assertRaisesRegexp(
ValueError, 'Error in dag location from Composer environment foo_environment'):
test_composer.deploy('foo_name', 'foo_dag_string')
test_composer.deploy('foo_name', 'foo_dag_string')
mock_client.return_value.get_bucket.assert_called_with('foo_bucket')
mock_blob_class.assert_called_with('foo_random/dags/foo_name.py', mock.ANY)
mock_blob = mock_blob_class.return_value
mock_blob.upload_from_string.assert_called_with('foo_dag_string')

@mock.patch('google.datalab.contrib.pipeline.composer._api.Api.environment_details_get')
def test_gcs_dag_location(self, mock_environment_details):
Expand All @@ -70,11 +72,66 @@ def test_gcs_dag_location(self, mock_environment_details):
}
}
test_composer = Composer('foo_zone', 'foo_environment')
self.assertEqual('gs://foo_bucket/dags', test_composer.gcs_dag_location)
self.assertEqual('gs://foo_bucket/dags/', test_composer.gcs_dag_location)

# Composer returns good result
mock_environment_details.return_value = {
'config': {
'gcsDagLocation': 'gs://foo_bucket'
}
}
test_composer = Composer('foo_zone', 'foo_environment')
self.assertEqual('gs://foo_bucket/', test_composer.gcs_dag_location)

# Composer returns empty result
mock_environment_details.return_value = {}
test_composer = Composer('foo_zone', 'foo_environment')
with self.assertRaisesRegexp(
ValueError, 'Dag location unavailable from Composer environment foo_environment'):
test_composer.gcs_dag_location

# Composer returns empty result
mock_environment_details.return_value = {
'config': {}
}
test_composer = Composer('foo_zone', 'foo_environment')
with self.assertRaisesRegexp(
ValueError, 'Dag location unavailable from Composer environment foo_environment'):
test_composer.gcs_dag_location

# Composer returns None result
mock_environment_details.return_value = {
'config': {
'gcsDagLocation': None
}
}
test_composer = Composer('foo_zone', 'foo_environment')
with self.assertRaisesRegexp(
ValueError,
'Dag location None from Composer environment foo_environment is in incorrect format'):
test_composer.gcs_dag_location

# Composer returns incorrect formats
mock_environment_details.return_value = {
'config': {
'gcsDagLocation': 'gs:/foo_bucket'
}
}
test_composer = Composer('foo_zone', 'foo_environment')
with self.assertRaisesRegexp(
ValueError,
('Dag location gs:/foo_bucket from Composer environment foo_environment is in'
' incorrect format')):
test_composer.gcs_dag_location

mock_environment_details.return_value = {
'config': {
'gcsDagLocation': 'as://foo_bucket'
}
}
test_composer = Composer('foo_zone', 'foo_environment')
with self.assertRaisesRegexp(
ValueError,
('Dag location as://foo_bucket from Composer environment foo_environment is in'
' incorrect format')):
test_composer.gcs_dag_location

0 comments on commit 58b9167

Please sign in to comment.