You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/07/06 18:46:17 UTC
incubator-airflow git commit: [AIRFLOW-1271] Add Google CloudML
Training Operator
Repository: incubator-airflow
Updated Branches:
refs/heads/master d231dce37 -> 0fc45045a
[AIRFLOW-1271] Add Google CloudML Training Operator
Closes #2408 from leomzhong/cloudml_training
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/0fc45045
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/0fc45045
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/0fc45045
Branch: refs/heads/master
Commit: 0fc45045a27a0b1867410613d6c0edba820e3abf
Parents: d231dce
Author: Ming Zhong <le...@gmail.com>
Authored: Thu Jul 6 11:46:13 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Thu Jul 6 11:46:13 2017 -0700
----------------------------------------------------------------------
airflow/contrib/hooks/gcp_cloudml_hook.py | 82 ++++-----
airflow/contrib/operators/cloudml_operator.py | 148 ++++++++++++++-
tests/contrib/hooks/test_gcp_cloudml_hook.py | 111 +++++++++++-
.../contrib/operators/test_cloudml_operator.py | 179 ++++++++++++++-----
4 files changed, 428 insertions(+), 92 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
index 3af8508..6f634b2 100644
--- a/airflow/contrib/hooks/gcp_cloudml_hook.py
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -62,30 +62,37 @@ class CloudMLHook(GoogleCloudBaseHook):
credentials = GoogleCredentials.get_application_default()
return build('ml', 'v1', credentials=credentials)
- def create_job(self, project_name, job):
+ def create_job(self, project_name, job, use_existing_job_fn=None):
"""
- Creates and executes a CloudML job.
-
- Returns the job object if the job was created and finished
- successfully, or raises an error otherwise.
-
- Raises:
- apiclient.errors.HttpError: if the job cannot be created
- successfully
-
- project_name is the name of the project to use, such as
- 'my-project'
-
- job is the complete Cloud ML Job object that should be provided to the
- Cloud ML API, such as
-
- {
- 'jobId': 'my_job_id',
- 'trainingInput': {
- 'scaleTier': 'STANDARD_1',
- ...
- }
- }
+ Launches a CloudML job and waits for it to reach a terminal state.
+
+ :param project_name: The Google Cloud project name within which CloudML
+ job will be launched.
+ :type project_name: string
+
+ :param job: CloudML Job object that should be provided to the CloudML
+ API, such as:
+ {
+ 'jobId': 'my_job_id',
+ 'trainingInput': {
+ 'scaleTier': 'STANDARD_1',
+ ...
+ }
+ }
+ :type job: dict
+
+ :param use_existing_job_fn: In case a CloudML job with the same
+ job_id already exists, this function (if provided) decides whether
+ we should use the existing job, continue waiting for it to finish
+ and return the job object. It should accept a CloudML job
+ object and return a boolean value indicating whether it is OK to
+ reuse the existing job. If 'use_existing_job_fn' is not provided,
+ we by default reuse the existing CloudML job.
+ :type use_existing_job_fn: function
+
+ :return: The CloudML job object if the job successfully reaches a
+ terminal state (which might be FAILED or CANCELLED state).
+ :rtype: dict
"""
request = self._cloudml.projects().jobs().create(
parent='projects/{}'.format(project_name),
@@ -94,29 +101,24 @@ class CloudMLHook(GoogleCloudBaseHook):
try:
request.execute()
- return self._wait_for_job_done(project_name, job_id)
except errors.HttpError as e:
+ # 409 means there is an existing job with the same job ID.
if e.resp.status == 409:
- existing_job = self._get_job(project_name, job_id)
+ if use_existing_job_fn is not None:
+ existing_job = self._get_job(project_name, job_id)
+ if not use_existing_job_fn(existing_job):
+ logging.error(
+ 'Job with job_id {} already exist, but it does '
+ 'not match our expectation: {}'.format(
+ job_id, existing_job))
+ raise
logging.info(
- 'Job with job_id {} already exist: {}.'.format(
- job_id,
- existing_job))
-
- if existing_job.get('predictionInput', None) == \
- job['predictionInput']:
- return self._wait_for_job_done(project_name, job_id)
- else:
- logging.error(
- 'Job with job_id {} already exists, but the '
- 'predictionInput mismatch: {}'
- .format(job_id, existing_job))
- raise ValueError(
- 'Found a existing job with job_id {}, but with '
- 'different predictionInput.'.format(job_id))
+ 'Job with job_id {} already exist. Will waiting for it to '
+ 'finish'.format(job_id))
else:
logging.error('Failed to create CloudML job: {}'.format(e))
raise
+ return self._wait_for_job_done(project_name, job_id)
def _get_job(self, project_name, job_id):
"""
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
index 871cc73..3ad6f5a 100644
--- a/airflow/contrib/operators/cloudml_operator.py
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -18,8 +18,9 @@ import logging
import re
from airflow import settings
-from airflow.operators import BaseOperator
from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
+from airflow.exceptions import AirflowException
+from airflow.operators import BaseOperator
from airflow.utils.decorators import apply_defaults
from apiclient import errors
@@ -239,10 +240,14 @@ class CloudMLBatchPredictionOperator(BaseOperator):
def execute(self, context):
hook = CloudMLHook(self.gcp_conn_id, self.delegate_to)
+ def check_existing_job(existing_job):
+ return existing_job.get('predictionInput', None) == \
+ self.prediction_job_request['predictionInput']
try:
finished_prediction_job = hook.create_job(
self.project_id,
- self.prediction_job_request)
+ self.prediction_job_request,
+ check_existing_job)
except errors.HttpError:
raise
@@ -406,3 +411,142 @@ class CloudMLVersionOperator(BaseOperator):
self._version['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
+
+
+class CloudMLTrainingOperator(BaseOperator):
+ """
+ Operator for launching a CloudML training job.
+
+ :param project_name: The Google Cloud project name within which CloudML
+ training job should run. This field could be templated.
+ :type project_name: string
+
+ :param job_id: A unique templated id for the submitted Google CloudML
+ training job.
+ :type job_id: string
+
+ :param package_uris: A list of package locations for CloudML training job,
+ which should include the main training program + any additional
+ dependencies.
+ :type package_uris: list of string
+
+ :param training_python_module: The Python module name to run within CloudML
+ training job after installing 'package_uris' packages.
+ :type training_python_module: string
+
+ :param training_args: A list of templated command line arguments to pass to
+ the CloudML training program.
+ :type training_args: string
+
+ :param region: The Google Compute Engine region to run the CloudML training
+ job in. This field could be templated.
+ :type region: string
+
+ :param scale_tier: Resource tier for CloudML training job.
+ :type scale_tier: string
+
+ :param gcp_conn_id: The connection ID to use when fetching connection info.
+ :type gcp_conn_id: string
+
+ :param delegate_to: The account to impersonate, if any.
+ For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: string
+
+ :param mode: 'DRY_RUN' or any other value (default 'PRODUCTION'). In
+ 'DRY_RUN' mode, no real training job will be launched, but the CloudML
+ training job request will be printed out. In any other mode, a real
+ CloudML training job creation request will be issued.
+ :type mode: string
+ """
+
+ template_fields = [
+ '_project_name',
+ '_job_id',
+ '_package_uris',
+ '_training_python_module',
+ '_training_args',
+ '_region',
+ '_scale_tier',
+ ]
+
+ @apply_defaults
+ def __init__(self,
+ project_name,
+ job_id,
+ package_uris,
+ training_python_module,
+ training_args,
+ region,
+ scale_tier=None,
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ mode='PRODUCTION',
+ *args,
+ **kwargs):
+ super(CloudMLTrainingOperator, self).__init__(*args, **kwargs)
+ self._project_name = project_name
+ self._job_id = job_id
+ self._package_uris = package_uris
+ self._training_python_module = training_python_module
+ self._training_args = training_args
+ self._region = region
+ self._scale_tier = scale_tier
+ self._gcp_conn_id = gcp_conn_id
+ self._delegate_to = delegate_to
+ self._mode = mode
+
+ if not self._project_name:
+ raise AirflowException('Google Cloud project name is required.')
+ if not self._job_id:
+ raise AirflowException(
+ 'An unique job id is required for Google CloudML training '
+ 'job.')
+ if not package_uris:
+ raise AirflowException(
+ 'At least one python package is required for CloudML '
+ 'Training job.')
+ if not training_python_module:
+ raise AirflowException(
+ 'Python module name to run after installing required '
+ 'packages is required.')
+ if not self._region:
+ raise AirflowException('Google Compute Engine region is required.')
+
+ def execute(self, context):
+ job_id = _normalize_cloudml_job_id(self._job_id)
+ training_request = {
+ 'jobId': job_id,
+ 'trainingInput': {
+ 'scaleTier': self._scale_tier,
+ 'packageUris': self._package_uris,
+ 'pythonModule': self._training_python_module,
+ 'region': self._region,
+ 'args': self._training_args,
+ }
+ }
+
+ if self._mode == 'DRY_RUN':
+ logging.info('In dry_run mode.')
+ logging.info(
+ 'CloudML Training job request is: {}'.format(training_request))
+ return
+
+ hook = CloudMLHook(
+ gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+ # Helper method to check if the existing job's training input is the
+ # same as the request we get here.
+ def check_existing_job(existing_job):
+ return existing_job.get('trainingInput', None) == \
+ training_request['trainingInput']
+ try:
+ finished_training_job = hook.create_job(
+ self._project_name, training_request, check_existing_job)
+ except errors.HttpError:
+ raise
+
+ if finished_training_job['state'] != 'SUCCEEDED':
+ logging.error('CloudML training job failed: {}'.format(
+ str(finished_training_job)))
+ raise RuntimeError(finished_training_job['errorMessage'])
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
index e34e05f..53aba41 100644
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -20,6 +20,7 @@ except ImportError: # python 3
from urllib.parse import urlparse, parse_qsl
from airflow.contrib.hooks import gcp_cloudml_hook as hook
+from apiclient import errors
from apiclient.discovery import build
from apiclient.http import HttpMockSequence
from oauth2client.contrib.gce import HttpAccessTokenRefreshError
@@ -137,8 +138,8 @@ class TestCloudMLHook(unittest.TestCase):
expected_requests = [
('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, model_name, version), 'POST',
- '{}'),
+ self._SERVICE_URI_PREFIX, project, model_name, version),
+ 'POST', '{}'),
]
with _TestCloudMLHook(
@@ -175,7 +176,8 @@ class TestCloudMLHook(unittest.TestCase):
self._SERVICE_URI_PREFIX, project, model_name), 'GET',
None),
] + [
- ('{}projects/{}/models/{}/versions?alt=json&pageToken={}&pageSize=100'.format(
+ ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
+ '&pageSize=100'.format(
self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
None) for ix in range(len(versions) - 1)
]
@@ -303,6 +305,109 @@ class TestCloudMLHook(unittest.TestCase):
project_name=project, job=my_job)
self.assertEquals(create_job_response, my_job)
+ @_SKIP_IF
+ def test_create_cloudml_job_reuse_existing_job_by_default(self):
+ project = 'test-project'
+ job_id = 'test-job-id'
+ my_job = {
+ 'jobId': job_id,
+ 'foo': 4815162342,
+ 'state': 'SUCCEEDED',
+ }
+ response_body = json.dumps(my_job)
+ job_already_exist_response = ({'status': '409'}, json.dumps({}))
+ succeeded_response = ({'status': '200'}, response_body)
+
+ create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+ self._SERVICE_URI_PREFIX, project), 'POST', response_body)
+ ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+ self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+ expected_requests = [
+ create_job_request,
+ ask_if_done_request,
+ ]
+ responses = [job_already_exist_response, succeeded_response]
+
+ # By default, 'create_job' reuses the existing job.
+ with _TestCloudMLHook(
+ self,
+ responses=responses,
+ expected_requests=expected_requests) as cml_hook:
+ create_job_response = cml_hook.create_job(
+ project_name=project, job=my_job)
+ self.assertEquals(create_job_response, my_job)
+
+ @_SKIP_IF
+ def test_create_cloudml_job_check_existing_job(self):
+ project = 'test-project'
+ job_id = 'test-job-id'
+ my_job = {
+ 'jobId': job_id,
+ 'foo': 4815162342,
+ 'state': 'SUCCEEDED',
+ 'someInput': {
+ 'input': 'someInput'
+ }
+ }
+ different_job = {
+ 'jobId': job_id,
+ 'foo': 4815162342,
+ 'state': 'SUCCEEDED',
+ 'someInput': {
+ 'input': 'someDifferentInput'
+ }
+ }
+
+ my_job_response_body = json.dumps(my_job)
+ different_job_response_body = json.dumps(different_job)
+ job_already_exist_response = ({'status': '409'}, json.dumps({}))
+ different_job_response = ({'status': '200'},
+ different_job_response_body)
+
+ create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+ self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body)
+ ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+ self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+ expected_requests = [
+ create_job_request,
+ ask_if_done_request,
+ ]
+
+ # Returning a different job (with a different 'someInput' field)
+ # causes the 'create_job' request to fail.
+ responses = [job_already_exist_response, different_job_response]
+
+ def check_input(existing_job):
+ return existing_job.get('someInput', None) == \
+ my_job['someInput']
+ with _TestCloudMLHook(
+ self,
+ responses=responses,
+ expected_requests=expected_requests) as cml_hook:
+ with self.assertRaises(errors.HttpError):
+ cml_hook.create_job(
+ project_name=project, job=my_job,
+ use_existing_job_fn=check_input)
+
+ my_job_response = ({'status': '200'}, my_job_response_body)
+ expected_requests = [
+ create_job_request,
+ ask_if_done_request,
+ ask_if_done_request,
+ ]
+ responses = [
+ job_already_exist_response,
+ my_job_response,
+ my_job_response]
+ with _TestCloudMLHook(
+ self,
+ responses=responses,
+ expected_requests=expected_requests) as cml_hook:
+ create_job_response = cml_hook.create_job(
+ project_name=project, job=my_job,
+ use_existing_job_fn=check_input)
+ self.assertEquals(create_job_response, my_job)
+
if __name__ == '__main__':
unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/tests/contrib/operators/test_cloudml_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py
index b76a0c6..dc8c204 100644
--- a/tests/contrib/operators/test_cloudml_operator.py
+++ b/tests/contrib/operators/test_cloudml_operator.py
@@ -26,41 +26,41 @@ import unittest
from airflow import configuration, DAG
from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
+from airflow.contrib.operators.cloudml_operator import CloudMLTrainingOperator
+from mock import ANY
from mock import patch
DEFAULT_DATE = datetime.datetime(2017, 6, 6)
-INPUT_MISSING_ORIGIN = {
- 'dataFormat': 'TEXT',
- 'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
- 'outputPath': 'gs://legal-bucket/fake-output-path',
- 'region': 'us-east1',
-}
-
-SUCCESS_MESSAGE_MISSING_INPUT = {
- 'jobId': 'test_prediction',
- 'predictionOutput': {
- 'outputPath': 'gs://fake-output-path',
- 'predictionCount': 5000,
- 'errorCount': 0,
- 'nodeHours': 2.78
- },
- 'state': 'SUCCEEDED'
-}
-
-DEFAULT_ARGS = {
- 'project_id': 'test-project',
- 'job_id': 'test_prediction',
- 'region': 'us-east1',
- 'data_format': 'TEXT',
- 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
- 'output_path': 'gs://12_legal_bucket_underscore_number/legal-output-path',
- 'task_id': 'test-prediction'
-}
-
class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
+ INPUT_MISSING_ORIGIN = {
+ 'dataFormat': 'TEXT',
+ 'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
+ 'outputPath': 'gs://legal-bucket/fake-output-path',
+ 'region': 'us-east1',
+ }
+ SUCCESS_MESSAGE_MISSING_INPUT = {
+ 'jobId': 'test_prediction',
+ 'predictionOutput': {
+ 'outputPath': 'gs://fake-output-path',
+ 'predictionCount': 5000,
+ 'errorCount': 0,
+ 'nodeHours': 2.78
+ },
+ 'state': 'SUCCEEDED'
+ }
+ BATCH_PREDICTION_DEFAULT_ARGS = {
+ 'project_id': 'test-project',
+ 'job_id': 'test_prediction',
+ 'region': 'us-east1',
+ 'data_format': 'TEXT',
+ 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
+ 'output_path':
+ 'gs://12_legal_bucket_underscore_number/legal-output-path',
+ 'task_id': 'test-prediction'
+ }
def setUp(self):
super(CloudMLBatchPredictionOperatorTest, self).setUp()
@@ -78,10 +78,10 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
as mock_hook:
- input_with_model = INPUT_MISSING_ORIGIN.copy()
+ input_with_model = self.INPUT_MISSING_ORIGIN.copy()
input_with_model['modelName'] = \
'projects/test-project/models/test_model'
- success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+ success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
success_message['predictionInput'] = input_with_model
hook_instance = mock_hook.return_value
@@ -104,12 +104,12 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
prediction_output = prediction_task.execute(None)
mock_hook.assert_called_with('google_cloud_default', None)
- hook_instance.create_job.assert_called_with(
+ hook_instance.create_job.assert_called_once_with(
'test-project',
{
'jobId': 'test_prediction',
'predictionInput': input_with_model
- })
+ }, ANY)
self.assertEquals(
success_message['predictionOutput'],
prediction_output)
@@ -118,10 +118,10 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
as mock_hook:
- input_with_version = INPUT_MISSING_ORIGIN.copy()
+ input_with_version = self.INPUT_MISSING_ORIGIN.copy()
input_with_version['versionName'] = \
'projects/test-project/models/test_model/versions/test_version'
- success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+ success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
success_message['predictionInput'] = input_with_version
hook_instance = mock_hook.return_value
@@ -132,8 +132,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
hook_instance.create_job.return_value = success_message
prediction_task = CloudMLBatchPredictionOperator(
- job_id='test_prediction',
- project_id='test-project',
+ job_id='test_prediction', project_id='test-project',
region=input_with_version['region'],
data_format=input_with_version['dataFormat'],
input_paths=input_with_version['inputPaths'],
@@ -150,7 +149,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
{
'jobId': 'test_prediction',
'predictionInput': input_with_version
- })
+ }, ANY)
self.assertEquals(
success_message['predictionOutput'],
prediction_output)
@@ -159,9 +158,9 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
as mock_hook:
- input_with_uri = INPUT_MISSING_ORIGIN.copy()
+ input_with_uri = self.INPUT_MISSING_ORIGIN.copy()
input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
- success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+ success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
success_message['predictionInput'] = input_with_uri
hook_instance = mock_hook.return_value
@@ -189,14 +188,14 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
{
'jobId': 'test_prediction',
'predictionInput': input_with_uri
- })
+ }, ANY)
self.assertEquals(
success_message['predictionOutput'],
prediction_output)
def testInvalidModelOrigin(self):
# Test that both uri and model is given
- task_args = DEFAULT_ARGS.copy()
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['uri'] = 'gs://fake-uri/saved_model'
task_args['model_name'] = 'fake_model'
with self.assertRaises(ValueError) as context:
@@ -204,7 +203,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
self.assertEquals('Ambiguous model origin.', str(context.exception))
# Test that both uri and model/version is given
- task_args = DEFAULT_ARGS.copy()
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['uri'] = 'gs://fake-uri/saved_model'
task_args['model_name'] = 'fake_model'
task_args['version_name'] = 'fake_version'
@@ -213,7 +212,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
self.assertEquals('Ambiguous model origin.', str(context.exception))
# Test that a version is given without a model
- task_args = DEFAULT_ARGS.copy()
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['version_name'] = 'bare_version'
with self.assertRaises(ValueError) as context:
CloudMLBatchPredictionOperator(**task_args).execute(None)
@@ -222,7 +221,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
str(context.exception))
# Test that none of uri, model, model/version is given
- task_args = DEFAULT_ARGS.copy()
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
with self.assertRaises(ValueError) as context:
CloudMLBatchPredictionOperator(**task_args).execute(None)
self.assertEquals(
@@ -234,7 +233,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
as mock_hook:
- input_with_model = INPUT_MISSING_ORIGIN.copy()
+ input_with_model = self.INPUT_MISSING_ORIGIN.copy()
input_with_model['modelName'] = \
'projects/experimental/models/test_model'
@@ -263,7 +262,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
{
'jobId': 'test_prediction',
'predictionInput': input_with_model
- })
+ }, ANY)
self.assertEquals(http_error_code, context.exception.resp.status)
@@ -275,7 +274,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
'state': 'FAILED',
'errorMessage': 'A failure message'
}
- task_args = DEFAULT_ARGS.copy()
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['uri'] = 'a uri'
with self.assertRaises(RuntimeError) as context:
@@ -284,5 +283,91 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
self.assertEquals('A failure message', str(context.exception))
+class CloudMLTrainingOperatorTest(unittest.TestCase):
+ TRAINING_DEFAULT_ARGS = {
+ 'project_name': 'test-project',
+ 'job_id': 'test_training',
+ 'package_uris': ['gs://some-bucket/package1'],
+ 'training_python_module': 'trainer',
+ 'training_args': '--some_arg=\'aaa\'',
+ 'region': 'us-east1',
+ 'scale_tier': 'STANDARD_1',
+ 'task_id': 'test-training'
+ }
+ TRAINING_INPUT = {
+ 'jobId': 'test_training',
+ 'trainingInput': {
+ 'scaleTier': 'STANDARD_1',
+ 'packageUris': ['gs://some-bucket/package1'],
+ 'pythonModule': 'trainer',
+ 'args': '--some_arg=\'aaa\'',
+ 'region': 'us-east1'
+ }
+ }
+
+ def testSuccessCreateTrainingJob(self):
+ with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+ as mock_hook:
+ success_response = self.TRAINING_INPUT.copy()
+ success_response['state'] = 'SUCCEEDED'
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.return_value = success_response
+
+ training_op = CloudMLTrainingOperator(**self.TRAINING_DEFAULT_ARGS)
+ training_op.execute(None)
+
+ mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
+ delegate_to=None)
+ # Make sure only 'create_job' is invoked on hook instance
+ self.assertEquals(len(hook_instance.mock_calls), 1)
+ hook_instance.create_job.assert_called_with(
+ 'test-project', self.TRAINING_INPUT, ANY)
+
+ def testHttpError(self):
+ http_error_code = 403
+ with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+ as mock_hook:
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.side_effect = errors.HttpError(
+ resp=httplib2.Response({
+ 'status': http_error_code
+ }), content=b'Forbidden')
+
+ with self.assertRaises(errors.HttpError) as context:
+ training_op = CloudMLTrainingOperator(
+ **self.TRAINING_DEFAULT_ARGS)
+ training_op.execute(None)
+
+ mock_hook.assert_called_with(
+ gcp_conn_id='google_cloud_default', delegate_to=None)
+ # Make sure only 'create_job' is invoked on hook instance
+ self.assertEquals(len(hook_instance.mock_calls), 1)
+ hook_instance.create_job.assert_called_with(
+ 'test-project', self.TRAINING_INPUT, ANY)
+ self.assertEquals(http_error_code, context.exception.resp.status)
+
+ def testFailedJobError(self):
+ with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+ as mock_hook:
+ failure_response = self.TRAINING_INPUT.copy()
+ failure_response['state'] = 'FAILED'
+ failure_response['errorMessage'] = 'A failure message'
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.return_value = failure_response
+
+ with self.assertRaises(RuntimeError) as context:
+ training_op = CloudMLTrainingOperator(
+ **self.TRAINING_DEFAULT_ARGS)
+ training_op.execute(None)
+
+ mock_hook.assert_called_with(
+ gcp_conn_id='google_cloud_default', delegate_to=None)
+ # Make sure only 'create_job' is invoked on hook instance
+ self.assertEquals(len(hook_instance.mock_calls), 1)
+ hook_instance.create_job.assert_called_with(
+ 'test-project', self.TRAINING_INPUT, ANY)
+ self.assertEquals('A failure message', str(context.exception))
+
+
if __name__ == '__main__':
unittest.main()