You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/09/06 16:51:30 UTC

[3/3] incubator-airflow git commit: [AIRFLOW-1567][Airflow-1567] Renamed cloudml hook and operator to mlengine

[AIRFLOW-1567][Airflow-1567] Renamed cloudml hook and operator to mlengine

Closes #2567 from yk5/cmle


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/af91e2ac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/af91e2ac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/af91e2ac

Branch: refs/heads/master
Commit: af91e2ac0636685c0c1c25ddeba97f78b7009b88
Parents: 86063ba
Author: Younghee Kwon <yo...@google.com>
Authored: Wed Sep 6 09:51:17 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Wed Sep 6 09:51:17 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/gcp_cloudml_hook.py       | 269 ---------
 airflow/contrib/hooks/gcp_mlengine_hook.py      | 269 +++++++++
 airflow/contrib/operators/cloudml_operator.py   | 565 -------------------
 .../contrib/operators/cloudml_operator_utils.py | 245 --------
 .../operators/cloudml_prediction_summary.py     | 177 ------
 airflow/contrib/operators/mlengine_operator.py  | 564 ++++++++++++++++++
 .../operators/mlengine_operator_utils.py        | 245 ++++++++
 .../operators/mlengine_prediction_summary.py    | 177 ++++++
 tests/contrib/hooks/test_gcp_cloudml_hook.py    | 413 --------------
 tests/contrib/hooks/test_gcp_mlengine_hook.py   | 413 ++++++++++++++
 .../contrib/operators/test_cloudml_operator.py  | 373 ------------
 .../operators/test_cloudml_operator_utils.py    | 183 ------
 .../contrib/operators/test_mlengine_operator.py | 373 ++++++++++++
 .../operators/test_mlengine_operator_utils.py   | 183 ++++++
 14 files changed, 2224 insertions(+), 2225 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
