Posted to commits@airflow.apache.org by cr...@apache.org on 2017/07/31 20:30:41 UTC
incubator-airflow git commit: [AIRFLOW-1359] Use default_args in Cloud ML eval
Repository: incubator-airflow
Updated Branches:
refs/heads/master 6e2640766 -> 1932ccc88
[AIRFLOW-1359] Use default_args in Cloud ML eval
This change makes the create_evaluate_ops utility
method use the DAG's default_args parameters when
possible. This simplifies calls to
create_evaluate_ops and makes a wider range of
default_args useful.
To extend the usefulness of default_args for
Cloud ML operators, this change also introduces a
version_name argument on CloudMLVersionOperator,
allowing model_name and version_name to be
specified once for an entire pipeline.
This change also resolves a small TODO by making
the DataFlowPythonOperator's `options` and
`dataflow_default_options` arguments template fields.
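
As a rough sketch of the resulting usage (the project, bucket, model, and
version names below are hypothetical placeholders; the metric and validate
lambdas mirror the ones used in the tests), a pipeline can now set the shared
values once in default_args and omit them from the call:

    from datetime import datetime
    from airflow import DAG
    from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops

    dag = DAG(
        'example_cloudml_eval',
        default_args={
            'owner': 'airflow',
            'start_date': datetime(2017, 7, 31),
            # Picked up by create_evaluate_ops when the explicit arguments are None.
            'project_id': 'my-project',
            'region': 'us-east1',
            'model_name': 'my_model',
            'version_name': 'my_version',
            'dataflow_default_options': {'temp_location': 'gs://my-bucket/tmp'},
        },
        schedule_interval='@daily')

    pred, summary, validate = create_evaluate_ops(
        task_prefix='eval-example',
        batch_prediction_job_id='eval-example-prediction',
        data_format='TEXT',
        input_paths=['gs://my-bucket/eval-input/*'],
        prediction_path='gs://my-bucket/eval-output',
        metric_fn_and_keys=(lambda x: (0.1,), ['err']),
        validate_fn=lambda x: 'err=%.1f' % x['err'],
        dag=dag)
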
Closes #2445 from
peterjdolan/eval_ops_arguments_from_default_args
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/1932ccc8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/1932ccc8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/1932ccc8
Branch: refs/heads/master
Commit: 1932ccc881d10d220d1d06efaa477373a08596bb
Parents: 6e26407
Author: Peter Dolan <pe...@google.com>
Authored: Mon Jul 31 13:27:12 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Mon Jul 31 13:27:22 2017 -0700
----------------------------------------------------------------------
airflow/contrib/operators/cloudml_operator.py | 16 +++-
.../contrib/operators/cloudml_operator_utils.py | 84 ++++++++++++--------
airflow/contrib/operators/dataflow_operator.py | 2 +
.../operators/test_cloudml_operator_utils.py | 40 +++++-----
4 files changed, 91 insertions(+), 51 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
index 34b2e83..6bdd516 100644
--- a/airflow/contrib/operators/cloudml_operator.py
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -341,6 +341,12 @@ class CloudMLVersionOperator(BaseOperator):
If it is None, the only `operation` possible would be `list`.
:type version: dict
+ :param version_name: A name to use for the version being operated upon. If
+ not None and the `version` argument is None or does not have a value for
+ the `name` key, then this will be populated in the payload for the
+ `name` key.
+ :type version_name: string
+
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: string
@@ -372,13 +378,15 @@ class CloudMLVersionOperator(BaseOperator):
template_fields = [
'_model_name',
'_version',
+ '_version_name',
]
@apply_defaults
def __init__(self,
model_name,
project_id,
- version,
+ version=None,
+ version_name=None,
gcp_conn_id='google_cloud_default',
operation='create',
delegate_to=None,
@@ -387,13 +395,17 @@ class CloudMLVersionOperator(BaseOperator):
super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
self._model_name = model_name
- self._version = version
+ self._version = version or {}
+ self._version_name = version_name
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
self._project_id = project_id
self._operation = operation
def execute(self, context):
+ if 'name' not in self._version:
+ self._version['name'] = self._version_name
+
hook = CloudMLHook(
gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
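
As an illustration of the new argument (a sketch only; the project, bucket,
and model names are hypothetical, and `dag` is assumed to be defined
elsewhere), a version create can now name the version through the templated
version_name instead of embedding it in the `version` dict:

    from airflow.contrib.operators.cloudml_operator import CloudMLVersionOperator

    create_version = CloudMLVersionOperator(
        task_id='create-version',
        project_id='my-project',
        model_name='my_model',
        # Rendered per run; used as the payload 'name' when `version` has no 'name' key.
        version_name='my_model_v_{{ ds_nodash }}',
        version={'deploymentUri': 'gs://my-bucket/exported-model'},
        operation='create',
        dag=dag)
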
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/airflow/contrib/operators/cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py
index f4abb32..81cd54f 100644
--- a/airflow/contrib/operators/cloudml_operator_utils.py
+++ b/airflow/contrib/operators/cloudml_operator_utils.py
@@ -18,31 +18,26 @@ import base64
import json
import os
import re
-try: # python 2
- from urlparse import urlsplit
-except ImportError: # python 3
- from urllib.parse import urlsplit
import dill
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
-from airflow.contrib.operators.cloudml_operator import _normalize_cloudml_job_id
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
from airflow.exceptions import AirflowException
from airflow.operators.python_operator import PythonOperator
-
+from six.moves.urllib.parse import urlsplit
def create_evaluate_ops(task_prefix,
- project_id,
- job_id,
- region,
data_format,
input_paths,
prediction_path,
metric_fn_and_keys,
validate_fn,
- dataflow_options,
+ batch_prediction_job_id=None,
+ project_id=None,
+ region=None,
+ dataflow_options=None,
model_uri=None,
model_name=None,
version_name=None,
@@ -114,22 +109,6 @@ def create_evaluate_ops(task_prefix,
job name, which doesn't allow other characters.
:type task_prefix: string
- :param model_uri: GCS path of the model exported by Tensorflow using
- tensorflow.estimator.export_savedmodel(). It cannot be used with
- model_name or version_name below. See CloudMLBatchPredictionOperator for
- more detail.
- :type model_uri: string
-
- :param model_name: Used to indicate a model to use for prediction. Can be
- used in combination with version_name, but cannot be used together with
- model_uri. See CloudMLBatchPredictionOperator for more detail.
- :type model_name: string
-
- :param version_name: Used to indicate a model version to use for prediciton,
- in combination with model_name. Cannot be used together with model_uri.
- See CloudMLBatchPredictionOperator for more detail.
- :type version_name: string
-
:param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
:type data_format: string
@@ -149,9 +128,46 @@ def create_evaluate_ops(task_prefix,
good enough to push the model.
:type validate_fn: function
- :param dataflow_options: options to run Dataflow jobs.
+ :param batch_prediction_job_id: the id to use for the Cloud ML Batch
+ prediction job. Passed directly to the CloudMLBatchPredictionOperator as
+ the job_id argument.
+ :type batch_prediction_job_id: string
+
+ :param project_id: the Google Cloud Platform project id in which to execute
+ Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
+ `default_args['project_id']` will be used.
+ :type project_id: string
+
+ :param region: the Google Cloud Platform region in which to execute Cloud ML
+ Batch Prediction and Dataflow jobs. If None, then the `dag`'s
+ `default_args['region']` will be used.
+ :type region: string
+
+ :param dataflow_options: options to run Dataflow jobs. If None, then the
+ `dag`'s `default_args['dataflow_default_options']` will be used.
:type dataflow_options: dictionary
+ :param model_uri: GCS path of the model exported by Tensorflow using
+ tensorflow.estimator.export_savedmodel(). It cannot be used with
+ model_name or version_name below. See CloudMLBatchPredictionOperator for
+ more detail.
+ :type model_uri: string
+
+ :param model_name: Used to indicate a model to use for prediction. Can be
+ used in combination with version_name, but cannot be used together with
+ model_uri. See CloudMLBatchPredictionOperator for more detail. If None,
+ then the `dag`'s `default_args['model_name']` will be used.
+ :type model_name: string
+
+ :param version_name: Used to indicate a model version to use for prediction,
+ in combination with model_name. Cannot be used together with model_uri.
+ See CloudMLBatchPredictionOperator for more detail. If None, then the
+ `dag`'s `default_args['version_name']` will be used.
+ :type version_name: string
+
+ :param dag: The `DAG` to use for all Operators.
+ :type dag: airflow.DAG
+
:returns: a tuple of three operators, (prediction, summary, validation)
:rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
PythonOperator)
@@ -170,10 +186,19 @@ def create_evaluate_ops(task_prefix,
if not callable(validate_fn):
raise AirflowException("`validate_fn` param must be callable.")
+ if dag is not None and dag.default_args is not None:
+ default_args = dag.default_args
+ project_id = project_id or default_args.get('project_id')
+ region = region or default_args.get('region')
+ model_name = model_name or default_args.get('model_name')
+ version_name = version_name or default_args.get('version_name')
+ dataflow_options = dataflow_options or \
+ default_args.get('dataflow_default_options')
+
evaluate_prediction = CloudMLBatchPredictionOperator(
task_id=(task_prefix + "-prediction"),
project_id=project_id,
- job_id=_normalize_cloudml_job_id(job_id),
+ job_id=batch_prediction_job_id,
region=region,
data_format=data_format,
input_paths=input_paths,
@@ -195,9 +220,6 @@ def create_evaluate_ops(task_prefix,
"metric_keys": ','.join(metric_keys)
},
dag=dag)
- # TODO: "options" is not template_field of DataFlowPythonOperator (not sure
- # if intended or by mistake); consider fixing in the DataFlowPythonOperator.
- evaluate_summary.template_fields.append("options")
evaluate_summary.set_upstream(evaluate_prediction)
def apply_validate_fn(*args, **kwargs):
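
Because explicit arguments are checked first (`project_id or
default_args.get('project_id')`), a caller can still override a single value
while leaving the rest to the DAG's defaults. A minimal sketch, reusing the
hypothetical dag from the earlier example:

    # Only project_id is overridden; region, model_name, version_name and
    # dataflow_default_options still fall back to dag.default_args.
    pred, summary, validate = create_evaluate_ops(
        task_prefix='eval-override',
        batch_prediction_job_id='eval-override-prediction',
        data_format='TEXT',
        input_paths=['gs://my-bucket/eval-input/*'],
        prediction_path='gs://my-bucket/eval-output',
        metric_fn_and_keys=(lambda x: (0.1,), ['err']),
        validate_fn=lambda x: 'err=%.1f' % x['err'],
        project_id='another-project',
        dag=dag)
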
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/airflow/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index c1dca24..5cb8cf8 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -123,6 +123,8 @@ class DataFlowJavaOperator(BaseOperator):
class DataFlowPythonOperator(BaseOperator):
+ template_fields = ['options', 'dataflow_default_options']
+
@apply_defaults
def __init__(
self,
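
With `options` and `dataflow_default_options` now listed in template_fields,
their values can carry Jinja templates that are rendered per run. A sketch
under that assumption (the pipeline file and paths are hypothetical, and
`dag` is assumed to exist):

    from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator

    summarize = DataFlowPythonOperator(
        task_id='summarize',
        py_file='gs://my-bucket/dataflow/summarize.py',
        options={
            # Jinja is rendered at runtime now that 'options' is a template field.
            'input': 'gs://my-bucket/data/{{ ds }}/*',
            'output': 'gs://my-bucket/summary/{{ ds }}',
        },
        dataflow_default_options={
            'project': 'my-project',
            'temp_location': 'gs://my-bucket/tmp',
        },
        dag=dag)
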
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/tests/contrib/operators/test_cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py
index 91a9f77..b2a5a30 100644
--- a/tests/contrib/operators/test_cloudml_operator_utils.py
+++ b/tests/contrib/operators/test_cloudml_operator_utils.py
@@ -40,6 +40,7 @@ class CreateEvaluateOpsTest(unittest.TestCase):
'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
'outputPath': 'gs://legal-bucket/fake-output-path',
'region': 'us-east1',
+ 'versionName': 'projects/test-project/models/test_model/versions/test_version',
}
SUCCESS_MESSAGE_MISSING_INPUT = {
'jobId': 'eval_test_prediction',
@@ -61,30 +62,27 @@ class CreateEvaluateOpsTest(unittest.TestCase):
'owner': 'airflow',
'start_date': DEFAULT_DATE,
'end_date': DEFAULT_DATE,
+ 'project_id': 'test-project',
+ 'region': 'us-east1',
+ 'model_name': 'test_model',
+ 'version_name': 'test_version',
},
schedule_interval='@daily')
self.metric_fn = lambda x: (0.1,)
self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))
-
def testSuccessfulRun(self):
input_with_model = self.INPUT_MISSING_ORIGIN.copy()
- input_with_model['modelName'] = (
- 'projects/test-project/models/test_model')
pred, summary, validate = create_evaluate_ops(
task_prefix='eval-test',
- project_id='test-project',
- job_id='eval-test-prediction',
- region=input_with_model['region'],
+ batch_prediction_job_id='eval-test-prediction',
data_format=input_with_model['dataFormat'],
input_paths=input_with_model['inputPaths'],
prediction_path=input_with_model['outputPath'],
- model_name=input_with_model['modelName'].split('/')[-1],
metric_fn_and_keys=(self.metric_fn, ['err']),
validate_fn=(lambda x: 'err=%.1f' % x['err']),
- dataflow_options=None,
dag=self.dag)
with patch('airflow.contrib.operators.cloudml_operator.'
@@ -100,8 +98,9 @@ class CreateEvaluateOpsTest(unittest.TestCase):
'test-project',
{
'jobId': 'eval_test_prediction',
- 'predictionInput': input_with_model
- }, ANY)
+ 'predictionInput': input_with_model,
+ },
+ ANY)
self.assertEqual(success_message['predictionOutput'], result)
with patch('airflow.contrib.operators.dataflow_operator.'
@@ -133,22 +132,27 @@ class CreateEvaluateOpsTest(unittest.TestCase):
self.assertEqual('err=0.9', result)
def testFailures(self):
- input_with_model = self.INPUT_MISSING_ORIGIN.copy()
- input_with_model['modelName'] = (
- 'projects/test-project/models/test_model')
+ dag = DAG(
+ 'test_dag',
+ default_args={
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE,
+ 'end_date': DEFAULT_DATE,
+ 'project_id': 'test-project',
+ 'region': 'us-east1',
+ },
+ schedule_interval='@daily')
+ input_with_model = self.INPUT_MISSING_ORIGIN.copy()
other_params_but_models = {
'task_prefix': 'eval-test',
- 'project_id': 'test-project',
- 'job_id': 'eval-test-prediction',
- 'region': input_with_model['region'],
+ 'batch_prediction_job_id': 'eval-test-prediction',
'data_format': input_with_model['dataFormat'],
'input_paths': input_with_model['inputPaths'],
'prediction_path': input_with_model['outputPath'],
'metric_fn_and_keys': (self.metric_fn, ['err']),
'validate_fn': (lambda x: 'err=%.1f' % x['err']),
- 'dataflow_options': None,
- 'dag': self.dag,
+ 'dag': dag,
}
with self.assertRaisesRegexp(ValueError, 'Missing model origin'):