You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2018/01/16 17:32:37 UTC
incubator-airflow git commit: [AIRFLOW-2000] Support non-main
dataflow job class
Repository: incubator-airflow
Updated Branches:
refs/heads/master 88130a5d7 -> f6a1c3cf7
[AIRFLOW-2000] Support non-main dataflow job class
Closes #2942 from fenglu-g/master
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/f6a1c3cf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/f6a1c3cf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/f6a1c3cf
Branch: refs/heads/master
Commit: f6a1c3cf7f9716ed0c298e79bc94a9066a7def18
Parents: 88130a5
Author: fenglu-g <fe...@google.com>
Authored: Tue Jan 16 09:32:32 2018 -0800
Committer: Chris Riccomini <cr...@apache.org>
Committed: Tue Jan 16 09:32:32 2018 -0800
----------------------------------------------------------------------
airflow/contrib/hooks/gcp_dataflow_hook.py | 23 +++++-----
airflow/contrib/operators/dataflow_operator.py | 8 +++-
tests/contrib/hooks/test_gcp_dataflow_hook.py | 27 +++++++++++-
.../contrib/operators/test_dataflow_operator.py | 44 +++++++++++++++++++-
4 files changed, 88 insertions(+), 14 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/airflow/contrib/hooks/gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index d60b498..72a2225 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -151,23 +151,25 @@ class DataFlowHook(GoogleCloudBaseHook):
http_authorized = self._authorize()
return build('dataflow', 'v1b3', http=http_authorized)
- def _start_dataflow(self, task_id, variables, dataflow,
- name, command_prefix, label_formatter):
+ def _start_dataflow(self, task_id, variables, name,
+ command_prefix, label_formatter):
cmd = command_prefix + self._build_cmd(task_id, variables,
- dataflow, label_formatter)
+ label_formatter)
_Dataflow(cmd).wait_for_done()
_DataflowJob(self.get_conn(), variables['project'],
name, self.poll_sleep).wait_for_done()
- def start_java_dataflow(self, task_id, variables, dataflow):
+ def start_java_dataflow(self, task_id, variables, dataflow, job_class=None):
name = task_id + "-" + str(uuid.uuid1())[:8]
variables['jobName'] = name
def label_formatter(labels_dict):
return ['--labels={}'.format(
json.dumps(labels_dict).replace(' ', ''))]
- self._start_dataflow(task_id, variables, dataflow, name,
- ["java", "-jar"], label_formatter)
+ command_prefix = (["java", "-cp", dataflow, job_class] if job_class
+ else ["java", "-jar", dataflow])
+ self._start_dataflow(task_id, variables, name,
+ command_prefix, label_formatter)
def start_template_dataflow(self, task_id, variables, parameters, dataflow_template):
name = task_id + "-" + str(uuid.uuid1())[:8]
@@ -181,11 +183,12 @@ class DataFlowHook(GoogleCloudBaseHook):
def label_formatter(labels_dict):
return ['--labels={}={}'.format(key, value)
for key, value in labels_dict.items()]
- self._start_dataflow(task_id, variables, dataflow, name,
- ["python"] + py_options, label_formatter)
+ self._start_dataflow(task_id, variables, name,
+ ["python"] + py_options + [dataflow],
+ label_formatter)
- def _build_cmd(self, task_id, variables, dataflow, label_formatter):
- command = [dataflow, "--runner=DataflowRunner"]
+ def _build_cmd(self, task_id, variables, label_formatter):
+ command = ["--runner=DataflowRunner"]
if variables is not None:
for attr, value in variables.items():
if attr == 'labels':
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/airflow/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 915e26c..5c34678 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -73,6 +73,7 @@ class DataFlowJavaOperator(BaseOperator):
gcp_conn_id='google_cloud_default',
delegate_to=None,
poll_sleep=10,
+ job_class=None,
*args,
**kwargs):
"""
@@ -103,6 +104,9 @@ class DataFlowJavaOperator(BaseOperator):
Cloud Platform for the dataflow job status while the job is in the
JOB_STATE_RUNNING state.
:type poll_sleep: int
+ :param job_class: The name of the dataflow job class to be executed, it
+ is often not the main class configured in the dataflow jar file.
+ :type job_class: string
"""
super(DataFlowJavaOperator, self).__init__(*args, **kwargs)
@@ -116,6 +120,7 @@ class DataFlowJavaOperator(BaseOperator):
self.dataflow_default_options = dataflow_default_options
self.options = options
self.poll_sleep = poll_sleep
+ self.job_class = job_class
def execute(self, context):
bucket_helper = GoogleCloudBucketHelper(
@@ -128,7 +133,8 @@ class DataFlowJavaOperator(BaseOperator):
dataflow_options = copy.copy(self.dataflow_default_options)
dataflow_options.update(self.options)
- hook.start_java_dataflow(self.task_id, dataflow_options, self.jar)
+ hook.start_java_dataflow(self.task_id, dataflow_options,
+ self.jar, self.job_class)
class DataflowTemplateOperator(BaseOperator):
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/tests/contrib/hooks/test_gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
index bf513c8..6089f4b 100644
--- a/tests/contrib/hooks/test_gcp_dataflow_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -37,6 +37,7 @@ PARAMETERS = {
}
PY_FILE = 'apache_beam.examples.wordcount'
JAR_FILE = 'unitest.jar'
+JOB_CLASS = 'com.example.UnitTest'
PY_OPTIONS = ['-m']
DATAFLOW_OPTIONS_PY = {
'project': 'test',
@@ -62,7 +63,7 @@ def mock_init(self, gcp_conn_id, delegate_to=None):
pass
-class DataFlowPythonHookTest(unittest.TestCase):
+class DataFlowHookTest(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
@@ -115,6 +116,30 @@ class DataFlowPythonHookTest(unittest.TestCase):
self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
sorted(EXPECTED_CMD))
+ @mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
+ @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
+ @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
+ @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
+ def test_start_java_dataflow_with_job_class(
+ self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
+ mock_uuid.return_value = MOCK_UUID
+ mock_conn.return_value = None
+ dataflow_instance = mock_dataflow.return_value
+ dataflow_instance.wait_for_done.return_value = None
+ dataflowjob_instance = mock_dataflowjob.return_value
+ dataflowjob_instance.wait_for_done.return_value = None
+ self.dataflow_hook.start_java_dataflow(
+ task_id=TASK_ID, variables=DATAFLOW_OPTIONS_JAVA,
+ dataflow=JAR_FILE, job_class=JOB_CLASS)
+ EXPECTED_CMD = ['java', '-cp', JAR_FILE, JOB_CLASS,
+ '--runner=DataflowRunner', '--project=test',
+ '--stagingLocation=gs://test/staging',
+ '--labels={"foo":"bar"}',
+ '--jobName={}-{}'.format(TASK_ID, MOCK_UUID)]
+ self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
+ sorted(EXPECTED_CMD))
+
+
@mock.patch('airflow.contrib.hooks.gcp_dataflow_hook._Dataflow.log')
@mock.patch('subprocess.Popen')
@mock.patch('select.select')
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f6a1c3cf/tests/contrib/operators/test_dataflow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py
index da95d18..88d13d8 100644
--- a/tests/contrib/operators/test_dataflow_operator.py
+++ b/tests/contrib/operators/test_dataflow_operator.py
@@ -16,7 +16,7 @@
import unittest
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator, \
- DataflowTemplateOperator
+ DataFlowJavaOperator, DataflowTemplateOperator
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
from airflow.version import version
@@ -36,8 +36,10 @@ PARAMETERS = {
'output': 'gs://test/output/my_output'
}
PY_FILE = 'gs://my-bucket/my-object.py'
+JAR_FILE = 'example/test.jar'
+JOB_CLASS = 'com.test.NotMain'
PY_OPTIONS = ['-m']
-DEFAULT_OPTIONS_PYTHON = {
+DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
'project': 'test',
'stagingLocation': 'gs://test/staging',
}
@@ -105,6 +107,44 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
+class DataFlowJavaOperatorTest(unittest.TestCase):
+
+ def setUp(self):
+ self.dataflow = DataFlowJavaOperator(
+ task_id=TASK_ID,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ dataflow_default_options=DEFAULT_OPTIONS_JAVA,
+ options=ADDITIONAL_OPTIONS,
+ poll_sleep=POLL_SLEEP)
+
+ def test_init(self):
+ """Test DataFlowJavaOperator instance is properly initialized."""
+ self.assertEqual(self.dataflow.task_id, TASK_ID)
+ self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
+ self.assertEqual(self.dataflow.dataflow_default_options,
+ DEFAULT_OPTIONS_JAVA)
+ self.assertEqual(self.dataflow.job_class, JOB_CLASS)
+ self.assertEqual(self.dataflow.jar, JAR_FILE)
+ self.assertEqual(self.dataflow.options,
+ EXPECTED_ADDITIONAL_OPTIONS)
+
+ @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
+ @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
+ def test_exec(self, gcs_hook, dataflow_mock):
+ """Test DataFlowHook is created and the right args are passed to
+ start_java_dataflow.
+
+ """
+ start_java_hook = dataflow_mock.return_value.start_java_dataflow
+ gcs_download_hook = gcs_hook.return_value.google_cloud_to_local
+ self.dataflow.execute(None)
+ self.assertTrue(dataflow_mock.called)
+ gcs_download_hook.assert_called_once_with(JAR_FILE)
+ start_java_hook.assert_called_once_with(TASK_ID, mock.ANY,
+ mock.ANY, JOB_CLASS)
+
+
class DataFlowTemplateOperatorTest(unittest.TestCase):
def setUp(self):