You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/07/14 00:07:04 UTC
incubator-airflow git commit: [AIRFLOW-1359] Add Google CloudML utils
for model evaluation
Repository: incubator-airflow
Updated Branches:
refs/heads/master e88ecff6a -> 194d1d6e5
[AIRFLOW-1359] Add Google CloudML utils for model evaluation
Closes #2407 from yk5/evaluate
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/194d1d6e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/194d1d6e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/194d1d6e
Branch: refs/heads/master
Commit: 194d1d6e5b89918f22267ae6a86455a0acc771df
Parents: e88ecff
Author: Younghee Kwon <yo...@google.com>
Authored: Thu Jul 13 17:06:06 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Thu Jul 13 17:06:56 2017 -0700
----------------------------------------------------------------------
.../contrib/operators/cloudml_operator_utils.py | 223 +++++++++++++++++++
.../operators/cloudml_prediction_summary.py | 177 +++++++++++++++
.../operators/test_cloudml_operator_utils.py | 179 +++++++++++++++
3 files changed, 579 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/airflow/contrib/operators/cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py
new file mode 100644
index 0000000..f4abb32
--- /dev/null
+++ b/airflow/contrib/operators/cloudml_operator_utils.py
@@ -0,0 +1,223 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+import os
+import re
+try: # python 2
+ from urlparse import urlsplit
+except ImportError: # python 3
+ from urllib.parse import urlsplit
+
+import dill
+
+from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
+from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
+from airflow.contrib.operators.cloudml_operator import _normalize_cloudml_job_id
+from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+from airflow.exceptions import AirflowException
+from airflow.operators.python_operator import PythonOperator
+
+
def create_evaluate_ops(task_prefix,
                        project_id,
                        job_id,
                        region,
                        data_format,
                        input_paths,
                        prediction_path,
                        metric_fn_and_keys,
                        validate_fn,
                        dataflow_options,
                        model_uri=None,
                        model_name=None,
                        version_name=None,
                        dag=None):
    """
    Creates the Operators needed for model evaluation and returns them.

    It gets predictions over the inputs via the Cloud ML Engine BatchPrediction
    API by calling CloudMLBatchPredictionOperator, then summarizes the result
    via Cloud Dataflow using DataFlowPythonOperator, and finally validates the
    summary with a PythonOperator.

    For details and pricing about Batch prediction, please refer to the website
    https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
    and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/

    It returns three chained operators for prediction, summary, and validation,
    named as <prefix>-prediction, <prefix>-summary, and <prefix>-validation,
    respectively.
    (<prefix> must start with a letter and contain only alphanumeric
    characters or hyphens.)

    The upstream and downstream can be set accordingly like:
      pred, _, val = create_evaluate_ops(...)
      pred.set_upstream(upstream_op)
      ...
      downstream_op.set_upstream(val)

    Callers will provide two python callables, metric_fn and validate_fn, in
    order to customize the evaluation behavior as they wish.
    - metric_fn receives a dictionary per instance derived from json in the
      batch prediction result. The keys might vary depending on the model.
      It should return a tuple of metrics.
    - validate_fn receives a dictionary of the averaged metrics that metric_fn
      generated over all instances.
      The key/value of the dictionary matches to what's given by
      metric_fn_and_keys arg.
      The dictionary contains an additional metric, 'count' to represent the
      total number of instances received for evaluation.
      The function would raise an exception to mark the task as failed, in a
      case the validation result is not okay to proceed (i.e. to set the trained
      version as default). Its return value is returned by the validation
      task's callable.

    Typical examples are like this:

      def get_metric_fn_and_keys():
          import math  # imports should be outside of the metric_fn below.
          def error_and_squared_error(inst):
              label = float(inst['input_label'])
              classes = float(inst['classes'])  # 0 or 1
              err = abs(classes-label)
              squared_err = math.pow(classes-label, 2)
              return (err, squared_err)  # returns a tuple.
          return error_and_squared_error, ['err', 'mse']  # key order must match.

      def validate_err_and_count(summary):
          if summary['err'] > 0.2:
              raise ValueError('Too high err>0.2; summary=%s' % summary)
          if summary['mse'] > 0.05:
              raise ValueError('Too high mse>0.05; summary=%s' % summary)
          if summary['count'] < 1000:
              raise ValueError('Too few instances<1000; summary=%s' % summary)
          return summary

    For the details on the other BatchPrediction-related arguments (project_id,
    job_id, region, data_format, input_paths, prediction_path, model_uri),
    please refer to CloudMLBatchPredictionOperator too.

    :param task_prefix: a prefix for the tasks. It must start with a letter,
        and only alphanumeric characters and hyphen are allowed (no
        underscores), since this will be used as dataflow job name, which
        doesn't allow other characters.
    :type task_prefix: string

    :param model_uri: GCS path of the model exported by Tensorflow using
        tensorflow.estimator.export_savedmodel(). It cannot be used with
        model_name or version_name below. See CloudMLBatchPredictionOperator for
        more detail.
    :type model_uri: string

    :param model_name: Used to indicate a model to use for prediction. Can be
        used in combination with version_name, but cannot be used together with
        model_uri. See CloudMLBatchPredictionOperator for more detail.
    :type model_name: string

    :param version_name: Used to indicate a model version to use for
        prediction, in combination with model_name. Cannot be used together
        with model_uri. See CloudMLBatchPredictionOperator for more detail.
    :type version_name: string

    :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
    :type data_format: string

    :param input_paths: a list of input paths to be sent to BatchPrediction.
    :type input_paths: list of strings

    :param prediction_path: GCS path (gs://bucket/path) to put the prediction
        results in. The summary task writes 'prediction.summary.json' into it
        and the validation task reads that file back; the validation task
        raises ValueError if this is not a gs:// URI with a bucket and path.
    :type prediction_path: string

    :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
        - metric_fn is a function that accepts a dictionary (for an instance),
          and returns a tuple of metric(s) that it calculates.
        - metric_keys is a list of strings to denote the key of each metric.
    :type metric_fn_and_keys: tuple of a function and a list of strings

    :param validate_fn: a function to validate whether the averaged metric(s) is
        good enough to push the model.
    :type validate_fn: function

    :param dataflow_options: options to run Dataflow jobs.
    :type dataflow_options: dictionary

    :param dag: the DAG the three operators are attached to.

    :returns: a tuple of three operators, (prediction, summary, validation)
    :rtype: tuple(CloudMLBatchPredictionOperator, DataFlowPythonOperator,
                  PythonOperator)

    :raises AirflowException: if task_prefix is malformed, or if metric_fn or
        validate_fn is not callable. Errors raised while constructing
        CloudMLBatchPredictionOperator (e.g. ValueError for a missing or
        ambiguous model origin) propagate unchanged.
    """

    # Verify that task_prefix doesn't have any special characters except hyphen
    # '-', which is the only allowed non-alphanumeric character by Dataflow.
    if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
        raise AirflowException(
            "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
            "and hyphens are allowed but got: " + task_prefix + ")")

    metric_fn, metric_keys = metric_fn_and_keys
    if not callable(metric_fn):
        raise AirflowException("`metric_fn` param must be callable.")
    if not callable(validate_fn):
        raise AirflowException("`validate_fn` param must be callable.")

    evaluate_prediction = CloudMLBatchPredictionOperator(
        task_id=(task_prefix + "-prediction"),
        project_id=project_id,
        job_id=_normalize_cloudml_job_id(job_id),
        region=region,
        data_format=data_format,
        input_paths=input_paths,
        output_path=prediction_path,
        uri=model_uri,
        model_name=model_name,
        version_name=version_name,
        dag=dag)

    # The metric function is shipped to the Dataflow job as a base64 string of
    # its dill serialization; cloudml_prediction_summary decodes it.
    # NOTE(review): base64.b64encode returns bytes on Python 3; confirm how
    # DataFlowHook formats non-str option values on the command line.
    metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))
    evaluate_summary = DataFlowPythonOperator(
        task_id=(task_prefix + "-summary"),
        py_options=["-m"],
        py_file="airflow.contrib.operators.cloudml_prediction_summary",
        dataflow_default_options=dataflow_options,
        options={
            "prediction_path": prediction_path,
            "metric_fn_encoded": metric_fn_encoded,
            "metric_keys": ','.join(metric_keys)
        },
        dag=dag)
    # "options" is not a template field of DataFlowPythonOperator, but it must
    # be rendered so a templated prediction_path works. Calling .append() on
    # the existing attribute would mutate whichever list it resolves to: if
    # DataFlowPythonOperator inherits template_fields from a base class, that
    # list is shared with every other subclass. It would also add a duplicate
    # entry on every call. Instead, give DataFlowPythonOperator its own list,
    # exactly once.
    # TODO: declare "options" in DataFlowPythonOperator.template_fields itself.
    if "options" not in DataFlowPythonOperator.template_fields:
        DataFlowPythonOperator.template_fields = (
            list(DataFlowPythonOperator.template_fields) + ["options"])
    evaluate_summary.set_upstream(evaluate_prediction)

    def apply_validate_fn(*args, **kwargs):
        """Download the summary JSON from GCS and hand it to validate_fn."""
        # Use the rendered value from templates_dict, not the raw closure.
        rendered_path = kwargs["templates_dict"]["prediction_path"]
        scheme, bucket, obj, _, _ = urlsplit(rendered_path)
        if scheme != "gs" or not bucket or not obj:
            raise ValueError("Wrong format prediction_path: %s" %
                             rendered_path)
        summary_object = os.path.join(obj.strip("/"),
                                      "prediction.summary.json")
        gcs_hook = GoogleCloudStorageHook()
        summary = json.loads(gcs_hook.download(bucket, summary_object))
        return validate_fn(summary)

    evaluate_validation = PythonOperator(
        task_id=(task_prefix + "-validation"),
        python_callable=apply_validate_fn,
        provide_context=True,
        templates_dict={"prediction_path": prediction_path},
        dag=dag)
    evaluate_validation.set_upstream(evaluate_summary)

    return evaluate_prediction, evaluate_summary, evaluate_validation
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/airflow/contrib/operators/cloudml_prediction_summary.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_prediction_summary.py b/airflow/contrib/operators/cloudml_prediction_summary.py
new file mode 100644
index 0000000..3128dc3
--- /dev/null
+++ b/airflow/contrib/operators/cloudml_prediction_summary.py
@@ -0,0 +1,177 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
+
+It accepts a user function to calculate the metric(s) per instance in
+the prediction results, then aggregates to output as a summary.
+
+Args:
+ --prediction_path:
+ The GCS folder that contains BatchPrediction results, containing
+ prediction.results-NNNNN-of-NNNNN files in the json format.
+ Output will be also stored in this folder, as 'prediction.summary.json'.
+
+ --metric_fn_encoded:
+ An encoded function that calculates and returns a tuple of metric(s)
+ for a given instance (as a dictionary). It should be encoded
+ via base64.b64encode(dill.dumps(fn, recurse=True)).
+
+ --metric_keys:
+ A comma-separated key(s) of the aggregated metric(s) in the summary
+ output. The order and the size of the keys must match to the output
+ of metric_fn.
+ The summary will have an additional key, 'count', to represent the
+ total number of instances, so the keys shouldn't include 'count'.
+
+# Usage example:
+def get_metric_fn():
+ import math # all imports must be outside of the function to be passed.
+ def metric_fn(inst):
+ label = float(inst["input_label"])
+ classes = float(inst["classes"])
+ prediction = float(inst["scores"][1])
+ log_loss = math.log(1 + math.exp(
+ -(label * 2 - 1) * math.log(prediction / (1 - prediction))))
+ squared_err = (classes-label)**2
+ return (log_loss, squared_err)
+ return metric_fn
+metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
+
+airflow.contrib.operators.DataFlowPythonOperator(
+ task_id="summary-prediction",
+ py_options=["-m"],
+ py_file="airflow.contrib.operators.cloudml_prediction_summary",
+ options={
+ "prediction_path": prediction_path,
+ "metric_fn_encoded": metric_fn_encoded,
+ "metric_keys": "log_loss,mse"
+ },
+ dataflow_default_options={
+ "project": "xxx", "region": "us-east1",
+ "staging_location": "gs://yy", "temp_location": "gs://zz",
+ })
+ >> dag
+
+# When the input file is like the following:
+{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
+{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
+{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
+{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
+
+# The output file will be:
+{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
+
+# To test outside of the dag:
+subprocess.check_call(["python",
+ "-m",
+ "airflow.contrib.operators.cloudml_prediction_summary",
+ "--prediction_path=gs://...",
+ "--metric_fn_encoded=" + metric_fn_encoded,
+ "--metric_keys=log_loss,mse",
+ "--runner=DataflowRunner",
+ "--staging_location=gs://...",
+ "--temp_location=gs://...",
+ ])
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import base64
+import json
+import logging
+import os
+
+import apache_beam as beam
+import dill
+
+
class JsonCoder(object):
    """Beam coder mapping each text line to and from a JSON value.

    Used as the ``coder`` of both ReadFromText (decode) and WriteToText
    (encode) in this module's pipeline.
    """

    def encode(self, x):
        """Serialize ``x`` as a JSON document in UTF-8 encoded bytes.

        Beam's text sink writes encoded elements to a binary file, so this must
        return bytes rather than ``str`` on Python 3. On Python 2 ``str`` is
        already bytes, and because ``json.dumps`` escapes non-ASCII characters
        by default, ``.encode`` is a no-op there.
        """
        return json.dumps(x).encode("utf-8")

    def decode(self, x):
        """Parse one JSON document given as ``str`` or UTF-8 ``bytes``."""
        # json.loads only accepts bytes from Python 3.6 on; decode explicitly
        # so the coder behaves the same on every interpreter.
        if isinstance(x, bytes):
            x = x.decode("utf-8")
        return json.loads(x)
+
+
@beam.ptransform_fn
def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
    """Reduce per-instance prediction records to one averaged summary dict.

    Every element is mapped through ``metric_fn``, which must return a tuple
    of metrics. A trailing 1 is appended to each tuple so that a column-wise
    sum yields the metric totals plus the number of instances. Each total is
    then divided by that count.

    :param pcoll: PCollection of decoded prediction records.
    :param metric_fn: callable mapping one record to a tuple of metrics.
    :param metric_keys: names for the metric tuple positions, in order.
    :returns: PCollection with a single dict holding ``metric_keys`` mapped
        to their averages, plus ``'count'`` mapped to the instance count.
    """
    column_count = len(metric_keys) + 1  # metrics + the trailing counter

    def append_one(metrics):
        return metrics + (1,)

    def average_into_dict(totals):
        instances = totals[-1]
        pairs = [(key, totals[pos] / instances)
                 for pos, key in enumerate(metric_keys)]
        pairs.append(("count", instances))
        return dict(pairs)

    per_instance = pcoll | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
    counted = per_instance | "PairWith1" >> beam.Map(append_one)
    column_sums = counted | "SumTuple" >> beam.CombineGlobally(
        beam.combiners.TupleCombineFn(*([sum] * column_count)))
    return column_sums | "AverageAndMakeDict" >> beam.Map(average_into_dict)
+
+
def run(argv=None):
    """Parse flags and run the Beam pipeline that summarizes predictions.

    Reads every ``prediction.results-*-of-*`` file under --prediction_path and
    decodes each line as JSON. It aggregates the per-instance metrics with
    MakeSummary and writes a single ``prediction.summary.json`` back into
    --prediction_path. Flags not recognized here (e.g. --runner,
    --staging_location) are forwarded to Beam as pipeline options.

    :param argv: command-line arguments; ``None`` lets argparse read
        ``sys.argv``.
    :raises ValueError: if --metric_fn_encoded does not decode to a callable,
        or if --metric_keys contains the reserved key 'count'.
    """
    # GCS URIs always use '/' as separator; os.path.join would insert
    # backslashes on Windows.
    import posixpath

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_path", required=True,
        help=(
            "The GCS folder that contains BatchPrediction results, containing "
            "prediction.results-NNNNN-of-NNNNN files in the json format. "
            "Output will be also stored in this folder, as a file "
            "'prediction.summary.json'."))
    parser.add_argument(
        "--metric_fn_encoded", required=True,
        help=(
            "An encoded function that calculates and returns a tuple of "
            "metric(s) for a given instance (as a dictionary). It should be "
            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."))
    parser.add_argument(
        "--metric_keys", required=True,
        help=(
            "A comma-separated keys of the aggregated metric(s) in the summary "
            "output. The order and the size of the keys must match to the "
            "output of metric_fn. The summary will have an additional key, "
            "'count', to represent the total number of instances, so this flag "
            "shouldn't include 'count'."))
    known_args, pipeline_args = parser.parse_known_args(argv)

    # SECURITY: dill.loads executes arbitrary code embedded in the payload.
    # This flag must only carry a function serialized by the DAG author;
    # never feed it untrusted input.
    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
    if not callable(metric_fn):
        raise ValueError("--metric_fn_encoded must be an encoded callable.")
    metric_keys = known_args.metric_keys.split(",")
    # MakeSummary always adds its own 'count' entry, which would silently
    # overwrite a user metric of the same name.
    if "count" in metric_keys:
        raise ValueError(
            "--metric_keys must not include the reserved key 'count'.")

    with beam.Pipeline(
            options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
        # This is apache-beam ptransform's convention
        # pylint: disable=no-value-for-parameter
        _ = (p
             | "ReadPredictionResult" >> beam.io.ReadFromText(
                 posixpath.join(known_args.prediction_path,
                                "prediction.results-*-of-*"),
                 coder=JsonCoder())
             | "Summary" >> MakeSummary(metric_fn, metric_keys)
             | "Write" >> beam.io.WriteToText(
                 posixpath.join(known_args.prediction_path,
                                "prediction.summary.json"),
                 shard_name_template='',  # without trailing -NNNNN-of-NNNNN.
                 coder=JsonCoder()))
        # pylint: enable=no-value-for-parameter
+
+
if __name__ == "__main__":
    # Entry point when launched as `python -m ...cloudml_prediction_summary`
    # (this is how DataFlowPythonOperator invokes it). Show INFO-level logs.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    run()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/tests/contrib/operators/test_cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py
new file mode 100644
index 0000000..91a9f77
--- /dev/null
+++ b/tests/contrib/operators/test_cloudml_operator_utils.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+import unittest
+
+from airflow import configuration, DAG
+from airflow.contrib.operators import cloudml_operator_utils
+from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops
+from airflow.exceptions import AirflowException
+
+from mock import ANY
+from mock import patch
+
DEFAULT_DATE = datetime.datetime(2017, 6, 6)


class CreateEvaluateOpsTest(unittest.TestCase):
    """Tests for create_evaluate_ops() and the three operators it returns."""

    # BatchPrediction input without a model origin; each test adds the
    # origin (e.g. 'modelName') it needs.
    INPUT_MISSING_ORIGIN = {
        'dataFormat': 'TEXT',
        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
        'outputPath': 'gs://legal-bucket/fake-output-path',
        'region': 'us-east1',
    }
    # Canned successful CloudML job response; the test fills in
    # 'predictionInput' so it mirrors the request.
    SUCCESS_MESSAGE_MISSING_INPUT = {
        'jobId': 'eval_test_prediction',
        'predictionOutput': {
            'outputPath': 'gs://fake-output-path',
            'predictionCount': 5000,
            'errorCount': 0,
            'nodeHours': 2.78
        },
        'state': 'SUCCEEDED'
    }

    def setUp(self):
        super(CreateEvaluateOpsTest, self).setUp()
        configuration.load_test_config()
        self.dag = DAG(
            'test_dag',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE,
                'end_date': DEFAULT_DATE,
            },
            schedule_interval='@daily')
        self.metric_fn = lambda x: (0.1,)
        # Expected value of the summary task's 'metric_fn_encoded' option.
        self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
            cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))

    def _assert_raises_regex(self, expected_exception, expected_regex):
        """Return an assertRaisesRegex context manager on Python 2 and 3.

        Python 2 only provides ``assertRaisesRegexp``. Python 3 deprecates
        that alias in favour of ``assertRaisesRegex``, and Python 3.12
        removes it.
        """
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None)
        if assert_raises_regex is None:  # Python 2
            assert_raises_regex = self.assertRaisesRegexp
        return assert_raises_regex(expected_exception, expected_regex)

    def testSuccessfulRun(self):
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        input_with_model['modelName'] = (
            'projects/test-project/models/test_model')

        pred, summary, validate = create_evaluate_ops(
            task_prefix='eval-test',
            project_id='test-project',
            job_id='eval-test-prediction',
            region=input_with_model['region'],
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            prediction_path=input_with_model['outputPath'],
            model_name=input_with_model['modelName'].split('/')[-1],
            metric_fn_and_keys=(self.metric_fn, ['err']),
            validate_fn=(lambda x: 'err=%.1f' % x['err']),
            dataflow_options=None,
            dag=self.dag)

        # Prediction: the job is submitted with the normalized job id
        # ('eval-test-prediction' is expected to become 'eval_test_prediction').
        with patch('airflow.contrib.operators.cloudml_operator.'
                   'CloudMLHook') as mock_cloudml_hook:

            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_model
            hook_instance = mock_cloudml_hook.return_value
            hook_instance.create_job.return_value = success_message
            result = pred.execute(None)
            mock_cloudml_hook.assert_called_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_once_with(
                'test-project',
                {
                    'jobId': 'eval_test_prediction',
                    'predictionInput': input_with_model
                }, ANY)
            self.assertEqual(success_message['predictionOutput'], result)

        # Summary: the Dataflow job receives the encoded metric function.
        with patch('airflow.contrib.operators.dataflow_operator.'
                   'DataFlowHook') as mock_dataflow_hook:

            hook_instance = mock_dataflow_hook.return_value
            hook_instance.start_python_dataflow.return_value = None
            summary.execute(None)
            mock_dataflow_hook.assert_called_with(
                gcp_conn_id='google_cloud_default', delegate_to=None)
            hook_instance.start_python_dataflow.assert_called_once_with(
                'eval-test-summary',
                {
                    'prediction_path': 'gs://legal-bucket/fake-output-path',
                    'metric_keys': 'err',
                    'metric_fn_encoded': self.metric_fn_encoded,
                },
                'airflow.contrib.operators.cloudml_prediction_summary',
                ['-m'])

        # Validation: the summary JSON is read from GCS and passed on.
        with patch('airflow.contrib.operators.cloudml_operator_utils.'
                   'GoogleCloudStorageHook') as mock_gcs_hook:

            hook_instance = mock_gcs_hook.return_value
            hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
            result = validate.execute({})
            hook_instance.download.assert_called_once_with(
                'legal-bucket', 'fake-output-path/prediction.summary.json')
            self.assertEqual('err=0.9', result)

    def testFailures(self):
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        input_with_model['modelName'] = (
            'projects/test-project/models/test_model')

        other_params_but_models = {
            'task_prefix': 'eval-test',
            'project_id': 'test-project',
            'job_id': 'eval-test-prediction',
            'region': input_with_model['region'],
            'data_format': input_with_model['dataFormat'],
            'input_paths': input_with_model['inputPaths'],
            'prediction_path': input_with_model['outputPath'],
            'metric_fn_and_keys': (self.metric_fn, ['err']),
            'validate_fn': (lambda x: 'err=%.1f' % x['err']),
            'dataflow_options': None,
            'dag': self.dag,
        }

        with self._assert_raises_regex(ValueError, 'Missing model origin'):
            _ = create_evaluate_ops(**other_params_but_models)

        with self._assert_raises_regex(ValueError, 'Ambiguous model origin'):
            _ = create_evaluate_ops(model_uri='abc', model_name='cde',
                                    **other_params_but_models)

        with self._assert_raises_regex(ValueError, 'Ambiguous model origin'):
            _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
                                    **other_params_but_models)

        with self._assert_raises_regex(AirflowException,
                                       '`metric_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['metric_fn_and_keys'] = (None, ['abc'])
            _ = create_evaluate_ops(model_uri='gs://blah', **params)

        with self._assert_raises_regex(AirflowException,
                                       '`validate_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['validate_fn'] = None
            _ = create_evaluate_ops(model_uri='gs://blah', **params)
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main(module='__main__')