You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/06/29 15:54:42 UTC

incubator-airflow git commit: [AIRFLOW-1272] Google Cloud ML Batch Prediction Operator

Repository: incubator-airflow
Updated Branches:
  refs/heads/master e414844b7 -> e92d6bf72


[AIRFLOW-1272] Google Cloud ML Batch Prediction Operator

Closes #2390 from
jiwang576/GCP_CML_batch_prediction


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/e92d6bf7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/e92d6bf7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/e92d6bf7

Branch: refs/heads/master
Commit: e92d6bf72a6dbc192bdbb2f8c0266a81be9e2676
Parents: e414844
Author: Ji Wang <ji...@berkeley.edu>
Authored: Thu Jun 29 08:54:30 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Thu Jun 29 08:54:30 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/gcp_cloudml_hook.py       | 116 +++++-
 airflow/contrib/operators/cloudml_operator.py   | 356 +++++++++++++++----
 tests/contrib/hooks/test_gcp_cloudml_hook.py    |  79 +++-
 .../contrib/operators/test_cloudml_operator.py  | 288 +++++++++++++++
 4 files changed, 755 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
index e722b2a..3af8508 100644
--- a/airflow/contrib/hooks/gcp_cloudml_hook.py
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -33,7 +33,8 @@ def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
         try:
             response = request.execute()
             if is_error_func(response):
-                raise ValueError('The response contained an error: {}'.format(response))
+                raise ValueError(
+                    'The response contained an error: {}'.format(response))
             elif is_done_func(response):
                 logging.info('Operation is done: {}'.format(response))
                 return response
@@ -41,8 +42,9 @@ def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
                 time.sleep((2**i) + (random.randint(0, 1000) / 1000))
         except errors.HttpError as e:
             if e.resp.status != 429:
-                logging.info('Something went wrong. Not retrying: {}'.format(e))
-                raise e
+                logging.info(
+                    'Something went wrong. Not retrying: {}'.format(e))
+                raise
             else:
                 time.sleep((2**i) + (random.randint(0, 1000) / 1000))
 
@@ -60,12 +62,109 @@ class CloudMLHook(GoogleCloudBaseHook):
         credentials = GoogleCredentials.get_application_default()
         return build('ml', 'v1', credentials=credentials)
 
