
Commit

Composer integration for %%bq pipeline (#682)
* Composer integration for %%bq pipeline

* Addressing code-review feedback
rajivpb authored May 9, 2018
1 parent 60d015b commit ec3002a
Showing 9 changed files with 347 additions and 3 deletions.
16 changes: 16 additions & 0 deletions google/datalab/bigquery/commands/_bigquery.py
@@ -899,6 +899,12 @@ def _create_pipeline_subparser(parser):
                               help='The Google Cloud Storage bucket for the Airflow dags.')
  pipeline_parser.add_argument('-f', '--gcs_dag_file_path', type=str,
                               help='The file path suffix for the Airflow dags.')
  pipeline_parser.add_argument('-e', '--environment', type=str,
                               help='The name of the Google Cloud Composer environment.')
  pipeline_parser.add_argument('-l', '--location', type=str,
                               help='The location of the Google Cloud Composer environment. '
                                    'Refer to https://cloud.google.com/about/locations/ for '
                                    'further details.')
  pipeline_parser.add_argument('-g', '--debug', type=str,
                               help='Debug output with the airflow spec.')
  return pipeline_parser
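
For illustration only (not part of the diff): a cell using the new flags might look like the following, assuming the subparser's existing name argument. The pipeline name, environment name, and location are hypothetical, and the cell body carries the pipeline's YAML spec as before.

%%bq pipeline -n my_pipeline -e my-composer-env -l us-central1
# cell body: the pipeline's YAML definition goes here
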
@@ -937,6 +943,16 @@ def _pipeline_cell(args, cell_body):
  except AttributeError:
    return "Perhaps you're missing: import google.datalab.contrib.pipeline.airflow"

  location = args.get('location')
  environment = args.get('environment')

  if location and environment:
    try:
      composer = google.datalab.contrib.pipeline.composer.Composer(location, environment)
      composer.deploy(name, airflow_spec)
    except AttributeError:
      return "Perhaps you're missing: import google.datalab.contrib.pipeline.composer"

  if args.get('debug'):
    error_message += '\n\n' + airflow_spec

12 changes: 12 additions & 0 deletions google/datalab/contrib/pipeline/composer/__init__.py
@@ -0,0 +1,12 @@
# Copyright 2018 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
from ._composer import Composer # noqa
39 changes: 39 additions & 0 deletions google/datalab/contrib/pipeline/composer/_api.py
@@ -0,0 +1,39 @@
# Copyright 2018 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

"""Implements Composer HTTP API wrapper."""
import google.datalab.utils


class Api(object):
  """A helper class to issue Composer HTTP requests."""

  _ENDPOINT = 'https://composer.googleapis.com/v1alpha1'
  _ENVIRONMENTS_PATH_FORMAT = '/projects/%s/locations/%s/environments/%s'

  @staticmethod
  def get_environment_details(zone, environment):
    """ Issues a request to Composer to get the environment details.

    Args:
      zone: GCP zone of the Composer environment
      environment: name of the Composer environment
    Returns:
      A parsed result object.
    Raises:
      Exception if there is an error performing the operation.
    """
    default_context = google.datalab.Context.default()
    url = (Api._ENDPOINT + (Api._ENVIRONMENTS_PATH_FORMAT % (default_context.project_id, zone,
                                                             environment)))

    return google.datalab.utils.Http.request(url, credentials=default_context.credentials)
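
As a quick illustration (not part of the diff) of the URL this helper builds, assume the default Context resolves to project 'my-project':

from google.datalab.contrib.pipeline.composer._api import Api

# Issues an authenticated GET to:
#   https://composer.googleapis.com/v1alpha1/projects/my-project/locations/us-central1/environments/my-env
details = Api.get_environment_details('us-central1', 'my-env')
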
66 changes: 66 additions & 0 deletions google/datalab/contrib/pipeline/composer/_composer.py
@@ -0,0 +1,66 @@
# Copyright 2018 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import google.datalab.storage as storage
from google.datalab.contrib.pipeline.composer._api import Api
import re


class Composer(object):
  """ Represents a Composer object that encapsulates a set of functionality relating to the
  Cloud Composer service.

  This object can be used to deploy the Python Airflow spec to a Composer environment.
  """

  gcs_file_regexp = re.compile('gs://.*')

  def __init__(self, zone, environment):
    """ Initializes an instance of a Composer object.

    Args:
      zone: Zone in which the Composer environment has been created.
      environment: Name of the Composer environment.
    """
    self._zone = zone
    self._environment = environment
    self._gcs_dag_location = None

  def deploy(self, name, dag_string):
    # A dag location like 'gs://foo_bucket/dags/' splits (with maxsplit=3) into
    # ['gs:', '', 'foo_bucket', 'dags/']; keep the bucket name and file path.
    bucket_name, file_path = self.gcs_dag_location.split('/', 3)[2:]
    file_name = '{0}{1}.py'.format(file_path, name)

    bucket = storage.Bucket(bucket_name)
    file_object = bucket.object(file_name)
    file_object.write_stream(dag_string, 'text/plain')

  @property
  def gcs_dag_location(self):
    if not self._gcs_dag_location:
      environment_details = Api.get_environment_details(self._zone, self._environment)

      if ('config' not in environment_details or
              'gcsDagLocation' not in environment_details.get('config')):
        raise ValueError('Dag location unavailable from Composer environment {0}'.format(
            self._environment))
      gcs_dag_location = environment_details['config']['gcsDagLocation']

      if gcs_dag_location is None or not self.gcs_file_regexp.match(gcs_dag_location):
        raise ValueError(
            'Dag location {0} from Composer environment {1} is in incorrect format'.format(
                gcs_dag_location, self._environment))

      self._gcs_dag_location = gcs_dag_location
      if not gcs_dag_location.endswith('/'):
        self._gcs_dag_location = self._gcs_dag_location + '/'

    return self._gcs_dag_location
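
A minimal usage sketch (not part of the diff), assuming the environment reports a gcsDagLocation of 'gs://foo_bucket/dags' as in the tests below:

from google.datalab.contrib.pipeline.composer import Composer

dag_string = '...'  # a generated Airflow DAG definition (the airflow_spec from _pipeline_cell)
composer = Composer('us-central1', 'my-env')
# gcs_dag_location is fetched lazily and normalized to 'gs://foo_bucket/dags/', so
# deploy() writes the DAG to gs://foo_bucket/dags/my_pipeline.py.
composer.deploy('my_pipeline', dag_string)
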
1 change: 1 addition & 0 deletions setup.py
@@ -33,6 +33,7 @@
'google.datalab.contrib.mlworkbench.commands',
'google.datalab.contrib.pipeline',
'google.datalab.contrib.pipeline.airflow',
'google.datalab.contrib.pipeline.composer',
'google.datalab.contrib.pipeline.commands',
'google.datalab.data',
'google.datalab.kernel',
10 changes: 7 additions & 3 deletions tests/bigquery/pipeline_tests.py
Expand Up @@ -577,21 +577,25 @@ def compare_parameters(self, actual_parameters, user_parameters):
                                for item in user_parameters}
    self.assertDictEqual(actual_paramaters_dict, user_parameters_dict)

  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.get_environment_details')
  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.utils.commands.notebook_environment')
  @mock.patch('google.datalab.utils.commands.get_notebook_item')
  @mock.patch('google.datalab.bigquery.Table.exists')
  @mock.patch('google.datalab.bigquery.commands._bigquery._get_table')
  @mock.patch('google.datalab.storage.Bucket')
  def test_pipeline_cell_golden(self, mock_bucket_class, mock_get_table, mock_table_exists,
-                               mock_notebook_item, mock_environment, mock_default_context):
+                               mock_notebook_item, mock_environment, mock_default_context,
+                               mock_composer_env):
    import google.datalab.contrib.pipeline.airflow
    table = google.datalab.bigquery.Table('project.test.table')
    mock_get_table.return_value = table
    mock_table_exists.return_value = True
    context = TestCases._create_context()
    mock_default_context.return_value = context

    mock_composer_env.return_value = {
        'config': {'gcsDagLocation': 'gs://foo_bucket/dags'}
    }
    env = {
        'endpoint': 'Interact2',
        'job_id': '1234',
@@ -720,6 +724,6 @@ def test_pipeline_cell_golden(self, mock_bucket_class, mock_get_table, mock_tabl
                                                     name, cell_body_dict)

    mock_bucket_class.assert_called_with('foo_bucket')
-    mock_bucket_class.return_value.object.assert_called_with('foo_file_path/bq_pipeline_test.py')
+    mock_bucket_class.return_value.object.assert_called_with('dags/bq_pipeline_test.py')
    mock_bucket_class.return_value.object.return_value.write_stream.assert_called_with(
        expected_airflow_spec, 'text/plain')
4 changes: 4 additions & 0 deletions tests/main.py
@@ -56,6 +56,8 @@
import mlworkbench_magic.ml_tests
import mlworkbench_magic.shell_process_tests
import pipeline.airflow_tests
import pipeline.composer_tests
import pipeline.composer_api_tests
import pipeline.pipeline_tests
import stackdriver.commands.monitoring_tests
import stackdriver.monitoring.group_tests
@@ -104,6 +106,8 @@
ml.metrics_tests,
ml.summary_tests,
mlworkbench_magic.ml_tests,
pipeline.composer_api_tests,
pipeline.composer_tests,
pipeline.airflow_tests,
pipeline.pipeline_tests,
stackdriver.commands.monitoring_tests,
60 changes: 60 additions & 0 deletions tests/pipeline/composer_api_tests.py
@@ -0,0 +1,60 @@
# Copyright 2018 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import unittest
import mock

import google.auth
import google.datalab.utils
from google.datalab.contrib.pipeline.composer._api import Api


class TestCases(unittest.TestCase):

  TEST_PROJECT_ID = 'test_project'

  def validate(self, mock_http_request, expected_url, expected_args=None, expected_data=None,
               expected_headers=None, expected_method=None):
    url = mock_http_request.call_args[0][0]
    kwargs = mock_http_request.call_args[1]
    self.assertEquals(expected_url, url)
    if expected_args is not None:
      self.assertEquals(expected_args, kwargs['args'])
    else:
      self.assertNotIn('args', kwargs)
    if expected_data is not None:
      self.assertEquals(expected_data, kwargs['data'])
    else:
      self.assertNotIn('data', kwargs)
    if expected_headers is not None:
      self.assertEquals(expected_headers, kwargs['headers'])
    else:
      self.assertNotIn('headers', kwargs)
    if expected_method is not None:
      self.assertEquals(expected_method, kwargs['method'])
    else:
      self.assertNotIn('method', kwargs)

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.utils.Http.request')
  def test_environment_details_get(self, mock_http_request, mock_context_default):
    mock_context_default.return_value = TestCases._create_context()
    Api.get_environment_details('ZONE', 'ENVIRONMENT')
    self.validate(mock_http_request,
                  'https://composer.googleapis.com/v1alpha1/projects/test_project/locations/ZONE/'
                  'environments/ENVIRONMENT')

  @staticmethod
  def _create_context():
    project_id = TestCases.TEST_PROJECT_ID
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return google.datalab.Context(project_id, creds)