Posted to commits@airflow.apache.org by cr...@apache.org on 2017/09/26 15:56:06 UTC

incubator-airflow git commit: [AIRFLOW-1576] Added region param to Dataproc{*}Operators

Repository: incubator-airflow
Updated Branches:
  refs/heads/master ba0b8f683 -> 7962627a9


[AIRFLOW-1576] Added region param to Dataproc{*}Operators

Closes #2625 from cjqian/1576
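
For context only (not part of the commit): a minimal sketch of how the new
'region' argument might be used in a DAG. The bucket path, cluster name and
region below are placeholders, and the parameter defaults to 'global', so
existing DAGs keep their old behaviour.

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.dataproc_operator import DataProcPySparkOperator

    # Hypothetical DAG; names and dates are illustrative only.
    dag = DAG('dataproc_region_example',
              start_date=datetime(2017, 9, 26),
              schedule_interval=None)

    run_pyspark = DataProcPySparkOperator(
        task_id='run_pyspark',
        main='gs://example-bucket/jobs/wordcount.py',  # placeholder GCS URI
        dataproc_cluster='example-cluster',            # placeholder cluster name
        region='europe-west1',                         # new in this commit; defaults to 'global'
        gcp_conn_id='google_cloud_default',
        dag=dag)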


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/7962627a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/7962627a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/7962627a

Branch: refs/heads/master
Commit: 7962627a90672d6aa5c0330a7f1ee06e39dc677f
Parents: ba0b8f6
Author: Crystal Qian <cr...@gmail.com>
Authored: Tue Sep 26 08:55:56 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Tue Sep 26 08:55:56 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/gcp_dataproc_hook.py      | 11 ++--
 airflow/contrib/operators/dataproc_operator.py  | 24 +++++--
 .../contrib/operators/test_dataproc_operator.py | 68 +++++++++++++++++++-
 3 files changed, 91 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7962627a/airflow/contrib/hooks/gcp_dataproc_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py
index c964f4c..a1bba0b 100644
--- a/airflow/contrib/hooks/gcp_dataproc_hook.py
+++ b/airflow/contrib/hooks/gcp_dataproc_hook.py
@@ -22,12 +22,13 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 
 
 class _DataProcJob(LoggingMixin):
-    def __init__(self, dataproc_api, project_id, job):
+    def __init__(self, dataproc_api, project_id, job, region='global'):
         self.dataproc_api = dataproc_api
         self.project_id = project_id
+        self.region = region
         self.job = dataproc_api.projects().regions().jobs().submit(
             projectId=self.project_id,
-            region='global',
+            region=self.region,
             body=job).execute()
         self.job_id = self.job['reference']['jobId']
         self.log.info(
@@ -39,7 +40,7 @@ class _DataProcJob(LoggingMixin):
         while True:
             self.job = self.dataproc_api.projects().regions().jobs().get(
                 projectId=self.project_id,
-                region='global',
+                region=self.region,
                 jobId=self.job_id).execute()
             if 'ERROR' == self.job['status']['state']:
                 print(str(self.job))
@@ -153,8 +154,8 @@ class DataProcHook(GoogleCloudBaseHook):
         http_authorized = self._authorize()
         return build('dataproc', 'v1', http=http_authorized)
 
-    def submit(self, project_id, job):
-        submitted = _DataProcJob(self.get_conn(), project_id, job)
+    def submit(self, project_id, job, region='global'):
+        submitted = _DataProcJob(self.get_conn(), project_id, job, region)
         if not submitted.wait_for_done():
             submitted.raise_error("DataProcTask has errors")
 
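
At the hook level the new keyword is simply forwarded to the Dataproc jobs
API (projects().regions().jobs()). A rough sketch of a direct call, assuming
a configured 'google_cloud_default' connection and a job body built by one of
the operators' job builders:

    from airflow.contrib.hooks.gcp_dataproc_hook import DataProcHook

    hook = DataProcHook(gcp_conn_id='google_cloud_default')

    # Placeholder job body; in practice the operators build this dict for you.
    job_body = {'job': {'placement': {'clusterName': 'example-cluster'},
                        'reference': {'projectId': 'example-project'}}}

    # The third argument is the new region; leaving it out keeps the old
    # 'global' behaviour.
    hook.submit(hook.project_id, job_body, 'us-central1')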

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7962627a/airflow/contrib/operators/dataproc_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
index bdb0335..6ef89ba 100644
--- a/airflow/contrib/operators/dataproc_operator.py
+++ b/airflow/contrib/operators/dataproc_operator.py
@@ -503,6 +503,7 @@ class DataProcHiveOperator(BaseOperator):
             dataproc_hive_jars=None,
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
+            region='global',
             *args,
             **kwargs):
         """
@@ -532,6 +533,8 @@ class DataProcHiveOperator(BaseOperator):
             For this to work, the service account making the request must have domain-wide
             delegation enabled.
         :type delegate_to: string
+        :param region: The specified region where the dataproc cluster is created.
+        :type region: string
         """
         super(DataProcHiveOperator, self).__init__(*args, **kwargs)
         self.gcp_conn_id = gcp_conn_id
@@ -543,6 +546,7 @@ class DataProcHiveOperator(BaseOperator):
         self.dataproc_cluster = dataproc_cluster
         self.dataproc_properties = dataproc_hive_properties
         self.dataproc_jars = dataproc_hive_jars
+        self.region = region
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -559,7 +563,7 @@ class DataProcHiveOperator(BaseOperator):
         job.add_jar_file_uris(self.dataproc_jars)
         job.set_job_name(self.job_name)
 
-        hook.submit(hook.project_id, job.build())
+        hook.submit(hook.project_id, job.build(), self.region)
 
 
 class DataProcSparkSqlOperator(BaseOperator):
@@ -663,6 +667,7 @@ class DataProcSparkOperator(BaseOperator):
             dataproc_spark_jars=None,
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
+            region='global',
             *args,
             **kwargs):
         """
@@ -699,6 +704,8 @@ class DataProcSparkOperator(BaseOperator):
             For this to work, the service account making the request must have domain-wide
             delegation enabled.
         :type delegate_to: string
+        :param region: The specified region where the dataproc cluster is created.
+        :type region: string
         """
         super(DataProcSparkOperator, self).__init__(*args, **kwargs)
         self.gcp_conn_id = gcp_conn_id
@@ -712,6 +719,7 @@ class DataProcSparkOperator(BaseOperator):
         self.dataproc_cluster = dataproc_cluster
         self.dataproc_properties = dataproc_spark_properties
         self.dataproc_jars = dataproc_spark_jars
+        self.region = region
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -726,7 +734,7 @@ class DataProcSparkOperator(BaseOperator):
         job.add_file_uris(self.files)
         job.set_job_name(self.job_name)
 
-        hook.submit(hook.project_id, job.build())
+        hook.submit(hook.project_id, job.build(), self.region)
 
 
 class DataProcHadoopOperator(BaseOperator):
@@ -751,6 +759,7 @@ class DataProcHadoopOperator(BaseOperator):
             dataproc_hadoop_jars=None,
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
+            region='global',
             *args,
             **kwargs):
         """
@@ -787,6 +796,8 @@ class DataProcHadoopOperator(BaseOperator):
             For this to work, the service account making the request must have domain-wide
             delegation enabled.
         :type delegate_to: string
+        :param region: The specified region where the dataproc cluster is created.
+        :type region: string
         """
         super(DataProcHadoopOperator, self).__init__(*args, **kwargs)
         self.gcp_conn_id = gcp_conn_id
@@ -800,6 +811,7 @@ class DataProcHadoopOperator(BaseOperator):
         self.dataproc_cluster = dataproc_cluster
         self.dataproc_properties = dataproc_hadoop_properties
         self.dataproc_jars = dataproc_hadoop_jars
+        self.region = region
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -814,7 +826,7 @@ class DataProcHadoopOperator(BaseOperator):
         job.add_file_uris(self.files)
         job.set_job_name(self.job_name)
 
-        hook.submit(hook.project_id, job.build())
+        hook.submit(hook.project_id, job.build(), self.region)
 
 
 class DataProcPySparkOperator(BaseOperator):
@@ -839,6 +851,7 @@ class DataProcPySparkOperator(BaseOperator):
             dataproc_pyspark_jars=None,
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
+            region='global',
             *args,
             **kwargs):
         """
@@ -875,6 +888,8 @@ class DataProcPySparkOperator(BaseOperator):
             For this to work, the service account making the request must have
             domain-wide delegation enabled.
         :type delegate_to: string
+        :param region: The specified region where the dataproc cluster is created.
+        :type region: string
          """
         super(DataProcPySparkOperator, self).__init__(*args, **kwargs)
         self.gcp_conn_id = gcp_conn_id
@@ -888,6 +903,7 @@ class DataProcPySparkOperator(BaseOperator):
         self.dataproc_cluster = dataproc_cluster
         self.dataproc_properties = dataproc_pyspark_properties
         self.dataproc_jars = dataproc_pyspark_jars
+        self.region = region
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -903,4 +919,4 @@ class DataProcPySparkOperator(BaseOperator):
         job.add_python_file_uris(self.pyfiles)
         job.set_job_name(self.job_name)
 
-        hook.submit(hook.project_id, job.build())
+        hook.submit(hook.project_id, job.build(), self.region)
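
Each of the four job operators now stores 'region' and passes it as the third
argument to hook.submit(); since it defaults to 'global', DAGs written before
this change are unaffected. A hedged example with the Hive operator (the query
text and cluster name are placeholders, and 'query' is the operator's assumed
first parameter, not shown in this diff):

    from airflow.contrib.operators.dataproc_operator import DataProcHiveOperator

    hive_task = DataProcHiveOperator(
        task_id='hive_in_region',
        query='SHOW DATABASES;',             # placeholder HiveQL
        dataproc_cluster='example-cluster',  # placeholder cluster name
        region='asia-east1',                 # new parameter; defaults to 'global'
        gcp_conn_id='google_cloud_default')  # attach to a DAG as usual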

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7962627a/tests/contrib/operators/test_dataproc_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py
index 7ce6199..d206fba 100644
--- a/tests/contrib/operators/test_dataproc_operator.py
+++ b/tests/contrib/operators/test_dataproc_operator.py
@@ -20,14 +20,25 @@ import unittest
 from airflow import DAG
 from airflow.contrib.operators.dataproc_operator import DataprocClusterCreateOperator
 from airflow.contrib.operators.dataproc_operator import DataprocClusterDeleteOperator
+from airflow.contrib.operators.dataproc_operator import DataProcHadoopOperator
+from airflow.contrib.operators.dataproc_operator import DataProcHiveOperator
+from airflow.contrib.operators.dataproc_operator import DataProcPySparkOperator
+from airflow.contrib.operators.dataproc_operator import DataProcSparkOperator
 from airflow.version import version
 
 from copy import deepcopy
 
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
 from mock import Mock
 from mock import patch
 
-
 TASK_ID = 'test-dataproc-operator'
 CLUSTER_NAME = 'test-cluster-name'
 PROJECT_ID = 'test-project-id'
@@ -47,9 +58,11 @@ SERVICE_ACCOUNT_SCOPES = [
     'https://www.googleapis.com/auth/bigtable.data'
 ]
 DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+REGION = 'test-region'
+MAIN_URI = 'test-uri'
 
 class DataprocClusterCreateOperatorTest(unittest.TestCase):
-    # Unitest for the DataprocClusterCreateOperator
+    # Unit test for the DataprocClusterCreateOperator
     def setUp(self):
         # instantiate two different test cases with different labels.
         self.labels = [LABEL1, LABEL2]
@@ -158,7 +171,7 @@ class DataprocClusterCreateOperatorTest(unittest.TestCase):
                 mock_info.assert_called_with('Creating cluster: %s', u'smoke-cluster-testnodash')
 
 class DataprocClusterDeleteOperatorTest(unittest.TestCase):
-    # Unitest for the DataprocClusterDeleteOperator
+    # Unit test for the DataprocClusterDeleteOperator
     def setUp(self):
         self.mock_execute = Mock()
         self.mock_execute.execute = Mock(return_value={'done' : True})
@@ -213,3 +226,52 @@ class DataprocClusterDeleteOperatorTest(unittest.TestCase):
                 with self.assertRaises(TypeError) as _:
                     dataproc_task.execute(None)
                 mock_info.assert_called_with('Deleting cluster: %s', u'smoke-cluster-testnodash')
+
+class DataProcHadoopOperatorTest(unittest.TestCase):
+    # Unit test for the DataProcHadoopOperator
+    def test_hook_correct_region(self):
+       with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook:
+            dataproc_task = DataProcHadoopOperator(
+                task_id=TASK_ID,
+                region=REGION
+            )
+
+            dataproc_task.execute(None)
+            mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION)
+
+class DataProcHiveOperatorTest(unittest.TestCase):
+    # Unit test for the DataProcHiveOperator
+    def test_hook_correct_region(self):
+       with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook:
+            dataproc_task = DataProcHiveOperator(
+                task_id=TASK_ID,
+                region=REGION
+            )
+
+            dataproc_task.execute(None)
+            mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION)
+
+class DataProcPySparkOperatorTest(unittest.TestCase):
+    # Unit test for the DataProcPySparkOperator
+    def test_hook_correct_region(self):
+       with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook:
+            dataproc_task = DataProcPySparkOperator(
+                task_id=TASK_ID,
+                main=MAIN_URI,
+                region=REGION
+            )
+
+            dataproc_task.execute(None)
+            mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION)
+
+class DataProcSparkOperatorTest(unittest.TestCase):
+    # Unit test for the DataProcSparkOperator
+    def test_hook_correct_region(self):
+       with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook:
+            dataproc_task = DataProcSparkOperator(
+                task_id=TASK_ID,
+                region=REGION
+            )
+
+            dataproc_task.execute(None)
+            mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION)
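
The new tests patch DataProcHook and only assert on the arguments handed to
submit(), so they run without GCP credentials. One way to run just these cases
from the repository root, assuming mock and the Airflow test dependencies are
installed:

    import unittest

    from tests.contrib.operators.test_dataproc_operator import (
        DataProcHadoopOperatorTest,
        DataProcHiveOperatorTest,
        DataProcPySparkOperatorTest,
        DataProcSparkOperatorTest,
    )

    loader = unittest.TestLoader()
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case)
        for case in (DataProcHadoopOperatorTest, DataProcHiveOperatorTest,
                     DataProcPySparkOperatorTest, DataProcSparkOperatorTest))
    unittest.TextTestRunner(verbosity=2).run(suite)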