Posted to commits@airflow.apache.org by cr...@apache.org on 2018/01/16 17:32:37 UTC

incubator-airflow git commit: [AIRFLOW-2000] Support non-main dataflow job class

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 88130a5d7 -> f6a1c3cf7


[AIRFLOW-2000] Support non-main dataflow job class

Closes #2942 from fenglu-g/master
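
With this change the Java operator can launch a pipeline class other than the
jar's Main-Class: when job_class is set, the hook invokes
"java -cp <jar> <job_class>" instead of "java -jar <jar>". A minimal usage
sketch (DAG id, bucket, jar path and class name below are hypothetical):

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.dataflow_operator import DataFlowJavaOperator

    dag = DAG('dataflow_non_main_class', start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    # job_class selects the entry point; leaving it unset keeps the old
    # behaviour of running the jar's Main-Class via "java -jar".
    DataFlowJavaOperator(
        task_id='launch-wordcount',
        jar='gs://my-bucket/pipelines/wordcount-bundled.jar',
        job_class='com.example.dataflow.WordCount',
        dataflow_default_options={
            'project': 'my-gcp-project',
            'stagingLocation': 'gs://my-bucket/staging',
        },
        options={'output': 'gs://my-bucket/output/'},
        dag=dag)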


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/f6a1c3cf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/f6a1c3cf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/f6a1c3cf

Branch: refs/heads/master
Commit: f6a1c3cf7f9716ed0c298e79bc94a9066a7def18
Parents: 88130a5
Author: fenglu-g <fe...@google.com>
Authored: Tue Jan 16 09:32:32 2018 -0800
Committer: Chris Riccomini <cr...@apache.org>
Committed: Tue Jan 16 09:32:32 2018 -0800

----------------------------------------------------------------------
 airflow/contrib/hooks/gcp_dataflow_hook.py      | 23 +++++-----
 airflow/contrib/operators/dataflow_operator.py  |  8 +++-
 tests/contrib/hooks/test_gcp_dataflow_hook.py   | 27 +++++++++++-
 .../contrib/operators/test_dataflow_operator.py | 44 +++++++++++++++++++-
 4 files changed, 88 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/airflow/contrib/hooks/gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index d60b498..72a2225 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -151,23 +151,25 @@ class DataFlowHook(GoogleCloudBaseHook):
         http_authorized = self._authorize()
         return build('dataflow', 'v1b3', http=http_authorized)
 
-    def _start_dataflow(self, task_id, variables, dataflow,
-                        name, command_prefix, label_formatter):
+    def _start_dataflow(self, task_id, variables, name,
+                        command_prefix, label_formatter):
         cmd = command_prefix + self._build_cmd(task_id, variables,
-                                               dataflow, label_formatter)
+                                               label_formatter)
         _Dataflow(cmd).wait_for_done()
         _DataflowJob(self.get_conn(), variables['project'],
                      name, self.poll_sleep).wait_for_done()
 
-    def start_java_dataflow(self, task_id, variables, dataflow):
+    def start_java_dataflow(self, task_id, variables, dataflow, job_class=None):
         name = task_id + "-" + str(uuid.uuid1())[:8]
         variables['jobName'] = name
 
         def label_formatter(labels_dict):
             return ['--labels={}'.format(
                     json.dumps(labels_dict).replace(' ', ''))]
-        self._start_dataflow(task_id, variables, dataflow, name,
-                             ["java", "-jar"], label_formatter)
+        command_prefix = (["java", "-cp", dataflow, job_class] if job_class
+                          else ["java", "-jar", dataflow])
+        self._start_dataflow(task_id, variables, name,
+                             command_prefix, label_formatter)
 
     def start_template_dataflow(self, task_id, variables, parameters, dataflow_template):
         name = task_id + "-" + str(uuid.uuid1())[:8]
@@ -181,11 +183,12 @@ class DataFlowHook(GoogleCloudBaseHook):
         def label_formatter(labels_dict):
             return ['--labels={}={}'.format(key, value)
                     for key, value in labels_dict.items()]
-        self._start_dataflow(task_id, variables, dataflow, name,
-                             ["python"] + py_options, label_formatter)
+        self._start_dataflow(task_id, variables, name,
+                             ["python"] + py_options + [dataflow],
+                             label_formatter)
 
-    def _build_cmd(self, task_id, variables, dataflow, label_formatter):
-        command = [dataflow, "--runner=DataflowRunner"]
+    def _build_cmd(self, task_id, variables, label_formatter):
+        command = ["--runner=DataflowRunner"]
         if variables is not None:
             for attr, value in variables.items():
                 if attr == 'labels':
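
To make the new prefix logic concrete, here is a small standalone sketch of the
selection above (build_java_prefix and the values passed to it are illustrative,
not part of the hook itself):

    def build_java_prefix(dataflow, job_class=None):
        # With a job_class the jar goes on the classpath and the class is
        # named explicitly; without one the original -jar behaviour is kept.
        return (["java", "-cp", dataflow, job_class] if job_class
                else ["java", "-jar", dataflow])

    build_java_prefix('unitest.jar')
    # ['java', '-jar', 'unitest.jar']
    build_java_prefix('unitest.jar', 'com.example.UnitTest')
    # ['java', '-cp', 'unitest.jar', 'com.example.UnitTest']

_build_cmd then appends --runner=DataflowRunner plus the --key=value options, so
the full command matches EXPECTED_CMD in the new hook test below.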

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/airflow/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 915e26c..5c34678 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -73,6 +73,7 @@ class DataFlowJavaOperator(BaseOperator):
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             poll_sleep=10,
+            job_class=None,
             *args,
             **kwargs):
         """
@@ -103,6 +104,9 @@ class DataFlowJavaOperator(BaseOperator):
             Cloud Platform for the dataflow job status while the job is in the
             JOB_STATE_RUNNING state.
         :type poll_sleep: int
+        :param job_class: The name of the dataflow job class to be executed,
+            which is often not the main class configured in the dataflow jar file.
+        :type job_class: string
         """
         super(DataFlowJavaOperator, self).__init__(*args, **kwargs)
 
@@ -116,6 +120,7 @@ class DataFlowJavaOperator(BaseOperator):
         self.dataflow_default_options = dataflow_default_options
         self.options = options
         self.poll_sleep = poll_sleep
+        self.job_class = job_class
 
     def execute(self, context):
         bucket_helper = GoogleCloudBucketHelper(
@@ -128,7 +133,8 @@ class DataFlowJavaOperator(BaseOperator):
         dataflow_options = copy.copy(self.dataflow_default_options)
         dataflow_options.update(self.options)
 
-        hook.start_java_dataflow(self.task_id, dataflow_options, self.jar)
+        hook.start_java_dataflow(self.task_id, dataflow_options,
+                                 self.jar, self.job_class)
 
 
 class DataflowTemplateOperator(BaseOperator):
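
The net effect in execute() is that the locally staged jar and the optional
job_class are both forwarded to the hook. Roughly (values are illustrative,
the local path being where GoogleCloudBucketHelper typically stages the jar):

    hook.start_java_dataflow(
        'my-dataflow-task',                        # task_id
        {'project': 'test',
         'stagingLocation': 'gs://test/staging'},  # defaults merged with options
        '/tmp/dataflow/test.jar',                  # jar downloaded from GCS
        'com.test.NotMain')                        # job_class; None keeps -jar mode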

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/tests/contrib/hooks/test_gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
index bf513c8..6089f4b 100644
--- a/tests/contrib/hooks/test_gcp_dataflow_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -37,6 +37,7 @@ PARAMETERS = {
 }
 PY_FILE = 'apache_beam.examples.wordcount'
 JAR_FILE = 'unitest.jar'
+JOB_CLASS = 'com.example.UnitTest'
 PY_OPTIONS = ['-m']
 DATAFLOW_OPTIONS_PY = {
     'project': 'test',
@@ -62,7 +63,7 @@ def mock_init(self, gcp_conn_id, delegate_to=None):
     pass
 
 
-class DataFlowPythonHookTest(unittest.TestCase):
+class DataFlowHookTest(unittest.TestCase):
 
     def setUp(self):
         with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
@@ -115,6 +116,30 @@ class DataFlowPythonHookTest(unittest.TestCase):
         self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
                              sorted(EXPECTED_CMD))
 
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
+    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
+    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
+    def test_start_java_dataflow_with_job_class(
+            self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        self.dataflow_hook.start_java_dataflow(
+            task_id=TASK_ID, variables=DATAFLOW_OPTIONS_JAVA,
+            dataflow=JAR_FILE, job_class=JOB_CLASS)
+        EXPECTED_CMD = ['java', '-cp', JAR_FILE, JOB_CLASS,
+                        '--runner=DataflowRunner', '--project=test',
+                        '--stagingLocation=gs://test/staging',
+                        '--labels={"foo":"bar"}',
+                        '--jobName={}-{}'.format(TASK_ID, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
+                             sorted(EXPECTED_CMD))
+
+
     @mock.patch('airflow.contrib.hooks.gcp_dataflow_hook._Dataflow.log')
     @mock.patch('subprocess.Popen')
     @mock.patch('select.select')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/tests/contrib/operators/test_dataflow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py
index da95d18..88d13d8 100644
--- a/tests/contrib/operators/test_dataflow_operator.py
+++ b/tests/contrib/operators/test_dataflow_operator.py
@@ -16,7 +16,7 @@
 import unittest
 
 from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator, \
-    DataflowTemplateOperator
+    DataFlowJavaOperator, DataflowTemplateOperator
 from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
 from airflow.version import version
 
@@ -36,8 +36,10 @@ PARAMETERS = {
     'output': 'gs://test/output/my_output'
 }
 PY_FILE = 'gs://my-bucket/my-object.py'
+JAR_FILE = 'example/test.jar'
+JOB_CLASS = 'com.test.NotMain'
 PY_OPTIONS = ['-m']
-DEFAULT_OPTIONS_PYTHON = {
+DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
     'project': 'test',
     'stagingLocation': 'gs://test/staging',
 }
@@ -105,6 +107,44 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
 
 
+class DataFlowJavaOperatorTest(unittest.TestCase):
+
+    def setUp(self):
+        self.dataflow = DataFlowJavaOperator(
+            task_id=TASK_ID,
+            jar=JAR_FILE,
+            job_class=JOB_CLASS,
+            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
+            options=ADDITIONAL_OPTIONS,
+            poll_sleep=POLL_SLEEP)
+
+    def test_init(self):
+        """Test DataflowTemplateOperator instance is properly initialized."""
+        self.assertEqual(self.dataflow.task_id, TASK_ID)
+        self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
+        self.assertEqual(self.dataflow.dataflow_default_options,
+                         DEFAULT_OPTIONS_JAVA)
+        self.assertEqual(self.dataflow.job_class, JOB_CLASS)
+        self.assertEqual(self.dataflow.jar, JAR_FILE)
+        self.assertEqual(self.dataflow.options,
+                         EXPECTED_ADDITIONAL_OPTIONS)
+
+    @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
+    @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
+    def test_exec(self, gcs_hook, dataflow_mock):
+        """Test DataFlowHook is created and the right args are passed to
+        start_java_dataflow.
+
+        """
+        start_java_hook = dataflow_mock.return_value.start_java_dataflow
+        gcs_download_hook = gcs_hook.return_value.google_cloud_to_local
+        self.dataflow.execute(None)
+        self.assertTrue(dataflow_mock.called)
+        gcs_download_hook.assert_called_once_with(JAR_FILE)
+        start_java_hook.assert_called_once_with(TASK_ID, mock.ANY,
+                                                mock.ANY, JOB_CLASS)
+
+
 class DataFlowTemplateOperatorTest(unittest.TestCase):
 
     def setUp(self):