Posted to commits@airflow.apache.org by po...@apache.org on 2023/06/04 21:45:22 UTC

[airflow] branch main updated: Remove return statement after yield from triggers class (#31703)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 86b5ba2802 Remove return statement after yield from triggers class (#31703)
86b5ba2802 is described below

commit 86b5ba28026fc6e8b7d868b83080189df9b09306
Author: Pankaj Singh <98...@users.noreply.github.com>
AuthorDate: Mon Jun 5 03:15:05 2023 +0530

    Remove return statement after yield from triggers class (#31703)
    
    * Remove return statement after yield from triggers class
    
    We have a couple of trigger classes that both yield and return;
    the return statement is not required once we yield.
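
[Editor's note: a minimal sketch of the pattern behind this cleanup, under assumed names. Airflow's triggerer consumes at most one TriggerEvent from a trigger's run() async generator, so a bare `return` placed directly after the final `yield` is never executed and can be dropped. The `run`/`main` names and the status poll below are illustrative, not the provider code in this diff; the `asend(None)` call mirrors how the updated tests further down pull a single event.]

    import asyncio
    from typing import Any, AsyncIterator


    # Illustrative stand-in for a trigger's run() method (not Airflow code).
    # The consumer reads at most one event, so a `return` directly after the
    # final `yield` is dead code.
    async def run() -> AsyncIterator[dict[str, Any]]:
        while True:
            status = "success"  # stand-in for polling a job/pod status
            if status == "success":
                yield {"status": "success", "message": "Job completed"}
                # a bare `return` here would never run: the caller stops
                # after receiving this event, as the updated tests show
            await asyncio.sleep(1)


    async def main() -> None:
        gen = run()
        event = await gen.asend(None)  # pull exactly one event, like the new tests
        assert event == {"status": "success", "message": "Job completed"}
        await gen.aclose()  # close the generator instead of iterating it dry


    asyncio.run(main())
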
---
 airflow/providers/cncf/kubernetes/triggers/pod.py  |  5 ---
 .../providers/databricks/triggers/databricks.py    |  1 -
 .../providers/google/cloud/triggers/bigquery.py    | 15 -------
 .../google/cloud/triggers/bigquery_dts.py          |  4 --
 .../providers/google/cloud/triggers/dataflow.py    |  4 --
 .../providers/google/cloud/triggers/datafusion.py  |  2 -
 .../providers/google/cloud/triggers/dataproc.py    |  2 -
 airflow/providers/google/cloud/triggers/gcs.py     |  3 --
 .../google/cloud/triggers/kubernetes_engine.py     |  3 --
 .../databricks/triggers/test_databricks.py         | 48 +++++++++++-----------
 .../google/cloud/triggers/test_bigquery.py         | 30 ++++----------
 11 files changed, 33 insertions(+), 84 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py
index 0a7b41b8c0..74a1f8787e 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -141,7 +141,6 @@ class KubernetesPodTrigger(BaseTrigger):
                             "message": "All containers inside pod have started successfully.",
                         }
                     )
-                    return
                 elif self.should_wait(pod_phase=pod_status, container_state=container_state):
                     self.log.info("Container is not completed and still working.")
 
@@ -160,7 +159,6 @@ class KubernetesPodTrigger(BaseTrigger):
                                     "message": message,
                                 }
                             )
-                            return
 
                     self.log.info("Sleeping for %s seconds.", self.poll_interval)
                     await asyncio.sleep(self.poll_interval)
@@ -173,7 +171,6 @@ class KubernetesPodTrigger(BaseTrigger):
                             "message": pod.status.message,
                         }
                     )
-                    return
             except CancelledError:
                 # That means that task was marked as failed
                 if self.get_logs:
@@ -196,7 +193,6 @@ class KubernetesPodTrigger(BaseTrigger):
                         "message": "Pod execution was cancelled",
                     }
                 )
-                return
             except Exception as e:
                 self.log.exception("Exception occurred while checking pod phase:")
                 yield TriggerEvent(
@@ -207,7 +203,6 @@ class KubernetesPodTrigger(BaseTrigger):
                         "message": str(e),
                     }
                 )
-                return
 
     def _get_async_hook(self) -> AsyncKubernetesHook:
         if self._hook is None:
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
index e5e56cc0ff..c2400e2c97 100644
--- a/airflow/providers/databricks/triggers/databricks.py
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -89,7 +89,6 @@ class DatabricksExecutionTrigger(BaseTrigger):
                             "run_state": run_state.to_json(),
                         }
                     )
-                    return
                 else:
                     self.log.info(
                         "run-id %s in run state %s. sleeping for %s seconds",
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index c7b17af2ed..e63a3df1a4 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -88,14 +88,12 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                             "message": "Job completed",
                         }
                     )
-                    return
                 elif response_from_hook == "pending":
                     self.log.info("Query is still running...")
                     self.log.info("Sleeping for %s seconds.", self.poll_interval)
                     await asyncio.sleep(self.poll_interval)
                 else:
                     yield TriggerEvent({"status": "error", "message": response_from_hook})
-                    return
 
             except Exception as e:
                 self.log.exception("Exception occurred while checking for query completion")
@@ -151,7 +149,6 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
                                 "records": first_record,
                             }
                         )
-                    return
 
                 elif response_from_hook == "pending":
                     self.log.info("Query is still running...")
@@ -209,18 +206,15 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
                             "records": records,
                         }
                     )
-                    return
                 elif response_from_hook == "pending":
                     self.log.info("Query is still running...")
                     self.log.info("Sleeping for %s seconds.", self.poll_interval)
                     await asyncio.sleep(self.poll_interval)
                 else:
                     yield TriggerEvent({"status": "error", "message": response_from_hook})
-                    return
             except Exception as e:
                 self.log.exception("Exception occurred while checking for query completion")
                 yield TriggerEvent({"status": "error", "message": str(e)})
-                return
 
 
 class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
@@ -351,7 +345,6 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
                             "second_row_data": second_job_row,
                         }
                     )
-                    return
                 elif first_job_response_from_hook == "pending" or second_job_response_from_hook == "pending":
                     self.log.info("Query is still running...")
                     self.log.info("Sleeping for %s seconds.", self.poll_interval)
@@ -360,12 +353,10 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
                     yield TriggerEvent(
                         {"status": "error", "message": second_job_response_from_hook, "data": None}
                     )
-                    return
 
             except Exception as e:
                 self.log.exception("Exception occurred while checking for query completion")
                 yield TriggerEvent({"status": "error", "message": str(e)})
-                return
 
 
 class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
@@ -437,19 +428,16 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
                     records = records.pop(0) if records else None
                     hook.value_check(self.sql, self.pass_value, records, self.tolerance)
                     yield TriggerEvent({"status": "success", "message": "Job completed", "records": records})
-                    return
                 elif response_from_hook == "pending":
                     self.log.info("Query is still running...")
                     self.log.info("Sleeping for %s seconds.", self.poll_interval)
                     await asyncio.sleep(self.poll_interval)
                 else:
                     yield TriggerEvent({"status": "error", "message": response_from_hook, "records": None})
-                    return
 
             except Exception as e:
                 self.log.exception("Exception occurred while checking for query completion")
                 yield TriggerEvent({"status": "error", "message": str(e)})
-                return
 
 
 class BigQueryTableExistenceTrigger(BaseTrigger):
@@ -507,12 +495,10 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
                 )
                 if response:
                     yield TriggerEvent({"status": "success", "message": "success"})
-                    return
                 await asyncio.sleep(self.poll_interval)
             except Exception as e:
                 self.log.exception("Exception occurred while checking for Table existence")
                 yield TriggerEvent({"status": "error", "message": str(e)})
-                return
 
     async def _table_exists(
         self, hook: BigQueryTableAsyncHook, dataset: str, table_id: str, project_id: str
@@ -593,7 +579,6 @@ class BigQueryTablePartitionExistenceTrigger(BigQueryTableExistenceTrigger):
                     job_id = None
                 elif status == "error":
                     yield TriggerEvent({"status": "error", "message": status})
-                    return
                 self.log.info("Sleeping for %s seconds.", self.poll_interval)
                 await asyncio.sleep(self.poll_interval)
 
diff --git a/airflow/providers/google/cloud/triggers/bigquery_dts.py b/airflow/providers/google/cloud/triggers/bigquery_dts.py
index 8ab40d99df..354a689445 100644
--- a/airflow/providers/google/cloud/triggers/bigquery_dts.py
+++ b/airflow/providers/google/cloud/triggers/bigquery_dts.py
@@ -105,7 +105,6 @@ class BigQueryDataTransferRunTrigger(BaseTrigger):
                             "config_id": self.config_id,
                         }
                     )
-                    return
 
                 elif state == TransferState.FAILED:
                     self.log.info("Job has failed")
@@ -116,7 +115,6 @@ class BigQueryDataTransferRunTrigger(BaseTrigger):
                             "message": "Job has failed",
                         }
                     )
-                    return
 
                 if state == TransferState.CANCELLED:
                     self.log.info("Job has been cancelled.")
@@ -127,7 +125,6 @@ class BigQueryDataTransferRunTrigger(BaseTrigger):
                             "message": "Job was cancelled",
                         }
                     )
-                    return
 
                 else:
                     self.log.info("Job is still working...")
@@ -141,7 +138,6 @@ class BigQueryDataTransferRunTrigger(BaseTrigger):
                         "message": f"Trigger failed with exception: {str(e)}",
                     }
                 )
-                return
 
     def _get_async_hook(self) -> AsyncBiqQueryDataTransferServiceHook:
         return AsyncBiqQueryDataTransferServiceHook(
diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py
index 5dfdf5106a..bc04ce64b2 100644
--- a/airflow/providers/google/cloud/triggers/dataflow.py
+++ b/airflow/providers/google/cloud/triggers/dataflow.py
@@ -107,7 +107,6 @@ class TemplateJobStartTrigger(BaseTrigger):
                             "message": "Job completed",
                         }
                     )
-                    return
                 elif status == JobState.JOB_STATE_FAILED:
                     yield TriggerEvent(
                         {
@@ -115,7 +114,6 @@ class TemplateJobStartTrigger(BaseTrigger):
                             "message": f"Dataflow job with id {self.job_id} has failed its execution",
                         }
                     )
-                    return
                 elif status == JobState.JOB_STATE_STOPPED:
                     yield TriggerEvent(
                         {
@@ -123,7 +121,6 @@ class TemplateJobStartTrigger(BaseTrigger):
                             "message": f"Dataflow job with id {self.job_id} was stopped",
                         }
                     )
-                    return
                 else:
                     self.log.info("Job is still running...")
                     self.log.info("Current job status is: %s", status)
@@ -132,7 +129,6 @@ class TemplateJobStartTrigger(BaseTrigger):
             except Exception as e:
                 self.log.exception("Exception occurred while checking for job completion.")
                 yield TriggerEvent({"status": "error", "message": str(e)})
-                return
 
     def _get_async_hook(self) -> AsyncDataflowHook:
         return AsyncDataflowHook(
diff --git a/airflow/providers/google/cloud/triggers/datafusion.py b/airflow/providers/google/cloud/triggers/datafusion.py
index 34fa7d0258..982fda00e1 100644
--- a/airflow/providers/google/cloud/triggers/datafusion.py
+++ b/airflow/providers/google/cloud/triggers/datafusion.py
@@ -101,14 +101,12 @@ class DataFusionStartPipelineTrigger(BaseTrigger):
                             "message": "Pipeline is running",
                         }
                     )
-                    return
                 elif response_from_hook == "pending":
                     self.log.info("Pipeline is not still in running state...")
                     self.log.info("Sleeping for %s seconds.", self.poll_interval)
                     await asyncio.sleep(self.poll_interval)
                 else:
                     yield TriggerEvent({"status": "error", "message": response_from_hook})
-                    return
 
             except Exception as e:
                 self.log.exception("Exception occurred while checking for pipeline state")
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index d896f1190d..9614e9507b 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -322,7 +322,6 @@ class DataprocWorkflowTrigger(DataprocBaseTrigger):
                                 "message": operation.error.message,
                             }
                         )
-                        return
                     yield TriggerEvent(
                         {
                             "operation_name": operation.name,
@@ -331,7 +330,6 @@ class DataprocWorkflowTrigger(DataprocBaseTrigger):
                             "message": "Operation is successfully ended.",
                         }
                     )
-                    return
                 else:
                     self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
                     await asyncio.sleep(self.polling_interval_seconds)
diff --git a/airflow/providers/google/cloud/triggers/gcs.py b/airflow/providers/google/cloud/triggers/gcs.py
index 34dc163e23..84907a4ced 100644
--- a/airflow/providers/google/cloud/triggers/gcs.py
+++ b/airflow/providers/google/cloud/triggers/gcs.py
@@ -82,7 +82,6 @@ class GCSBlobTrigger(BaseTrigger):
                 await asyncio.sleep(self.poke_interval)
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
-            return
 
     def _get_async_hook(self) -> GCSAsyncHook:
         return GCSAsyncHook(gcp_conn_id=self.google_cloud_conn_id, **self.hook_params)
@@ -266,7 +265,6 @@ class GCSPrefixBlobTrigger(GCSBlobTrigger):
                 await asyncio.sleep(self.poke_interval)
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
-            return
 
     async def _list_blobs_with_prefix(self, hook: GCSAsyncHook, bucket_name: str, prefix: str) -> list[str]:
         """
@@ -369,7 +367,6 @@ class GCSUploadSessionTrigger(GCSPrefixBlobTrigger):
                 await asyncio.sleep(self.poke_interval)
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
-            return
 
     def _get_time(self) -> datetime:
         """
diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index 237a88e352..ff19350cfe 100644
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -176,7 +176,6 @@ class GKEOperationTrigger(BaseTrigger):
                             "operation_name": operation.name,
                         }
                     )
-                    return
 
                 elif status == Operation.Status.RUNNING or status == Operation.Status.PENDING:
                     self.log.info("Operation is still running.")
@@ -190,7 +189,6 @@ class GKEOperationTrigger(BaseTrigger):
                             "message": f"Operation has failed with status: {operation.status}",
                         }
                     )
-                    return
             except Exception as e:
                 self.log.exception("Exception occurred while checking operation status")
                 yield TriggerEvent(
@@ -199,7 +197,6 @@ class GKEOperationTrigger(BaseTrigger):
                         "message": str(e),
                     }
                 )
-                return
 
     def _get_hook(self) -> GKEAsyncHook:
         if self._hook is None:
diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py
index 1ac52a3e9d..53f67eb251 100644
--- a/tests/providers/databricks/triggers/test_databricks.py
+++ b/tests/providers/databricks/triggers/test_databricks.py
@@ -107,17 +107,17 @@ class TestDatabricksExecutionTrigger:
             result_state="SUCCESS",
         )
 
-        trigger_event = self.trigger.run()
-        async for event in trigger_event:
-            assert event == TriggerEvent(
-                {
-                    "run_id": RUN_ID,
-                    "run_state": RunState(
-                        life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS"
-                    ).to_json(),
-                    "run_page_url": RUN_PAGE_URL,
-                }
-            )
+        generator = self.trigger.run()
+        actual = await generator.asend(None)
+        assert actual == TriggerEvent(
+            {
+                "run_id": RUN_ID,
+                "run_state": RunState(
+                    life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS"
+                ).to_json(),
+                "run_page_url": RUN_PAGE_URL,
+            }
+        )
 
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
@@ -137,16 +137,16 @@ class TestDatabricksExecutionTrigger:
             ),
         ]
 
-        trigger_event = self.trigger.run()
-        async for event in trigger_event:
-            assert event == TriggerEvent(
-                {
-                    "run_id": RUN_ID,
-                    "run_state": RunState(
-                        life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS"
-                    ).to_json(),
-                    "run_page_url": RUN_PAGE_URL,
-                }
-            )
-            mock_sleep.assert_called_once()
-            mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS)
+        generator = self.trigger.run()
+        actual = await generator.asend(None)
+        assert actual == TriggerEvent(
+            {
+                "run_id": RUN_ID,
+                "run_state": RunState(
+                    life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS"
+                ).to_json(),
+                "run_page_url": RUN_PAGE_URL,
+            }
+        )
+        mock_sleep.assert_called_once()
+        mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS)
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index 079323a27f..b1acd7b530 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -528,15 +528,10 @@ class TestBigQueryIntervalCheckTrigger:
         mock_job_status.side_effect = Exception("Test exception")
         caplog.set_level(logging.DEBUG)
 
-        # trigger event is yielded so it creates a generator object
-        # so i have used async for to get all the values and added it to task
-        task = [i async for i in interval_check_trigger.run()]
-        # since we use return as soon as we yield the trigger event
-        # at any given point there should be one trigger event returned to the task
-        # so we validate for length of task to be 1
+        generator = interval_check_trigger.run()
+        actual = await generator.asend(None)
 
-        assert len(task) == 1
-        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
 
 
 class TestBigQueryValueCheckTrigger:
@@ -627,16 +622,9 @@ class TestBigQueryValueCheckTrigger:
             job_id=TEST_JOB_ID,
             project_id=TEST_GCP_PROJECT_ID,
         )
-
-        # trigger event is yielded so it creates a generator object
-        # so i have used async for to get all the values and added it to task
-        task = [i async for i in trigger.run()]
-        # since we use return as soon as we yield the trigger event
-        # at any given point there should be one trigger event returned to the task
-        # so we validate for length of task to be 1
-
-        assert len(task) == 1
-        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
 
 
 class TestBigQueryTableExistenceTrigger:
@@ -693,9 +681,9 @@ class TestBigQueryTableExistenceTrigger:
         """Test BigQueryTableExistenceTrigger throws exception if any error."""
         mock_table_exists.side_effect = AsyncMock(side_effect=Exception("Test exception"))
 
-        task = [i async for i in table_existence_trigger.run()]
-        assert len(task) == 1
-        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+        generator = table_existence_trigger.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
 
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryTableAsyncHook.get_table_client")