Posted to commits@airflow.apache.org by cr...@apache.org on 2018/01/03 19:16:47 UTC

incubator-airflow git commit: [AIRFLOW-1953] Add labels to dataflow operators

Repository: incubator-airflow
Updated Branches:
  refs/heads/master b9f4a7437 -> cc9295fe3


[AIRFLOW-1953] Add labels to dataflow operators

Closes #2913 from fenglu-g/master


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/cc9295fe
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/cc9295fe
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/cc9295fe

Branch: refs/heads/master
Commit: cc9295fe37ed6fb1ddfa077ee065ca6e0849a617
Parents: b9f4a74
Author: fenglu-g <fe...@google.com>
Authored: Wed Jan 3 11:16:39 2018 -0800
Committer: Chris Riccomini <cr...@apache.org>
Committed: Wed Jan 3 11:16:39 2018 -0800

----------------------------------------------------------------------
 UPDATING.md                                     |  5 ++
 airflow/contrib/hooks/gcp_dataflow_hook.py      | 36 ++++++++----
 airflow/contrib/operators/dataflow_operator.py  | 13 +++--
 setup.py                                        |  2 +-
 tests/contrib/hooks/test_gcp_dataflow_hook.py   | 60 +++++++++++++++++---
 .../contrib/operators/test_dataflow_operator.py | 14 ++++-
 .../operators/test_mlengine_operator_utils.py   |  3 +
 7 files changed, 107 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/UPDATING.md
----------------------------------------------------------------------
diff --git a/UPDATING.md b/UPDATING.md
index 9c39634..7a801e5 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -14,6 +14,11 @@ celery_result_backend -> result_backend
 ```
 This will result in the same config parameters as Celery 4 and will make it more transparent.
 
+### GCP Dataflow Operators
+Dataflow job labeling is now supported in Dataflow{Java,Python}Operator with a default
+"airflow-version" label, please upgrade your google-cloud-dataflow or apache-beam version
+to 2.2.0 or greater.
+
 ## Airflow 1.9
 
 ### SSH Hook updates, along with new SSH Operator & SFTP Operator
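
As a usage sketch of the new behavior (the DAG, project, bucket, and label names
below are hypothetical), user-supplied labels are passed through the operator's
options and merged with the default "airflow-version" label:

    import datetime

    from airflow import DAG
    from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator

    dag = DAG('example_dataflow_labels',
              start_date=datetime.datetime(2018, 1, 1),
              schedule_interval=None)

    task = DataFlowPythonOperator(
        task_id='wordcount',
        py_file='gs://my-bucket/wordcount.py',
        options={
            'project': 'my-project',
            'staging_location': 'gs://my-bucket/staging',
            # merged with the default {'airflow-version': ...} label
            'labels': {'team': 'data-eng'},
        },
        dag=dag)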

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/airflow/contrib/hooks/gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index 1928c3b..7cb7c79 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import select
 import subprocess
 import time
@@ -147,27 +148,40 @@ class DataFlowHook(GoogleCloudBaseHook):
         http_authorized = self._authorize()
         return build('dataflow', 'v1b3', http=http_authorized)
 
-    def _start_dataflow(self, task_id, variables, dataflow, name, command_prefix):
-        cmd = command_prefix + self._build_cmd(task_id, variables, dataflow)
+    def _start_dataflow(self, task_id, variables, dataflow,
+                        name, command_prefix, label_formatter):
+        cmd = command_prefix + self._build_cmd(task_id, variables,
+                                               dataflow, label_formatter)
         _Dataflow(cmd).wait_for_done()
-        _DataflowJob(
-            self.get_conn(), variables['project'], name, self.poll_sleep).wait_for_done()
+        _DataflowJob(self.get_conn(), variables['project'],
+                     name, self.poll_sleep).wait_for_done()
 
     def start_java_dataflow(self, task_id, variables, dataflow):
         name = task_id + "-" + str(uuid.uuid1())[:8]
         variables['jobName'] = name
-        self._start_dataflow(
-            task_id, variables, dataflow, name, ["java", "-jar"])
+
+        def label_formatter(labels_dict):
+            return ['--labels={}'.format(
+                    json.dumps(labels_dict).replace(' ', ''))]
+        self._start_dataflow(task_id, variables, dataflow, name,
+                             ["java", "-jar"], label_formatter)
 
     def start_python_dataflow(self, task_id, variables, dataflow, py_options):
         name = task_id + "-" + str(uuid.uuid1())[:8]
         variables["job_name"] = name
-        self._start_dataflow(
-            task_id, variables, dataflow, name, ["python"] + py_options)
 
-    def _build_cmd(self, task_id, variables, dataflow):
+        def label_formatter(labels_dict):
+            return ['--labels={}={}'.format(key, value)
+                    for key, value in labels_dict.items()]
+        self._start_dataflow(task_id, variables, dataflow, name,
+                             ["python"] + py_options, label_formatter)
+
+    def _build_cmd(self, task_id, variables, dataflow, label_formatter):
         command = [dataflow, "--runner=DataflowRunner"]
         if variables is not None:
-            for attr, value in variables.iteritems():
-                command.append("--" + attr + "=" + value)
+            for attr, value in variables.items():
+                if attr == 'labels':
+                    command += label_formatter(value)
+                else:
+                    command.append("--" + attr + "=" + value)
         return command
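
A standalone sketch of what the two label_formatter variants above produce
(the label values are made up; dict ordering may vary):

    import json

    labels = {'foo': 'bar', 'airflow-version': 'v1-9-0'}

    # Java pipelines get a single JSON-encoded flag, with spaces stripped:
    java_flags = ['--labels={}'.format(json.dumps(labels).replace(' ', ''))]
    # -> ['--labels={"foo":"bar","airflow-version":"v1-9-0"}']

    # Python pipelines get one key=value flag per label:
    python_flags = ['--labels={}={}'.format(k, v) for k, v in labels.items()]
    # -> ['--labels=foo=bar', '--labels=airflow-version=v1-9-0']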

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/airflow/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 6fd23f1..01fbd35 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -19,6 +19,7 @@ import uuid
 from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
 from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
 from airflow.models import BaseOperator
+from airflow.version import version
 from airflow.utils.decorators import apply_defaults
 
 
@@ -52,7 +53,8 @@ class DataFlowJavaOperator(BaseOperator):
             'autoscalingAlgorithm': 'BASIC',
             'maxNumWorkers': '50',
             'start': '{{ds}}',
-            'partitionType': 'DAY'
+            'partitionType': 'DAY',
+            'labels': {'foo' : 'bar'}
         },
         dag=my-dag)
     ```
@@ -97,7 +99,7 @@ class DataFlowJavaOperator(BaseOperator):
             For this to work, the service account making the request must have
             domain-wide delegation enabled.
         :type delegate_to: string
-        :param poll_sleep: The time in seconds to sleep between polling Google 
+        :param poll_sleep: The time in seconds to sleep between polling Google
             Cloud Platform for the dataflow job status while the job is in the
             JOB_STATE_RUNNING state.
         :type poll_sleep: int
@@ -106,7 +108,8 @@ class DataFlowJavaOperator(BaseOperator):
 
         dataflow_default_options = dataflow_default_options or {}
         options = options or {}
-
+        options.setdefault('labels', {}).update(
+            {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')})
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.jar = jar
@@ -171,7 +174,7 @@ class DataFlowPythonOperator(BaseOperator):
             For this to work, the service account making the request must have
             domain-wide  delegation enabled.
         :type delegate_to: string
-        :param poll_sleep: The time in seconds to sleep between polling Google 
+        :param poll_sleep: The time in seconds to sleep between polling Google
             Cloud Platform for the dataflow job status while the job is in the
             JOB_STATE_RUNNING state.
         :type poll_sleep: int
@@ -182,6 +185,8 @@ class DataFlowPythonOperator(BaseOperator):
         self.py_options = py_options or []
         self.dataflow_default_options = dataflow_default_options or {}
         self.options = options or {}
+        self.options.setdefault('labels', {}).update(
+            {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')})
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.poll_sleep = poll_sleep
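
The setdefault/update pattern above injects the default label without clobbering
any labels the user already supplied; the version string is rewritten because
'.' and '+' are not valid in GCP label values. A minimal sketch (hypothetical
version string and options):

    version = '1.9.0'  # normally imported from airflow.version
    options = {'output': 'gs://bucket/out', 'labels': {'foo': 'bar'}}
    options.setdefault('labels', {}).update(
        {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')})
    # options['labels'] == {'foo': 'bar', 'airflow-version': 'v1-9-0'}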

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 84da6f1..a63ce79 100644
--- a/setup.py
+++ b/setup.py
@@ -123,7 +123,7 @@ gcp_api = [
     'google-api-python-client>=1.5.0, <1.6.0',
     'oauth2client>=2.0.2, <2.1.0',
     'PyOpenSSL',
-    'google-cloud-dataflow',
+    'google-cloud-dataflow>=2.2.0',
     'pandas-gbq'
 ]
 hdfs = ['snakebite>=2.7.8']

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/tests/contrib/hooks/test_gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
index 1ab5a99..a37b153 100644
--- a/tests/contrib/hooks/test_gcp_dataflow_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -31,13 +31,21 @@ except ImportError:
 
 TASK_ID = 'test-python-dataflow'
 PY_FILE = 'apache_beam.examples.wordcount'
+JAR_FILE = 'unitest.jar'
 PY_OPTIONS = ['-m']
-OPTIONS = {
+DATAFLOW_OPTIONS_PY = {
     'project': 'test',
-    'staging_location': 'gs://test/staging'
+    'staging_location': 'gs://test/staging',
+    'labels': {'foo': 'bar'}
+}
+DATAFLOW_OPTIONS_JAVA = {
+    'project': 'test',
+    'stagingLocation': 'gs://test/staging',
+    'labels': {'foo': 'bar'}
 }
 BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
 DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
+MOCK_UUID = '12345678'
 
 
 def mock_init(self, gcp_conn_id, delegate_to=None):
@@ -51,13 +59,51 @@ class DataFlowHookTest(unittest.TestCase):
                         new=mock_init):
             self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
 
-    @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow'))
-    def test_start_python_dataflow(self, internal_dataflow_mock):
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
+    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
+    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
+    def test_start_python_dataflow(self, mock_conn,
+                                   mock_dataflow, mock_dataflowjob, mock_uuid):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_python_dataflow(
-            task_id=TASK_ID, variables=OPTIONS,
+            task_id=TASK_ID, variables=DATAFLOW_OPTIONS_PY,
             dataflow=PY_FILE, py_options=PY_OPTIONS)
-        internal_dataflow_mock.assert_called_once_with(
-            TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS)
+        EXPECTED_CMD = ['python', '-m', PY_FILE,
+                        '--runner=DataflowRunner', '--project=test',
+                        '--labels=foo=bar',
+                        '--staging_location=gs://test/staging',
+                        '--job_name={}-{}'.format(TASK_ID, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
+                             sorted(EXPECTED_CMD))
+
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
+    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
+    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
+    def test_start_java_dataflow(self, mock_conn,
+                                 mock_dataflow, mock_dataflowjob, mock_uuid):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        self.dataflow_hook.start_java_dataflow(
+            task_id=TASK_ID, variables=DATAFLOW_OPTIONS_JAVA,
+            dataflow=JAR_FILE)
+        EXPECTED_CMD = ['java', '-jar', JAR_FILE,
+                        '--runner=DataflowRunner', '--project=test',
+                        '--stagingLocation=gs://test/staging',
+                        '--labels={"foo":"bar"}',
+                        '--jobName={}-{}'.format(TASK_ID, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
+                             sorted(EXPECTED_CMD))
 
     @mock.patch('airflow.contrib.hooks.gcp_dataflow_hook._Dataflow.log')
     @mock.patch('subprocess.Popen')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/tests/contrib/operators/test_dataflow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py
index 77fc1f6..5b07051 100644
--- a/tests/contrib/operators/test_dataflow_operator.py
+++ b/tests/contrib/operators/test_dataflow_operator.py
@@ -16,6 +16,7 @@
 import unittest
 
 from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+from airflow.version import version
 
 try:
     from unittest import mock
@@ -34,7 +35,13 @@ DEFAULT_OPTIONS = {
     'stagingLocation': 'gs://test/staging'
 }
 ADDITIONAL_OPTIONS = {
-    'output': 'gs://test/output'
+    'output': 'gs://test/output',
+    'labels': {'foo': 'bar'}
+}
+TEST_VERSION = 'v{}'.format(version.replace('.', '-').replace('+', '-'))
+EXPECTED_ADDITIONAL_OPTIONS = {
+    'output': 'gs://test/output',
+    'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}
 }
 POLL_SLEEP = 30
 GCS_HOOK_STRING = 'airflow.contrib.operators.dataflow_operator.{}'
@@ -60,7 +67,7 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         self.assertEqual(self.dataflow.dataflow_default_options,
                          DEFAULT_OPTIONS)
         self.assertEqual(self.dataflow.options,
-                         ADDITIONAL_OPTIONS)
+                         EXPECTED_ADDITIONAL_OPTIONS)
 
     @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
     @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
@@ -76,7 +83,8 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         expected_options = {
             'project': 'test',
             'staging_location': 'gs://test/staging',
-            'output': 'gs://test/output'
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}
         }
         gcs_download_hook.assert_called_once_with(PY_FILE)
         start_python_hook.assert_called_once_with(TASK_ID, expected_options,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/tests/contrib/operators/test_mlengine_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_mlengine_operator_utils.py b/tests/contrib/operators/test_mlengine_operator_utils.py
index 80ab01a..c8f6fb5 100644
--- a/tests/contrib/operators/test_mlengine_operator_utils.py
+++ b/tests/contrib/operators/test_mlengine_operator_utils.py
@@ -26,11 +26,13 @@ from airflow import configuration, DAG
 from airflow.contrib.operators import mlengine_operator_utils
 from airflow.contrib.operators.mlengine_operator_utils import create_evaluate_ops
 from airflow.exceptions import AirflowException
+from airflow.version import version
 
 from mock import ANY
 from mock import patch
 
 DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+TEST_VERSION = 'v{}'.format(version.replace('.', '-').replace('+', '-'))
 
 
 class CreateEvaluateOpsTest(unittest.TestCase):
@@ -115,6 +117,7 @@ class CreateEvaluateOpsTest(unittest.TestCase):
                 'eval-test-summary',
                 {
                     'prediction_path': 'gs://legal-bucket/fake-output-path',
+                    'labels': {'airflow-version': TEST_VERSION},
                     'metric_keys': 'err',
                     'metric_fn_encoded': self.metric_fn_encoded,
                 },