You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/06/29 15:54:42 UTC
incubator-airflow git commit: [AIRFLOW-1272] Google Cloud ML Batch Prediction Operator
Repository: incubator-airflow
Updated Branches:
refs/heads/master e414844b7 -> e92d6bf72
[AIRFLOW-1272] Google Cloud ML Batch Prediction Operator
Closes #2390 from jiwang576/GCP_CML_batch_prediction
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/e92d6bf7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/e92d6bf7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/e92d6bf7
Branch: refs/heads/master
Commit: e92d6bf72a6dbc192bdbb2f8c0266a81be9e2676
Parents: e414844
Author: Ji Wang <ji...@berkeley.edu>
Authored: Thu Jun 29 08:54:30 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Thu Jun 29 08:54:30 2017 -0700
----------------------------------------------------------------------
airflow/contrib/hooks/gcp_cloudml_hook.py | 116 +++++-
airflow/contrib/operators/cloudml_operator.py | 356 +++++++++++++++----
tests/contrib/hooks/test_gcp_cloudml_hook.py | 79 +++-
.../contrib/operators/test_cloudml_operator.py | 288 +++++++++++++++
4 files changed, 755 insertions(+), 84 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
index e722b2a..3af8508 100644
--- a/airflow/contrib/hooks/gcp_cloudml_hook.py
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -33,7 +33,8 @@ def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
try:
response = request.execute()
if is_error_func(response):
- raise ValueError('The response contained an error: {}'.format(response))
+ raise ValueError(
+ 'The response contained an error: {}'.format(response))
elif is_done_func(response):
logging.info('Operation is done: {}'.format(response))
return response
@@ -41,8 +42,9 @@ def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
time.sleep((2**i) + (random.randint(0, 1000) / 1000))
except errors.HttpError as e:
if e.resp.status != 429:
- logging.info('Something went wrong. Not retrying: {}'.format(e))
- raise e
+ logging.info(
+ 'Something went wrong. Not retrying: {}'.format(e))
+ raise
else:
time.sleep((2**i) + (random.randint(0, 1000) / 1000))
@@ -60,12 +62,109 @@ class CloudMLHook(GoogleCloudBaseHook):
credentials = GoogleCredentials.get_application_default()
return build('ml', 'v1', credentials=credentials)
+    def create_job(self, project_name, job):
+        """
+        Creates a CloudML job and waits for it to reach a terminal state.
+
+        Returns the job object once its state is SUCCEEDED, FAILED or
+        CANCELLED. A FAILED or CANCELLED job is returned rather than raised,
+        so callers must inspect job['state'].
+
+        If a job with the same jobId already exists (HTTP 409), that job is
+        waited on instead, provided its predictionInput equals the one in
+        `job`; otherwise a ValueError is raised.
+
+        Raises:
+            apiclient.errors.HttpError: if the job cannot be created
+                successfully, or if polling the job fails
+            ValueError: if a job with the same jobId but a different
+                predictionInput already exists
+
+        project_name is the name of the project to use, such as
+        'my-project'
+
+        job is the complete Cloud ML Job object that should be provided to the
+        Cloud ML API, such as
+
+            {
+                'jobId': 'my_job_id',
+                'trainingInput': {
+                    'scaleTier': 'STANDARD_1',
+                    ...
+                }
+            }
+        """
+        request = self._cloudml.projects().jobs().create(
+            parent='projects/{}'.format(project_name),
+            body=job)
+        job_id = job['jobId']
+
+        # Only the create call is guarded: errors raised while polling must
+        # propagate as-is instead of being treated as creation failures.
+        try:
+            request.execute()
+        except errors.HttpError as e:
+            if e.resp.status != 409:
+                logging.error('Failed to create CloudML job: {}'.format(e))
+                raise
+
+            # 409 Conflict: a job with this id already exists.
+            existing_job = self._get_job(project_name, job_id)
+            logging.info(
+                'Job with job_id {} already exist: {}.'.format(
+                    job_id,
+                    existing_job))
+
+            # .get() on both sides so that jobs without a predictionInput
+            # (e.g. training jobs) do not raise KeyError here.
+            if existing_job.get('predictionInput', None) != \
+                    job.get('predictionInput', None):
+                logging.error(
+                    'Job with job_id {} already exists, but the '
+                    'predictionInput mismatch: {}'
+                    .format(job_id, existing_job))
+                raise ValueError(
+                    'Found a existing job with job_id {}, but with '
+                    'different predictionInput.'.format(job_id))
+
+        return self._wait_for_job_done(project_name, job_id)
+
+    def _get_job(self, project_name, job_id):
+        """
+        Fetches a single CloudML job by project name and job id.
+
+        HTTP 429 (quota exceeded) responses are retried every 30 seconds,
+        with no limit on the number of attempts. Any other HTTP error is
+        logged and re-raised.
+
+        :return: the CloudML job object returned by the API.
+        :rtype: dict
+
+        Raises:
+            apiclient.errors.HttpError: on any non-429 HTTP error
+        """
+        request = self._cloudml.projects().jobs().get(
+            name='projects/{}/jobs/{}'.format(project_name, job_id))
+        while True:
+            try:
+                return request.execute()
+            except errors.HttpError as e:
+                if e.resp.status != 429:
+                    logging.error('Failed to get CloudML job: {}'.format(e))
+                    raise
+                # Quota failure: back off before polling again.
+                time.sleep(30)
+
+    def _wait_for_job_done(self, project_name, job_id, interval=30):
+        """
+        Waits for the Job to reach a terminal state.
+
+        Polls the job every `interval` seconds until its state is one of
+        SUCCEEDED, FAILED or CANCELLED, then returns the final job object.
+        FAILED and CANCELLED jobs are returned, not raised.
+
+        Raises:
+            ValueError: if interval is not positive
+            apiclient.errors.HttpError: if HTTP error is returned when getting
+                the job
+        """
+        # Explicit check instead of `assert`, which is stripped under -O.
+        if interval <= 0:
+            raise ValueError(
+                'Interval must be > 0, got {}'.format(interval))
+        while True:
+            job = self._get_job(project_name, job_id)
+            if job['state'] in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
+                return job
+            time.sleep(interval)
+
def create_version(self, project_name, model_name, version_spec):
"""
Creates the Version on Cloud ML.
- Returns the operation if the version was created successfully and raises
- an error otherwise.
+ Returns the operation if the version was created successfully and
+ raises an error otherwise.
"""
parent_name = 'projects/{}/models/{}'.format(project_name, model_name)
create_request = self._cloudml.projects().models().versions().create(
@@ -91,11 +190,12 @@ class CloudMLHook(GoogleCloudBaseHook):
try:
response = request.execute()
- logging.info('Successfully set version: {} to default'.format(response))
+ logging.info(
+ 'Successfully set version: {} to default'.format(response))
return response
except errors.HttpError as e:
logging.error('Something went wrong: {}'.format(e))
- raise e
+ raise
def list_versions(self, project_name, model_name):
"""
@@ -164,4 +264,4 @@ class CloudMLHook(GoogleCloudBaseHook):
if e.resp.status == 404:
logging.error('Model was not found: {}'.format(e))
return None
- raise e
+ raise
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
index b0b6e91..871cc73 100644
--- a/airflow/contrib/operators/cloudml_operator.py
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -14,16 +14,307 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+import re
+
from airflow import settings
-from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
from airflow.operators import BaseOperator
+from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
from airflow.utils.decorators import apply_defaults
+from apiclient import errors
+
logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
+def _create_prediction_input(project_id,
+                             region,
+                             data_format,
+                             input_paths,
+                             output_path,
+                             model_name=None,
+                             version_name=None,
+                             uri=None,
+                             max_worker_count=None,
+                             runtime_version=None):
+    """
+    Build the predictionInput payload for a Cloud ML batch prediction job.
+
+    Exactly one model origin must be resolvable: `uri` alone, `model_name`
+    alone (the model's default version), or `model_name` together with
+    `version_name`. Optional fields are included only when truthy.
+
+    Args:
+        A subset of arguments documented in __init__ method of class
+        CloudMLBatchPredictionOperator
+
+    Returns:
+        A dictionary representing the predictionInput object as documented
+        in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
+
+    Raises:
+        ValueError: 'Ambiguous model origin.' if `uri` is combined with a
+            model or version name; 'Missing model origin.' if neither `uri`
+            nor `model_name` is given (including a bare `version_name`).
+    """
+    # Validate the model origin up front (guard clauses).
+    if uri and (model_name or version_name):
+        logging.error(
+            'Ambiguous model origin: Both uri and model/version name are '
+            'provided.')
+        raise ValueError('Ambiguous model origin.')
+    if not uri and not model_name:
+        logging.error(
+            'Missing model origin: Batch prediction expects a model, '
+            'a model & version combination, or a URI to savedModel.')
+        raise ValueError('Missing model origin.')
+
+    prediction_input = {
+        'dataFormat': data_format,
+        'inputPaths': input_paths,
+        'outputPath': output_path,
+        'region': region
+    }
+    if uri:
+        prediction_input['uri'] = uri
+    else:
+        model_path = 'projects/{}/models/{}'.format(project_id, model_name)
+        if version_name:
+            prediction_input['versionName'] = '{}/versions/{}'.format(
+                model_path, version_name)
+        else:
+            prediction_input['modelName'] = model_path
+
+    optional_fields = (('maxWorkerCount', max_worker_count),
+                       ('runtimeVersion', runtime_version))
+    for api_key, value in optional_fields:
+        if value:
+            prediction_input[api_key] = value
+    return prediction_input
+
+
+def _normalize_cloudml_job_id(job_id):
+    """
+    Turns an arbitrary string into a valid CloudML job_id.
+
+    Every run of characters outside [0-9a-zA-Z] is replaced with a single
+    '_'. CloudML job ids must start with a letter, so 'z_' is prepended
+    whenever job_id does not already start with an ASCII letter. This covers
+    ids starting with a digit as well as ids starting with a character that
+    would otherwise turn into a leading '_'.
+
+    Args:
+        job_id: A job_id str that may have invalid characters.
+
+    Returns:
+        A valid job_id representation.
+    """
+    # re.match anchors at position 0. Only ASCII letters count, matching the
+    # character class kept by re.sub below.
+    if not re.match(r'[a-zA-Z]', job_id):
+        job_id = 'z_{}'.format(job_id)
+    return re.sub(r'[^0-9a-zA-Z]+', '_', job_id)
+
+
+class CloudMLBatchPredictionOperator(BaseOperator):
+    """
+    Start a Cloud ML prediction job.
+
+    NOTE: For model origin, users should consider exactly one from the
+    three options below:
+    1. Populate 'uri' field only, which should be a GCS location that
+    points to a tensorflow savedModel directory.
+    2. Populate 'model_name' field only, which refers to an existing
+    model, and the default version of the model will be used.
+    3. Populate both 'model_name' and 'version_name' fields, which
+    refers to a specific version of a specific model.
+
+    In options 2 and 3, both model and version name should contain the
+    minimal identifier. For instance, call
+        CloudMLBatchPredictionOperator(
+            ...,
+            model_name='my_model',
+            version_name='my_version',
+            ...)
+    if the desired model version is
+    "projects/my_project/models/my_model/versions/my_version".
+
+    The job_id is normalized by _normalize_cloudml_job_id (characters
+    outside [0-9a-zA-Z] are replaced with '_') only when the task executes,
+    i.e. after the templated prediction_job_request has been rendered. A
+    templated job_id such as 'prediction_{{ ds_nodash }}' is therefore
+    rendered before it is normalized.
+
+    :param project_id: The Google Cloud project name where the
+        prediction job is submitted.
+    :type project_id: string
+
+    :param job_id: A unique id for the prediction job on Google Cloud
+        ML Engine. May contain Jinja templates.
+    :type job_id: string
+
+    :param data_format: The format of the input data, e.g. "TEXT",
+        "TF_RECORD" or "TF_RECORD_GZIP". It is passed to the API unchanged.
+    :type data_format: string
+
+    :param input_paths: A list of GCS paths of input data for batch
+        prediction. Accepting wildcard operator *, but only at the end.
+    :type input_paths: list of string
+
+    :param output_path: The GCS path where the prediction results are
+        written to.
+    :type output_path: string
+
+    :param region: The Google Compute Engine region to run the
+        prediction job in.
+    :type region: string
+
+    :param model_name: The Google Cloud ML model to use for prediction.
+        If version_name is not provided, the default version of this
+        model will be used.
+        Should not be None if version_name is provided.
+        Should be None if uri is provided.
+    :type model_name: string
+
+    :param version_name: The Google Cloud ML model version to use for
+        prediction.
+        Should be None if uri is provided.
+    :type version_name: string
+
+    :param uri: The GCS path of the saved model to use for prediction.
+        Should be None if model_name is provided.
+        It should be a GCS path pointing to a tensorflow SavedModel.
+    :type uri: string
+
+    :param max_worker_count: The maximum number of workers to be used
+        for parallel processing. If not set, the field is omitted from the
+        request and the service default applies.
+    :type max_worker_count: int
+
+    :param runtime_version: The Google Cloud ML runtime version to use
+        for batch prediction.
+    :type runtime_version: string
+
+    :param gcp_conn_id: The connection ID used for connection to Google
+        Cloud Platform.
+    :type gcp_conn_id: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must
+        have domain-wide delegation enabled.
+    :type delegate_to: string
+
+    Raises:
+        ValueError: if a unique model/version origin cannot be determined.
+    """
+
+    template_fields = [
+        "prediction_job_request",
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 job_id,
+                 region,
+                 data_format,
+                 input_paths,
+                 output_path,
+                 model_name=None,
+                 version_name=None,
+                 uri=None,
+                 max_worker_count=None,
+                 runtime_version=None,
+                 gcp_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(CloudMLBatchPredictionOperator, self).__init__(*args, **kwargs)
+
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+
+        try:
+            prediction_input = _create_prediction_input(
+                project_id, region, data_format, input_paths, output_path,
+                model_name, version_name, uri, max_worker_count,
+                runtime_version)
+        except ValueError as e:
+            logging.error(
+                'Cannot create batch prediction job request due to: {}'
+                .format(str(e)))
+            raise
+
+        # jobId is kept raw here and normalized in execute(). Normalizing it
+        # now would rewrite Jinja syntax ('{{', '}}', spaces) before the
+        # template fields are rendered.
+        self.prediction_job_request = {
+            'jobId': job_id,
+            'predictionInput': prediction_input
+        }
+
+    def execute(self, context):
+        """
+        Submits the batch prediction job and waits for it to finish.
+
+        Returns:
+            The job's predictionOutput dict.
+
+        Raises:
+            apiclient.errors.HttpError: propagated from the hook.
+            RuntimeError: if the job finished in a state other than
+                SUCCEEDED.
+        """
+        job_request = dict(self.prediction_job_request)
+        job_request['jobId'] = _normalize_cloudml_job_id(
+            job_request['jobId'])
+
+        hook = CloudMLHook(self.gcp_conn_id, self.delegate_to)
+        finished_prediction_job = hook.create_job(
+            self.project_id, job_request)
+
+        if finished_prediction_job['state'] != 'SUCCEEDED':
+            logging.error(
+                'Batch prediction job failed: %s',
+                str(finished_prediction_job))
+            # errorMessage is not guaranteed to be present (e.g. CANCELLED).
+            raise RuntimeError(finished_prediction_job.get(
+                'errorMessage',
+                'Batch prediction job ended in state {}'.format(
+                    finished_prediction_job['state'])))
+
+        return finished_prediction_job['predictionOutput']
+
+
+class CloudMLModelOperator(BaseOperator):
+    """
+    Operator for managing a Google Cloud ML model.
+
+    :param model: A dictionary containing the information about the model.
+        If the `operation` is `create`, then the `model` parameter should
+        contain all the information about this model such as `name`.
+
+        If the `operation` is `get`, the `model` parameter
+        should contain the `name` of the model.
+    :type model: dict
+
+    :param project_name: The Google Cloud project name to which CloudML
+        model belongs.
+    :type project_name: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new model as provided by the `model` parameter.
+        'get': Gets a particular model where the name is specified in `model`.
+    :type operation: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+
+    Raises:
+        ValueError: from execute() if `operation` is neither 'create' nor
+            'get'.
+    """
+
+    template_fields = [
+        '_model',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 model,
+                 project_name,
+                 gcp_conn_id='google_cloud_default',
+                 operation='create',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(CloudMLModelOperator, self).__init__(*args, **kwargs)
+        self._model = model
+        self._operation = operation
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._project_name = project_name
+
+    def execute(self, context):
+        """
+        Runs the configured operation and returns the hook's result.
+
+        Returning the result lets Airflow hand it to downstream tasks
+        (presumably via XCom; confirm against the BaseOperator version in
+        use). Previously the result of 'get' was discarded.
+        """
+        hook = CloudMLHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+        if self._operation == 'create':
+            return hook.create_model(self._project_name, self._model)
+        elif self._operation == 'get':
+            return hook.get_model(self._project_name, self._model['name'])
+        else:
+            raise ValueError('Unknown operation: {}'.format(self._operation))
+
+
class CloudMLVersionOperator(BaseOperator):
"""
Operator for managing a Google Cloud ML version.
@@ -72,7 +363,6 @@ class CloudMLVersionOperator(BaseOperator):
:type delegate_to: string
"""
-
template_fields = [
'_model_name',
'_version',
@@ -116,63 +406,3 @@ class CloudMLVersionOperator(BaseOperator):
self._version['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
-
-
-class CloudMLModelOperator(BaseOperator):
- """
- Operator for managing a Google Cloud ML model.
-
- :param model: A dictionary containing the information about the model.
- If the `operation` is `create`, then the `model` parameter should
- contain all the information about this model such as `name`.
-
- If the `operation` is `get`, the `model` parameter
- should contain the `name` of the model.
- :type model: dict
-
- :param project_name: The Google Cloud project name to which CloudML
- model belongs.
- :type project_name: string
-
- :param gcp_conn_id: The connection ID to use when fetching connection info.
- :type gcp_conn_id: string
-
- :param operation: The operation to perform. Available operations are:
- 'create': Creates a new model as provided by the `model` parameter.
- 'get': Gets a particular model where the name is specified in `model`.
-
- :param delegate_to: The account to impersonate, if any.
- For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :type delegate_to: string
- """
-
- template_fields = [
- '_model',
- ]
-
- @apply_defaults
- def __init__(self,
- model,
- project_name,
- gcp_conn_id='google_cloud_default',
- operation='create',
- delegate_to=None,
- *args,
- **kwargs):
- super(CloudMLModelOperator, self).__init__(*args, **kwargs)
- self._model = model
- self._operation = operation
- self._gcp_conn_id = gcp_conn_id
- self._delegate_to = delegate_to
- self._project_name = project_name
-
- def execute(self, context):
- hook = CloudMLHook(
- gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
- if self._operation == 'create':
- hook.create_model(self._project_name, self._model)
- elif self._operation == 'get':
- hook.get_model(self._project_name, self._model['name'])
- else:
- raise ValueError('Unknown operation: {}'.format(self._operation))
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
index aa50e69..e34e05f 100644
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -13,9 +13,10 @@
import json
import mock
import unittest
-try: # python 2
+
+try: # python 2
from urlparse import urlparse, parse_qsl
-except ImportError: #python 3
+except ImportError: # python 3
from urllib.parse import urlparse, parse_qsl
from airflow.contrib.hooks import gcp_cloudml_hook as hook
@@ -49,12 +50,16 @@ class _TestCloudMLHook(object):
self._test_cls = test_cls
self._responses = responses
self._expected_requests = [
- self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in expected_requests]
+ self._normalize_requests_for_comparison(x[0], x[1], x[2])
+ for x in expected_requests]
self._actual_requests = []
def _normalize_requests_for_comparison(self, uri, http_method, body):
parts = urlparse(uri)
- return (parts._replace(query=set(parse_qsl(parts.query))), http_method, body)
+ return (
+ parts._replace(query=set(parse_qsl(parts.query))),
+ http_method,
+ body)
def __enter__(self):
http = HttpMockSequence(self._responses)
@@ -76,7 +81,9 @@ class _TestCloudMLHook(object):
if any(args):
return None
self._test_cls.assertEquals(
- [self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in self._actual_requests], self._expected_requests)
+ [self._normalize_requests_for_comparison(x[0], x[1], x[2])
+ for x in self._actual_requests],
+ self._expected_requests)
class TestCloudMLHook(unittest.TestCase):
@@ -86,6 +93,7 @@ class TestCloudMLHook(unittest.TestCase):
_SKIP_IF = unittest.skipIf(not cml_available,
'CloudML is not available to run tests')
+
_SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'
@_SKIP_IF
@@ -112,7 +120,8 @@ class TestCloudMLHook(unittest.TestCase):
responses=[succeeded_response] * 2,
expected_requests=expected_requests) as cml_hook:
create_version_response = cml_hook.create_version(
- project_name=project, model_name=model_name, version_spec=version)
+ project_name=project, model_name=model_name,
+ version_spec=version)
self.assertEquals(create_version_response, response_body)
@_SKIP_IF
@@ -137,7 +146,8 @@ class TestCloudMLHook(unittest.TestCase):
responses=[succeeded_response],
expected_requests=expected_requests) as cml_hook:
set_default_version_response = cml_hook.set_default_version(
- project_name=project, model_name=model_name, version_name=version)
+ project_name=project, model_name=model_name,
+ version_name=version)
self.assertEquals(set_default_version_response, response_body)
@_SKIP_IF
@@ -150,8 +160,12 @@ class TestCloudMLHook(unittest.TestCase):
# This test returns the versions one at a time.
versions = ['ver_{}'.format(ix) for ix in range(3)]
- response_bodies = [{'name': operation_name, 'nextPageToken': ix, 'versions': [
- ver]} for ix, ver in enumerate(versions)]
+ response_bodies = [
+ {
+ 'name': operation_name,
+ 'nextPageToken': ix,
+ 'versions': [ver]
+ } for ix, ver in enumerate(versions)]
response_bodies[-1].pop('nextPageToken')
responses = [({'status': '200'}, json.dumps(body))
for body in response_bodies]
@@ -190,9 +204,11 @@ class TestCloudMLHook(unittest.TestCase):
{'status': '200'}, json.dumps(done_response_body))
expected_requests = [
- ('{}projects/{}/models/{}/versions/{}?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, model_name, version), 'DELETE',
- None),
+ (
+ '{}projects/{}/models/{}/versions/{}?alt=json'.format(
+ self._SERVICE_URI_PREFIX, project, model_name, version),
+ 'DELETE',
+ None),
('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
'GET', None),
]
@@ -202,7 +218,8 @@ class TestCloudMLHook(unittest.TestCase):
responses=[not_done_response, succeeded_response],
expected_requests=expected_requests) as cml_hook:
delete_version_response = cml_hook.delete_version(
- project_name=project, model_name=model_name, version_name=version)
+ project_name=project, model_name=model_name,
+ version_name=version)
self.assertEquals(delete_version_response, done_response_body)
@_SKIP_IF
@@ -250,6 +267,42 @@ class TestCloudMLHook(unittest.TestCase):
project_name=project, model_name=model_name)
self.assertEquals(get_model_response, response_body)
+    @_SKIP_IF
+    def test_create_cloudml_job(self):
+        # create_job should POST the job, then poll it with GET until the
+        # returned state is terminal (QUEUED, then SUCCEEDED here).
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+        }
+        response_body = json.dumps(my_job)
+        queued_body = json.dumps({
+            'jobId': job_id,
+            'state': 'QUEUED',
+        })
+        jobs_uri = '{}projects/{}/jobs'.format(
+            self._SERVICE_URI_PREFIX, project)
+        get_job_uri = '{}/{}?alt=json'.format(jobs_uri, job_id)
+
+        expected_requests = [
+            ('{}?alt=json'.format(jobs_uri), 'POST', response_body)]
+        expected_requests += [(get_job_uri, 'GET', None)] * 2
+        responses = [
+            ({'status': '200'}, body)
+            for body in (response_body, queued_body, response_body)
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            self.assertEqual(
+                cml_hook.create_job(project_name=project, job=my_job),
+                my_job)
+
if __name__ == '__main__':
unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e92d6bf7/tests/contrib/operators/test_cloudml_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py
new file mode 100644
index 0000000..b76a0c6
--- /dev/null
+++ b/tests/contrib/operators/test_cloudml_operator.py
@@ -0,0 +1,288 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+from apiclient import errors
+import httplib2
+import unittest
+
+from airflow import configuration, DAG
+from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
+
+from mock import patch
+
+# Start/end date of the DAG built in setUp().
+DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+
+# predictionInput fields without any model origin. Tests copy this dict and
+# add one origin key (modelName, versionName or uri).
+INPUT_MISSING_ORIGIN = {
+ 'dataFormat': 'TEXT',
+ 'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
+ 'outputPath': 'gs://legal-bucket/fake-output-path',
+ 'region': 'us-east1',
+}
+
+# Job object returned by the mocked hook on success. Tests copy it and
+# attach the matching predictionInput.
+SUCCESS_MESSAGE_MISSING_INPUT = {
+ 'jobId': 'test_prediction',
+ 'predictionOutput': {
+ 'outputPath': 'gs://fake-output-path',
+ 'predictionCount': 5000,
+ 'errorCount': 0,
+ 'nodeHours': 2.78
+ },
+ 'state': 'SUCCEEDED'
+}
+
+# Constructor kwargs for CloudMLBatchPredictionOperator with no model origin.
+# Tests copy it and add uri/model_name/version_name as needed. The bucket
+# names appear meant to cover dashes, capitals, leading digits and
+# underscores; the operator passes these paths through unchanged.
+DEFAULT_ARGS = {
+ 'project_id': 'test-project',
+ 'job_id': 'test_prediction',
+ 'region': 'us-east1',
+ 'data_format': 'TEXT',
+ 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
+ 'output_path': 'gs://12_legal_bucket_underscore_number/legal-output-path',
+ 'task_id': 'test-prediction'
+}
+
+
+class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
+    """Unit tests for CloudMLBatchPredictionOperator with a mocked hook."""
+
+    def setUp(self):
+        super(CloudMLBatchPredictionOperatorTest, self).setUp()
+        configuration.load_test_config()
+        self.dag = DAG(
+            'test_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE,
+                'end_date': DEFAULT_DATE,
+            },
+            schedule_interval='@daily')
+
+    def _assert_successful_prediction(self, prediction_input,
+                                      **origin_kwargs):
+        """
+        Runs the operator against a mocked CloudMLHook that reports success.
+
+        prediction_input is the predictionInput the operator is expected to
+        send. origin_kwargs carries the model origin arguments (uri,
+        model_name and/or version_name) that should produce it.
+        """
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+            success_message['predictionInput'] = prediction_input
+
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.return_value = success_message
+
+            prediction_task = CloudMLBatchPredictionOperator(
+                job_id='test_prediction',
+                project_id='test-project',
+                region=prediction_input['region'],
+                data_format=prediction_input['dataFormat'],
+                input_paths=prediction_input['inputPaths'],
+                output_path=prediction_input['outputPath'],
+                dag=self.dag,
+                task_id='test-prediction',
+                **origin_kwargs)
+            prediction_output = prediction_task.execute(None)
+
+            mock_hook.assert_called_with('google_cloud_default', None)
+            hook_instance.create_job.assert_called_with(
+                'test-project',
+                {
+                    'jobId': 'test_prediction',
+                    'predictionInput': prediction_input
+                })
+            self.assertEqual(
+                success_message['predictionOutput'],
+                prediction_output)
+
+    def testSuccessWithModel(self):
+        input_with_model = INPUT_MISSING_ORIGIN.copy()
+        input_with_model['modelName'] = \
+            'projects/test-project/models/test_model'
+        self._assert_successful_prediction(
+            input_with_model, model_name='test_model')
+
+    def testSuccessWithVersion(self):
+        input_with_version = INPUT_MISSING_ORIGIN.copy()
+        input_with_version['versionName'] = \
+            'projects/test-project/models/test_model/versions/test_version'
+        self._assert_successful_prediction(
+            input_with_version,
+            model_name='test_model',
+            version_name='test_version')
+
+    def testSuccessWithURI(self):
+        input_with_uri = INPUT_MISSING_ORIGIN.copy()
+        input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
+        self._assert_successful_prediction(
+            input_with_uri, uri=input_with_uri['uri'])
+
+    def testInvalidModelOrigin(self):
+        # Model origin validation runs in the constructor, so the ValueError
+        # is raised before execute() is reached.
+        cases = [
+            # both uri and model are given
+            ({'uri': 'gs://fake-uri/saved_model',
+              'model_name': 'fake_model'},
+             'Ambiguous model origin.'),
+            # both uri and model/version are given
+            ({'uri': 'gs://fake-uri/saved_model',
+              'model_name': 'fake_model',
+              'version_name': 'fake_version'},
+             'Ambiguous model origin.'),
+            # a version is given without a model
+            ({'version_name': 'bare_version'},
+             'Missing model origin.'),
+            # none of uri, model, model/version is given
+            ({},
+             'Missing model origin.'),
+        ]
+        for origin_args, expected_message in cases:
+            task_args = DEFAULT_ARGS.copy()
+            task_args.update(origin_args)
+            with self.assertRaises(ValueError) as context:
+                CloudMLBatchPredictionOperator(**task_args).execute(None)
+            self.assertEqual(expected_message, str(context.exception))
+
+    def testHttpError(self):
+        http_error_code = 403
+
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            input_with_model = INPUT_MISSING_ORIGIN.copy()
+            input_with_model['modelName'] = \
+                'projects/test-project/models/test_model'
+
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.side_effect = errors.HttpError(
+                resp=httplib2.Response({
+                    'status': http_error_code
+                }), content=b'Forbidden')
+
+            prediction_task = CloudMLBatchPredictionOperator(
+                job_id='test_prediction',
+                project_id='test-project',
+                region=input_with_model['region'],
+                data_format=input_with_model['dataFormat'],
+                input_paths=input_with_model['inputPaths'],
+                output_path=input_with_model['outputPath'],
+                model_name=input_with_model['modelName'].split('/')[-1],
+                dag=self.dag,
+                task_id='test-prediction')
+            with self.assertRaises(errors.HttpError) as context:
+                prediction_task.execute(None)
+
+            # These checks must sit outside the assertRaises block: any
+            # statement after the raising call inside it would never run.
+            mock_hook.assert_called_with('google_cloud_default', None)
+            hook_instance.create_job.assert_called_with(
+                'test-project',
+                {
+                    'jobId': 'test_prediction',
+                    'predictionInput': input_with_model
+                })
+
+        self.assertEqual(http_error_code, context.exception.resp.status)
+
+    def testFailedJobError(self):
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.return_value = {
+                'state': 'FAILED',
+                'errorMessage': 'A failure message'
+            }
+            task_args = DEFAULT_ARGS.copy()
+            task_args['uri'] = 'a uri'
+
+            with self.assertRaises(RuntimeError) as context:
+                CloudMLBatchPredictionOperator(**task_args).execute(None)
+
+            self.assertEqual('A failure message', str(context.exception))
+
+
+# Allow running this test module directly with the Python interpreter.
+if __name__ == '__main__':
+ unittest.main()