You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by tu...@apache.org on 2021/03/28 14:53:22 UTC
[airflow] branch master updated: Override project in
DataprocSubmitJobOperator (#14981)
This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 099c490 Override project in dataprocSubmitJobOperator (#14981)
099c490 is described below
commit 099c490cffae9556e56e141addcb41e9676e0d8f
Author: Sam Wheating <sa...@shopify.com>
AuthorDate: Sun Mar 28 10:53:06 2021 -0400
Override project in dataprocSubmitJobOperator (#14981)
---
.../providers/google/cloud/operators/dataproc.py | 7 +++--
.../google/cloud/operators/test_dataproc.py | 32 ++++++++++++++++++++++
2 files changed, 37 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index bcfb2c7..d578565 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -858,6 +858,9 @@ class DataprocJobBaseOperator(BaseOperator):
:type job_name: str
:param cluster_name: The name of the DataProc cluster.
:type cluster_name: str
+ :param project_id: The ID of the Google Cloud project the cluster belongs to,
+ if not specified the project will be inferred from the provided GCP connection.
+ :type project_id: str
:param dataproc_properties: Map for the Hive properties. Ideal to put in
default arguments (templated)
:type dataproc_properties: dict
@@ -912,6 +915,7 @@ class DataprocJobBaseOperator(BaseOperator):
*,
job_name: str = '{{task.task_id}}_{{ds_nodash}}',
cluster_name: str = "cluster-1",
+ project_id: Optional[str] = None,
dataproc_properties: Optional[Dict] = None,
dataproc_jars: Optional[List[str]] = None,
gcp_conn_id: str = 'google_cloud_default',
@@ -943,9 +947,8 @@ class DataprocJobBaseOperator(BaseOperator):
self.job_error_states = job_error_states if job_error_states is not None else {'ERROR'}
self.impersonation_chain = impersonation_chain
-
self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
- self.project_id = self.hook.project_id
+ self.project_id = self.hook.project_id if project_id is None else project_id
self.job_template = None
self.job = None
self.dataproc_job_id = None
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index e1c712e..e66acb4 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -781,6 +781,12 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
}
+ other_project_job = {
+ "reference": {"project_id": "other-project", "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
+ "placement": {"cluster_name": "cluster-1"},
+ "labels": {"airflow-version": AIRFLOW_VERSION},
+ "spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
+ }
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
@@ -815,6 +821,32 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute_override_project_id(self, mock_hook, mock_uuid):
+ mock_uuid.return_value = self.job_id
+ mock_hook.return_value.project_id = GCP_PROJECT
+ mock_hook.return_value.wait_for_job.return_value = None
+ mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id
+
+ op = DataprocSubmitSparkSqlJobOperator(
+ project_id="other-project",
+ task_id=TASK_ID,
+ region=GCP_LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ query=self.query,
+ variables=self.variables,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.submit_job.assert_called_once_with(
+ project_id="other-project", job=self.other_project_job, location=GCP_LOCATION
+ )
+ mock_hook.return_value.wait_for_job.assert_called_once_with(
+ job_id=self.job_id, location=GCP_LOCATION, project_id="other-project"
+ )
+
+ @mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_builder(self, mock_hook, mock_uuid):
mock_hook.return_value.project_id = GCP_PROJECT
mock_uuid.return_value = self.job_id