You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/07/31 20:30:41 UTC

incubator-airflow git commit: [AIRFLOW-1359] Use default_args in Cloud ML eval

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 6e2640766 -> 1932ccc88


[AIRFLOW-1359] Use default_args in Cloud ML eval

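This change makes the create_evaluate_ops utility
method fall back to the DAG's default_args for
project_id, region, model_name, version_name and
dataflow_options (read from
dataflow_default_options) when they are not passed
explicitly. This simplifies calls to
create_evaluate_ops and makes those default_args
more broadly useful. Note that the job_id parameter
is renamed to batch_prediction_job_id.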

To make default_args more useful for the Cloud ML
operators, this change also adds an optional
version_name parameter to CloudMLVersionOperator
(whose version argument also becomes optional),
allowing model_name and version_name to be
specified once for an entire pipeline.

This change also resolves a small TODO by making
the DataFlowPythonOperator's `options` and
`dataflow_default_options` attributes template
fields.

Closes #2445 from
peterjdolan/eval_ops_arguments_from_default_args


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/1932ccc8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/1932ccc8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/1932ccc8

Branch: refs/heads/master
Commit: 1932ccc881d10d220d1d06efaa477373a08596bb
Parents: 6e26407
Author: Peter Dolan <pe...@google.com>
Authored: Mon Jul 31 13:27:12 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Mon Jul 31 13:27:22 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/operators/cloudml_operator.py   | 16 +++-
 .../contrib/operators/cloudml_operator_utils.py | 84 ++++++++++++--------
 airflow/contrib/operators/dataflow_operator.py  |  2 +
 .../operators/test_cloudml_operator_utils.py    | 40 +++++-----
 4 files changed, 91 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
index 34b2e83..6bdd516 100644
--- a/airflow/contrib/operators/cloudml_operator.py
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -341,6 +341,12 @@ class CloudMLVersionOperator(BaseOperator):
         If it is None, the only `operation` possible would be `list`.
     :type version: dict
 
+    :param version_name: A name to use for the version being operated upon. If
+        not None and the `version` argument is None or does not have a value for
+        the `name` key, then this will be populated in the payload for the
+        `name` key.
+    :type version_name: string
+
     :param gcp_conn_id: The connection ID to use when fetching connection info.
     :type gcp_conn_id: string
 
@@ -372,13 +378,15 @@ class CloudMLVersionOperator(BaseOperator):
     template_fields = [
         '_model_name',
         '_version',
+        '_version_name',
     ]
 
     @apply_defaults
     def __init__(self,
                  model_name,
                  project_id,
-                 version,
+                 version=None,
+                 version_name=None,
                  gcp_conn_id='google_cloud_default',
                  operation='create',
                  delegate_to=None,
@@ -387,13 +395,17 @@ class CloudMLVersionOperator(BaseOperator):
 
         super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
         self._model_name = model_name
-        self._version = version
+        self._version = version or {}
+        self._version_name = version_name
         self._gcp_conn_id = gcp_conn_id
         self._delegate_to = delegate_to
         self._project_id = project_id
         self._operation = operation
 
     def execute(self, context):
+        if 'name' not in self._version:
+            self._version['name'] = self._version_name
+
         hook = CloudMLHook(
             gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/airflow/contrib/operators/cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py
index f4abb32..81cd54f 100644
--- a/airflow/contrib/operators/cloudml_operator_utils.py
+++ b/airflow/contrib/operators/cloudml_operator_utils.py
@@ -18,31 +18,26 @@ import base64
 import json
 import os
 import re
-try:  # python 2
-    from urlparse import urlsplit
-except ImportError:  # python 3
-    from urllib.parse import urlsplit
 
 import dill
 
 from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
 from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
-from airflow.contrib.operators.cloudml_operator import _normalize_cloudml_job_id
 from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
 from airflow.exceptions import AirflowException
 from airflow.operators.python_operator import PythonOperator
-
+from six.moves.urllib.parse import urlsplit
 
 def create_evaluate_ops(task_prefix,
-                        project_id,
-                        job_id,
-                        region,
                         data_format,
                         input_paths,
                         prediction_path,
                         metric_fn_and_keys,
                         validate_fn,
-                        dataflow_options,
+                        batch_prediction_job_id=None,
+                        project_id=None,
+                        region=None,
+                        dataflow_options=None,
                         model_uri=None,
                         model_name=None,
                         version_name=None,
@@ -114,22 +109,6 @@ def create_evaluate_ops(task_prefix,
         job name, which doesn't allow other characters.
     :type task_prefix: string
 
-    :param model_uri: GCS path of the model exported by Tensorflow using
-        tensorflow.estimator.export_savedmodel(). It cannot be used with
-        model_name or version_name below. See CloudMLBatchPredictionOperator for
-        more detail.
-    :type model_uri: string
-
-    :param model_name: Used to indicate a model to use for prediction. Can be
-        used in combination with version_name, but cannot be used together with
-        model_uri. See CloudMLBatchPredictionOperator for more detail.
-    :type model_name: string
-
-    :param version_name: Used to indicate a model version to use for prediciton,
-        in combination with model_name. Cannot be used together with model_uri.
-        See CloudMLBatchPredictionOperator for more detail.
-    :type version_name: string
-
     :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
     :type data_format: string
 
@@ -149,9 +128,46 @@ def create_evaluate_ops(task_prefix,
         good enough to push the model.
     :type validate_fn: function
 
-    :param dataflow_options: options to run Dataflow jobs.
+    :param batch_prediction_job_id: the id to use for the Cloud ML Batch
+        prediction job. Passed directly to the CloudMLBatchPredictionOperator as
+        the job_id argument.
+    :type batch_prediction_job_id: string
+
+    :param project_id: the Google Cloud Platform project id in which to execute
+        Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
+        `default_args['project_id']` will be used.
+    :type project_id: string
+
+    :param region: the Google Cloud Platform region in which to execute Cloud ML
+        Batch Prediction and Dataflow jobs. If None, then the `dag`'s
+        `default_args['region']` will be used.
+    :type region: string
+
+    :param dataflow_options: options to run Dataflow jobs. If None, then the
+        `dag`'s `default_args['dataflow_default_options']` will be used.
     :type dataflow_options: dictionary
 
+    :param model_uri: GCS path of the model exported by Tensorflow using
+        tensorflow.estimator.export_savedmodel(). It cannot be used with
+        model_name or version_name below. See CloudMLBatchPredictionOperator for
+        more detail.
+    :type model_uri: string
+
+    :param model_name: Used to indicate a model to use for prediction. Can be
+        used in combination with version_name, but cannot be used together with
+        model_uri. See CloudMLBatchPredictionOperator for more detail. If None,
+        then the `dag`'s `default_args['model_name']` will be used.
+    :type model_name: string
+
+    :param version_name: Used to indicate a model version to use for prediciton,
+        in combination with model_name. Cannot be used together with model_uri.
+        See CloudMLBatchPredictionOperator for more detail. If None, then the
+        `dag`'s `default_args['version_name']` will be used.
+    :type version_name: string
+
+    :param dag: The `DAG` to use for all Operators.
+    :type dag: airflow.DAG
+
     :returns: a tuple of three operators, (prediction, summary, validation)
     :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
                   PythonOperator)
@@ -170,10 +186,19 @@ def create_evaluate_ops(task_prefix,
     if not callable(validate_fn):
         raise AirflowException("`validate_fn` param must be callable.")
 
+    if dag is not None and dag.default_args is not None:
+        default_args = dag.default_args
+        project_id = project_id or default_args.get('project_id')
+        region = region or default_args.get('region')
+        model_name = model_name or default_args.get('model_name')
+        version_name = version_name or default_args.get('version_name')
+        dataflow_options = dataflow_options or \
+            default_args.get('dataflow_default_options')
+
     evaluate_prediction = CloudMLBatchPredictionOperator(
         task_id=(task_prefix + "-prediction"),
         project_id=project_id,
-        job_id=_normalize_cloudml_job_id(job_id),
+        job_id=batch_prediction_job_id,
         region=region,
         data_format=data_format,
         input_paths=input_paths,
@@ -195,9 +220,6 @@ def create_evaluate_ops(task_prefix,
             "metric_keys": ','.join(metric_keys)
         },
         dag=dag)
-    # TODO: "options" is not template_field of DataFlowPythonOperator (not sure
-    # if intended or by mistake); consider fixing in the DataFlowPythonOperator.
-    evaluate_summary.template_fields.append("options")
     evaluate_summary.set_upstream(evaluate_prediction)
 
     def apply_validate_fn(*args, **kwargs):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/airflow/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index c1dca24..5cb8cf8 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -123,6 +123,8 @@ class DataFlowJavaOperator(BaseOperator):
 
 class DataFlowPythonOperator(BaseOperator):
 
+    template_fields = ['options', 'dataflow_default_options']
+
     @apply_defaults
     def __init__(
             self,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1932ccc8/tests/contrib/operators/test_cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py
index 91a9f77..b2a5a30 100644
--- a/tests/contrib/operators/test_cloudml_operator_utils.py
+++ b/tests/contrib/operators/test_cloudml_operator_utils.py
@@ -40,6 +40,7 @@ class CreateEvaluateOpsTest(unittest.TestCase):
         'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
         'outputPath': 'gs://legal-bucket/fake-output-path',
         'region': 'us-east1',
+        'versionName': 'projects/test-project/models/test_model/versions/test_version',
     }
     SUCCESS_MESSAGE_MISSING_INPUT = {
         'jobId': 'eval_test_prediction',
@@ -61,30 +62,27 @@ class CreateEvaluateOpsTest(unittest.TestCase):
                 'owner': 'airflow',
                 'start_date': DEFAULT_DATE,
                 'end_date': DEFAULT_DATE,
+                'project_id': 'test-project',
+                'region': 'us-east1',
+                'model_name': 'test_model',
+                'version_name': 'test_version',
             },
             schedule_interval='@daily')
         self.metric_fn = lambda x: (0.1,)
         self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
             cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))
 
-
     def testSuccessfulRun(self):
         input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-        input_with_model['modelName'] = (
-            'projects/test-project/models/test_model')
 
         pred, summary, validate = create_evaluate_ops(
             task_prefix='eval-test',
-            project_id='test-project',
-            job_id='eval-test-prediction',
-            region=input_with_model['region'],
+            batch_prediction_job_id='eval-test-prediction',
             data_format=input_with_model['dataFormat'],
             input_paths=input_with_model['inputPaths'],
             prediction_path=input_with_model['outputPath'],
-            model_name=input_with_model['modelName'].split('/')[-1],
             metric_fn_and_keys=(self.metric_fn, ['err']),
             validate_fn=(lambda x: 'err=%.1f' % x['err']),
-            dataflow_options=None,
             dag=self.dag)
 
         with patch('airflow.contrib.operators.cloudml_operator.'
@@ -100,8 +98,9 @@ class CreateEvaluateOpsTest(unittest.TestCase):
                 'test-project',
                 {
                     'jobId': 'eval_test_prediction',
-                    'predictionInput': input_with_model
-                }, ANY)
+                    'predictionInput': input_with_model,
+                },
+                ANY)
             self.assertEqual(success_message['predictionOutput'], result)
 
         with patch('airflow.contrib.operators.dataflow_operator.'
@@ -133,22 +132,27 @@ class CreateEvaluateOpsTest(unittest.TestCase):
             self.assertEqual('err=0.9', result)
 
     def testFailures(self):
-        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-        input_with_model['modelName'] = (
-            'projects/test-project/models/test_model')
+        dag = DAG(
+            'test_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE,
+                'end_date': DEFAULT_DATE,
+                'project_id': 'test-project',
+                'region': 'us-east1',
+            },
+            schedule_interval='@daily')
 
+        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
         other_params_but_models = {
             'task_prefix': 'eval-test',
-            'project_id': 'test-project',
-            'job_id': 'eval-test-prediction',
-            'region': input_with_model['region'],
+            'batch_prediction_job_id': 'eval-test-prediction',
             'data_format': input_with_model['dataFormat'],
             'input_paths': input_with_model['inputPaths'],
             'prediction_path': input_with_model['outputPath'],
             'metric_fn_and_keys': (self.metric_fn, ['err']),
             'validate_fn': (lambda x: 'err=%.1f' % x['err']),
-            'dataflow_options': None,
-            'dag': self.dag,
+            'dag': dag,
         }
 
         with self.assertRaisesRegexp(ValueError, 'Missing model origin'):