From ec3002a1cb1bd1cf94043e2d35f5267b71a52ad0 Mon Sep 17 00:00:00 2001 From: Rajiv Bharadwaja <4618540+rajivpb@users.noreply.github.com> Date: Tue, 8 May 2018 21:19:40 -0700 Subject: [PATCH] Composer integration for %%bq pipeline (#682) * Composer integration for %%bq pipeline * Addressing code-review feedback --- google/datalab/bigquery/commands/_bigquery.py | 16 ++ .../contrib/pipeline/composer/__init__.py | 12 ++ .../datalab/contrib/pipeline/composer/_api.py | 39 +++++ .../contrib/pipeline/composer/_composer.py | 66 ++++++++ setup.py | 1 + tests/bigquery/pipeline_tests.py | 10 +- tests/main.py | 4 + tests/pipeline/composer_api_tests.py | 60 ++++++++ tests/pipeline/composer_tests.py | 142 ++++++++++++++++++ 9 files changed, 347 insertions(+), 3 deletions(-) create mode 100644 google/datalab/contrib/pipeline/composer/__init__.py create mode 100644 google/datalab/contrib/pipeline/composer/_api.py create mode 100644 google/datalab/contrib/pipeline/composer/_composer.py create mode 100644 tests/pipeline/composer_api_tests.py create mode 100644 tests/pipeline/composer_tests.py diff --git a/google/datalab/bigquery/commands/_bigquery.py b/google/datalab/bigquery/commands/_bigquery.py index 5dbfa8d8f..b2bdbd67b 100644 --- a/google/datalab/bigquery/commands/_bigquery.py +++ b/google/datalab/bigquery/commands/_bigquery.py @@ -899,6 +899,12 @@ def _create_pipeline_subparser(parser): help='The Google Cloud Storage bucket for the Airflow dags.') pipeline_parser.add_argument('-f', '--gcs_dag_file_path', type=str, help='The file path suffix for the Airflow dags.') + pipeline_parser.add_argument('-e', '--environment', type=str, + help='The name of the Google Cloud Composer environment.') + pipeline_parser.add_argument('-l', '--location', type=str, + help='The location of the Google Cloud Composer environment. 
' + 'Refer https://cloud.google.com/about/locations/ for further ' + 'details.') pipeline_parser.add_argument('-g', '--debug', type=str, help='Debug output with the airflow spec.') return pipeline_parser @@ -937,6 +943,16 @@ def _pipeline_cell(args, cell_body): except AttributeError: return "Perhaps you're missing: import google.datalab.contrib.pipeline.airflow" + location = args.get('location') + environment = args.get('environment') + + if location and environment: + try: + composer = google.datalab.contrib.pipeline.composer.Composer(location, environment) + composer.deploy(name, airflow_spec) + except AttributeError: + return "Perhaps you're missing: import google.datalab.contrib.pipeline.composer" + if args.get('debug'): error_message += '\n\n' + airflow_spec diff --git a/google/datalab/contrib/pipeline/composer/__init__.py b/google/datalab/contrib/pipeline/composer/__init__.py new file mode 100644 index 000000000..21989d198 --- /dev/null +++ b/google/datalab/contrib/pipeline/composer/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2018 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +from ._composer import Composer # noqa diff --git a/google/datalab/contrib/pipeline/composer/_api.py b/google/datalab/contrib/pipeline/composer/_api.py new file mode 100644 index 000000000..6e30b6aad --- /dev/null +++ b/google/datalab/contrib/pipeline/composer/_api.py @@ -0,0 +1,39 @@ +# Copyright 2018 Google Inc. All rights reserved. 
class Api(object):
  """Thin wrapper that issues HTTP requests against the Cloud Composer REST API."""

  _ENDPOINT = 'https://composer.googleapis.com/v1alpha1'
  _ENVIRONMENTS_PATH_FORMAT = '/projects/%s/locations/%s/environments/%s'

  @staticmethod
  def get_environment_details(zone, environment):
    """Requests the details of a Composer environment.

    The project id and the credentials both come from the default Datalab context.

    Args:
      zone: value substituted into the `locations/` segment of the request URL, i.e. the
          location in which the Composer environment lives.
      environment: name of the Composer environment.
    Returns:
      The result of `google.datalab.utils.Http.request` for the environment URL.
    Raises:
      Exception if there is an error performing the operation.
    """
    context = google.datalab.Context.default()
    resource_path = Api._ENVIRONMENTS_PATH_FORMAT % (context.project_id, zone, environment)
    return google.datalab.utils.Http.request(Api._ENDPOINT + resource_path,
                                             credentials=context.credentials)
class Composer(object):
  """ Represents a Cloud Composer environment and the operations Datalab performs against it.

  An instance looks up the Google Cloud Storage DAG location reported by the environment and can
  upload an Airflow spec (python source) into that location.
  """

  # A usable DAG location needs a non-empty bucket name right after the gs:// scheme; without
  # one, deploy() would end up addressing a bucket named ''.
  gcs_file_regexp = re.compile(r'gs://[^/]+')

  def __init__(self, zone, environment):
    """ Initializes an instance of a Composer object.

    No request is issued here; the environment details are fetched on first use of
    `gcs_dag_location`.

    Args:
      zone: Location in which the Composer environment has been created. It is handed to
          `Api.get_environment_details` unchanged.
      environment: Name of the Composer environment.
    """
    self._zone = zone
    self._environment = environment
    self._gcs_dag_location = None

  def deploy(self, name, dag_string):
    """ Writes an Airflow spec into the DAG location of the Composer environment.

    The object written is `<path part of the DAG location><name>.py`.

    Args:
      name: Name of the pipeline; becomes the file name, without the '.py' extension.
      dag_string: The Airflow spec to upload.
    Raises:
      ValueError if the DAG location of the environment is unavailable or malformed.
    """
    # gcs_dag_location always ends with '/', so splitting 'gs://bucket/path/' on at most three
    # slashes gives ['gs:', '', 'bucket', 'path/']; the last item is '' for a bare bucket.
    bucket_name, file_path = self.gcs_dag_location.split('/', 3)[2:]
    file_name = '{0}{1}.py'.format(file_path, name)

    bucket = storage.Bucket(bucket_name)
    file_object = bucket.object(file_name)
    # NOTE(review): the second argument is presumably the content type - confirm against
    # google.datalab.storage.
    file_object.write_stream(dag_string, 'text/plain')

  @property
  def gcs_dag_location(self):
    """ The gs:// DAG location of the environment, always ending with a '/'.

    The value is fetched through `Api.get_environment_details` on first access and cached on the
    instance afterwards.

    Raises:
      ValueError if the environment details carry no 'config' / 'gcsDagLocation' entry, or if
      the location is None or is not of the form gs://<bucket>[/<path>].
    """
    if not self._gcs_dag_location:
      environment_details = Api.get_environment_details(self._zone, self._environment)

      # 'config' may be absent or null; both mean that there is no DAG location to read, and
      # both must surface as the ValueError below rather than as a TypeError.
      config = (environment_details or {}).get('config') or {}
      if 'gcsDagLocation' not in config:
        raise ValueError('Dag location unavailable from Composer environment {0}'.format(
          self._environment))
      gcs_dag_location = config['gcsDagLocation']

      if gcs_dag_location is None or not self.gcs_file_regexp.match(gcs_dag_location):
        raise ValueError(
          'Dag location {0} from Composer environment {1} is in incorrect format'.format(
            gcs_dag_location, self._environment))

      # Normalise to a single trailing slash so that deploy() can split bucket and path.
      if not gcs_dag_location.endswith('/'):
        gcs_dag_location += '/'
      self._gcs_dag_location = gcs_dag_location

    return self._gcs_dag_location
user_parameters_dict) + @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.get_environment_details') @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.utils.commands.notebook_environment') @mock.patch('google.datalab.utils.commands.get_notebook_item') @@ -584,14 +585,17 @@ def compare_parameters(self, actual_parameters, user_parameters): @mock.patch('google.datalab.bigquery.commands._bigquery._get_table') @mock.patch('google.datalab.storage.Bucket') def test_pipeline_cell_golden(self, mock_bucket_class, mock_get_table, mock_table_exists, - mock_notebook_item, mock_environment, mock_default_context): + mock_notebook_item, mock_environment, mock_default_context, + mock_composer_env): import google.datalab.contrib.pipeline.airflow table = google.datalab.bigquery.Table('project.test.table') mock_get_table.return_value = table mock_table_exists.return_value = True context = TestCases._create_context() mock_default_context.return_value = context - + mock_composer_env.return_value = { + 'config': {'gcsDagLocation': 'gs://foo_bucket/dags'} + } env = { 'endpoint': 'Interact2', 'job_id': '1234', @@ -720,6 +724,6 @@ def test_pipeline_cell_golden(self, mock_bucket_class, mock_get_table, mock_tabl name, cell_body_dict) mock_bucket_class.assert_called_with('foo_bucket') - mock_bucket_class.return_value.object.assert_called_with('foo_file_path/bq_pipeline_test.py') + mock_bucket_class.return_value.object.assert_called_with('dags/bq_pipeline_test.py') mock_bucket_class.return_value.object.return_value.write_stream.assert_called_with( expected_airflow_spec, 'text/plain') diff --git a/tests/main.py b/tests/main.py index addcc23d7..3d4843bb9 100644 --- a/tests/main.py +++ b/tests/main.py @@ -56,6 +56,8 @@ import mlworkbench_magic.ml_tests import mlworkbench_magic.shell_process_tests import pipeline.airflow_tests +import pipeline.composer_tests +import pipeline.composer_api_tests import pipeline.pipeline_tests import 
stackdriver.commands.monitoring_tests import stackdriver.monitoring.group_tests @@ -104,6 +106,8 @@ ml.metrics_tests, ml.summary_tests, mlworkbench_magic.ml_tests, + pipeline.composer_api_tests, + pipeline.composer_tests, pipeline.airflow_tests, pipeline.pipeline_tests, stackdriver.commands.monitoring_tests, diff --git a/tests/pipeline/composer_api_tests.py b/tests/pipeline/composer_api_tests.py new file mode 100644 index 000000000..eb0f00735 --- /dev/null +++ b/tests/pipeline/composer_api_tests.py @@ -0,0 +1,60 @@ +# Copyright 2018 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. 
class TestCases(unittest.TestCase):
  """Unit tests for the Composer HTTP API wrapper."""

  TEST_PROJECT_ID = 'test_project'

  def validate(self, mock_http_request, expected_url, expected_args=None, expected_data=None,
               expected_headers=None, expected_method=None):
    """Checks the arguments of the most recent call made to the mocked Http.request.

    Args:
      mock_http_request: the mock that replaced google.datalab.utils.Http.request.
      expected_url: the URL expected as the first positional argument.
      expected_args: expected value of the 'args' keyword, or None if it must be absent.
      expected_data: expected value of the 'data' keyword, or None if it must be absent.
      expected_headers: expected value of the 'headers' keyword, or None if it must be absent.
      expected_method: expected value of the 'method' keyword, or None if it must be absent.
    """
    url = mock_http_request.call_args[0][0]
    kwargs = mock_http_request.call_args[1]
    # assertEqual rather than the assertEquals alias, which Python 3.12 removed.
    self.assertEqual(expected_url, url)

    expectations = (
      ('args', expected_args),
      ('data', expected_data),
      ('headers', expected_headers),
      ('method', expected_method),
    )
    for keyword, expected in expectations:
      if expected is not None:
        self.assertEqual(expected, kwargs[keyword])
      else:
        # An expectation of None means that the keyword must not have been passed at all.
        self.assertNotIn(keyword, kwargs)

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.utils.Http.request')
  def test_environment_details_get(self, mock_http_request, mock_context_default):
    """The request URL is built from the project of the default context, zone and name."""
    mock_context_default.return_value = TestCases._create_context()
    Api.get_environment_details('ZONE', 'ENVIRONMENT')
    self.validate(mock_http_request,
                  'https://composer.googleapis.com/v1alpha1/projects/test_project/locations/ZONE/'
                  'environments/ENVIRONMENT')

  @staticmethod
  def _create_context():
    """Builds a Datalab context for the test project with mocked credentials."""
    project_id = TestCases.TEST_PROJECT_ID
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return google.datalab.Context(project_id, creds)
class TestCases(unittest.TestCase):
  """Unit tests for the Composer class."""

  @staticmethod
  def _environment_details(gcs_dag_location):
    """Builds environment details that report the given DAG location."""
    return {'config': {'gcsDagLocation': gcs_dag_location}}

  def _assert_dag_location_error(self, mock_environment_details, details, message):
    """Asserts that reading gcs_dag_location raises a ValueError matching `message`."""
    mock_environment_details.return_value = details
    test_composer = Composer('foo_zone', 'foo_environment')
    # Python 3.12 removed the assertRaisesRegexp alias and Python 2.7 only has that spelling, so
    # pick whichever one this interpreter provides.
    assert_raises_regex = getattr(self, 'assertRaisesRegex', None) or self.assertRaisesRegexp
    with assert_raises_regex(ValueError, message):
      test_composer.gcs_dag_location

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.storage.Bucket')
  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.get_environment_details')
  def test_deploy(self, mock_environment_details, mock_bucket_class, mock_default_context):
    """deploy() writes '<name>.py' under the path part of the reported DAG location."""
    scenarios = [
      ('gs://foo_bucket/dags', 'dags/foo_name.py'),  # Happy path
      ('gs://foo_bucket', 'foo_name.py'),  # Only bucket with no path
      # GCS dag location has additional parts
      ('gs://foo_bucket/foo_random/dags', 'foo_random/dags/foo_name.py'),
    ]
    for gcs_dag_location, expected_object_name in scenarios:
      # Forget the calls of the previous scenario so that every assertion below can only be
      # satisfied by the deploy() call of this scenario.
      mock_bucket_class.reset_mock()
      mock_environment_details.return_value = self._environment_details(gcs_dag_location)

      test_composer = Composer('foo_zone', 'foo_environment')
      test_composer.deploy('foo_name', 'foo_dag_string')

      mock_bucket_class.assert_called_with('foo_bucket')
      mock_bucket_class.return_value.object.assert_called_with(expected_object_name)
      mock_bucket_class.return_value.object.return_value.write_stream.assert_called_with(
        'foo_dag_string', 'text/plain')

  @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.get_environment_details')
  def test_gcs_dag_location(self, mock_environment_details):
    """gcs_dag_location normalises good locations and rejects missing or malformed ones."""
    # Composer returns good result
    good_locations = [
      ('gs://foo_bucket/dags', 'gs://foo_bucket/dags/'),
      ('gs://foo_bucket', 'gs://foo_bucket/'),  # only bucket
      ('gs://foo_bucket/', 'gs://foo_bucket/'),  # with trailing slash
    ]
    for reported, expected in good_locations:
      mock_environment_details.return_value = self._environment_details(reported)
      test_composer = Composer('foo_zone', 'foo_environment')
      self.assertEqual(expected, test_composer.gcs_dag_location)

    # Composer returns an empty result, or a config without a DAG location
    for details in ({}, {'config': {}}):
      self._assert_dag_location_error(
        mock_environment_details, details,
        'Dag location unavailable from Composer environment foo_environment')

    # Composer returns None result or incorrect formats
    for reported in (None, 'gs:/foo_bucket', 'as://foo_bucket'):
      self._assert_dag_location_error(
        mock_environment_details, self._environment_details(reported),
        'Dag location {0} from Composer environment foo_environment is in'
        ' incorrect format'.format(reported))