+    def create_job(self, project_name, job):
+        """
+        Creates and executes a CloudML job.
+
+        Returns the job object if the job was created and finished
+        successfully, or raises an error otherwise.
+
+        Raises:
+            apiclient.errors.HttpError: if the job cannot be created
+            successfully
+
+        project_name is the name of the project to use, such as
+        'my-project'
+
+        job is the complete Cloud ML Job object that should be provided to the
+        Cloud ML API, such as
+
+        {
+          'jobId': 'my_job_id',
+          'trainingInput': {
+            'scaleTier': 'STANDARD_1',
+            ...
+          }
+        }
+        """
+        request = self._cloudml.projects().jobs().create(
+            parent='projects/{}'.format(project_name),
+            body=job)
+        job_id = job['jobId']
+
+        try:
+            request.execute()
+            return self._wait_for_job_done(project_name, job_id)
+        except errors.HttpError as e:
+            if e.resp.status == 409:
+                existing_job = self._get_job(project_name, job_id)
+                logging.info(
+                    'Job with job_id {} already exist: {}.'.format(
+                        job_id,
+                        existing_job))
+
+                if existing_job.get('predictionInput', None) == \
+                        job['predictionInput']:
+                    return self._wait_for_job_done(project_name, job_id)
+                else:
+                    logging.error(
+                        'Job with job_id {} already exists, but the '
+                        'predictionInput mismatch: {}'
+                        .format(job_id, existing_job))
+                    raise ValueError(
+                        'Found a existing job with job_id {}, but with '
+                        'different predictionInput.'.format(job_id))
+            else:
+                logging.error('Failed to create CloudML job: {}'.format(e))
+                raise
+
+    def _get_job(self, project_name, job_id):
+        """
+        Gets a CloudML job based on the job name.
+
+        :return: CloudML job object if succeed.
+        :rtype: dict
+
+        Raises:
+            apiclient.errors.HttpError: if HTTP error is returned from server
+        """
+        job_name = 'projects/{}/jobs/{}'.format(project_name, job_id)
+        request = self._cloudml.projects().jobs().get(name=job_name)
+        while True:
+            try:
+                return request.execute()
+            except errors.HttpError as e:
+                if e.resp.status == 429:
+                    # polling after 30 seconds when quota failure occurs
+                    time.sleep(30)
+                else:
+                    logging.error('Failed to get CloudML job: {}'.format(e))
+                    raise
+
+    def _wait_for_job_done(self, project_name, job_id, interval=30):
+        """
+        Waits for the Job to reach a terminal state.
+
+        This method will periodically check the job state until the job reach
+        a terminal state.
+
+        Raises:
+            apiclient.errors.HttpError: if HTTP error is returned when getting
+            the job
+        """
+        assert interval > 0
+        while True:
+            job = self._get_job(project_name, job_id)
+            if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
+                return job
+            time.sleep(interval)
+
     def create_version(self, project_name, model_name, version_spec):
         """
         Creates the Version on Cloud ML.
 
-        Returns the operation if the version was created successfully and raises
-        an error otherwise.
+        Returns the operation if the version was created successfully and
+        raises an error otherwise.
         """
         parent_name = 'projects/{}/models/{}'.format(project_name, model_name)
         create_request = self._cloudml.projects().models().versions().create(
@@ -91,11 +190,12 @@ class CloudMLHook(GoogleCloudBaseHook):
 
         try:
             response = request.execute()
-            logging.info('Successfully set version: {} to default'.format(response))
+            logging.info(
+                'Successfully set version: {} to default'.format(response))
             return response
         except errors.HttpError as e:
             logging.error('Something went wrong: {}'.format(e))
-            raise e
+            raise
 
     def list_versions(self, project_name, model_name):
         """
@@ -164,4 +264,4 @@ class CloudMLHook(GoogleCloudBaseHook):
             if e.resp.status == 404:
                 logging.error('Model was not found: {}'.format(e))
                 return None
-            raise e
+            raise

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
index b0b6e91..871cc73 100644
--- a/airflow/contrib/operators/cloudml_operator.py
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -14,16 +14,307 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import logging
+import re
+
 from airflow import settings
-from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
 from airflow.operators import BaseOperator
+from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
 from airflow.utils.decorators import apply_defaults
+from apiclient import errors
+
 
 logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
 
 
+def _create_prediction_input(project_id,
+                             region,
+                             data_format,
+                             input_paths,
+                             output_path,
+                             model_name=None,
+                             version_name=None,
+                             uri=None,
+                             max_worker_count=None,
+                             runtime_version=None):
+    """
+    Create the batch prediction input from the given parameters.
+
+    Args:
+        A subset of arguments documented in __init__ method of class
+        CloudMLBatchPredictionOperator
+
+    Returns:
+        A dictionary representing the predictionInput object as documented
+        in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
+
+    Raises:
+        ValueError: if a unique model/version origin cannot be determined.
+    """
+
+    prediction_input = {
+        'dataFormat': data_format,
+        'inputPaths': input_paths,
+        'outputPath': output_path,
+        'region': region
+    }
+
+    if uri:
+        if model_name or version_name:
+            logging.error(
+                'Ambiguous model origin: Both uri and model/version name are '
+                'provided.')
+            raise ValueError('Ambiguous model origin.')
+        prediction_input['uri'] = uri
+    elif model_name:
+        origin_name = 'projects/{}/models/{}'.format(project_id, model_name)
+        if not version_name:
+            prediction_input['modelName'] = origin_name
+        else:
+            prediction_input['versionName'] = \
+                origin_name + '/versions/{}'.format(version_name)
+    else:
+        logging.error(
+            'Missing model origin: Batch prediction expects a model, '
+            'a model & version combination, or a URI to savedModel.')
+        raise ValueError('Missing model origin.')
+
+    if max_worker_count:
+        prediction_input['maxWorkerCount'] = max_worker_count
+    if runtime_version:
+        prediction_input['runtimeVersion'] = runtime_version
+
+    return prediction_input
+
+
+def _normalize_cloudml_job_id(job_id):
+    """
+    Replaces invalid CloudML job_id characters with '_'.
+
+    This also adds a leading 'z' in case job_id starts with an invalid
+    character.
+
+    Args:
+        job_id: A job_id str that may have invalid characters.
+
+    Returns:
+        A valid job_id representation.
+    """
+    match = re.search(r'\d', job_id)
+    if match and match.start() is 0:
+        job_id = 'z_{}'.format(job_id)
+    return re.sub('[^0-9a-zA-Z]+', '_', job_id)
+
+
+class CloudMLBatchPredictionOperator(BaseOperator):
+    """
+    Start a Cloud ML prediction job.
+
+    NOTE: For model origin, users should consider exactly one from the
+    three options below:
+    1. Populate 'uri' field only, which should be a GCS location that
+    points to a tensorflow savedModel directory.
+    2. Populate 'model_name' field only, which refers to an existing
+    model, and the default version of the model will be used.
+    3. Populate both 'model_name' and 'version_name' fields, which
+    refers to a specific version of a specific model.
+
+    In options 2 and 3, both model and version name should contain the
+    minimal identifier. For instance, call
+        CloudMLBatchPredictionOperator(
+            ...,
+            model_name='my_model',
+            version_name='my_version',
+            ...)
+    if the desired model version is
+    "projects/my_project/models/my_model/versions/my_version".
+
+
+    :param project_id: The Google Cloud project name where the
+        prediction job is submitted.
+    :type project_id: string
+
+    :param job_id: A unique id for the prediction job on Google Cloud
+        ML Engine.
+    :type job_id: string
+
+    :param data_format: The format of the input data.
+        It will default to 'DATA_FORMAT_UNSPECIFIED' if is not provided
+        or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"].
+    :type data_format: string
+
+    :param input_paths: A list of GCS paths of input data for batch
+        prediction. Accepting wildcard operator *, but only at the end.
+    :type input_paths: list of string
+
+    :param output_path: The GCS path where the prediction results are
+        written to.
+    :type output_path: string
+
+    :param region: The Google Compute Engine region to run the
+        prediction job in.:
+    :type region: string
+
+    :param model_name: The Google Cloud ML model to use for prediction.
+        If version_name is not provided, the default version of this
+        model will be used.
+        Should not be None if version_name is provided.
+        Should be None if uri is provided.
+    :type model_name: string
+
+    :param version_name: The Google Cloud ML model version to use for
+        prediction.
+        Should be None if uri is provided.
+    :type version_name: string
+
+    :param uri: The GCS path of the saved model to use for prediction.
+        Should be None if model_name is provided.
+        It should be a GCS path pointing to a tensorflow SavedModel.
+    :type uri: string
+
+    :param max_worker_count: The maximum number of workers to be used
+        for parallel processing. Defaults to 10 if not specified.
+    :type max_worker_count: int
+
+    :param runtime_version: The Google Cloud ML runtime version to use
+        for batch prediction.
+    :type runtime_version: string
+
+    :param gcp_conn_id: The connection ID used for connection to Google
+        Cloud Platform.
+    :type gcp_conn_id: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must
+        have doamin-wide delegation enabled.
+    :type delegate_to: string
+
+    Raises:
+        ValueError: if a unique model/version origin cannot be determined.
+    """
+
+    template_fields = [
+        "prediction_job_request",
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 job_id,
+                 region,
+                 data_format,
+                 input_paths,
+                 output_path,
+                 model_name=None,
+                 version_name=None,
+                 uri=None,
+                 max_worker_count=None,
+                 runtime_version=None,
+                 gcp_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(CloudMLBatchPredictionOperator, self).__init__(*args, **kwargs)
+
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+
+        try:
+            prediction_input = _create_prediction_input(
+                project_id, region, data_format, input_paths, output_path,
+                model_name, version_name, uri, max_worker_count,
+                runtime_version)
+        except ValueError as e:
+            logging.error(
+                'Cannot create batch prediction job request due to: {}'
+                .format(str(e)))
+            raise
+
+        self.prediction_job_request = {
+            'jobId': _normalize_cloudml_job_id(job_id),
+            'predictionInput': prediction_input
+        }
+
+    def execute(self, context):
+        hook = CloudMLHook(self.gcp_conn_id, self.delegate_to)
+
+        try:
+            finished_prediction_job = hook.create_job(
+                self.project_id,
+                self.prediction_job_request)
+        except errors.HttpError:
+            raise
+
+        if finished_prediction_job['state'] != 'SUCCEEDED':
+            logging.error(
+                'Batch prediction job failed: %s',
+                str(finished_prediction_job))
+            raise RuntimeError(finished_prediction_job['errorMessage'])
+
+        return finished_prediction_job['predictionOutput']
+
+
+class CloudMLModelOperator(BaseOperator):
+    """
+    Operator for managing a Google Cloud ML model.
+
+    :param model: A dictionary containing the information about the model.
+        If the `operation` is `create`, then the `model` parameter should
+        contain all the information about this model such as `name`.
+
+        If the `operation` is `get`, the `model` parameter
+        should contain the `name` of the model.
+    :type model: dict
+
+    :param project_name: The Google Cloud project name to which CloudML
+        model belongs.
+    :type project_name: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new model as provided by the `model` parameter.
+        'get': Gets a particular model where the name is specified in `model`.
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+    """
+
+    template_fields = [
+        '_model',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 model,
+                 project_name,
+                 gcp_conn_id='google_cloud_default',
+                 operation='create',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(CloudMLModelOperator, self).__init__(*args, **kwargs)
+        self._model = model
+        self._operation = operation
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._project_name = project_name
+
+    def execute(self, context):
+        hook = CloudMLHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+        if self._operation == 'create':
+            hook.create_model(self._project_name, self._model)
+        elif self._operation == 'get':
+            hook.get_model(self._project_name, self._model['name'])
+        else:
+            raise ValueError('Unknown operation: {}'.format(self._operation))
+
+
 class CloudMLVersionOperator(BaseOperator):
     """
     Operator for managing a Google Cloud ML version.
@@ -72,7 +363,6 @@ class CloudMLVersionOperator(BaseOperator):
     :type delegate_to: string
     """
 
-
     template_fields = [
         '_model_name',
         '_version',
@@ -116,63 +406,3 @@ class CloudMLVersionOperator(BaseOperator):
                                        self._version['name'])
         else:
             raise ValueError('Unknown operation: {}'.format(self._operation))
-
-
-class CloudMLModelOperator(BaseOperator):
-    """
-    Operator for managing a Google Cloud ML model.
-
-    :param model: A dictionary containing the information about the model.
-        If the `operation` is `create`, then the `model` parameter should
-        contain all the information about this model such as `name`.
-
-        If the `operation` is `get`, the `model` parameter
-        should contain the `name` of the model.
-    :type model: dict
-
-    :param project_name: The Google Cloud project name to which CloudML
-        model belongs.
-    :type project_name: string
-
-    :param gcp_conn_id: The connection ID to use when fetching connection info.
-    :type gcp_conn_id: string
-
-    :param operation: The operation to perform. Available operations are:
-        'create': Creates a new model as provided by the `model` parameter.
-        'get': Gets a particular model where the name is specified in `model`.
-
-    :param delegate_to: The account to impersonate, if any.
-        For this to work, the service account making the request must have
-        domain-wide delegation enabled.
-    :type delegate_to: string
-    """
-
-    template_fields = [
-        '_model',
-    ]
-
-    @apply_defaults
-    def __init__(self,
-                 model,
-                 project_name,
-                 gcp_conn_id='google_cloud_default',
-                 operation='create',
-                 delegate_to=None,
-                 *args,
-                 **kwargs):
-        super(CloudMLModelOperator, self).__init__(*args, **kwargs)
-        self._model = model
-        self._operation = operation
-        self._gcp_conn_id = gcp_conn_id
-        self._delegate_to = delegate_to
-        self._project_name = project_name
-
-    def execute(self, context):
-        hook = CloudMLHook(
-            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
-        if self._operation == 'create':
-            hook.create_model(self._project_name, self._model)
-        elif self._operation == 'get':
-            hook.get_model(self._project_name, self._model['name'])
-        else:
-            raise ValueError('Unknown operation: {}'.format(self._operation))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
index aa50e69..e34e05f 100644
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -13,9 +13,10 @@
 import json
 import mock
 import unittest
-try: # python 2
+
+try:  # python 2
     from urlparse import urlparse, parse_qsl
-except ImportError: #python 3
+except ImportError:  # python 3
     from urllib.parse import urlparse, parse_qsl
 
 from airflow.contrib.hooks import gcp_cloudml_hook as hook
@@ -49,12 +50,16 @@ class _TestCloudMLHook(object):
         self._test_cls = test_cls
         self._responses = responses
         self._expected_requests = [
-            self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in expected_requests]
+            self._normalize_requests_for_comparison(x[0], x[1], x[2])
+            for x in expected_requests]
         self._actual_requests = []
 
     def _normalize_requests_for_comparison(self, uri, http_method, body):
         parts = urlparse(uri)
-        return (parts._replace(query=set(parse_qsl(parts.query))), http_method, body)
+        return (
+            parts._replace(query=set(parse_qsl(parts.query))),
+            http_method,
+            body)
 
     def __enter__(self):
         http = HttpMockSequence(self._responses)
@@ -76,7 +81,9 @@ class _TestCloudMLHook(object):
         if any(args):
             return None
         self._test_cls.assertEquals(
-            [self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in self._actual_requests], self._expected_requests)
+            [self._normalize_requests_for_comparison(x[0], x[1], x[2])
+                for x in self._actual_requests],
+            self._expected_requests)
 
 
 class TestCloudMLHook(unittest.TestCase):
@@ -86,6 +93,7 @@ class TestCloudMLHook(unittest.TestCase):
 
     _SKIP_IF = unittest.skipIf(not cml_available,
                                'CloudML is not available to run tests')
+
     _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'
 
     @_SKIP_IF
@@ -112,7 +120,8 @@ class TestCloudMLHook(unittest.TestCase):
                 responses=[succeeded_response] * 2,
                 expected_requests=expected_requests) as cml_hook:
             create_version_response = cml_hook.create_version(
-                project_name=project, model_name=model_name, version_spec=version)
+                project_name=project, model_name=model_name,
+                version_spec=version)
             self.assertEquals(create_version_response, response_body)
 
     @_SKIP_IF
@@ -137,7 +146,8 @@ class TestCloudMLHook(unittest.TestCase):
                 responses=[succeeded_response],
                 expected_requests=expected_requests) as cml_hook:
             set_default_version_response = cml_hook.set_default_version(
-                project_name=project, model_name=model_name, version_name=version)
+                project_name=project, model_name=model_name,
+                version_name=version)
             self.assertEquals(set_default_version_response, response_body)
 
     @_SKIP_IF
@@ -150,8 +160,12 @@ class TestCloudMLHook(unittest.TestCase):
         # This test returns the versions one at a time.
         versions = ['ver_{}'.format(ix) for ix in range(3)]
 
-        response_bodies = [{'name': operation_name, 'nextPageToken': ix, 'versions': [
-            ver]} for ix, ver in enumerate(versions)]
+        response_bodies = [
+            {
+                'name': operation_name,
+                'nextPageToken': ix,
+                'versions': [ver]
+            } for ix, ver in enumerate(versions)]
         response_bodies[-1].pop('nextPageToken')
         responses = [({'status': '200'}, json.dumps(body))
                      for body in response_bodies]
@@ -190,9 +204,11 @@ class TestCloudMLHook(unittest.TestCase):
             {'status': '200'}, json.dumps(done_response_body))
 
         expected_requests = [
-            ('{}projects/{}/models/{}/versions/{}?alt=json'.format(
-                self._SERVICE_URI_PREFIX, project, model_name, version), 'DELETE',
-             None),
+            (
+                '{}projects/{}/models/{}/versions/{}?alt=json'.format(
+                    self._SERVICE_URI_PREFIX, project, model_name, version),
+                'DELETE',
+                None),
             ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
              'GET', None),
         ]
@@ -202,7 +218,8 @@ class TestCloudMLHook(unittest.TestCase):
                 responses=[not_done_response, succeeded_response],
                 expected_requests=expected_requests) as cml_hook:
             delete_version_response = cml_hook.delete_version(
-                project_name=project, model_name=model_name, version_name=version)
+                project_name=project, model_name=model_name,
+                version_name=version)
             self.assertEquals(delete_version_response, done_response_body)
 
     @_SKIP_IF
@@ -250,6 +267,42 @@ class TestCloudMLHook(unittest.TestCase):
                 project_name=project, model_name=model_name)
             self.assertEquals(get_model_response, response_body)
 
+    @_SKIP_IF
+    def test_create_cloudml_job(self):
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+        }
+        response_body = json.dumps(my_job)
+        succeeded_response = ({'status': '200'}, response_body)
+        queued_response = ({'status': '200'}, json.dumps({
+            'jobId': job_id,
+            'state': 'QUEUED',
+        }))
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+            ask_if_done_request,
+        ]
+        responses = [succeeded_response,
+                     queued_response, succeeded_response]
+
+        with _TestCloudMLHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_name=project, job=my_job)
+            self.assertEquals(create_job_response, my_job)
+
 
 if __name__ == '__main__':
     unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/tests/contrib/operators/test_cloudml_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py
new file mode 100644
index 0000000..b76a0c6
--- /dev/null
+++ b/tests/contrib/operators/test_cloudml_operator.py
@@ -0,0 +1,288 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+from apiclient import errors
+import httplib2
+import unittest
+
+from airflow import configuration, DAG
+from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
+
+from mock import patch
+
+DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+
+# A ``predictionInput`` body with no model origin (no 'modelName',
+# 'versionName' or 'uri' key). Tests copy it and add the origin they
+# exercise, so it must not be mutated in place.
+INPUT_MISSING_ORIGIN = {'dataFormat': 'TEXT',
+                        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
+                        'outputPath': 'gs://legal-bucket/fake-output-path',
+                        'region': 'us-east1'}
+
+# A SUCCEEDED job response without 'predictionInput'; tests attach the
+# input they expect. The output dict is shared by every shallow copy.
+_PREDICTION_OUTPUT = {'outputPath': 'gs://fake-output-path',
+                      'predictionCount': 5000, 'errorCount': 0,
+                      'nodeHours': 2.78}
+SUCCESS_MESSAGE_MISSING_INPUT = {'jobId': 'test_prediction',
+                                 'predictionOutput': _PREDICTION_OUTPUT,
+                                 'state': 'SUCCEEDED'}
+
+# Operator keyword arguments with no model origin (no model_name,
+# version_name or uri); tests copy it and add origin arguments as needed.
+DEFAULT_ARGS = {
+    'project_id': 'test-project', 'job_id': 'test_prediction',
+    'region': 'us-east1', 'data_format': 'TEXT',
+    'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
+    'output_path':
+        'gs://12_legal_bucket_underscore_number/legal-output-path',
+    'task_id': 'test-prediction',
+}
+
+
+class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
+    """Tests for CloudMLBatchPredictionOperator with CloudMLHook mocked."""
+
+    def setUp(self):
+        super(CloudMLBatchPredictionOperatorTest, self).setUp()
+        configuration.load_test_config()
+        self.dag = DAG(
+            'test_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE,
+                'end_date': DEFAULT_DATE,
+            },
+            schedule_interval='@daily')
+
+    def testSuccessWithModel(self):
+        """A bare model name is sent as a fully qualified modelName."""
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+
+            input_with_model = INPUT_MISSING_ORIGIN.copy()
+            input_with_model['modelName'] = \
+                'projects/test-project/models/test_model'
+            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+            success_message['predictionInput'] = input_with_model
+
+            hook_instance = mock_hook.return_value
+            hook_instance.get_job.side_effect = errors.HttpError(
+                resp=httplib2.Response({'status': 404}), content=b'some bytes')
+            hook_instance.create_job.return_value = success_message
+
+            prediction_task = CloudMLBatchPredictionOperator(
+                job_id='test_prediction',
+                project_id='test-project',
+                region=input_with_model['region'],
+                data_format=input_with_model['dataFormat'],
+                input_paths=input_with_model['inputPaths'],
+                output_path=input_with_model['outputPath'],
+                model_name=input_with_model['modelName'].split('/')[-1],
+                dag=self.dag,
+                task_id='test-prediction')
+            prediction_output = prediction_task.execute(None)
+
+            mock_hook.assert_called_with('google_cloud_default', None)
+            hook_instance.create_job.assert_called_with(
+                'test-project',
+                {
+                    'jobId': 'test_prediction',
+                    'predictionInput': input_with_model
+                })
+            self.assertEqual(
+                success_message['predictionOutput'],
+                prediction_output)
+
+    def testSuccessWithVersion(self):
+        """Model and version names are sent as a qualified versionName."""
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+
+            input_with_version = INPUT_MISSING_ORIGIN.copy()
+            input_with_version['versionName'] = \
+                'projects/test-project/models/test_model/versions/test_version'
+            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+            success_message['predictionInput'] = input_with_version
+
+            hook_instance = mock_hook.return_value
+            hook_instance.get_job.side_effect = errors.HttpError(
+                resp=httplib2.Response({'status': 404}), content=b'some bytes')
+            hook_instance.create_job.return_value = success_message
+
+            prediction_task = CloudMLBatchPredictionOperator(
+                job_id='test_prediction',
+                project_id='test-project',
+                region=input_with_version['region'],
+                data_format=input_with_version['dataFormat'],
+                input_paths=input_with_version['inputPaths'],
+                output_path=input_with_version['outputPath'],
+                model_name=input_with_version['versionName'].split('/')[-3],
+                version_name=input_with_version['versionName'].split('/')[-1],
+                dag=self.dag,
+                task_id='test-prediction')
+            prediction_output = prediction_task.execute(None)
+
+            mock_hook.assert_called_with('google_cloud_default', None)
+            hook_instance.create_job.assert_called_with(
+                'test-project',
+                {
+                    'jobId': 'test_prediction',
+                    'predictionInput': input_with_version
+                })
+            self.assertEqual(
+                success_message['predictionOutput'],
+                prediction_output)
+
+    def testSuccessWithURI(self):
+        """A model URI is sent unchanged as ``uri``."""
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+
+            input_with_uri = INPUT_MISSING_ORIGIN.copy()
+            input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
+            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+            success_message['predictionInput'] = input_with_uri
+
+            hook_instance = mock_hook.return_value
+            hook_instance.get_job.side_effect = errors.HttpError(
+                resp=httplib2.Response({'status': 404}), content=b'some bytes')
+            hook_instance.create_job.return_value = success_message
+
+            prediction_task = CloudMLBatchPredictionOperator(
+                job_id='test_prediction',
+                project_id='test-project',
+                region=input_with_uri['region'],
+                data_format=input_with_uri['dataFormat'],
+                input_paths=input_with_uri['inputPaths'],
+                output_path=input_with_uri['outputPath'],
+                uri=input_with_uri['uri'],
+                dag=self.dag,
+                task_id='test-prediction')
+            prediction_output = prediction_task.execute(None)
+
+            mock_hook.assert_called_with('google_cloud_default', None)
+            hook_instance.create_job.assert_called_with(
+                'test-project',
+                {
+                    'jobId': 'test_prediction',
+                    'predictionInput': input_with_uri
+                })
+            self.assertEqual(
+                success_message['predictionOutput'],
+                prediction_output)
+
+    def testInvalidModelOrigin(self):
+        """Ambiguous or missing model origins raise ValueError."""
+        # Both uri and model are given.
+        task_args = DEFAULT_ARGS.copy()
+        task_args['uri'] = 'gs://fake-uri/saved_model'
+        task_args['model_name'] = 'fake_model'
+        with self.assertRaises(ValueError) as context:
+            CloudMLBatchPredictionOperator(**task_args).execute(None)
+        self.assertEqual('Ambiguous model origin.', str(context.exception))
+
+        # Both uri and model/version are given.
+        task_args = DEFAULT_ARGS.copy()
+        task_args['uri'] = 'gs://fake-uri/saved_model'
+        task_args['model_name'] = 'fake_model'
+        task_args['version_name'] = 'fake_version'
+        with self.assertRaises(ValueError) as context:
+            CloudMLBatchPredictionOperator(**task_args).execute(None)
+        self.assertEqual('Ambiguous model origin.', str(context.exception))
+
+        # A version is given without a model.
+        task_args = DEFAULT_ARGS.copy()
+        task_args['version_name'] = 'bare_version'
+        with self.assertRaises(ValueError) as context:
+            CloudMLBatchPredictionOperator(**task_args).execute(None)
+        self.assertEqual('Missing model origin.', str(context.exception))
+
+        # None of uri, model, or model/version is given.
+        task_args = DEFAULT_ARGS.copy()
+        with self.assertRaises(ValueError) as context:
+            CloudMLBatchPredictionOperator(**task_args).execute(None)
+        self.assertEqual('Missing model origin.', str(context.exception))
+
+    def testHttpError(self):
+        """An HttpError raised by create_job propagates out of execute."""
+        http_error_code = 403
+
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            input_with_model = INPUT_MISSING_ORIGIN.copy()
+            # Must match project_id below: testSuccessWithModel shows the
+            # operator sends 'projects/<project_id>/models/<model_name>'.
+            input_with_model['modelName'] = \
+                'projects/test-project/models/test_model'
+
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.side_effect = errors.HttpError(
+                resp=httplib2.Response({'status': http_error_code}),
+                content=b'Forbidden')
+
+            prediction_task = CloudMLBatchPredictionOperator(
+                job_id='test_prediction',
+                project_id='test-project',
+                region=input_with_model['region'],
+                data_format=input_with_model['dataFormat'],
+                input_paths=input_with_model['inputPaths'],
+                output_path=input_with_model['outputPath'],
+                model_name=input_with_model['modelName'].split('/')[-1],
+                dag=self.dag,
+                task_id='test-prediction')
+            with self.assertRaises(errors.HttpError) as context:
+                prediction_task.execute(None)
+
+            # Asserted outside assertRaises: code after the raise never runs.
+            mock_hook.assert_called_with('google_cloud_default', None)
+            hook_instance.create_job.assert_called_with(
+                'test-project',
+                {
+                    'jobId': 'test_prediction',
+                    'predictionInput': input_with_model
+                })
+
+            self.assertEqual(http_error_code, context.exception.resp.status)
+
+    def testFailedJobError(self):
+        """A FAILED job makes execute raise RuntimeError(errorMessage)."""
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.return_value = {
+                'state': 'FAILED',
+                'errorMessage': 'A failure message'
+            }
+            task_args = DEFAULT_ARGS.copy()
+            # Any uri satisfies the model-origin check; the hook is mocked.
+            task_args['uri'] = 'a uri'
+
+            with self.assertRaises(RuntimeError) as context:
+                CloudMLBatchPredictionOperator(**task_args).execute(None)
+
+            self.assertEqual('A failure message', str(context.exception))
+
+
+if __name__ == '__main__':  # allow running this module directly
+    unittest.main(module='__main__')