You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/09/06 16:51:29 UTC
[2/3] incubator-airflow git commit: [AIRFLOW-1567][Airflow-1567]
Renamed cloudml hook and operator to mlengine
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_operator_utils.py b/airflow/contrib/operators/mlengine_operator_utils.py
new file mode 100644
index 0000000..5fda6ae
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_operator_utils.py
@@ -0,0 +1,245 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+import os
+import re
+
+import dill
+
+from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
+from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator
+from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+from airflow.exceptions import AirflowException
+from airflow.operators.python_operator import PythonOperator
+from six.moves.urllib.parse import urlsplit
+
def create_evaluate_ops(task_prefix,
                        data_format,
                        input_paths,
                        prediction_path,
                        metric_fn_and_keys,
                        validate_fn,
                        batch_prediction_job_id=None,
                        project_id=None,
                        region=None,
                        dataflow_options=None,
                        model_uri=None,
                        model_name=None,
                        version_name=None,
                        dag=None):
    """
    Creates the Operators needed for model evaluation and returns them.

    It gets predictions over the inputs via the Cloud ML Engine BatchPrediction
    API by calling MLEngineBatchPredictionOperator, then summarizes the result
    via Cloud Dataflow using DataFlowPythonOperator, and finally validates the
    summary with a PythonOperator.

    For details and pricing about Batch prediction, please refer to the website
    https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
    and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/

    It returns three chained operators for prediction, summary, and validation,
    named as <prefix>-prediction, <prefix>-summary, and <prefix>-validation,
    respectively.
    (<prefix> must start with a letter and contain only alphanumeric
    characters or hyphens.)

    The upstream and downstream can be set accordingly like:
      pred, _, val = create_evaluate_ops(...)
      pred.set_upstream(upstream_op)
      ...
      downstream_op.set_upstream(val)

    Callers will provide two python callables, metric_fn and validate_fn, in
    order to customize the evaluation behavior as they wish.
    - metric_fn receives a dictionary per instance derived from json in the
      batch prediction result. The keys might vary depending on the model.
      It should return a tuple of metrics.
    - validate_fn receives a dictionary of the averaged metrics that metric_fn
      generated over all instances.
      The keys of the dictionary match those given by the metric_fn_and_keys
      arg.
      The dictionary contains an additional metric, 'count' to represent the
      total number of instances received for evaluation.
      The function should raise an exception to mark the task as failed when
      the validation result is not okay to proceed (i.e. to set the trained
      version as default).

    Typical examples are like this:

    def get_metric_fn_and_keys():
        import math  # imports should be outside of the metric_fn below.
        def error_and_squared_error(inst):
            label = float(inst['input_label'])
            classes = float(inst['classes'])  # 0 or 1
            err = abs(classes-label)
            squared_err = math.pow(classes-label, 2)
            return (err, squared_err)  # returns a tuple.
        return error_and_squared_error, ['err', 'mse']  # key order must match.

    def validate_err_and_count(summary):
        if summary['err'] > 0.2:
            raise ValueError('Too high err>0.2; summary=%s' % summary)
        if summary['mse'] > 0.05:
            raise ValueError('Too high mse>0.05; summary=%s' % summary)
        if summary['count'] < 1000:
            raise ValueError('Too few instances<1000; summary=%s' % summary)
        return summary

    For the details on the other BatchPrediction-related arguments (project_id,
    job_id, region, data_format, input_paths, prediction_path, model_uri),
    please refer to MLEngineBatchPredictionOperator too.

    :param task_prefix: a prefix for the tasks. It must start with a letter;
        only alphanumeric characters and hyphens are allowed (no underscores),
        since this will be used as a dataflow job name, which doesn't allow
        other characters.
    :type task_prefix: string

    :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
    :type data_format: string

    :param input_paths: a list of input paths to be sent to BatchPrediction.
    :type input_paths: list of strings

    :param prediction_path: GCS path to put the prediction results in.
    :type prediction_path: string

    :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
        - metric_fn is a function that accepts a dictionary (for an instance),
          and returns a tuple of metric(s) that it calculates.
        - metric_keys is a list of strings to denote the key of each metric.
    :type metric_fn_and_keys: tuple of a function and a list of strings

    :param validate_fn: a function to validate whether the averaged metric(s) is
        good enough to push the model.
    :type validate_fn: function

    :param batch_prediction_job_id: the id to use for the Cloud ML Batch
        prediction job. Passed directly to the MLEngineBatchPredictionOperator as
        the job_id argument.
    :type batch_prediction_job_id: string

    :param project_id: the Google Cloud Platform project id in which to execute
        Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
        `default_args['project_id']` will be used.
    :type project_id: string

    :param region: the Google Cloud Platform region in which to execute Cloud ML
        Batch Prediction and Dataflow jobs. If None, then the `dag`'s
        `default_args['region']` will be used.
    :type region: string

    :param dataflow_options: options to run Dataflow jobs. If None, then the
        `dag`'s `default_args['dataflow_default_options']` will be used.
    :type dataflow_options: dictionary

    :param model_uri: GCS path of the model exported by Tensorflow using
        tensorflow.estimator.export_savedmodel(). It cannot be used with
        model_name or version_name below. See MLEngineBatchPredictionOperator for
        more detail.
    :type model_uri: string

    :param model_name: Used to indicate a model to use for prediction. Can be
        used in combination with version_name, but cannot be used together with
        model_uri. See MLEngineBatchPredictionOperator for more detail. If None,
        then the `dag`'s `default_args['model_name']` will be used.
    :type model_name: string

    :param version_name: Used to indicate a model version to use for prediction,
        in combination with model_name. Cannot be used together with model_uri.
        See MLEngineBatchPredictionOperator for more detail. If None, then the
        `dag`'s `default_args['version_name']` will be used.
    :type version_name: string

    :param dag: The `DAG` to use for all Operators.
    :type dag: airflow.DAG

    :raises AirflowException: if task_prefix is malformed, or if metric_fn or
        validate_fn is not callable.

    :returns: a tuple of three operators, (prediction, summary, validation)
    :rtype: tuple(MLEngineBatchPredictionOperator, DataFlowPythonOperator,
                  PythonOperator)
    """

    # Verify that task_prefix doesn't have any special characters except hyphen
    # '-', which is the only allowed non-alphanumeric character by Dataflow.
    # `\Z` (unlike `$`) does not accept a trailing newline.
    if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*\Z", task_prefix):
        raise AirflowException(
            "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
            "and hyphens are allowed) but got: " + task_prefix)

    metric_fn, metric_keys = metric_fn_and_keys
    if not callable(metric_fn):
        raise AirflowException("`metric_fn` param must be callable.")
    if not callable(validate_fn):
        raise AirflowException("`validate_fn` param must be callable.")

    # Fall back to the DAG's default_args for any setting not given explicitly.
    if dag is not None and dag.default_args is not None:
        default_args = dag.default_args
        project_id = project_id or default_args.get('project_id')
        region = region or default_args.get('region')
        model_name = model_name or default_args.get('model_name')
        version_name = version_name or default_args.get('version_name')
        dataflow_options = dataflow_options or \
            default_args.get('dataflow_default_options')

    evaluate_prediction = MLEngineBatchPredictionOperator(
        task_id=(task_prefix + "-prediction"),
        project_id=project_id,
        job_id=batch_prediction_job_id,
        region=region,
        data_format=data_format,
        input_paths=input_paths,
        output_path=prediction_path,
        uri=model_uri,
        model_name=model_name,
        version_name=version_name,
        dag=dag)

    # b64encode returns bytes on Python 3; decode to ASCII text so the value is
    # passed on the Dataflow command line as plain base64, not as bytes.
    metric_fn_encoded = base64.b64encode(
        dill.dumps(metric_fn, recurse=True)).decode("ascii")
    evaluate_summary = DataFlowPythonOperator(
        task_id=(task_prefix + "-summary"),
        py_options=["-m"],
        py_file="airflow.contrib.operators.mlengine_prediction_summary",
        dataflow_default_options=dataflow_options,
        options={
            "prediction_path": prediction_path,
            "metric_fn_encoded": metric_fn_encoded,
            "metric_keys": ','.join(metric_keys)
        },
        dag=dag)
    evaluate_summary.set_upstream(evaluate_prediction)

    def apply_validate_fn(*args, **kwargs):
        """Download prediction.summary.json and hand it to validate_fn."""
        # Read the path from templates_dict rather than the closure so the
        # operator's templating is applied (NOTE(review): relies on
        # PythonOperator templating `templates_dict`; confirm per version).
        templated_path = kwargs["templates_dict"]["prediction_path"]
        scheme, bucket, obj, _, _ = urlsplit(templated_path)
        if scheme != "gs" or not bucket or not obj:
            # Format eagerly: a second exception argument is never
            # interpolated into the "%s" placeholder.
            raise ValueError(
                "Wrong format prediction_path: %s" % templated_path)
        summary_object = os.path.join(obj.strip("/"),
                                      "prediction.summary.json")
        gcs_hook = GoogleCloudStorageHook()
        raw_summary = gcs_hook.download(bucket, summary_object)
        if isinstance(raw_summary, bytes):
            raw_summary = raw_summary.decode("utf-8")
        summary = json.loads(raw_summary)
        return validate_fn(summary)

    evaluate_validation = PythonOperator(
        task_id=(task_prefix + "-validation"),
        python_callable=apply_validate_fn,
        provide_context=True,
        templates_dict={"prediction_path": prediction_path},
        dag=dag)
    evaluate_validation.set_upstream(evaluate_summary)

    return evaluate_prediction, evaluate_summary, evaluate_validation
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_prediction_summary.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_prediction_summary.py b/airflow/contrib/operators/mlengine_prediction_summary.py
new file mode 100644
index 0000000..1f4d540
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_prediction_summary.py
@@ -0,0 +1,177 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
+
+It accepts a user function to calculate the metric(s) per instance in
+the prediction results, then aggregates to output as a summary.
+
+Args:
+ --prediction_path:
+ The GCS folder that contains BatchPrediction results, containing
+ prediction.results-NNNNN-of-NNNNN files in the json format.
+ Output will also be stored in this folder, as 'prediction.summary.json'.
+
+ --metric_fn_encoded:
+ An encoded function that calculates and returns a tuple of metric(s)
+ for a given instance (as a dictionary). It should be encoded
+ via base64.b64encode(dill.dumps(fn, recurse=True)).
+
+ --metric_keys:
+ Comma-separated key(s) of the aggregated metric(s) in the summary
+ output. The order and number of the keys must match the output
+ of metric_fn.
+ The summary will have an additional key, 'count', to represent the
+ total number of instances, so the keys shouldn't include 'count'.
+
+# Usage example:
+def get_metric_fn():
+ import math # all imports must be outside of the function to be passed.
+ def metric_fn(inst):
+ label = float(inst["input_label"])
+ classes = float(inst["classes"])
+ prediction = float(inst["scores"][1])
+ log_loss = math.log(1 + math.exp(
+ -(label * 2 - 1) * math.log(prediction / (1 - prediction))))
+ squared_err = (classes-label)**2
+ return (log_loss, squared_err)
+ return metric_fn
+metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
+
+airflow.contrib.operators.DataFlowPythonOperator(
+ task_id="summary-prediction",
+ py_options=["-m"],
+ py_file="airflow.contrib.operators.mlengine_prediction_summary",
+ options={
+ "prediction_path": prediction_path,
+ "metric_fn_encoded": metric_fn_encoded,
+ "metric_keys": "log_loss,mse"
+ },
+ dataflow_default_options={
+ "project": "xxx", "region": "us-east1",
+ "staging_location": "gs://yy", "temp_location": "gs://zz",
+ },
+ dag=dag)
+
+# When the input file is like the following:
+{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
+{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
+{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
+{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
+
+# The output file will be:
+{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
+
+# To test outside of the dag:
+subprocess.check_call(["python",
+ "-m",
+ "airflow.contrib.operators.mlengine_prediction_summary",
+ "--prediction_path=gs://...",
+ "--metric_fn_encoded=" + metric_fn_encoded,
+ "--metric_keys=log_loss,mse",
+ "--runner=DataflowRunner",
+ "--staging_location=gs://...",
+ "--temp_location=gs://...",
+ ])
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import base64
+import json
+import logging
+import os
+
+import apache_beam as beam
+import dill
+
+
class JsonCoder(object):
    """Coder that maps each text line to/from one JSON document.

    Used for reading the prediction result files and for writing the summary.
    """

    def encode(self, x):
        """Serialize ``x`` to JSON and return it as UTF-8 encoded bytes.

        Beam's text sink writes whatever the coder returns into a binary file,
        so a plain ``str`` fails on Python 3 (NOTE(review): verify against the
        installed Beam version). json.dumps escapes non-ASCII by default, so the
        encoding step is also safe on Python 2.
        """
        return json.dumps(x).encode("utf-8")

    def decode(self, x):
        """Parse one JSON document given either as bytes or as text."""
        if isinstance(x, bytes):
            x = x.decode("utf-8")
        return json.loads(x)
+
+
@beam.ptransform_fn
def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
    """Average the per-instance metrics of a PCollection of prediction dicts.

    Args:
        pcoll: the PCollection of decoded prediction instances.
        metric_fn: callable returning a tuple of metric values for one
            instance, in the same order as ``metric_keys``.
        metric_keys: names used for the averaged metrics in the output.

    Returns:
        A PCollection holding one dict that maps every metric key to its mean
        over all instances, plus ``"count"`` with the number of instances.
    """
    def _with_count(metrics):
        # A trailing 1 per instance lets the same summation produce the count.
        return metrics + (1,)

    def _averages(totals):
        count = totals[-1]
        averaged = dict(
            (key, totals[pos] / count) for pos, key in enumerate(metric_keys))
        averaged["count"] = count
        return averaged

    # One `sum` per metric, plus one for the trailing instance counter.
    summers = [sum] * (len(metric_keys) + 1)
    totals = (pcoll
              | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
              | "PairWith1" >> beam.Map(_with_count)
              | "SumTuple" >> beam.CombineGlobally(
                  beam.combiners.TupleCombineFn(*summers)))
    return totals | "AverageAndMakeDict" >> beam.Map(_averages)
+
+
def run(argv=None):
    """Parse flags, then run the Beam pipeline that summarizes predictions.

    Reads every ``prediction.results-*-of-*`` file under ``--prediction_path``
    (one JSON document per line), applies the decoded metric function to each
    instance, averages the metrics and writes a single JSON line to
    ``<prediction_path>/prediction.summary.json``.

    Flags this parser does not define (e.g. ``--runner``,
    ``--staging_location``, ``--temp_location``) are passed on to Beam's
    ``PipelineOptions``.

    Args:
        argv: command-line arguments; ``None`` lets argparse read sys.argv.

    Raises:
        ValueError: if ``--metric_fn_encoded`` does not decode to a callable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_path", required=True,
        help=(
            "The GCS folder that contains BatchPrediction results, containing "
            "prediction.results-NNNNN-of-NNNNN files in the json format. "
            "Output will be also stored in this folder, as a file "
            "'prediction.summary.json'."))
    parser.add_argument(
        "--metric_fn_encoded", required=True,
        help=(
            "An encoded function that calculates and returns a tuple of "
            "metric(s) for a given instance (as a dictionary). It should be "
            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."))
    parser.add_argument(
        "--metric_keys", required=True,
        help=(
            "A comma-separated keys of the aggregated metric(s) in the summary "
            "output. The order and the size of the keys must match to the "
            "output of metric_fn. The summary will have an additional key, "
            "'count', to represent the total number of instances, so this flag "
            "shouldn't include 'count'."))
    known_args, pipeline_args = parser.parse_known_args(argv)

    # NOTE(review): dill.loads executes arbitrary code carried in the payload.
    # That is the design (the DAG author ships the metric function), so this
    # flag must only ever be supplied by trusted callers, never external input.
    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
    if not callable(metric_fn):
        raise ValueError("--metric_fn_encoded must be an encoded callable.")
    metric_keys = known_args.metric_keys.split(",")

    with beam.Pipeline(
            options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
        # This is apache-beam ptransform's convention
        # pylint: disable=no-value-for-parameter
        _ = (p
             | "ReadPredictionResult" >> beam.io.ReadFromText(
                 os.path.join(known_args.prediction_path,
                              "prediction.results-*-of-*"),
                 coder=JsonCoder())
             | "Summary" >> MakeSummary(metric_fn, metric_keys)
             | "Write" >> beam.io.WriteToText(
                 os.path.join(known_args.prediction_path,
                              "prediction.summary.json"),
                 shard_name_template='',  # without trailing -NNNNN-of-NNNNN.
                 coder=JsonCoder()))
        # pylint: enable=no-value-for-parameter
+
+
if __name__ == "__main__":
    # Show INFO-level log records (including pipeline progress) when this
    # module is executed directly, e.g. via `python -m`.
    _root_logger = logging.getLogger()
    _root_logger.setLevel(logging.INFO)
    run()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
deleted file mode 100644
index f56018d..0000000
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ /dev/null
@@ -1,413 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import mock
-import unittest
-
-try: # python 2
- from urlparse import urlparse, parse_qsl
-except ImportError: # python 3
- from urllib.parse import urlparse, parse_qsl
-
-from airflow.contrib.hooks import gcp_cloudml_hook as hook
-from apiclient import errors
-from apiclient.discovery import build
-from apiclient.http import HttpMockSequence
-from oauth2client.contrib.gce import HttpAccessTokenRefreshError
-
-cml_available = True  # Tests below are skipped (via _SKIP_IF) when this ends up False.
-try:
- hook.CloudMLHook().get_conn()
-except HttpAccessTokenRefreshError:  # Presumably no GCE access token could be obtained.
- cml_available = False
-
-
-class _TestCloudMLHook(object):  # Context manager: yields a CloudMLHook wired to canned HTTP responses.
-
- def __init__(self, test_cls, responses, expected_requests):
- """
- Records the canned responses and the requests the test expects.
-
- Usage example:
- with _TestCloudMLHook(self, responses, expected_requests) as hook:
- self.run_my_test(hook)
-
- Args:
- test_cls: The caller's instance used for test communication.
- responses: A list of (dict_response, response_content) tuples.
- expected_requests: A list of (uri, http_method, body) tuples.
- """
-
- self._test_cls = test_cls
- self._responses = responses
- self._expected_requests = [
- self._normalize_requests_for_comparison(x[0], x[1], x[2])
- for x in expected_requests]
- self._actual_requests = []
-
- def _normalize_requests_for_comparison(self, uri, http_method, body):  # Query params become a set, so their order is ignored.
- parts = urlparse(uri)
- return (
- parts._replace(query=set(parse_qsl(parts.query))),
- http_method,
- body)
-
- def __enter__(self):
- http = HttpMockSequence(self._responses)
- native_request_method = http.request
-
- # Collecting requests to validate at __exit__.
- def _request_wrapper(*args, **kwargs):
- self._actual_requests.append(args + (kwargs['body'],))  # Assumes (uri, method) positional, body as keyword - TODO confirm.
- return native_request_method(*args, **kwargs)
-
- http.request = _request_wrapper
- service_mock = build('ml', 'v1', http=http)
- with mock.patch.object(  # The patch is only active while the hook below is constructed.
- hook.CloudMLHook, 'get_conn', return_value=service_mock):
- return hook.CloudMLHook()
-
- def __exit__(self, *args):  # Compares recorded vs expected requests, order included.
- # An exception is already propagating; skip the assert so it is not masked.
- if any(args):
- return None
- self._test_cls.assertEquals(
- [self._normalize_requests_for_comparison(x[0], x[1], x[2])
- for x in self._actual_requests],
- self._expected_requests)
-
-
-class TestCloudMLHook(unittest.TestCase):  # CloudMLHook API calls checked against mocked HTTP traffic.
-
- def setUp(self):
- pass
-
- _SKIP_IF = unittest.skipIf(not cml_available,  # Skip when the import-time get_conn() probe failed.
- 'CloudML is not available to run tests')
-
- _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'  # Base URI that expected request URLs are built from.
-
- @_SKIP_IF
- def test_create_version(self):  # POST the version, then one GET polling the returned operation.
- project = 'test-project'
- model_name = 'test-model'
- version = 'test-version'
- operation_name = 'projects/{}/operations/test-operation'.format(
- project)
-
- response_body = {'name': operation_name, 'done': True}
- succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
- expected_requests = [
- ('{}projects/{}/models/{}/versions?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, model_name), 'POST',
- '"{}"'.format(version)),
- ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
- 'GET', None),
- ]
-
- with _TestCloudMLHook(
- self,
- responses=[succeeded_response] * 2,
- expected_requests=expected_requests) as cml_hook:
- create_version_response = cml_hook.create_version(
- project_id=project, model_name=model_name,
- version_spec=version)
- self.assertEquals(create_version_response, response_body)
-
- @_SKIP_IF
- def test_set_default_version(self):  # A single POST to the ':setDefault' endpoint.
- project = 'test-project'
- model_name = 'test-model'
- version = 'test-version'
- operation_name = 'projects/{}/operations/test-operation'.format(
- project)
-
- response_body = {'name': operation_name, 'done': True}
- succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
- expected_requests = [
- ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, model_name, version),
- 'POST', '{}'),
- ]
-
- with _TestCloudMLHook(
- self,
- responses=[succeeded_response],
- expected_requests=expected_requests) as cml_hook:
- set_default_version_response = cml_hook.set_default_version(
- project_id=project, model_name=model_name,
- version_name=version)
- self.assertEquals(set_default_version_response, response_body)
-
- @_SKIP_IF
- def test_list_versions(self):  # Follows nextPageToken across pages and expects every version.
- project = 'test-project'
- model_name = 'test-model'
- operation_name = 'projects/{}/operations/test-operation'.format(
- project)
-
- # This test returns the versions one at a time.
- versions = ['ver_{}'.format(ix) for ix in range(3)]
-
- response_bodies = [
- {
- 'name': operation_name,
- 'nextPageToken': ix,
- 'versions': [ver]
- } for ix, ver in enumerate(versions)]
- response_bodies[-1].pop('nextPageToken')
- responses = [({'status': '200'}, json.dumps(body))
- for body in response_bodies]
-
- expected_requests = [
- ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format(
- self._SERVICE_URI_PREFIX, project, model_name), 'GET',
- None),
- ] + [
- ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
- '&pageSize=100'.format(
- self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
- None) for ix in range(len(versions) - 1)
- ]
-
- with _TestCloudMLHook(
- self,
- responses=responses,
- expected_requests=expected_requests) as cml_hook:
- list_versions_response = cml_hook.list_versions(
- project_id=project, model_name=model_name)
- self.assertEquals(list_versions_response, versions)
-
- @_SKIP_IF
- def test_delete_version(self):  # DELETE, then poll the returned operation until it is done.
- project = 'test-project'
- model_name = 'test-model'
- version = 'test-version'
- operation_name = 'projects/{}/operations/test-operation'.format(
- project)
-
- not_done_response_body = {'name': operation_name, 'done': False}
- done_response_body = {'name': operation_name, 'done': True}
- not_done_response = (
- {'status': '200'}, json.dumps(not_done_response_body))
- succeeded_response = (
- {'status': '200'}, json.dumps(done_response_body))
-
- expected_requests = [
- (
- '{}projects/{}/models/{}/versions/{}?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, model_name, version),
- 'DELETE',
- None),
- ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
- 'GET', None),
- ]
-
- with _TestCloudMLHook(
- self,
- responses=[not_done_response, succeeded_response],
- expected_requests=expected_requests) as cml_hook:
- delete_version_response = cml_hook.delete_version(
- project_id=project, model_name=model_name,
- version_name=version)
- self.assertEquals(delete_version_response, done_response_body)
-
- @_SKIP_IF
- def test_create_model(self):  # A single POST whose body is the model as JSON.
- project = 'test-project'
- model_name = 'test-model'
- model = {
- 'name': model_name,
- }
- response_body = {}
- succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
- expected_requests = [
- ('{}projects/{}/models?alt=json'.format(
- self._SERVICE_URI_PREFIX, project), 'POST',
- json.dumps(model)),
- ]
-
- with _TestCloudMLHook(
- self,
- responses=[succeeded_response],
- expected_requests=expected_requests) as cml_hook:
- create_model_response = cml_hook.create_model(
- project_id=project, model=model)
- self.assertEquals(create_model_response, response_body)
-
- @_SKIP_IF
- def test_get_model(self):  # A single GET for the model resource.
- project = 'test-project'
- model_name = 'test-model'
- response_body = {'model': model_name}
- succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
- expected_requests = [
- ('{}projects/{}/models/{}?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, model_name), 'GET',
- None),
- ]
-
- with _TestCloudMLHook(
- self,
- responses=[succeeded_response],
- expected_requests=expected_requests) as cml_hook:
- get_model_response = cml_hook.get_model(
- project_id=project, model_name=model_name)
- self.assertEquals(get_model_response, response_body)
-
- @_SKIP_IF
- def test_create_cloudml_job(self):  # POST the job, then poll (QUEUED, SUCCEEDED) until it succeeds.
- project = 'test-project'
- job_id = 'test-job-id'
- my_job = {
- 'jobId': job_id,
- 'foo': 4815162342,
- 'state': 'SUCCEEDED',
- }
- response_body = json.dumps(my_job)
- succeeded_response = ({'status': '200'}, response_body)
- queued_response = ({'status': '200'}, json.dumps({
- 'jobId': job_id,
- 'state': 'QUEUED',
- }))
-
- create_job_request = ('{}projects/{}/jobs?alt=json'.format(
- self._SERVICE_URI_PREFIX, project), 'POST', response_body)
- ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
- expected_requests = [
- create_job_request,
- ask_if_done_request,
- ask_if_done_request,
- ]
- responses = [succeeded_response,
- queued_response, succeeded_response]
-
- with _TestCloudMLHook(
- self,
- responses=responses,
- expected_requests=expected_requests) as cml_hook:
- create_job_response = cml_hook.create_job(
- project_id=project, job=my_job)
- self.assertEquals(create_job_response, my_job)
-
- @_SKIP_IF
- def test_create_cloudml_job_reuse_existing_job_by_default(self):  # HTTP 409 on create -> fetch and reuse the existing job.
- project = 'test-project'
- job_id = 'test-job-id'
- my_job = {
- 'jobId': job_id,
- 'foo': 4815162342,
- 'state': 'SUCCEEDED',
- }
- response_body = json.dumps(my_job)
- job_already_exist_response = ({'status': '409'}, json.dumps({}))
- succeeded_response = ({'status': '200'}, response_body)
-
- create_job_request = ('{}projects/{}/jobs?alt=json'.format(
- self._SERVICE_URI_PREFIX, project), 'POST', response_body)
- ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
- expected_requests = [
- create_job_request,
- ask_if_done_request,
- ]
- responses = [job_already_exist_response, succeeded_response]
-
- # By default, 'create_job' reuses the existing job.
- with _TestCloudMLHook(
- self,
- responses=responses,
- expected_requests=expected_requests) as cml_hook:
- create_job_response = cml_hook.create_job(
- project_id=project, job=my_job)
- self.assertEquals(create_job_response, my_job)
-
- @_SKIP_IF
- def test_create_cloudml_job_check_existing_job(self):  # use_existing_job_fn rejects a mismatching job (HttpError) and accepts a matching one.
- project = 'test-project'
- job_id = 'test-job-id'
- my_job = {
- 'jobId': job_id,
- 'foo': 4815162342,
- 'state': 'SUCCEEDED',
- 'someInput': {
- 'input': 'someInput'
- }
- }
- different_job = {
- 'jobId': job_id,
- 'foo': 4815162342,
- 'state': 'SUCCEEDED',
- 'someInput': {
- 'input': 'someDifferentInput'
- }
- }
-
- my_job_response_body = json.dumps(my_job)
- different_job_response_body = json.dumps(different_job)
- job_already_exist_response = ({'status': '409'}, json.dumps({}))
- different_job_response = ({'status': '200'},
- different_job_response_body)
-
- create_job_request = ('{}projects/{}/jobs?alt=json'.format(
- self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body)
- ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
- self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
- expected_requests = [
- create_job_request,
- ask_if_done_request,
- ]
-
- # Returning a different job (with a different 'someInput' field) should
- # cause the 'create_job' request to fail.
- responses = [job_already_exist_response, different_job_response]
-
- def check_input(existing_job):
- return existing_job.get('someInput', None) == \
- my_job['someInput']
- with _TestCloudMLHook(
- self,
- responses=responses,
- expected_requests=expected_requests) as cml_hook:
- with self.assertRaises(errors.HttpError):
- cml_hook.create_job(
- project_id=project, job=my_job,
- use_existing_job_fn=check_input)
-
- my_job_response = ({'status': '200'}, my_job_response_body)
- expected_requests = [
- create_job_request,
- ask_if_done_request,
- ask_if_done_request,
- ]
- responses = [
- job_already_exist_response,
- my_job_response,
- my_job_response]
- with _TestCloudMLHook(
- self,
- responses=responses,
- expected_requests=expected_requests) as cml_hook:
- create_job_response = cml_hook.create_job(
- project_id=project, job=my_job,
- use_existing_job_fn=check_input)
- self.assertEquals(create_job_response, my_job)
-
-
-if __name__ == '__main__':  # Allow running these tests by executing the module directly.
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/hooks/test_gcp_mlengine_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_mlengine_hook.py b/tests/contrib/hooks/test_gcp_mlengine_hook.py
new file mode 100644
index 0000000..372d47c
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_mlengine_hook.py
@@ -0,0 +1,413 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import mock
+import unittest
+
+try: # python 2
+ from urlparse import urlparse, parse_qsl
+except ImportError: # python 3
+ from urllib.parse import urlparse, parse_qsl
+
+from airflow.contrib.hooks import gcp_mlengine_hook as hook
+from apiclient import errors
+from apiclient.discovery import build
+from apiclient.http import HttpMockSequence
+from oauth2client.contrib.gce import HttpAccessTokenRefreshError
+
+# Probe once, at import time, whether an ML Engine connection can be built.
+# The result feeds the skipIf decorator used by the test class below; any
+# exception other than a token-refresh failure propagates and aborts import.
+try:
+    hook.MLEngineHook().get_conn()
+except HttpAccessTokenRefreshError:
+    # An access-token refresh failure is treated as "ML Engine unavailable".
+    cml_available = False
+else:
+    cml_available = True
+
+
+class _TestMLEngineHook(object):
+    """Context manager yielding an MLEngineHook backed by canned HTTP replies.
+
+    On entry it builds an 'ml' v1 API client on top of an HttpMockSequence
+    that replays ``responses`` in order, and records every request issued
+    through it. On a clean exit the recorded requests are compared (with
+    query strings compared as unordered sets) against ``expected_requests``
+    using the caller's TestCase.
+    """
+
+    def __init__(self, test_cls, responses, expected_requests):
+        """
+        Init method.
+
+        Usage example:
+        with _TestMLEngineHook(self, responses, expected_requests) as hook:
+            self.run_my_test(hook)
+
+        Args:
+          test_cls: The caller's TestCase instance; its assertEqual is used
+            to compare requests on exit.
+          responses: A list of (dict_response, response_content) tuples.
+          expected_requests: A list of (uri, http_method, body) tuples.
+        """
+        self._test_cls = test_cls
+        self._responses = responses
+        # Only the first three fields (uri, method, body) are compared.
+        self._expected_requests = [
+            self._normalize_requests_for_comparison(*request[:3])
+            for request in expected_requests]
+        self._actual_requests = []
+
+    def _normalize_requests_for_comparison(self, uri, http_method, body):
+        """Return (parsed_uri, http_method, body) with an unordered query.
+
+        The parsed URI's query string is replaced by the set of its
+        (key, value) pairs, so parameter order does not affect equality
+        (duplicate pairs also collapse).
+        """
+        parts = urlparse(uri)
+        return (
+            parts._replace(query=set(parse_qsl(parts.query))),
+            http_method,
+            body)
+
+    def __enter__(self):
+        """Build the mocked API client and return an MLEngineHook using it."""
+        http = HttpMockSequence(self._responses)
+        native_request_method = http.request
+
+        # Record each request (positional args plus the 'body' kwarg) so
+        # that __exit__ can compare them with the expected requests.
+        def _request_wrapper(*args, **kwargs):
+            self._actual_requests.append(args + (kwargs['body'],))
+            return native_request_method(*args, **kwargs)
+
+        http.request = _request_wrapper
+        service_mock = build('ml', 'v1', http=http)
+        # NOTE(review): get_conn is patched only while the hook is
+        # constructed; this presumes MLEngineHook.__init__ calls get_conn()
+        # and keeps the result -- confirm against the hook implementation.
+        with mock.patch.object(
+                hook.MLEngineHook, 'get_conn', return_value=service_mock):
+            return hook.MLEngineHook()
+
+    def __exit__(self, *args):
+        # Propagate any exception raised inside the with-block untouched:
+        # an assertion failure raised here would replace it.
+        if any(args):
+            return None
+        self._test_cls.assertEqual(
+            [self._normalize_requests_for_comparison(*request[:3])
+             for request in self._actual_requests],
+            self._expected_requests)
+
+
+class TestMLEngineHook(unittest.TestCase):
+    """Request/response-level tests for MLEngineHook.
+
+    Each test scripts the HTTP replies the hook receives and lists the exact
+    (uri, method, body) requests it is expected to issue; _TestMLEngineHook
+    checks the issued requests against that list on exit.
+    """
+
+    _SKIP_IF = unittest.skipIf(not cml_available,
+                               'MLEngine is not available to run tests')
+
+    _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'
+
+    @_SKIP_IF
+    def test_create_version(self):
+        """create_version POSTs the spec, then polls the returned operation."""
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'POST',
+             '"{}"'.format(version)),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response] * 2,
+                expected_requests=expected_requests) as cml_hook:
+            create_version_response = cml_hook.create_version(
+                project_id=project, model_name=model_name,
+                version_spec=version)
+            self.assertEqual(create_version_response, response_body)
+
+    @_SKIP_IF
+    def test_set_default_version(self):
+        """set_default_version issues a single ':setDefault' POST."""
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, version),
+             'POST', '{}'),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            set_default_version_response = cml_hook.set_default_version(
+                project_id=project, model_name=model_name,
+                version_name=version)
+            self.assertEqual(set_default_version_response, response_body)
+
+    @_SKIP_IF
+    def test_list_versions(self):
+        """list_versions follows nextPageToken and concatenates the pages."""
+        project = 'test-project'
+        model_name = 'test-model'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        # This test returns the versions one at a time.
+        versions = ['ver_{}'.format(ix) for ix in range(3)]
+
+        response_bodies = [
+            {
+                'name': operation_name,
+                'nextPageToken': ix,
+                'versions': [ver]
+            } for ix, ver in enumerate(versions)]
+        # The last page carries no token, so no further page is requested.
+        response_bodies[-1].pop('nextPageToken')
+        responses = [({'status': '200'}, json.dumps(body))
+                     for body in response_bodies]
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ] + [
+            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
+             '&pageSize=100'.format(
+                 self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
+             None) for ix in range(len(versions) - 1)
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            list_versions_response = cml_hook.list_versions(
+                project_id=project, model_name=model_name)
+            self.assertEqual(list_versions_response, versions)
+
+    @_SKIP_IF
+    def test_delete_version(self):
+        """delete_version polls the operation until it reports done."""
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        not_done_response_body = {'name': operation_name, 'done': False}
+        done_response_body = {'name': operation_name, 'done': True}
+        not_done_response = (
+            {'status': '200'}, json.dumps(not_done_response_body))
+        succeeded_response = (
+            {'status': '200'}, json.dumps(done_response_body))
+
+        expected_requests = [
+            (
+                '{}projects/{}/models/{}/versions/{}?alt=json'.format(
+                    self._SERVICE_URI_PREFIX, project, model_name, version),
+                'DELETE',
+                None),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[not_done_response, succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            delete_version_response = cml_hook.delete_version(
+                project_id=project, model_name=model_name,
+                version_name=version)
+            self.assertEqual(delete_version_response, done_response_body)
+
+    @_SKIP_IF
+    def test_create_model(self):
+        """create_model POSTs the model body and returns the response."""
+        project = 'test-project'
+        model_name = 'test-model'
+        model = {
+            'name': model_name,
+        }
+        response_body = {}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project), 'POST',
+             json.dumps(model)),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            create_model_response = cml_hook.create_model(
+                project_id=project, model=model)
+            self.assertEqual(create_model_response, response_body)
+
+    @_SKIP_IF
+    def test_get_model(self):
+        """get_model GETs the model resource and returns the response."""
+        project = 'test-project'
+        model_name = 'test-model'
+        response_body = {'model': model_name}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            get_model_response = cml_hook.get_model(
+                project_id=project, model_name=model_name)
+            self.assertEqual(get_model_response, response_body)
+
+    @_SKIP_IF
+    def test_create_mlengine_job(self):
+        """create_job submits the job, then polls until it reports SUCCEEDED."""
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+        }
+        response_body = json.dumps(my_job)
+        succeeded_response = ({'status': '200'}, response_body)
+        queued_response = ({'status': '200'}, json.dumps({
+            'jobId': job_id,
+            'state': 'QUEUED',
+        }))
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+            ask_if_done_request,
+        ]
+        responses = [succeeded_response,
+                     queued_response, succeeded_response]
+
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_id=project, job=my_job)
+            self.assertEqual(create_job_response, my_job)
+
+    @_SKIP_IF
+    def test_create_mlengine_job_reuse_existing_job_by_default(self):
+        """On a 409 conflict, create_job reuses the existing job by default."""
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+        }
+        response_body = json.dumps(my_job)
+        job_already_exist_response = ({'status': '409'}, json.dumps({}))
+        succeeded_response = ({'status': '200'}, response_body)
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+        ]
+        responses = [job_already_exist_response, succeeded_response]
+
+        # By default, 'create_job' reuses the existing job.
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_id=project, job=my_job)
+            self.assertEqual(create_job_response, my_job)
+
+    @_SKIP_IF
+    def test_create_mlengine_job_check_existing_job(self):
+        """use_existing_job_fn decides whether a conflicting job is reused."""
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+            'someInput': {
+                'input': 'someInput'
+            }
+        }
+        different_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+            'someInput': {
+                'input': 'someDifferentInput'
+            }
+        }
+
+        my_job_response_body = json.dumps(my_job)
+        different_job_response_body = json.dumps(different_job)
+        job_already_exist_response = ({'status': '409'}, json.dumps({}))
+        different_job_response = ({'status': '200'},
+                                  different_job_response_body)
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+        ]
+
+        # The existing job has a different 'someInput' field, so check_input
+        # rejects it and 'create_job' is expected to raise HttpError.
+        responses = [job_already_exist_response, different_job_response]
+
+        def check_input(existing_job):
+            return existing_job.get('someInput', None) == \
+                my_job['someInput']
+
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            with self.assertRaises(errors.HttpError):
+                cml_hook.create_job(
+                    project_id=project, job=my_job,
+                    use_existing_job_fn=check_input)
+
+        # The existing job matches, so it is reused and polled to completion.
+        my_job_response = ({'status': '200'}, my_job_response_body)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+            ask_if_done_request,
+        ]
+        responses = [
+            job_already_exist_response,
+            my_job_response,
+            my_job_response]
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_id=project, job=my_job,
+                use_existing_job_fn=check_input)
+            self.assertEqual(create_job_response, my_job)
+
+
+# Allow running this test module directly as a script.
+if __name__ == '__main__':
+    unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_cloudml_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py
deleted file mode 100644
index dc2366e..0000000
--- a/tests/contrib/operators/test_cloudml_operator.py
+++ /dev/null
@@ -1,373 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import datetime
-from apiclient import errors
-import httplib2
-import unittest
-
-from airflow import configuration, DAG
-from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
-from airflow.contrib.operators.cloudml_operator import CloudMLTrainingOperator
-
-from mock import ANY
-from mock import patch
-
-# Fixed date used as the test DAG's start_date and end_date.
-DEFAULT_DATE = datetime.datetime(year=2017, month=6, day=6)
-
-
-class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
-    """Tests for CloudMLBatchPredictionOperator with CloudMLHook mocked out."""
-
-    INPUT_MISSING_ORIGIN = {
-        'dataFormat': 'TEXT',
-        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
-        'outputPath': 'gs://legal-bucket/fake-output-path',
-        'region': 'us-east1',
-    }
-    SUCCESS_MESSAGE_MISSING_INPUT = {
-        'jobId': 'test_prediction',
-        'predictionOutput': {
-            'outputPath': 'gs://fake-output-path',
-            'predictionCount': 5000,
-            'errorCount': 0,
-            'nodeHours': 2.78
-        },
-        'state': 'SUCCEEDED'
-    }
-    BATCH_PREDICTION_DEFAULT_ARGS = {
-        'project_id': 'test-project',
-        'job_id': 'test_prediction',
-        'region': 'us-east1',
-        'data_format': 'TEXT',
-        'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
-        'output_path':
-            'gs://12_legal_bucket_underscore_number/legal-output-path',
-        'task_id': 'test-prediction'
-    }
-
-    def setUp(self):
-        super(CloudMLBatchPredictionOperatorTest, self).setUp()
-        configuration.load_test_config()
-        self.dag = DAG(
-            'test_dag',
-            default_args={
-                'owner': 'airflow',
-                'start_date': DEFAULT_DATE,
-                'end_date': DEFAULT_DATE,
-            },
-            schedule_interval='@daily')
-
-    def testSuccessWithModel(self):
-        """A model name is submitted as 'modelName' in the prediction input."""
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-
-            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_model['modelName'] = \
-                'projects/test-project/models/test_model'
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_model
-
-            hook_instance = mock_hook.return_value
-            # get_job is stubbed to fail with HTTP 404 (job not found).
-            hook_instance.get_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': 404
-                }), content=b'some bytes')
-            hook_instance.create_job.return_value = success_message
-
-            prediction_task = CloudMLBatchPredictionOperator(
-                job_id='test_prediction',
-                project_id='test-project',
-                region=input_with_model['region'],
-                data_format=input_with_model['dataFormat'],
-                input_paths=input_with_model['inputPaths'],
-                output_path=input_with_model['outputPath'],
-                model_name=input_with_model['modelName'].split('/')[-1],
-                dag=self.dag,
-                task_id='test-prediction')
-            prediction_output = prediction_task.execute(None)
-
-            mock_hook.assert_called_with('google_cloud_default', None)
-            # The third create_job argument is not checked (ANY).
-            hook_instance.create_job.assert_called_once_with(
-                'test-project',
-                {
-                    'jobId': 'test_prediction',
-                    'predictionInput': input_with_model
-                }, ANY)
-            self.assertEqual(
-                success_message['predictionOutput'],
-                prediction_output)
-
-    def testSuccessWithVersion(self):
-        """A model/version pair is submitted as 'versionName'."""
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-
-            input_with_version = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_version['versionName'] = \
-                'projects/test-project/models/test_model/versions/test_version'
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_version
-
-            hook_instance = mock_hook.return_value
-            hook_instance.get_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': 404
-                }), content=b'some bytes')
-            hook_instance.create_job.return_value = success_message
-
-            prediction_task = CloudMLBatchPredictionOperator(
-                job_id='test_prediction', project_id='test-project',
-                region=input_with_version['region'],
-                data_format=input_with_version['dataFormat'],
-                input_paths=input_with_version['inputPaths'],
-                output_path=input_with_version['outputPath'],
-                model_name=input_with_version['versionName'].split('/')[-3],
-                version_name=input_with_version['versionName'].split('/')[-1],
-                dag=self.dag,
-                task_id='test-prediction')
-            prediction_output = prediction_task.execute(None)
-
-            mock_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_with(
-                'test-project',
-                {
-                    'jobId': 'test_prediction',
-                    'predictionInput': input_with_version
-                }, ANY)
-            self.assertEqual(
-                success_message['predictionOutput'],
-                prediction_output)
-
-    def testSuccessWithURI(self):
-        """A model URI is submitted as 'uri' in the prediction input."""
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-
-            input_with_uri = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_uri
-
-            hook_instance = mock_hook.return_value
-            hook_instance.get_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': 404
-                }), content=b'some bytes')
-            hook_instance.create_job.return_value = success_message
-
-            prediction_task = CloudMLBatchPredictionOperator(
-                job_id='test_prediction',
-                project_id='test-project',
-                region=input_with_uri['region'],
-                data_format=input_with_uri['dataFormat'],
-                input_paths=input_with_uri['inputPaths'],
-                output_path=input_with_uri['outputPath'],
-                uri=input_with_uri['uri'],
-                dag=self.dag,
-                task_id='test-prediction')
-            prediction_output = prediction_task.execute(None)
-
-            mock_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_with(
-                'test-project',
-                {
-                    'jobId': 'test_prediction',
-                    'predictionInput': input_with_uri
-                }, ANY)
-            self.assertEqual(
-                success_message['predictionOutput'],
-                prediction_output)
-
-    def testInvalidModelOrigin(self):
-        """Ambiguous or missing model origins raise ValueError."""
-        # Test that both uri and model is given
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        task_args['uri'] = 'gs://fake-uri/saved_model'
-        task_args['model_name'] = 'fake_model'
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEqual('Ambiguous model origin.', str(context.exception))
-
-        # Test that both uri and model/version is given
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        task_args['uri'] = 'gs://fake-uri/saved_model'
-        task_args['model_name'] = 'fake_model'
-        task_args['version_name'] = 'fake_version'
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEqual('Ambiguous model origin.', str(context.exception))
-
-        # Test that a version is given without a model
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        task_args['version_name'] = 'bare_version'
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEqual(
-            'Missing model origin.',
-            str(context.exception))
-
-        # Test that none of uri, model, model/version is given
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEqual(
-            'Missing model origin.',
-            str(context.exception))
-
-    def testHttpError(self):
-        """An HttpError raised by create_job propagates out of execute()."""
-        http_error_code = 403
-
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_model['modelName'] = \
-                'projects/experimental/models/test_model'
-
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': http_error_code
-                }), content=b'Forbidden')
-
-            with self.assertRaises(errors.HttpError) as context:
-                prediction_task = CloudMLBatchPredictionOperator(
-                    job_id='test_prediction',
-                    project_id='test-project',
-                    region=input_with_model['region'],
-                    data_format=input_with_model['dataFormat'],
-                    input_paths=input_with_model['inputPaths'],
-                    output_path=input_with_model['outputPath'],
-                    model_name=input_with_model['modelName'].split('/')[-1],
-                    dag=self.dag,
-                    task_id='test-prediction')
-                prediction_task.execute(None)
-
-            mock_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_with(
-                'test-project',
-                {
-                    'jobId': 'test_prediction',
-                    'predictionInput': input_with_model
-                }, ANY)
-
-            self.assertEqual(http_error_code, context.exception.resp.status)
-
-    def testFailedJobError(self):
-        """A FAILED job surfaces its errorMessage as a RuntimeError."""
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.return_value = {
-                'state': 'FAILED',
-                'errorMessage': 'A failure message'
-            }
-            task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-            task_args['uri'] = 'a uri'
-
-            with self.assertRaises(RuntimeError) as context:
-                CloudMLBatchPredictionOperator(**task_args).execute(None)
-
-            self.assertEqual('A failure message', str(context.exception))
-
-
-class CloudMLTrainingOperatorTest(unittest.TestCase):
-    """Tests for CloudMLTrainingOperator with CloudMLHook mocked out."""
-
-    TRAINING_DEFAULT_ARGS = {
-        'project_id': 'test-project',
-        'job_id': 'test_training',
-        'package_uris': ['gs://some-bucket/package1'],
-        'training_python_module': 'trainer',
-        'training_args': '--some_arg=\'aaa\'',
-        'region': 'us-east1',
-        'scale_tier': 'STANDARD_1',
-        'task_id': 'test-training'
-    }
-    # The job body the operator is expected to pass to create_job for
-    # TRAINING_DEFAULT_ARGS.
-    TRAINING_INPUT = {
-        'jobId': 'test_training',
-        'trainingInput': {
-            'scaleTier': 'STANDARD_1',
-            'packageUris': ['gs://some-bucket/package1'],
-            'pythonModule': 'trainer',
-            'args': '--some_arg=\'aaa\'',
-            'region': 'us-east1'
-        }
-    }
-
-    def testSuccessCreateTrainingJob(self):
-        """A successful training run calls only create_job on the hook."""
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            success_response = self.TRAINING_INPUT.copy()
-            success_response['state'] = 'SUCCEEDED'
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.return_value = success_response
-
-            training_op = CloudMLTrainingOperator(**self.TRAINING_DEFAULT_ARGS)
-            training_op.execute(None)
-
-            mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
-                                         delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEqual(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-
-    def testHttpError(self):
-        """An HttpError raised by create_job propagates out of execute()."""
-        http_error_code = 403
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': http_error_code
-                }), content=b'Forbidden')
-
-            with self.assertRaises(errors.HttpError) as context:
-                training_op = CloudMLTrainingOperator(
-                    **self.TRAINING_DEFAULT_ARGS)
-                training_op.execute(None)
-
-            mock_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEqual(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-            self.assertEqual(http_error_code, context.exception.resp.status)
-
-    def testFailedJobError(self):
-        """A FAILED job surfaces its errorMessage as a RuntimeError."""
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            failure_response = self.TRAINING_INPUT.copy()
-            failure_response['state'] = 'FAILED'
-            failure_response['errorMessage'] = 'A failure message'
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.return_value = failure_response
-
-            with self.assertRaises(RuntimeError) as context:
-                training_op = CloudMLTrainingOperator(
-                    **self.TRAINING_DEFAULT_ARGS)
-                training_op.execute(None)
-
-            mock_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEqual(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-            self.assertEqual('A failure message', str(context.exception))
-
-
-# Allow running this test module directly as a script.
-if __name__ == '__main__':
-    unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py
deleted file mode 100644
index b2a5a30..0000000
--- a/tests/contrib/operators/test_cloudml_operator_utils.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import datetime
-import unittest
-
-from airflow import configuration, DAG
-from airflow.contrib.operators import cloudml_operator_utils
-from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops
-from airflow.exceptions import AirflowException
-
-from mock import ANY
-from mock import patch
-
-DEFAULT_DATE = datetime.datetime(2017, 6, 6)
-
-
-class CreateEvaluateOpsTest(unittest.TestCase):
-
- INPUT_MISSING_ORIGIN = {
- 'dataFormat': 'TEXT',
- 'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
- 'outputPath': 'gs://legal-bucket/fake-output-path',
- 'region': 'us-east1',
- 'versionName': 'projects/test-project/models/test_model/versions/test_version',
- }
- SUCCESS_MESSAGE_MISSING_INPUT = {
- 'jobId': 'eval_test_prediction',
- 'predictionOutput': {
- 'outputPath': 'gs://fake-output-path',
- 'predictionCount': 5000,
- 'errorCount': 0,
- 'nodeHours': 2.78
- },
- 'state': 'SUCCEEDED'
- }
-
- def setUp(self):
- super(CreateEvaluateOpsTest, self).setUp()
- configuration.load_test_config()
- self.dag = DAG(
- 'test_dag',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE,
- 'end_date': DEFAULT_DATE,
- 'project_id': 'test-project',
- 'region': 'us-east1',
- 'model_name': 'test_model',
- 'version_name': 'test_version',
- },
- schedule_interval='@daily')
- self.metric_fn = lambda x: (0.1,)
- self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
- cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))
-
- def testSuccessfulRun(self):
- input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-
- pred, summary, validate = create_evaluate_ops(
- task_prefix='eval-test',
- batch_prediction_job_id='eval-test-prediction',
- data_format=input_with_model['dataFormat'],
- input_paths=input_with_model['inputPaths'],
- prediction_path=input_with_model['outputPath'],
- metric_fn_and_keys=(self.metric_fn, ['err']),
- validate_fn=(lambda x: 'err=%.1f' % x['err']),
- dag=self.dag)
-
- with patch('airflow.contrib.operators.cloudml_operator.'
- 'CloudMLHook') as mock_cloudml_hook:
-
- success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
- success_message['predictionInput'] = input_with_model
- hook_instance = mock_cloudml_hook.return_value
- hook_instance.create_job.return_value = success_message
- result = pred.execute(None)
- mock_cloudml_hook.assert_called_with('google_cloud_default', None)
- hook_instance.create_job.assert_called_once_with(
- 'test-project',
- {
- 'jobId': 'eval_test_prediction',
- 'predictionInput': input_with_model,
- },
- ANY)
- self.assertEqual(success_message['predictionOutput'], result)
-
- with patch('airflow.contrib.operators.dataflow_operator.'
- 'DataFlowHook') as mock_dataflow_hook:
-
- hook_instance = mock_dataflow_hook.return_value
- hook_instance.start_python_dataflow.return_value = None
- summary.execute(None)
- mock_dataflow_hook.assert_called_with(
- gcp_conn_id='google_cloud_default', delegate_to=None)
- hook_instance.start_python_dataflow.assert_called_once_with(
- 'eval-test-summary',
- {
- 'prediction_path': 'gs://legal-bucket/fake-output-path',
- 'metric_keys': 'err',
- 'metric_fn_encoded': self.metric_fn_encoded,
- },
- 'airflow.contrib.operators.cloudml_prediction_summary',
- ['-m'])
-
- with patch('airflow.contrib.operators.cloudml_operator_utils.'
- 'GoogleCloudStorageHook') as mock_gcs_hook:
-
- hook_instance = mock_gcs_hook.return_value
- hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
- result = validate.execute({})
- hook_instance.download.assert_called_once_with(
- 'legal-bucket', 'fake-output-path/prediction.summary.json')
- self.assertEqual('err=0.9', result)
-
- def testFailures(self):
- dag = DAG(
- 'test_dag',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE,
- 'end_date': DEFAULT_DATE,
- 'project_id': 'test-project',
- 'region': 'us-east1',
- },
- schedule_interval='@daily')
-
- input_with_model = self.INPUT_MISSING_ORIGIN.copy()
- other_params_but_models = {
- 'task_prefix': 'eval-test',
- 'batch_prediction_job_id': 'eval-test-prediction',
- 'data_format': input_with_model['dataFormat'],
- 'input_paths': input_with_model['inputPaths'],
- 'prediction_path': input_with_model['outputPath'],
- 'metric_fn_and_keys': (self.metric_fn, ['err']),
- 'validate_fn': (lambda x: 'err=%.1f' % x['err']),
- 'dag': dag,
- }
-
- with self.assertRaisesRegexp(ValueError, 'Missing model origin'):
- _ = create_evaluate_ops(**other_params_but_models)
-
- with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
- _ = create_evaluate_ops(model_uri='abc', model_name='cde',
- **other_params_but_models)
-
- with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
- _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
- **other_params_but_models)
-
- with self.assertRaisesRegexp(AirflowException,
- '`metric_fn` param must be callable'):
- params = other_params_but_models.copy()
- params['metric_fn_and_keys'] = (None, ['abc'])
- _ = create_evaluate_ops(model_uri='gs://blah', **params)
-
- with self.assertRaisesRegexp(AirflowException,
- '`validate_fn` param must be callable'):
- params = other_params_but_models.copy()
- params['validate_fn'] = None
- _ = create_evaluate_ops(model_uri='gs://blah', **params)
-
-
-if __name__ == '__main__':
- unittest.main()