deleted file mode 100644
index e1ff155..0000000
--- a/airflow/contrib/hooks/gcp_cloudml_hook.py
+++ /dev/null
@@ -1,269 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-import random
-import time
-from airflow import settings
-from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
-from apiclient.discovery import build
-from apiclient import errors
-from oauth2client.client import GoogleCredentials
-
-logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
-
-
-def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
-
-    for i in range(0, max_n):
-        try:
-            response = request.execute()
-            if is_error_func(response):
-                raise ValueError(
-                    'The response contained an error: {}'.format(response))
-            elif is_done_func(response):
-                logging.info('Operation is done: {}'.format(response))
-                return response
-            else:
-                time.sleep((2**i) + (random.randint(0, 1000) / 1000))
-        except errors.HttpError as e:
-            if e.resp.status != 429:
-                logging.info(
-                    'Something went wrong. Not retrying: {}'.format(e))
-                raise
-            else:
-                time.sleep((2**i) + (random.randint(0, 1000) / 1000))
-
-
-class CloudMLHook(GoogleCloudBaseHook):
-
-    def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
-        super(CloudMLHook, self).__init__(gcp_conn_id, delegate_to)
-        self._cloudml = self.get_conn()
-
-    def get_conn(self):
-        """
-        Returns a Google CloudML service object.
-        """
-        credentials = GoogleCredentials.get_application_default()
-        return build('ml', 'v1', credentials=credentials)
-
-    def create_job(self, project_id, job, use_existing_job_fn=None):
-        """
-        Launches a CloudML job and wait for it to reach a terminal state.
-
-        :param project_id: The Google Cloud project id within which CloudML
-            job will be launched.
-        :type project_id: string
-
-        :param job: CloudML Job object that should be provided to the CloudML
-            API, such as:
-            {
-              'jobId': 'my_job_id',
-              'trainingInput': {
-                'scaleTier': 'STANDARD_1',
-                ...
-              }
-            }
-        :type job: dict
-
-        :param use_existing_job_fn: In case that a CloudML job with the same
-            job_id already exist, this method (if provided) will decide whether
-            we should use this existing job, continue waiting for it to finish
-            and returning the job object. It should accepts a CloudML job
-            object, and returns a boolean value indicating whether it is OK to
-            reuse the existing job. If 'use_existing_job_fn' is not provided,
-            we by default reuse the existing CloudML job.
-        :type use_existing_job_fn: function
-
-        :return: The CloudML job object if the job successfully reach a
-            terminal state (which might be FAILED or CANCELLED state).
-        :rtype: dict
-        """
-        request = self._cloudml.projects().jobs().create(
-            parent='projects/{}'.format(project_id),
-            body=job)
-        job_id = job['jobId']
-
-        try:
-            request.execute()
-        except errors.HttpError as e:
-            # 409 means there is an existing job with the same job ID.
-            if e.resp.status == 409:
-                if use_existing_job_fn is not None:
-                    existing_job = self._get_job(project_id, job_id)
-                    if not use_existing_job_fn(existing_job):
-                        logging.error(
-                            'Job with job_id {} already exist, but it does '
-                            'not match our expectation: {}'.format(
-                                job_id, existing_job))
-                        raise
-                logging.info(
-                    'Job with job_id {} already exist. Will waiting for it to '
-                    'finish'.format(job_id))
-            else:
-                logging.error('Failed to create CloudML job: {}'.format(e))
-                raise
-        return self._wait_for_job_done(project_id, job_id)
-
-    def _get_job(self, project_id, job_id):
-        """
-        Gets a CloudML job based on the job name.
-
-        :return: CloudML job object if succeed.
-        :rtype: dict
-
-        Raises:
-            apiclient.errors.HttpError: if HTTP error is returned from server
-        """
-        job_name = 'projects/{}/jobs/{}'.format(project_id, job_id)
-        request = self._cloudml.projects().jobs().get(name=job_name)
-        while True:
-            try:
-                return request.execute()
-            except errors.HttpError as e:
-                if e.resp.status == 429:
-                    # polling after 30 seconds when quota failure occurs
-                    time.sleep(30)
-                else:
-                    logging.error('Failed to get CloudML job: {}'.format(e))
-                    raise
-
-    def _wait_for_job_done(self, project_id, job_id, interval=30):
-        """
-        Waits for the Job to reach a terminal state.
-
-        This method will periodically check the job state until the job reach
-        a terminal state.
-
-        Raises:
-            apiclient.errors.HttpError: if HTTP error is returned when getting
-            the job
-        """
-        assert interval > 0
-        while True:
-            job = self._get_job(project_id, job_id)
-            if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
-                return job
-            time.sleep(interval)
-
-    def create_version(self, project_id, model_name, version_spec):
-        """
-        Creates the Version on Cloud ML.
-
-        Returns the operation if the version was created successfully and
-        raises an error otherwise.
-        """
-        parent_name = 'projects/{}/models/{}'.format(project_id, model_name)
-        create_request = self._cloudml.projects().models().versions().create(
-            parent=parent_name, body=version_spec)
-        response = create_request.execute()
-        get_request = self._cloudml.projects().operations().get(
-            name=response['name'])
-
-        return _poll_with_exponential_delay(
-            request=get_request,
-            max_n=9,
-            is_done_func=lambda resp: resp.get('done', False),
-            is_error_func=lambda resp: resp.get('error', None) is not None)
-
-    def set_default_version(self, project_id, model_name, version_name):
-        """
-        Sets a version to be the default. Blocks until finished.
-        """
-        full_version_name = 'projects/{}/models/{}/versions/{}'.format(
-            project_id, model_name, version_name)
-        request = self._cloudml.projects().models().versions().setDefault(
-            name=full_version_name, body={})
-
-        try:
-            response = request.execute()
-            logging.info(
-                'Successfully set version: {} to default'.format(response))
-            return response
-        except errors.HttpError as e:
-            logging.error('Something went wrong: {}'.format(e))
-            raise
-
-    def list_versions(self, project_id, model_name):
-        """
-        Lists all available versions of a model. Blocks until finished.
-        """
-        result = []
-        full_parent_name = 'projects/{}/models/{}'.format(
-            project_id, model_name)
-        request = self._cloudml.projects().models().versions().list(
-            parent=full_parent_name, pageSize=100)
-
-        response = request.execute()
-        next_page_token = response.get('nextPageToken', None)
-        result.extend(response.get('versions', []))
-        while next_page_token is not None:
-            next_request = self._cloudml.projects().models().versions().list(
-                parent=full_parent_name,
-                pageToken=next_page_token,
-                pageSize=100)
-            response = next_request.execute()
-            next_page_token = response.get('nextPageToken', None)
-            result.extend(response.get('versions', []))
-            time.sleep(5)
-        return result
-
-    def delete_version(self, project_id, model_name, version_name):
-        """
-        Deletes the given version of a model. Blocks until finished.
-        """
-        full_name = 'projects/{}/models/{}/versions/{}'.format(
-            project_id, model_name, version_name)
-        delete_request = self._cloudml.projects().models().versions().delete(
-            name=full_name)
-        response = delete_request.execute()
-        get_request = self._cloudml.projects().operations().get(
-            name=response['name'])
-
-        return _poll_with_exponential_delay(
-            request=get_request,
-            max_n=9,
-            is_done_func=lambda resp: resp.get('done', False),
-            is_error_func=lambda resp: resp.get('error', None) is not None)
-
-    def create_model(self, project_id, model):
-        """
-        Create a Model. Blocks until finished.
-        """
-        assert model['name'] is not None and model['name'] is not ''
-        project = 'projects/{}'.format(project_id)
-
-        request = self._cloudml.projects().models().create(
-            parent=project, body=model)
-        return request.execute()
-
-    def get_model(self, project_id, model_name):
-        """
-        Gets a Model. Blocks until finished.
-        """
-        assert model_name is not None and model_name is not ''
-        full_model_name = 'projects/{}/models/{}'.format(
-            project_id, model_name)
-        request = self._cloudml.projects().models().get(name=full_model_name)
-        try:
-            return request.execute()
-        except errors.HttpError as e:
-            if e.resp.status == 404:
-                logging.error('Model was not found: {}'.format(e))
-                return None
-            raise

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/hooks/gcp_mlengine_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py
new file mode 100644
index 0000000..47d9700
--- /dev/null
+++ b/airflow/contrib/hooks/gcp_mlengine_hook.py
@@ -0,0 +1,269 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import random
+import time
+from airflow import settings
+from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from apiclient.discovery import build
+from apiclient import errors
+from oauth2client.client import GoogleCredentials
+
+logging.getLogger('GoogleCloudMLEngine').setLevel(settings.LOGGING_LEVEL)
+
+
+def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
+
+    for i in range(0, max_n):
+        try:
+            response = request.execute()
+            if is_error_func(response):
+                raise ValueError(
+                    'The response contained an error: {}'.format(response))
+            elif is_done_func(response):
+                logging.info('Operation is done: {}'.format(response))
+                return response
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000))
+        except errors.HttpError as e:
+            if e.resp.status != 429:
+                logging.info(
+                    'Something went wrong. Not retrying: {}'.format(e))
+                raise
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000))
+
+
+class MLEngineHook(GoogleCloudBaseHook):
+
+    def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
+        super(MLEngineHook, self).__init__(gcp_conn_id, delegate_to)
+        self._mlengine = self.get_conn()
+
+    def get_conn(self):
+        """
+        Returns a Google MLEngine service object.
+        """
+        credentials = GoogleCredentials.get_application_default()
+        return build('ml', 'v1', credentials=credentials)
+
+    def create_job(self, project_id, job, use_existing_job_fn=None):
+        """
+        Launches a MLEngine job and wait for it to reach a terminal state.
+
+        :param project_id: The Google Cloud project id within which MLEngine
+            job will be launched.
+        :type project_id: string
+
+        :param job: MLEngine Job object that should be provided to the MLEngine
+            API, such as:
+            {
+              'jobId': 'my_job_id',
+              'trainingInput': {
+                'scaleTier': 'STANDARD_1',
+                ...
+              }
+            }
+        :type job: dict
+
+        :param use_existing_job_fn: In case that a MLEngine job with the same
+            job_id already exist, this method (if provided) will decide whether
+            we should use this existing job, continue waiting for it to finish
+            and returning the job object. It should accepts a MLEngine job
+            object, and returns a boolean value indicating whether it is OK to
+            reuse the existing job. If 'use_existing_job_fn' is not provided,
+            we by default reuse the existing MLEngine job.
+        :type use_existing_job_fn: function
+
+        :return: The MLEngine job object if the job successfully reach a
+            terminal state (which might be FAILED or CANCELLED state).
+        :rtype: dict
+        """
+        request = self._mlengine.projects().jobs().create(
+            parent='projects/{}'.format(project_id),
+            body=job)
+        job_id = job['jobId']
+
+        try:
+            request.execute()
+        except errors.HttpError as e:
+            # 409 means there is an existing job with the same job ID.
+            if e.resp.status == 409:
+                if use_existing_job_fn is not None:
+                    existing_job = self._get_job(project_id, job_id)
+                    if not use_existing_job_fn(existing_job):
+                        logging.error(
+                            'Job with job_id {} already exist, but it does '
+                            'not match our expectation: {}'.format(
+                                job_id, existing_job))
+                        raise
+                logging.info(
+                    'Job with job_id {} already exist. Will waiting for it to '
+                    'finish'.format(job_id))
+            else:
+                logging.error('Failed to create MLEngine job: {}'.format(e))
+                raise
+        return self._wait_for_job_done(project_id, job_id)
+
+    def _get_job(self, project_id, job_id):
+        """
+        Gets a MLEngine job based on the job name.
+
+        :return: MLEngine job object if succeed.
+        :rtype: dict
+
+        Raises:
+            apiclient.errors.HttpError: if HTTP error is returned from server
+        """
+        job_name = 'projects/{}/jobs/{}'.format(project_id, job_id)
+        request = self._mlengine.projects().jobs().get(name=job_name)
+        while True:
+            try:
+                return request.execute()
+            except errors.HttpError as e:
+                if e.resp.status == 429:
+                    # polling after 30 seconds when quota failure occurs
+                    time.sleep(30)
+                else:
+                    logging.error('Failed to get MLEngine job: {}'.format(e))
+                    raise
+
+    def _wait_for_job_done(self, project_id, job_id, interval=30):
+        """
+        Waits for the Job to reach a terminal state.
+
+        This method will periodically check the job state until the job reach
+        a terminal state.
+
+        Raises:
+            apiclient.errors.HttpError: if HTTP error is returned when getting
+            the job
+        """
+        assert interval > 0
+        while True:
+            job = self._get_job(project_id, job_id)
+            if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
+                return job
+            time.sleep(interval)
+
+    def create_version(self, project_id, model_name, version_spec):
+        """
+        Creates the Version on Google Cloud ML Engine.
+
+        Returns the operation if the version was created successfully and
+        raises an error otherwise.
+        """
+        parent_name = 'projects/{}/models/{}'.format(project_id, model_name)
+        create_request = self._mlengine.projects().models().versions().create(
+            parent=parent_name, body=version_spec)
+        response = create_request.execute()
+        get_request = self._mlengine.projects().operations().get(
+            name=response['name'])
+
+        return _poll_with_exponential_delay(
+            request=get_request,
+            max_n=9,
+            is_done_func=lambda resp: resp.get('done', False),
+            is_error_func=lambda resp: resp.get('error', None) is not None)
+
+    def set_default_version(self, project_id, model_name, version_name):
+        """
+        Sets a version to be the default. Blocks until finished.
+        """
+        full_version_name = 'projects/{}/models/{}/versions/{}'.format(
+            project_id, model_name, version_name)
+        request = self._mlengine.projects().models().versions().setDefault(
+            name=full_version_name, body={})
+
+        try:
+            response = request.execute()
+            logging.info(
+                'Successfully set version: {} to default'.format(response))
+            return response
+        except errors.HttpError as e:
+            logging.error('Something went wrong: {}'.format(e))
+            raise
+
+    def list_versions(self, project_id, model_name):
+        """
+        Lists all available versions of a model. Blocks until finished.
+        """
+        result = []
+        full_parent_name = 'projects/{}/models/{}'.format(
+            project_id, model_name)
+        request = self._mlengine.projects().models().versions().list(
+            parent=full_parent_name, pageSize=100)
+
+        response = request.execute()
+        next_page_token = response.get('nextPageToken', None)
+        result.extend(response.get('versions', []))
+        while next_page_token is not None:
+            next_request = self._mlengine.projects().models().versions().list(
+                parent=full_parent_name,
+                pageToken=next_page_token,
+                pageSize=100)
+            response = next_request.execute()
+            next_page_token = response.get('nextPageToken', None)
+            result.extend(response.get('versions', []))
+            time.sleep(5)
+        return result
+
+    def delete_version(self, project_id, model_name, version_name):
+        """
+        Deletes the given version of a model. Blocks until finished.
+        """
+        full_name = 'projects/{}/models/{}/versions/{}'.format(
+            project_id, model_name, version_name)
+        delete_request = self._mlengine.projects().models().versions().delete(
+            name=full_name)
+        response = delete_request.execute()
+        get_request = self._mlengine.projects().operations().get(
+            name=response['name'])
+
+        return _poll_with_exponential_delay(
+            request=get_request,
+            max_n=9,
+            is_done_func=lambda resp: resp.get('done', False),
+            is_error_func=lambda resp: resp.get('error', None) is not None)
+
+    def create_model(self, project_id, model):
+        """
+        Create a Model. Blocks until finished.
+        """
+        assert model['name'] is not None and model['name'] is not ''
+        project = 'projects/{}'.format(project_id)
+
+        request = self._mlengine.projects().models().create(
+            parent=project, body=model)
+        return request.execute()
+
+    def get_model(self, project_id, model_name):
+        """
+        Gets a Model. Blocks until finished.
+        """
+        assert model_name is not None and model_name is not ''
+        full_model_name = 'projects/{}/models/{}'.format(
+            project_id, model_name)
+        request = self._mlengine.projects().models().get(name=full_model_name)
+        try:
+            return request.execute()
+        except errors.HttpError as e:
+            if e.resp.status == 404:
+                logging.error('Model was not found: {}'.format(e))
+                return None
+            raise

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
deleted file mode 100644
index 6bdd516..0000000
--- a/airflow/contrib/operators/cloudml_operator.py
+++ /dev/null
@@ -1,565 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the 'License'); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import re
-
-from airflow import settings
-from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
-from airflow.exceptions import AirflowException
-from airflow.operators import BaseOperator
-from airflow.utils.decorators import apply_defaults
-from apiclient import errors
-
-
-logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
-
-
-def _create_prediction_input(project_id,
-                             region,
-                             data_format,
-                             input_paths,
-                             output_path,
-                             model_name=None,
-                             version_name=None,
-                             uri=None,
-                             max_worker_count=None,
-                             runtime_version=None):
-    """
-    Create the batch prediction input from the given parameters.
-
-    Args:
-        A subset of arguments documented in __init__ method of class
-        CloudMLBatchPredictionOperator
-
-    Returns:
-        A dictionary representing the predictionInput object as documented
-        in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
-
-    Raises:
-        ValueError: if a unique model/version origin cannot be determined.
-    """
-
-    prediction_input = {
-        'dataFormat': data_format,
-        'inputPaths': input_paths,
-        'outputPath': output_path,
-        'region': region
-    }
-
-    if uri:
-        if model_name or version_name:
-            logging.error(
-                'Ambiguous model origin: Both uri and model/version name are '
-                'provided.')
-            raise ValueError('Ambiguous model origin.')
-        prediction_input['uri'] = uri
-    elif model_name:
-        origin_name = 'projects/{}/models/{}'.format(project_id, model_name)
-        if not version_name:
-            prediction_input['modelName'] = origin_name
-        else:
-            prediction_input['versionName'] = \
-                origin_name + '/versions/{}'.format(version_name)
-    else:
-        logging.error(
-            'Missing model origin: Batch prediction expects a model, '
-            'a model & version combination, or a URI to savedModel.')
-        raise ValueError('Missing model origin.')
-
-    if max_worker_count:
-        prediction_input['maxWorkerCount'] = max_worker_count
-    if runtime_version:
-        prediction_input['runtimeVersion'] = runtime_version
-
-    return prediction_input
-
-
-def _normalize_cloudml_job_id(job_id):
-    """
-    Replaces invalid CloudML job_id characters with '_'.
-
-    This also adds a leading 'z' in case job_id starts with an invalid
-    character.
-
-    Args:
-        job_id: A job_id str that may have invalid characters.
-
-    Returns:
-        A valid job_id representation.
-    """
-    match = re.search(r'\d', job_id)
-    if match and match.start() is 0:
-        job_id = 'z_{}'.format(job_id)
-    return re.sub('[^0-9a-zA-Z]+', '_', job_id)
-
-
-class CloudMLBatchPredictionOperator(BaseOperator):
-    """
-    Start a Cloud ML prediction job.
-
-    NOTE: For model origin, users should consider exactly one from the
-    three options below:
-    1. Populate 'uri' field only, which should be a GCS location that
-    points to a tensorflow savedModel directory.
-    2. Populate 'model_name' field only, which refers to an existing
-    model, and the default version of the model will be used.
-    3. Populate both 'model_name' and 'version_name' fields, which
-    refers to a specific version of a specific model.
-
-    In options 2 and 3, both model and version name should contain the
-    minimal identifier. For instance, call
-        CloudMLBatchPredictionOperator(
-            ...,
-            model_name='my_model',
-            version_name='my_version',
-            ...)
-    if the desired model version is
-    "projects/my_project/models/my_model/versions/my_version".
-
-
-    :param project_id: The Google Cloud project name where the
-        prediction job is submitted.
-    :type project_id: string
-
-    :param job_id: A unique id for the prediction job on Google Cloud
-        ML Engine.
-    :type job_id: string
-
-    :param data_format: The format of the input data.
-        It will default to 'DATA_FORMAT_UNSPECIFIED' if is not provided
-        or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"].
-    :type data_format: string
-
-    :param input_paths: A list of GCS paths of input data for batch
-        prediction. Accepting wildcard operator *, but only at the end.
-    :type input_paths: list of string
-
-    :param output_path: The GCS path where the prediction results are
-        written to.
-    :type output_path: string
-
-    :param region: The Google Compute Engine region to run the
-        prediction job in.:
-    :type region: string
-
-    :param model_name: The Google Cloud ML model to use for prediction.
-        If version_name is not provided, the default version of this
-        model will be used.
-        Should not be None if version_name is provided.
-        Should be None if uri is provided.
-    :type model_name: string
-
-    :param version_name: The Google Cloud ML model version to use for
-        prediction.
-        Should be None if uri is provided.
-    :type version_name: string
-
-    :param uri: The GCS path of the saved model to use for prediction.
-        Should be None if model_name is provided.
-        It should be a GCS path pointing to a tensorflow SavedModel.
-    :type uri: string
-
-    :param max_worker_count: The maximum number of workers to be used
-        for parallel processing. Defaults to 10 if not specified.
-    :type max_worker_count: int
-
-    :param runtime_version: The Google Cloud ML runtime version to use
-        for batch prediction.
-    :type runtime_version: string
-
-    :param gcp_conn_id: The connection ID used for connection to Google
-        Cloud Platform.
-    :type gcp_conn_id: string
-
-    :param delegate_to: The account to impersonate, if any.
-        For this to work, the service account making the request must
-        have doamin-wide delegation enabled.
-    :type delegate_to: string
-
-    Raises:
-        ValueError: if a unique model/version origin cannot be determined.
-    """
-
-    template_fields = [
-        "prediction_job_request",
-    ]
-
-    @apply_defaults
-    def __init__(self,
-                 project_id,
-                 job_id,
-                 region,
-                 data_format,
-                 input_paths,
-                 output_path,
-                 model_name=None,
-                 version_name=None,
-                 uri=None,
-                 max_worker_count=None,
-                 runtime_version=None,
-                 gcp_conn_id='google_cloud_default',
-                 delegate_to=None,
-                 *args,
-                 **kwargs):
-        super(CloudMLBatchPredictionOperator, self).__init__(*args, **kwargs)
-
-        self.project_id = project_id
-        self.gcp_conn_id = gcp_conn_id
-        self.delegate_to = delegate_to
-
-        try:
-            prediction_input = _create_prediction_input(
-                project_id, region, data_format, input_paths, output_path,
-                model_name, version_name, uri, max_worker_count,
-                runtime_version)
-        except ValueError as e:
-            logging.error(
-                'Cannot create batch prediction job request due to: {}'
-                .format(str(e)))
-            raise
-
-        self.prediction_job_request = {
-            'jobId': _normalize_cloudml_job_id(job_id),
-            'predictionInput': prediction_input
-        }
-
-    def execute(self, context):
-        hook = CloudMLHook(self.gcp_conn_id, self.delegate_to)
-
-        def check_existing_job(existing_job):
-            return existing_job.get('predictionInput', None) == \
-                self.prediction_job_request['predictionInput']
-        try:
-            finished_prediction_job = hook.create_job(
-                self.project_id,
-                self.prediction_job_request,
-                check_existing_job)
-        except errors.HttpError:
-            raise
-
-        if finished_prediction_job['state'] != 'SUCCEEDED':
-            logging.error(
-                'Batch prediction job failed: %s',
-                str(finished_prediction_job))
-            raise RuntimeError(finished_prediction_job['errorMessage'])
-
-        return finished_prediction_job['predictionOutput']
-
-
-class CloudMLModelOperator(BaseOperator):
-    """
-    Operator for managing a Google Cloud ML model.
-
-    :param model: A dictionary containing the information about the model.
-        If the `operation` is `create`, then the `model` parameter should
-        contain all the information about this model such as `name`.
-
-        If the `operation` is `get`, the `model` parameter
-        should contain the `name` of the model.
-    :type model: dict
-
-    :param project_id: The Google Cloud project name to which CloudML
-        model belongs.
-    :type project_id: string
-
-    :param gcp_conn_id: The connection ID to use when fetching connection info.
-    :type gcp_conn_id: string
-
-    :param operation: The operation to perform. Available operations are:
-        'create': Creates a new model as provided by the `model` parameter.
-        'get': Gets a particular model where the name is specified in `model`.
-
-    :param delegate_to: The account to impersonate, if any.
-        For this to work, the service account making the request must have
-        domain-wide delegation enabled.
-    :type delegate_to: string
-    """
-
-    template_fields = [
-        '_model',
-        '_model_name',
-    ]
-
-    @apply_defaults
-    def __init__(self,
-                 project_id,
-                 model,
-                 gcp_conn_id='google_cloud_default',
-                 operation='create',
-                 delegate_to=None,
-                 *args,
-                 **kwargs):
-        super(CloudMLModelOperator, self).__init__(*args, **kwargs)
-        self._model = model
-        self._operation = operation
-        self._gcp_conn_id = gcp_conn_id
-        self._delegate_to = delegate_to
-        self._project_id = project_id
-
-    def execute(self, context):
-        hook = CloudMLHook(
-            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
-        if self._operation == 'create':
-            hook.create_model(self._project_id, self._model)
-        elif self._operation == 'get':
-            hook.get_model(self._project_id, self._model['name'])
-        else:
-            raise ValueError('Unknown operation: {}'.format(self._operation))
-
-
-class CloudMLVersionOperator(BaseOperator):
-    """
-    Operator for managing a Google Cloud ML version.
-
-    :param model_name: The name of the Google Cloud ML model that the version
-        belongs to.
-    :type model_name: string
-
-    :param project_id: The Google Cloud project name to which CloudML
-        model belongs.
-    :type project_id: string
-
-    :param version: A dictionary containing the information about the version.
-        If the `operation` is `create`, `version` should contain all the
-        information about this version such as name, and deploymentUrl.
-        If the `operation` is `get` or `delete`, the `version` parameter
-        should contain the `name` of the version.
-        If it is None, the only `operation` possible would be `list`.
-    :type version: dict
-
-    :param version_name: A name to use for the version being operated upon. If
-        not None and the `version` argument is None or does not have a value for
-        the `name` key, then this will be populated in the payload for the
-        `name` key.
-    :type version_name: string
-
-    :param gcp_conn_id: The connection ID to use when fetching connection info.
-    :type gcp_conn_id: string
-
-    :param operation: The operation to perform. Available operations are:
-        'create': Creates a new version in the model specified by `model_name`,
-            in which case the `version` parameter should contain all the
-            information to create that version
-            (e.g. `name`, `deploymentUrl`).
-        'get': Gets full information of a particular version in the model
-            specified by `model_name`.
-            The name of the version should be specified in the `version`
-            parameter.
-
-        'list': Lists all available versions of the model specified
-            by `model_name`.
-
-        'delete': Deletes the version specified in `version` parameter from the
-            model specified by `model_name`).
-            The name of the version should be specified in the `version`
-            parameter.
-    :type operation: string
-
-    :param delegate_to: The account to impersonate, if any.
-        For this to work, the service account making the request must have
-        domain-wide delegation enabled.
-    :type delegate_to: string
-    """
-
-    template_fields = [
-        '_model_name',
-        '_version',
-        '_version_name',
-    ]
-
-    @apply_defaults
-    def __init__(self,
-                 model_name,
-                 project_id,
-                 version=None,
-                 version_name=None,
-                 gcp_conn_id='google_cloud_default',
-                 operation='create',
-                 delegate_to=None,
-                 *args,
-                 **kwargs):
-
-        super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
-        self._model_name = model_name
-        self._version = version or {}
-        self._version_name = version_name
-        self._gcp_conn_id = gcp_conn_id
-        self._delegate_to = delegate_to
-        self._project_id = project_id
-        self._operation = operation
-
-    def execute(self, context):
-        if 'name' not in self._version:
-            self._version['name'] = self._version_name
-
-        hook = CloudMLHook(
-            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
-
-        if self._operation == 'create':
-            assert self._version is not None
-            return hook.create_version(self._project_id, self._model_name,
-                                       self._version)
-        elif self._operation == 'set_default':
-            return hook.set_default_version(
-                self._project_id, self._model_name,
-                self._version['name'])
-        elif self._operation == 'list':
-            return hook.list_versions(self._project_id, self._model_name)
-        elif self._operation == 'delete':
-            return hook.delete_version(self._project_id, self._model_name,
-                                       self._version['name'])
-        else:
-            raise ValueError('Unknown operation: {}'.format(self._operation))
-
-
-class CloudMLTrainingOperator(BaseOperator):
-    """
-    Operator for launching a CloudML training job.
-
-    :param project_id: The Google Cloud project name within which CloudML
-        training job should run. This field could be templated.
-    :type project_id: string
-
-    :param job_id: A unique templated id for the submitted Google CloudML
-        training job.
-    :type job_id: string
-
-    :param package_uris: A list of package locations for CloudML training job,
-        which should include the main training program + any additional
-        dependencies.
-    :type package_uris: string
-
-    :param training_python_module: The Python module name to run within CloudML
-        training job after installing 'package_uris' packages.
-    :type training_python_module: string
-
-    :param training_args: A list of templated command line arguments to pass to
-        the CloudML training program.
-    :type training_args: string
-
-    :param region: The Google Compute Engine region to run the CloudML training
-        job in. This field could be templated.
-    :type region: string
-
-    :param scale_tier: Resource tier for CloudML training job.
-    :type scale_tier: string
-
-    :param gcp_conn_id: The connection ID to use when fetching connection info.
-    :type gcp_conn_id: string
-
-    :param delegate_to: The account to impersonate, if any.
-        For this to work, the service account making the request must have
-        domain-wide delegation enabled.
-    :type delegate_to: string
-
-    :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real
-        training job will be launched, but the CloudML training job request
-        will be printed out. In 'CLOUD' mode, a real CloudML training job
-        creation request will be issued.
-    :type mode: string
-    """
-
-    template_fields = [
-        '_project_id',
-        '_job_id',
-        '_package_uris',
-        '_training_python_module',
-        '_training_args',
-        '_region',
-        '_scale_tier',
-    ]
-
-    @apply_defaults
-    def __init__(self,
-                 project_id,
-                 job_id,
-                 package_uris,
-                 training_python_module,
-                 training_args,
-                 region,
-                 scale_tier=None,
-                 gcp_conn_id='google_cloud_default',
-                 delegate_to=None,
-                 mode='PRODUCTION',
-                 *args,
-                 **kwargs):
-        super(CloudMLTrainingOperator, self).__init__(*args, **kwargs)
-        self._project_id = project_id
-        self._job_id = job_id
-        self._package_uris = package_uris
-        self._training_python_module = training_python_module
-        self._training_args = training_args
-        self._region = region
-        self._scale_tier = scale_tier
-        self._gcp_conn_id = gcp_conn_id
-        self._delegate_to = delegate_to
-        self._mode = mode
-
-        if not self._project_id:
-            raise AirflowException('Google Cloud project id is required.')
-        if not self._job_id:
-            raise AirflowException(
-                'An unique job id is required for Google CloudML training '
-                'job.')
-        if not package_uris:
-            raise AirflowException(
-                'At least one python package is required for CloudML '
-                'Training job.')
-        if not training_python_module:
-            raise AirflowException(
-                'Python module name to run after installing required '
-                'packages is required.')
-        if not self._region:
-            raise AirflowException('Google Compute Engine region is required.')
-
-    def execute(self, context):
-        job_id = _normalize_cloudml_job_id(self._job_id)
-        training_request = {
-            'jobId': job_id,
-            'trainingInput': {
-                'scaleTier': self._scale_tier,
-                'packageUris': self._package_uris,
-                'pythonModule': self._training_python_module,
-                'region': self._region,
-                'args': self._training_args,
-            }
-        }
-
-        if self._mode == 'DRY_RUN':
-            logging.info('In dry_run mode.')
-            logging.info(
-                'CloudML Training job request is: {}'.format(training_request))
-            return
-
-        hook = CloudMLHook(
-            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
-
-        # Helper method to check if the existing job's training input is the
-        # same as the request we get here.
-        def check_existing_job(existing_job):
-            return existing_job.get('trainingInput', None) == \
-                training_request['trainingInput']
-        try:
-            finished_training_job = hook.create_job(
-                self._project_id, training_request, check_existing_job)
-        except errors.HttpError:
-            raise
-
-        if finished_training_job['state'] != 'SUCCEEDED':
-            logging.error('CloudML training job failed: {}'.format(
-                str(finished_training_job)))
-            raise RuntimeError(finished_training_job['errorMessage'])

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py
deleted file mode 100644
index 81cd54f..0000000
--- a/airflow/contrib/operators/cloudml_operator_utils.py
+++ /dev/null
@@ -1,245 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the 'License'); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import json
-import os
-import re
-
-import dill
-
-from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
-from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
-from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
-from airflow.exceptions import AirflowException
-from airflow.operators.python_operator import PythonOperator
-from six.moves.urllib.parse import urlsplit
-
-def create_evaluate_ops(task_prefix,
-                        data_format,
-                        input_paths,
-                        prediction_path,
-                        metric_fn_and_keys,
-                        validate_fn,
-                        batch_prediction_job_id=None,
-                        project_id=None,
-                        region=None,
-                        dataflow_options=None,
-                        model_uri=None,
-                        model_name=None,
-                        version_name=None,
-                        dag=None):
-    """
-    Creates the Operators needed for model evaluation and returns them.
-
-    It gets prediction over inputs via Cloud ML Engine BatchPrediction API by
-    calling CloudMLBatchPredictionOperator, then summarize and validate
-    the result via Cloud Dataflow using DataFlowPythonOperator.
-
-    For details and pricing about Batch prediction, please refer to the website
-    https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
-    and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/
-
-    It returns three chained operators for prediction, summary, and validation,
-    named as <prefix>-prediction, <prefix>-summary, and <prefix>-validation,
-    respectively.
-    (<prefix> should contain only alphanumeric characters or hyphen.)
-
-    The upstream and downstream can be set accordingly like:
-      pred, _, val = create_evaluate_ops(...)
-      pred.set_upstream(upstream_op)
-      ...
-      downstream_op.set_upstream(val)
-
-    Callers will provide two python callables, metric_fn and validate_fn, in
-    order to customize the evaluation behavior as they wish.
-    - metric_fn receives a dictionary per instance derived from json in the
-      batch prediction result. The keys might vary depending on the model.
-      It should return a tuple of metrics.
-    - validate_fn receives a dictionary of the averaged metrics that metric_fn
-      generated over all instances.
-      The key/value of the dictionary matches to what's given by
-      metric_fn_and_keys arg.
-      The dictionary contains an additional metric, 'count' to represent the
-      total number of instances received for evaluation.
-      The function would raise an exception to mark the task as failed, in a
-      case the validation result is not okay to proceed (i.e. to set the trained
-      version as default).
-
-    Typical examples are like this:
-
-    def get_metric_fn_and_keys():
-        import math  # imports should be outside of the metric_fn below.
-        def error_and_squared_error(inst):
-            label = float(inst['input_label'])
-            classes = float(inst['classes'])  # 0 or 1
-            err = abs(classes-label)
-            squared_err = math.pow(classes-label, 2)
-            return (err, squared_err)  # returns a tuple.
-        return error_and_squared_error, ['err', 'mse']  # key order must match.
-
-    def validate_err_and_count(summary):
-        if summary['err'] > 0.2:
-            raise ValueError('Too high err>0.2; summary=%s' % summary)
-        if summary['mse'] > 0.05:
-            raise ValueError('Too high mse>0.05; summary=%s' % summary)
-        if summary['count'] < 1000:
-            raise ValueError('Too few instances<1000; summary=%s' % summary)
-        return summary
-
-    For the details on the other BatchPrediction-related arguments (project_id,
-    job_id, region, data_format, input_paths, prediction_path, model_uri),
-    please refer to CloudMLBatchPredictionOperator too.
-
-    :param task_prefix: a prefix for the tasks. Only alphanumeric characters and
-        hyphen are allowed (no underscores), since this will be used as dataflow
-        job name, which doesn't allow other characters.
-    :type task_prefix: string
-
-    :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
-    :type data_format: string
-
-    :param input_paths: a list of input paths to be sent to BatchPrediction.
-    :type input_paths: list of strings
-
-    :param prediction_path: GCS path to put the prediction results in.
-    :type prediction_path: string
-
-    :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
-        - metric_fn is a function that accepts a dictionary (for an instance),
-          and returns a tuple of metric(s) that it calculates.
-        - metric_keys is a list of strings to denote the key of each metric.
-    :type metric_fn_and_keys: tuple of a function and a list of strings
-
-    :param validate_fn: a function to validate whether the averaged metric(s) is
-        good enough to push the model.
-    :type validate_fn: function
-
-    :param batch_prediction_job_id: the id to use for the Cloud ML Batch
-        prediction job. Passed directly to the CloudMLBatchPredictionOperator as
-        the job_id argument.
-    :type batch_prediction_job_id: string
-
-    :param project_id: the Google Cloud Platform project id in which to execute
-        Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
-        `default_args['project_id']` will be used.
-    :type project_id: string
-
-    :param region: the Google Cloud Platform region in which to execute Cloud ML
-        Batch Prediction and Dataflow jobs. If None, then the `dag`'s
-        `default_args['region']` will be used.
-    :type region: string
-
-    :param dataflow_options: options to run Dataflow jobs. If None, then the
-        `dag`'s `default_args['dataflow_default_options']` will be used.
-    :type dataflow_options: dictionary
-
-    :param model_uri: GCS path of the model exported by Tensorflow using
-        tensorflow.estimator.export_savedmodel(). It cannot be used with
-        model_name or version_name below. See CloudMLBatchPredictionOperator for
-        more detail.
-    :type model_uri: string
-
-    :param model_name: Used to indicate a model to use for prediction. Can be
-        used in combination with version_name, but cannot be used together with
-        model_uri. See CloudMLBatchPredictionOperator for more detail. If None,
-        then the `dag`'s `default_args['model_name']` will be used.
-    :type model_name: string
-
-    :param version_name: Used to indicate a model version to use for prediction,
-        in combination with model_name. Cannot be used together with model_uri.
-        See CloudMLBatchPredictionOperator for more detail. If None, then the
-        `dag`'s `default_args['version_name']` will be used.
-    :type version_name: string
-
-    :param dag: The `DAG` to use for all Operators.
-    :type dag: airflow.DAG
-
-    :returns: a tuple of three operators, (prediction, summary, validation)
-    :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
-                  PythonOperator)
-    """
-
-    # Verify that task_prefix doesn't have any special characters except hyphen
-    # '-', which is the only allowed non-alphanumeric character by Dataflow.
-    if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
-        raise AirflowException(
-            "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
-            "and hyphens are allowed but got: " + task_prefix)
-
-    metric_fn, metric_keys = metric_fn_and_keys
-    if not callable(metric_fn):
-        raise AirflowException("`metric_fn` param must be callable.")
-    if not callable(validate_fn):
-        raise AirflowException("`validate_fn` param must be callable.")
-
-    if dag is not None and dag.default_args is not None:
-        default_args = dag.default_args
-        project_id = project_id or default_args.get('project_id')
-        region = region or default_args.get('region')
-        model_name = model_name or default_args.get('model_name')
-        version_name = version_name or default_args.get('version_name')
-        dataflow_options = dataflow_options or \
-            default_args.get('dataflow_default_options')
-
-    evaluate_prediction = CloudMLBatchPredictionOperator(
-        task_id=(task_prefix + "-prediction"),
-        project_id=project_id,
-        job_id=batch_prediction_job_id,
-        region=region,
-        data_format=data_format,
-        input_paths=input_paths,
-        output_path=prediction_path,
-        uri=model_uri,
-        model_name=model_name,
-        version_name=version_name,
-        dag=dag)
-
-    metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))
-    evaluate_summary = DataFlowPythonOperator(
-        task_id=(task_prefix + "-summary"),
-        py_options=["-m"],
-        py_file="airflow.contrib.operators.cloudml_prediction_summary",
-        dataflow_default_options=dataflow_options,
-        options={
-            "prediction_path": prediction_path,
-            "metric_fn_encoded": metric_fn_encoded,
-            "metric_keys": ','.join(metric_keys)
-        },
-        dag=dag)
-    evaluate_summary.set_upstream(evaluate_prediction)
-
-    def apply_validate_fn(*args, **kwargs):
-        prediction_path = kwargs["templates_dict"]["prediction_path"]
-        scheme, bucket, obj, _, _ = urlsplit(prediction_path)
-        if scheme != "gs" or not bucket or not obj:
-            raise ValueError("Wrong format prediction_path: %s",
-                             prediction_path)
-        summary = os.path.join(obj.strip("/"),
-                               "prediction.summary.json")
-        gcs_hook = GoogleCloudStorageHook()
-        summary = json.loads(gcs_hook.download(bucket, summary))
-        return validate_fn(summary)
-
-    evaluate_validation = PythonOperator(
-        task_id=(task_prefix + "-validation"),
-        python_callable=apply_validate_fn,
-        provide_context=True,
-        templates_dict={"prediction_path": prediction_path},
-        dag=dag)
-    evaluate_validation.set_upstream(evaluate_summary)
-
-    return evaluate_prediction, evaluate_summary, evaluate_validation

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_prediction_summary.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_prediction_summary.py b/airflow/contrib/operators/cloudml_prediction_summary.py
deleted file mode 100644
index 3128dc3..0000000
--- a/airflow/contrib/operators/cloudml_prediction_summary.py
+++ /dev/null
@@ -1,177 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the 'License'); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
-
-It accepts a user function to calculate the metric(s) per instance in
-the prediction results, then aggregates to output as a summary.
-
-Args:
-  --prediction_path:
-      The GCS folder that contains BatchPrediction results, containing
-      prediction.results-NNNNN-of-NNNNN files in the json format.
-      Output will be also stored in this folder, as 'prediction.summary.json'.
-
-  --metric_fn_encoded:
-      An encoded function that calculates and returns a tuple of metric(s)
-      for a given instance (as a dictionary). It should be encoded
-      via base64.b64encode(dill.dumps(fn, recurse=True)).
-
-  --metric_keys:
-      A comma-separated key(s) of the aggregated metric(s) in the summary
-      output. The order and the size of the keys must match to the output
-      of metric_fn.
-      The summary will have an additional key, 'count', to represent the
-      total number of instances, so the keys shouldn't include 'count'.
-
-# Usage example:
-def get_metric_fn():
-    import math  # all imports must be outside of the function to be passed.
-    def metric_fn(inst):
-        label = float(inst["input_label"])
-        classes = float(inst["classes"])
-        prediction = float(inst["scores"][1])
-        log_loss = math.log(1 + math.exp(
-            -(label * 2 - 1) * math.log(prediction / (1 - prediction))))
-        squared_err = (classes-label)**2
-        return (log_loss, squared_err)
-    return metric_fn
-metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
-
-airflow.contrib.operators.DataFlowPythonOperator(
-    task_id="summary-prediction",
-    py_options=["-m"],
-    py_file="airflow.contrib.operators.cloudml_prediction_summary",
-    options={
-        "prediction_path": prediction_path,
-        "metric_fn_encoded": metric_fn_encoded,
-        "metric_keys": "log_loss,mse"
-    },
-    dataflow_default_options={
-        "project": "xxx", "region": "us-east1",
-        "staging_location": "gs://yy", "temp_location": "gs://zz",
-    })
-    >> dag
-
-# When the input file is like the following:
-{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
-{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
-{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
-{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
-
-# The output file will be:
-{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
-
-# To test outside of the dag:
-subprocess.check_call(["python",
-                       "-m",
-                       "airflow.contrib.operators.cloudml_prediction_summary",
-                       "--prediction_path=gs://...",
-                       "--metric_fn_encoded=" + metric_fn_encoded,
-                       "--metric_keys=log_loss,mse",
-                       "--runner=DataflowRunner",
-                       "--staging_location=gs://...",
-                       "--temp_location=gs://...",
-                       ])
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import base64
-import json
-import logging
-import os
-
-import apache_beam as beam
-import dill
-
-
-class JsonCoder(object):
-    def encode(self, x):
-        return json.dumps(x)
-
-    def decode(self, x):
-        return json.loads(x)
-
-
-@beam.ptransform_fn
-def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
-    return (
-        pcoll
-        | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
-        | "PairWith1" >> beam.Map(lambda tup: tup + (1,))
-        | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn(
-            *([sum] * (len(metric_keys) + 1))))
-        | "AverageAndMakeDict" >> beam.Map(
-            lambda tup: dict(
-                [(name, tup[i]/tup[-1]) for i, name in enumerate(metric_keys)] +
-                [("count", tup[-1])])))
-
-
-def run(argv=None):
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--prediction_path", required=True,
-        help=(
-            "The GCS folder that contains BatchPrediction results, containing "
-            "prediction.results-NNNNN-of-NNNNN files in the json format. "
-            "Output will be also stored in this folder, as a file"
-            "'prediction.summary.json'."))
-    parser.add_argument(
-        "--metric_fn_encoded", required=True,
-        help=(
-            "An encoded function that calculates and returns a tuple of "
-            "metric(s) for a given instance (as a dictionary). It should be "
-            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."))
-    parser.add_argument(
-        "--metric_keys", required=True,
-        help=(
-            "A comma-separated keys of the aggregated metric(s) in the summary "
-            "output. The order and the size of the keys must match to the "
-            "output of metric_fn. The summary will have an additional key, "
-            "'count', to represent the total number of instances, so this flag "
-            "shouldn't include 'count'."))
-    known_args, pipeline_args = parser.parse_known_args(argv)
-
-    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
-    if not callable(metric_fn):
-        raise ValueError("--metric_fn_encoded must be an encoded callable.")
-    metric_keys = known_args.metric_keys.split(",")
-
-    with beam.Pipeline(
-        options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
-        # This is apache-beam ptransform's convention
-        # pylint: disable=no-value-for-parameter
-        _ = (p
-             | "ReadPredictionResult" >> beam.io.ReadFromText(
-                 os.path.join(known_args.prediction_path,
-                              "prediction.results-*-of-*"),
-                 coder=JsonCoder())
-             | "Summary" >> MakeSummary(metric_fn, metric_keys)
-             | "Write" >> beam.io.WriteToText(
-                 os.path.join(known_args.prediction_path,
-                              "prediction.summary.json"),
-                 shard_name_template='',  # without trailing -NNNNN-of-NNNNN.
-                 coder=JsonCoder()))
-        # pylint: enable=no-value-for-parameter
-
-
-if __name__ == "__main__":
-    logging.getLogger().setLevel(logging.INFO)
-    run()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
new file mode 100644
index 0000000..7476825
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -0,0 +1,564 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from airflow import settings
+from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook
+from airflow.exceptions import AirflowException
+from airflow.operators import BaseOperator
+from airflow.utils.decorators import apply_defaults
+from apiclient import errors
+
+
+logging.getLogger('GoogleCloudMLEngine').setLevel(settings.LOGGING_LEVEL)
+
+
+def _create_prediction_input(project_id,
+                             region,
+                             data_format,
+                             input_paths,
+                             output_path,
+                             model_name=None,
+                             version_name=None,
+                             uri=None,
+                             max_worker_count=None,
+                             runtime_version=None):
+    """
+    Build the predictionInput payload of a batch prediction job request.
+
+    The model origin is `uri`, else `model_name` with `version_name`, else
+    `model_name` alone. Optional settings are only included when truthy.
+
+    Args:
+        A subset of arguments documented in __init__ method of class
+        MLEngineBatchPredictionOperator
+
+    Returns:
+        A dictionary representing the predictionInput object as documented
+        in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
+
+    Raises:
+        ValueError: if `uri` is combined with a model/version name, or if
+            neither `uri` nor `model_name` is given.
+    """
+    if uri and (model_name or version_name):
+        logging.error(
+            'Ambiguous model origin: Both uri and model/version name are '
+            'provided.')
+        raise ValueError('Ambiguous model origin.')
+    if not (uri or model_name):
+        logging.error(
+            'Missing model origin: Batch prediction expects a model, '
+            'a model & version combination, or a URI to savedModel.')
+        raise ValueError('Missing model origin.')
+
+    payload = {
+        'dataFormat': data_format,
+        'inputPaths': input_paths,
+        'outputPath': output_path,
+        'region': region,
+    }
+    if uri:
+        payload['uri'] = uri
+    elif version_name:
+        payload['versionName'] = 'projects/{}/models/{}/versions/{}'.format(
+            project_id, model_name, version_name)
+    else:
+        payload['modelName'] = 'projects/{}/models/{}'.format(
+            project_id, model_name)
+
+    optional = (('maxWorkerCount', max_worker_count),
+                ('runtimeVersion', runtime_version))
+    payload.update((key, value) for key, value in optional if value)
+    return payload
+
+
+def _normalize_mlengine_job_id(job_id):
+    """
+    Replaces invalid MLEngine job_id characters with '_'.
+
+    Each run of characters outside [0-9a-zA-Z] collapses into a single '_'.
+    A job_id that starts with a digit is additionally prefixed with 'z_'.
+
+    Args:
+        job_id: A job_id str that may have invalid characters.
+
+    Returns:
+        The normalized job_id string.
+    """
+    # re.match anchors at position 0; the old `match.start() is 0` compared
+    # ints by identity, which only works through CPython's small-int cache.
+    prefix = 'z_' if re.match(r'\d', job_id) else ''
+    return re.sub('[^0-9a-zA-Z]+', '_', prefix + job_id)
+
+
+class MLEngineBatchPredictionOperator(BaseOperator):
+    """
+    Start a Google Cloud ML Engine prediction job.
+
+    NOTE: For model origin, users should provide exactly one of the
+    three options below:
+    1. Populate the 'uri' field only, which should be a GCS location that
+    points to a tensorflow savedModel directory.
+    2. Populate the 'model_name' field only, which refers to an existing
+    model; the default version of that model will be used.
+    3. Populate both the 'model_name' and 'version_name' fields, which
+    refer to a specific version of a specific model.
+
+    In options 2 and 3, both model and version name should contain the
+    minimal identifier. For instance, call
+        MLEngineBatchPredictionOperator(
+            ...,
+            model_name='my_model',
+            version_name='my_version',
+            ...)
+    if the desired model version is
+    "projects/my_project/models/my_model/versions/my_version".
+
+
+    :param project_id: The Google Cloud project name where the
+        prediction job is submitted.
+    :type project_id: string
+
+    :param job_id: A unique id for the prediction job on Google Cloud
+        ML Engine. Invalid characters are normalized before submission.
+    :type job_id: string
+
+    :param data_format: The format of the input data.
+        It will default to 'DATA_FORMAT_UNSPECIFIED' if it is not provided
+        or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"].
+    :type data_format: string
+
+    :param input_paths: A list of GCS paths of input data for batch
+        prediction. The wildcard operator * is accepted, but only at the end.
+    :type input_paths: list of string
+
+    :param output_path: The GCS path where the prediction results are
+        written to.
+    :type output_path: string
+
+    :param region: The Google Compute Engine region to run the
+        prediction job in.
+    :type region: string
+
+    :param model_name: The Google Cloud ML Engine model to use for prediction.
+        If version_name is not provided, the default version of this
+        model will be used.
+        Should not be None if version_name is provided.
+        Should be None if uri is provided.
+    :type model_name: string
+
+    :param version_name: The Google Cloud ML Engine model version to use for
+        prediction.
+        Should be None if uri is provided.
+    :type version_name: string
+
+    :param uri: The GCS path of the saved model to use for prediction.
+        Should be None if model_name is provided.
+        It should be a GCS path pointing to a tensorflow SavedModel.
+    :type uri: string
+
+    :param max_worker_count: The maximum number of workers to be used
+        for parallel processing. Defaults to 10 if not specified.
+    :type max_worker_count: int
+
+    :param runtime_version: The Google Cloud ML Engine runtime version to use
+        for batch prediction.
+    :type runtime_version: string
+
+    :param gcp_conn_id: The connection ID used for connection to Google
+        Cloud Platform.
+    :type gcp_conn_id: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must
+        have domain-wide delegation enabled.
+    :type delegate_to: string
+
+    Raises:
+        ValueError: if a unique model/version origin cannot be determined.
+    """
+
+    template_fields = [
+        "prediction_job_request",
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 job_id,
+                 region,
+                 data_format,
+                 input_paths,
+                 output_path,
+                 model_name=None,
+                 version_name=None,
+                 uri=None,
+                 max_worker_count=None,
+                 runtime_version=None,
+                 gcp_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs)
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+
+        try:
+            prediction_input = _create_prediction_input(
+                project_id=project_id, region=region, data_format=data_format,
+                input_paths=input_paths, output_path=output_path,
+                model_name=model_name, version_name=version_name, uri=uri,
+                max_worker_count=max_worker_count,
+                runtime_version=runtime_version)
+        except ValueError as origin_error:
+            message = 'Cannot create batch prediction job request due to: {}'
+            logging.error(message.format(str(origin_error)))
+            raise
+
+        self.prediction_job_request = {
+            'jobId': _normalize_mlengine_job_id(job_id),
+            'predictionInput': prediction_input,
+        }
+
+    def execute(self, context):
+        """
+        Submit the request through MLEngineHook.create_job and return the
+        finished job's 'predictionOutput'.
+
+        Raises a RuntimeError with the job's 'errorMessage' when the job
+        state is not 'SUCCEEDED'.
+        """
+        def matches_request(existing_job):
+            # Handed to the hook as the check applied to an existing job.
+            return existing_job.get('predictionInput', None) == \
+                self.prediction_job_request['predictionInput']
+
+        ml_hook = MLEngineHook(self.gcp_conn_id, self.delegate_to)
+        job = ml_hook.create_job(
+            self.project_id, self.prediction_job_request, matches_request)
+        if job['state'] == 'SUCCEEDED':
+            return job['predictionOutput']
+
+        logging.error('Batch prediction job failed: %s', str(job))
+        raise RuntimeError(job['errorMessage'])
+
+
+class MLEngineModelOperator(BaseOperator):
+    """
+    Creates or fetches a Google Cloud ML Engine model.
+
+    :param project_id: The Google Cloud project name to which MLEngine
+        model belongs.
+    :type project_id: string
+
+    :param model: A dictionary containing the information about the model.
+        For the `create` operation it is handed to the hook as the model
+        definition, so it should hold all the information such as `name`.
+
+        For the `get` operation only its `name` entry is read.
+    :type model: dict
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new model as provided by the `model` parameter.
+        'get': Gets a particular model where the name is specified in `model`.
+        Any other value makes `execute` raise a ValueError.
+    :type operation: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+    """
+
+    template_fields = [
+        '_model',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 model,
+                 operation='create',
+                 gcp_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(MLEngineModelOperator, self).__init__(*args, **kwargs)
+        self._project_id, self._model = project_id, model
+        self._operation = operation
+        self._gcp_conn_id, self._delegate_to = gcp_conn_id, delegate_to
+
+    def execute(self, context):
+        """Run the configured operation and return the hook's result."""
+        ml_hook = MLEngineHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+        requested = self._operation
+        if requested == 'create':
+            return ml_hook.create_model(self._project_id, self._model)
+        if requested == 'get':
+            return ml_hook.get_model(self._project_id, self._model['name'])
+        raise ValueError('Unknown operation: {}'.format(requested))
+
+
+class MLEngineVersionOperator(BaseOperator):
+    """
+    Operator for managing a Google Cloud ML Engine version.
+
+    :param project_id: The Google Cloud project name to which MLEngine
+        model belongs.
+    :type project_id: string
+
+    :param model_name: The name of the Google Cloud ML Engine model that the version
+        belongs to.
+    :type model_name: string
+
+    :param version_name: A name to use for the version being operated upon. If
+        the `version` argument is None or does not have a value for
+        the `name` key, then this will be populated in the payload for the
+        `name` key.
+    :type version_name: string
+
+    :param version: A dictionary containing the information about the version.
+        If the `operation` is `create`, `version` should contain all the
+        information about this version such as name, and deploymentUrl.
+        If the `operation` is `set_default` or `delete`, the `version`
+        parameter should contain the `name` of the version.
+        If it is None, it is replaced by an empty dictionary.
+    :type version: dict
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new version in the model specified by `model_name`,
+            in which case the `version` parameter should contain all the
+            information to create that version
+            (e.g. `name`, `deploymentUrl`).
+        'set_default': Makes a version the default version of the model
+            specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+
+        'list': Lists all available versions of the model specified
+            by `model_name`.
+
+        'delete': Deletes the version specified in `version` parameter from the
+            model specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+    :type operation: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+    """
+
+    template_fields = [
+        '_model_name',
+        '_version_name',
+        '_version',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 model_name,
+                 version_name=None,
+                 version=None,
+                 operation='create',
+                 gcp_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+
+        super(MLEngineVersionOperator, self).__init__(*args, **kwargs)
+        self._project_id = project_id
+        self._model_name = model_name
+        self._version_name = version_name
+        self._version = version or {}
+        self._operation = operation
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+
+    def execute(self, context):
+        """Run the configured operation and return the hook's result."""
+        if 'name' not in self._version:
+            self._version['name'] = self._version_name
+
+        hook = MLEngineHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+        # self._version is never None here (`version or {}` in __init__).
+        if self._operation == 'create':
+            return hook.create_version(self._project_id, self._model_name,
+                                       self._version)
+        elif self._operation == 'set_default':
+            return hook.set_default_version(
+                self._project_id, self._model_name, self._version['name'])
+        elif self._operation == 'list':
+            return hook.list_versions(self._project_id, self._model_name)
+        elif self._operation == 'delete':
+            return hook.delete_version(self._project_id, self._model_name,
+                                       self._version['name'])
+        else:
+            raise ValueError('Unknown operation: {}'.format(self._operation))
+
+
+class MLEngineTrainingOperator(BaseOperator):
+    """
+    Operator for launching a MLEngine training job.
+
+    :param project_id: The Google Cloud project name within which MLEngine
+        training job should run. This field could be templated.
+    :type project_id: string
+
+    :param job_id: A unique templated id for the submitted Google MLEngine
+        training job.
+    :type job_id: string
+
+    :param package_uris: A list of package locations for MLEngine training job,
+        which should include the main training program + any additional
+        dependencies.
+    :type package_uris: list of string
+
+    :param training_python_module: The Python module name to run within MLEngine
+        training job after installing 'package_uris' packages.
+    :type training_python_module: string
+
+    :param training_args: A list of templated command line arguments to pass to
+        the MLEngine training program.
+    :type training_args: list of string
+
+    :param region: The Google Compute Engine region to run the MLEngine training
+        job in. This field could be templated.
+    :type region: string
+
+    :param scale_tier: Resource tier for MLEngine training job.
+    :type scale_tier: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+
+    :param mode: In 'DRY_RUN' mode, no real training job will be launched,
+        but the MLEngine training job request will be logged. Any other
+        value, including the default 'PRODUCTION', issues a real MLEngine
+        training job creation request.
+    :type mode: string
+    """
+
+    template_fields = [
+        '_project_id',
+        '_job_id',
+        '_package_uris',
+        '_training_python_module',
+        '_training_args',
+        '_region',
+        '_scale_tier',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 job_id,
+                 package_uris,
+                 training_python_module,
+                 training_args,
+                 region,
+                 scale_tier=None,
+                 gcp_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 mode='PRODUCTION',
+                 *args,
+                 **kwargs):
+        super(MLEngineTrainingOperator, self).__init__(*args, **kwargs)
+        self._project_id = project_id
+        self._job_id = job_id
+        self._package_uris = package_uris
+        self._training_python_module = training_python_module
+        self._training_args = training_args
+        self._region = region
+        self._scale_tier = scale_tier
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._mode = mode
+
+        if not self._project_id:
+            raise AirflowException('Google Cloud project id is required.')
+        if not self._job_id:
+            raise AirflowException(
+                'An unique job id is required for Google MLEngine training '
+                'job.')
+        if not package_uris:
+            raise AirflowException(
+                'At least one python package is required for MLEngine '
+                'Training job.')
+        if not training_python_module:
+            raise AirflowException(
+                'Python module name to run after installing required '
+                'packages is required.')
+        if not self._region:
+            raise AirflowException('Google Compute Engine region is required.')
+
+    def execute(self, context):
+        """Launch the training job, or just log it in 'DRY_RUN' mode."""
+        job_id = _normalize_mlengine_job_id(self._job_id)
+        training_request = {
+            'jobId': job_id,
+            'trainingInput': {
+                'scaleTier': self._scale_tier,
+                'packageUris': self._package_uris,
+                'pythonModule': self._training_python_module,
+                'region': self._region,
+                'args': self._training_args,
+            }
+        }
+
+        if self._mode == 'DRY_RUN':
+            logging.info('In dry_run mode.')
+            logging.info(
+                'MLEngine Training job request is: {}'.format(training_request))
+            return
+
+        hook = MLEngineHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+        # Helper method to check if the existing job's training input is the
+        # same as the request we get here.
+        def check_existing_job(existing_job):
+            return existing_job.get('trainingInput', None) == \
+                training_request['trainingInput']
+
+        # Errors raised by the hook propagate unchanged; no handler needed.
+        finished_training_job = hook.create_job(
+            self._project_id, training_request, check_existing_job)
+
+        if finished_training_job['state'] != 'SUCCEEDED':
+            logging.error('MLEngine training job failed: {}'.format(
+                str(finished_training_job)))
+            raise RuntimeError(finished_training_job['errorMessage'])