You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/01/05 16:51:03 UTC

[airflow] branch main updated: Refactor vertica_to_mysql to make it more 'mypy' friendly (#20618)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 919ff45  Refactor vertica_to_mysql to make it more 'mypy' friendly (#20618)
919ff45 is described below

commit 919ff4567d86a09fb069dcfd84885b496229eea9
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Wed Jan 5 17:50:31 2022 +0100

    Refactor vertica_to_mysql to make it more 'mypy' friendly (#20618)
    
    Part of #19891
    
    MyPy was confused by the logic in this method (and humans could
    be too) because there were some implicit relations between bulk_load
    and tmpfile. This refactor makes the bulk_load and non-bulk load
    paths separate (extracting common parts) and more obvious.
    
    Thanks MyPy for flagging this one.
---
 .../providers/mysql/transfers/vertica_to_mysql.py  | 93 ++++++++++++----------
 1 file changed, 51 insertions(+), 42 deletions(-)

diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py b/airflow/providers/mysql/transfers/vertica_to_mysql.py
index 0c37d40..e273e59 100644
--- a/airflow/providers/mysql/transfers/vertica_to_mysql.py
+++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py
@@ -94,63 +94,72 @@ class VerticaToMySqlOperator(BaseOperator):
         vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
         mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
 
-        tmpfile = None
-        result = None
+        if self.bulk_load:
+            self._bulk_load_transfer(mysql, vertica)
+        else:
+            self._non_bulk_load_transfer(mysql, vertica)
 
-        selected_columns = []
+        if self.mysql_postoperator:
+            self.log.info("Running MySQL postoperator...")
+            mysql.run(self.mysql_postoperator)
 
-        count = 0
+        self.log.info("Done")
+
+    def _non_bulk_load_transfer(self, mysql, vertica):
         with closing(vertica.get_conn()) as conn:
             with closing(conn.cursor()) as cursor:
                 cursor.execute(self.sql)
                 selected_columns = [d.name for d in cursor.description]
+                self.log.info("Selecting rows from Vertica...")
+                self.log.info(self.sql)
 
-                if self.bulk_load:
-                    with NamedTemporaryFile("w") as tmpfile:
-                        self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
-                        self.log.info(self.sql)
+                result = cursor.fetchall()
+                count = len(result)
 
-                        csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
-                        for row in cursor.iterate():
-                            csv_writer.writerow(row)
-                            count += 1
+                self.log.info("Selected rows from Vertica %s", count)
+        self._run_preoperator(mysql)
+        try:
+            self.log.info("Inserting rows into MySQL...")
+            mysql.insert_rows(table=self.mysql_table, rows=result, target_fields=selected_columns)
+            self.log.info("Inserted rows into MySQL %s", count)
+        except (MySQLdb.Error, MySQLdb.Warning):
+            self.log.info("Inserted rows into MySQL 0")
+            raise
 
-                        tmpfile.flush()
-                else:
-                    self.log.info("Selecting rows from Vertica...")
+    def _bulk_load_transfer(self, mysql, vertica):
+        count = 0
+        with closing(vertica.get_conn()) as conn:
+            with closing(conn.cursor()) as cursor:
+                cursor.execute(self.sql)
+                selected_columns = [d.name for d in cursor.description]
+                with NamedTemporaryFile("w") as tmpfile:
+                    self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
                     self.log.info(self.sql)
 
-                    result = cursor.fetchall()
-                    count = len(result)
-
-                self.log.info("Selected rows from Vertica %s", count)
-
-        if self.mysql_preoperator:
-            self.log.info("Running MySQL preoperator...")
-            mysql.run(self.mysql_preoperator)
+                    csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
+                    for row in cursor.iterate():
+                        csv_writer.writerow(row)
+                        count += 1
 
+                    tmpfile.flush()
+        self._run_preoperator(mysql)
         try:
-            if self.bulk_load:
-                self.log.info("Bulk inserting rows into MySQL...")
-                with closing(mysql.get_conn()) as conn:
-                    with closing(conn.cursor()) as cursor:
-                        cursor.execute(
-                            f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
-                            f"INTO TABLE {self.mysql_table} "
-                            f"LINES TERMINATED BY '\r\n' ({', '.join(selected_columns)})"
-                        )
-                        conn.commit()
-                tmpfile.close()
-            else:
-                self.log.info("Inserting rows into MySQL...")
-                mysql.insert_rows(table=self.mysql_table, rows=result, target_fields=selected_columns)
+            self.log.info("Bulk inserting rows into MySQL...")
+            with closing(mysql.get_conn()) as conn:
+                with closing(conn.cursor()) as cursor:
+                    cursor.execute(
+                        f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
+                        f"INTO TABLE {self.mysql_table} "
+                        f"LINES TERMINATED BY '\r\n' ({', '.join(selected_columns)})"
+                    )
+                    conn.commit()
+            tmpfile.close()
             self.log.info("Inserted rows into MySQL %s", count)
         except (MySQLdb.Error, MySQLdb.Warning):
             self.log.info("Inserted rows into MySQL 0")
             raise
 
-        if self.mysql_postoperator:
-            self.log.info("Running MySQL postoperator...")
-            mysql.run(self.mysql_postoperator)
-
-        self.log.info("Done")
+    def _run_preoperator(self, mysql):
+        if self.mysql_preoperator:
+            self.log.info("Running MySQL preoperator...")
+            mysql.run(self.mysql_preoperator)