Posted to commits@airflow.apache.org by el...@apache.org on 2023/08/22 06:44:37 UTC

[airflow] branch main updated: Refactor: Simplify code in Apache/Alibaba providers (#33227)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 32feab4100 Refactor: Simplify code in Apache/Alibaba providers (#33227)
32feab4100 is described below

commit 32feab41006897de182bfa684813be230027aca1
Author: Miroslav Šedivý <67...@users.noreply.github.com>
AuthorDate: Tue Aug 22 06:44:29 2023 +0000

    Refactor: Simplify code in Apache/Alibaba providers (#33227)
---
 .../alibaba/cloud/hooks/analyticdb_spark.py        |  2 +-
 airflow/providers/apache/beam/hooks/beam.py        |  6 +--
 airflow/providers/apache/beam/triggers/beam.py     | 47 ++++++++++------------
 airflow/providers/apache/hive/hooks/hive.py        | 33 ++++++---------
 airflow/providers/apache/livy/hooks/livy.py        | 12 +++---
 airflow/providers/apache/spark/hooks/spark_sql.py  |  2 +-
 6 files changed, 44 insertions(+), 58 deletions(-)

diff --git a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
index 9881ca38ae..e06ee91228 100644
--- a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
+++ b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
@@ -321,7 +321,7 @@ class AnalyticDBSparkHook(BaseHook, LoggingMixin):
         if conf:
             if not isinstance(conf, dict):
                 raise ValueError("'conf' argument must be a dict")
-            if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()):
+            if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
                 raise ValueError("'conf' values must be either strings or ints")
         return True
 
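The rewritten predicate reads off directly: a conf value is valid when it is a str or an int and is not the empty string. The same change recurs in LivyHook and LivyAsyncHook further down. A standalone illustration of the predicate (the conf keys here are made up for the example):

    # Accepted: non-empty strings and ints (including 0).
    conf = {"spark.executor.cores": "2", "spark.port.maxRetries": 0}
    assert all(isinstance(v, (str, int)) and v != "" for v in conf.values())

    # Rejected: empty strings and any other type, e.g. floats.
    bad_conf = {"spark.executor.memory": ""}
    assert not all(isinstance(v, (str, int)) and v != "" for v in bad_conf.values())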
diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
index 762dd2f07b..72dc224626 100644
--- a/airflow/providers/apache/beam/hooks/beam.py
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -104,10 +104,8 @@ def process_fd(
     fd_to_log = {proc.stderr: log.warning, proc.stdout: log.info}
     func_log = fd_to_log[fd]
 
-    while True:
-        line = fd.readline().decode()
-        if not line:
-            return
+    for line in iter(fd.readline, b""):
+        line = line.decode()
         if process_line_callback:
             process_line_callback(line)
         func_log(line.rstrip("\n"))
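
The two-argument form of iter(callable, sentinel) keeps calling the callable until it returns the sentinel, which replaces the manual while/readline/break loop; the HiveCliHook change below applies the same idiom. A minimal self-contained sketch, using io.BytesIO as a stand-in for the process pipe:

    import io

    fd = io.BytesIO(b"first\nsecond\n")
    for line in iter(fd.readline, b""):  # stops once readline() returns b""
        print(line.decode().rstrip("\n"))
    # first
    # second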
diff --git a/airflow/providers/apache/beam/triggers/beam.py b/airflow/providers/apache/beam/triggers/beam.py
index 9c29a9fbe2..0d201cd8c9 100644
--- a/airflow/providers/apache/beam/triggers/beam.py
+++ b/airflow/providers/apache/beam/triggers/beam.py
@@ -85,32 +85,29 @@ class BeamPipelineTrigger(BaseTrigger):
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook()
-        while True:
-            try:
-                return_code = await hook.start_python_pipeline_async(
-                    variables=self.variables,
-                    py_file=self.py_file,
-                    py_options=self.py_options,
-                    py_interpreter=self.py_interpreter,
-                    py_requirements=self.py_requirements,
-                    py_system_site_packages=self.py_system_site_packages,
+        try:
+            return_code = await hook.start_python_pipeline_async(
+                variables=self.variables,
+                py_file=self.py_file,
+                py_options=self.py_options,
+                py_interpreter=self.py_interpreter,
+                py_requirements=self.py_requirements,
+                py_system_site_packages=self.py_system_site_packages,
+            )
+        except Exception as e:
+            self.log.exception("Exception occurred while checking for pipeline state")
+            yield TriggerEvent({"status": "error", "message": str(e)})
+        else:
+            if return_code == 0:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": "Pipeline has finished SUCCESSFULLY",
+                    }
                 )
-                if return_code == 0:
-                    yield TriggerEvent(
-                        {
-                            "status": "success",
-                            "message": "Pipeline has finished SUCCESSFULLY",
-                        }
-                    )
-                    return
-                else:
-                    yield TriggerEvent({"status": "error", "message": "Operation failed"})
-                    return
-
-            except Exception as e:
-                self.log.exception("Exception occurred while checking for pipeline state")
-                yield TriggerEvent({"status": "error", "message": str(e)})
-                return
+            else:
+                yield TriggerEvent({"status": "error", "message": "Operation failed"})
+        return
 
     def _get_async_hook(self) -> BeamAsyncHook:
         return BeamAsyncHook(runner=self.runner)
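
Every branch of the old while True loop returned after its first yield, so the loop never ran a second pass; try/except/else expresses that single pass directly. A condensed sketch of the resulting shape, with hypothetical run_once() and event() standing in for the hook call and TriggerEvent:

    async def run(self):
        try:
            return_code = await run_once()  # the single pipeline invocation
        except Exception as e:
            yield event("error", str(e))
        else:
            if return_code == 0:
                yield event("success", "Pipeline has finished SUCCESSFULLY")
            else:
                yield event("error", "Operation failed")
        return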
diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index ea004860b4..7f02619024 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -277,13 +277,11 @@ class HiveCliHook(BaseHook):
                 )
                 self.sub_process = sub_process
                 stdout = ""
-                while True:
-                    line = sub_process.stdout.readline()
-                    if not line:
-                        break
-                    stdout += line.decode("UTF-8")
+                for line in iter(sub_process.stdout.readline, b""):
+                    line = line.decode()
+                    stdout += line
                     if verbose:
-                        self.log.info(line.decode("UTF-8").strip())
+                        self.log.info(line.strip())
                 sub_process.wait()
 
                 if sub_process.returncode:
@@ -704,25 +702,20 @@ class HiveMetastoreHook(BaseHook):
         # Assuming all specs have the same keys.
         if partition_key not in part_specs[0].keys():
             raise AirflowException(f"Provided partition_key {partition_key} is not in part_specs.")
-        is_subset = None
-        if filter_map:
-            is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys()))
-        if filter_map and not is_subset:
+        if filter_map and not set(filter_map).issubset(part_specs[0]):
             raise AirflowException(
                 f"Keys in provided filter_map {', '.join(filter_map.keys())} "
                 f"are not subset of part_spec keys: {', '.join(part_specs[0].keys())}"
             )
 
-        candidates = [
-            p_dict[partition_key]
-            for p_dict in part_specs
-            if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
-        ]
-
-        if not candidates:
-            return None
-        else:
-            return max(candidates)
+        return max(
+            (
+                p_dict[partition_key]
+                for p_dict in part_specs
+                if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
+            ),
+            default=None,
+        )
 
     def max_partition(
         self,
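
max() raises ValueError on an empty iterable unless a default is given, so default=None folds the old "no candidates" branch into the single expression. For illustration:

    assert max([], default=None) is None
    assert max(["2023-01-01", "2023-02-01"], default=None) == "2023-02-01"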
diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py
index ede3d2eb98..ba2ff1bb13 100644
--- a/airflow/providers/apache/livy/hooks/livy.py
+++ b/airflow/providers/apache/livy/hooks/livy.py
@@ -432,7 +432,7 @@ class LivyHook(HttpHook, LoggingMixin):
         if (
             vals is None
             or not isinstance(vals, (tuple, list))
-            or any(1 for val in vals if not isinstance(val, (str, int, float)))
+            or not all(isinstance(val, (str, int, float)) for val in vals)
         ):
             raise ValueError("List of strings expected")
         return True
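
By De Morgan's laws, any(1 for val in vals if not isinstance(...)) is equivalent to not all(isinstance(...) for val in vals); the latter states the requirement, that every value be a str, int, or float, in the affirmative. For example:

    ok = ["label", 1, 2.5]
    assert all(isinstance(val, (str, int, float)) for val in ok)
    assert not all(isinstance(val, (str, int, float)) for val in ["label", None])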
@@ -448,7 +448,7 @@ class LivyHook(HttpHook, LoggingMixin):
         if conf:
             if not isinstance(conf, dict):
                 raise ValueError("'conf' argument must be a dict")
-            if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()):
+            if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
                 raise ValueError("'conf' values must be either strings or ints")
         return True
 
@@ -542,8 +542,7 @@ class LivyAsyncHook(HttpAsyncHook, LoggingMixin):
             else:
                 return {"Response": f"Unexpected HTTP Method: {self.method}", "status": "error"}
 
-            attempt_num = 1
-            while True:
+            for attempt_num in range(1, 1 + self.retry_limit):
                 response = await request_func(
                     url,
                     json=data if self.method in ("POST", "PATCH") else None,
@@ -568,7 +567,6 @@ class LivyAsyncHook(HttpAsyncHook, LoggingMixin):
                         # Don't retry.
                         return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"}
 
-                attempt_num += 1
                 await asyncio.sleep(self.retry_delay)
 
     def _generate_base_url(self, conn: Connection) -> str:
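
Bounding the loop with range(1, 1 + self.retry_limit) replaces while True plus a hand-maintained counter. A minimal sketch of the pattern, with hypothetical do_request, retry_limit, and retry_delay standing in for the hook's request call and settings:

    import asyncio

    async def call_with_retries(do_request, retry_limit=3, retry_delay=1.0):
        for attempt_num in range(1, 1 + retry_limit):  # attempts 1..retry_limit
            try:
                return await do_request()
            except Exception:
                if attempt_num == retry_limit:
                    raise  # attempts exhausted
                await asyncio.sleep(retry_delay)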
@@ -815,7 +813,7 @@ class LivyAsyncHook(HttpAsyncHook, LoggingMixin):
         if (
             vals is None
             or not isinstance(vals, (tuple, list))
-            or any(1 for val in vals if not isinstance(val, (str, int, float)))
+            or not all(isinstance(val, (str, int, float)) for val in vals)
         ):
             raise ValueError("List of strings expected")
         return True
@@ -831,6 +829,6 @@ class LivyAsyncHook(HttpAsyncHook, LoggingMixin):
         if conf:
             if not isinstance(conf, dict):
                 raise ValueError("'conf' argument must be a dict")
-            if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()):
+            if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
                 raise ValueError("'conf' values must be either strings or ints")
         return True
diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py
index 6864aa52fe..41dc741ccd 100644
--- a/airflow/providers/apache/spark/hooks/spark_sql.py
+++ b/airflow/providers/apache/spark/hooks/spark_sql.py
@@ -134,7 +134,7 @@ class SparkSqlHook(BaseHook):
             connection_cmd += ["--num-executors", str(self._num_executors)]
         if self._sql:
             sql = self._sql.strip()
-            if sql.endswith(".sql") or sql.endswith(".hql"):
+            if sql.endswith((".sql", ".hql")):
                 connection_cmd += ["-f", sql]
             else:
                 connection_cmd += ["-e", sql]