Posted to commits@airflow.apache.org by tu...@apache.org on 2021/03/28 14:53:22 UTC

[airflow] branch master updated: Override project in DataprocSubmitJobOperator (#14981)

This is an automated email from the ASF dual-hosted git repository.

turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 099c490  Override project in DataprocSubmitJobOperator (#14981)
099c490 is described below

commit 099c490cffae9556e56e141addcb41e9676e0d8f
Author: Sam Wheating <sa...@shopify.com>
AuthorDate: Sun Mar 28 10:53:06 2021 -0400

    Override project in DataprocSubmitJobOperator (#14981)
---
 .../providers/google/cloud/operators/dataproc.py   |  7 +++--
 .../google/cloud/operators/test_dataproc.py        | 32 ++++++++++++++++++++++
 2 files changed, 37 insertions(+), 2 deletions(-)
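
For context, this change lets a DAG author submit a Dataproc job to a project other than the one configured on the GCP connection. A minimal usage sketch, assuming a cluster named "cluster-1" already exists in the target project; the DAG ID, region, and query below are illustrative and not taken from the commit:

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocSubmitSparkSqlJobOperator,
)
from airflow.utils.dates import days_ago

with DAG(
    dag_id="dataproc_cross_project_example",  # illustrative name
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    # With this commit, project_id overrides the project inferred from
    # the GCP connection; omitting it keeps the previous behaviour.
    submit_query = DataprocSubmitSparkSqlJobOperator(
        task_id="submit_spark_sql",
        project_id="other-project",
        region="us-central1",  # illustrative region
        cluster_name="cluster-1",
        query="SHOW DATABASES;",
        gcp_conn_id="google_cloud_default",
    )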

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index bcfb2c7..d578565 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -858,6 +858,9 @@ class DataprocJobBaseOperator(BaseOperator):
     :type job_name: str
     :param cluster_name: The name of the DataProc cluster.
     :type cluster_name: str
+    :param project_id: The ID of the Google Cloud project the cluster belongs to.
+        If not specified, the project will be inferred from the provided GCP connection.
+    :type project_id: str
     :param dataproc_properties: Map for the Hive properties. Ideal to put in
         default arguments (templated)
     :type dataproc_properties: dict
@@ -912,6 +915,7 @@ class DataprocJobBaseOperator(BaseOperator):
         *,
         job_name: str = '{{task.task_id}}_{{ds_nodash}}',
         cluster_name: str = "cluster-1",
+        project_id: Optional[str] = None,
         dataproc_properties: Optional[Dict] = None,
         dataproc_jars: Optional[List[str]] = None,
         gcp_conn_id: str = 'google_cloud_default',
@@ -943,9 +947,8 @@ class DataprocJobBaseOperator(BaseOperator):
 
         self.job_error_states = job_error_states if job_error_states is not None else {'ERROR'}
         self.impersonation_chain = impersonation_chain
-
         self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
-        self.project_id = self.hook.project_id
+        self.project_id = self.hook.project_id if project_id is None else project_id
         self.job_template = None
         self.job = None
         self.dataproc_job_id = None
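
The operative change is the last added line above: an explicit project_id now takes precedence, and only when it is left as None does the operator fall back to the project configured on the DataprocHook's connection. A standalone sketch of that resolution order, using a stand-in hook object rather than the real DataprocHook:

from typing import Optional

class FakeHook:
    # Stand-in for DataprocHook; in Airflow the project_id comes from
    # the connection behind gcp_conn_id.
    project_id = "connection-project"

def resolve_project_id(explicit: Optional[str], hook: FakeHook) -> str:
    # Mirrors: self.hook.project_id if project_id is None else project_id
    return hook.project_id if explicit is None else explicit

assert resolve_project_id(None, FakeHook()) == "connection-project"
assert resolve_project_id("other-project", FakeHook()) == "other-project"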
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index e1c712e..e66acb4 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -781,6 +781,12 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
         "labels": {"airflow-version": AIRFLOW_VERSION},
         "spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
     }
+    other_project_job = {
+        "reference": {"project_id": "other-project", "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
+        "placement": {"cluster_name": "cluster-1"},
+        "labels": {"airflow-version": AIRFLOW_VERSION},
+        "spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
+    }
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_deprecation_warning(self, mock_hook):
@@ -815,6 +821,32 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
 
     @mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_override_project_id(self, mock_hook, mock_uuid):
+        mock_uuid.return_value = self.job_id
+        mock_hook.return_value.project_id = GCP_PROJECT
+        mock_hook.return_value.wait_for_job.return_value = None
+        mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id
+
+        op = DataprocSubmitSparkSqlJobOperator(
+            project_id="other-project",
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            gcp_conn_id=GCP_CONN_ID,
+            query=self.query,
+            variables=self.variables,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={})
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.submit_job.assert_called_once_with(
+            project_id="other-project", job=self.other_project_job, location=GCP_LOCATION
+        )
+        mock_hook.return_value.wait_for_job.assert_called_once_with(
+            job_id=self.job_id, location=GCP_LOCATION, project_id="other-project"
+        )
+
+    @mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_builder(self, mock_hook, mock_uuid):
         mock_hook.return_value.project_id = GCP_PROJECT
         mock_uuid.return_value = self.job_id
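
The new test pins the hook's connection project to GCP_PROJECT and then asserts that both submit_job and wait_for_job receive "other-project", proving the override reaches the API calls. A complementary check one could add (not part of this commit) is the default path, where no project_id is passed and the connection's project should win. A sketch reusing the fixtures of the test class above, assuming its existing job fixture (the dict whose tail is shown before other_project_job) describes the default expected job:

    @mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_execute_default_project_id(self, mock_hook, mock_uuid):
        mock_uuid.return_value = self.job_id
        mock_hook.return_value.project_id = GCP_PROJECT
        mock_hook.return_value.wait_for_job.return_value = None
        mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

        op = DataprocSubmitSparkSqlJobOperator(
            task_id=TASK_ID,
            region=GCP_LOCATION,
            gcp_conn_id=GCP_CONN_ID,
            query=self.query,
            variables=self.variables,
        )
        op.execute(context={})
        # Without an explicit project_id, the connection's project is used.
        mock_hook.return_value.submit_job.assert_called_once_with(
            project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
        )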