You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by as...@apache.org on 2020/06/15 11:55:30 UTC

[airflow] 01/03: Bug fix for EmrAddStepOperator init with cluster_name error (#9235)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v1-10-stable
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 66382d8d68aa2c675367f007cc013613df54542c
Author: Ahmad Maruf <ah...@verizon.com>
AuthorDate: Fri Jun 12 15:03:51 2020 -0700

    Bug fix for EmrAddStepOperator init with cluster_name error (#9235)
    
    Closes #9127
---
 airflow/contrib/operators/emr_add_steps_operator.py    | 8 +++++---
 tests/contrib/operators/test_emr_add_steps_operator.py | 3 ++-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py
index 0075b1b..1917752 100644
--- a/airflow/contrib/operators/emr_add_steps_operator.py
+++ b/airflow/contrib/operators/emr_add_steps_operator.py
@@ -66,12 +66,14 @@ class EmrAddStepsOperator(BaseOperator):
         self.steps = steps
 
     def execute(self, context):
-        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
+        emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
 
-        job_flow_id = self.job_flow_id
+        emr = emr_hook.get_conn()
 
+        job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name,
+                                                                          self.cluster_states)
         if not job_flow_id:
-            job_flow_id = emr.get_cluster_id_by_name(self.job_flow_name, self.cluster_states)
+            raise AirflowException('No cluster found for name: ' + self.job_flow_name)
 
         if self.do_xcom_push:
             context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
diff --git a/tests/contrib/operators/test_emr_add_steps_operator.py b/tests/contrib/operators/test_emr_add_steps_operator.py
index e0cadfe..97ebc42 100644
--- a/tests/contrib/operators/test_emr_add_steps_operator.py
+++ b/tests/contrib/operators/test_emr_add_steps_operator.py
@@ -105,10 +105,11 @@ class TestEmrAddStepsOperator(unittest.TestCase):
         with patch('boto3.session.Session', self.boto3_session_mock):
             self.assertEqual(self.operator.execute(self.mock_context), ['s-2LH3R5GW3A53T'])
 
+    @patch.multiple('airflow.contrib.hooks.emr_hook.EmrHook',
+                    get_cluster_id_by_name=MagicMock(return_value='j-1231231234'))
     def test_init_with_cluster_name(self):
         expected_job_flow_id = 'j-1231231234'
 
-        self.emr_client_mock.get_cluster_id_by_name.return_value = expected_job_flow_id
         self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
 
         with patch('boto3.session.Session', self.boto3_session_mock):