You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/15 16:06:46 UTC

[airflow] 10/28: [AIRFLOW-4734] Upsert functionality for PostgresHook.insert_rows() (#8625)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ecd860f03e5ecc2fe4f910139ead6fc949fe9e1b
Author: William Tran <wi...@pager.com>
AuthorDate: Thu Apr 30 05:16:18 2020 -0400

    [AIRFLOW-4734] Upsert functionality for PostgresHook.insert_rows() (#8625)
    
    PostgresHook's parent class, DbApiHook, implements upsert in its insert_rows() method
    with the replace=True flag. However, the underlying generated SQL is specific to MySQL's
    "REPLACE INTO" syntax and is not applicable to PostgreSQL.
    
    This pulls the SQL generation code for insert/upsert out into a method that is then
    overridden in the PostgreSQL subclass to generate the "INSERT ... ON CONFLICT DO
    UPDATE" syntax ("new" since PostgreSQL 9.5).
    
    (cherry picked from commit a28c66f23d373cd0f8bfc765a515f21d4b66a0e9)
---
 airflow/contrib/hooks/bigquery_hook.py |  2 +-
 airflow/hooks/dbapi_hook.py            | 54 ++++++++++++++++++++++---------
 airflow/hooks/postgres_hook.py         | 54 +++++++++++++++++++++++++++++++
 tests/hooks/test_postgres_hook.py      | 59 ++++++++++++++++++++++++++++++++++
 4 files changed, 153 insertions(+), 16 deletions(-)

diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index 930d212..07a2ab8 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -85,7 +85,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
         return build(
             'bigquery', 'v2', http=http_authorized, cache_discovery=False)
 
-    def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
+    def insert_rows(self, table, rows, target_fields=None, commit_every=1000, **kwargs):
         """
         Insertion is currently unsupported. Theoretically, you could use
         BigQuery's streaming API to insert rows into a table, but this hasn't
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index 218ff83..ac54881 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -211,8 +211,43 @@ class DbApiHook(BaseHook):
         """
         return self.get_conn().cursor()
 
+    @staticmethod
+    def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
+        """
+        Static helper method that generates the INSERT SQL statement.
+        The REPLACE variant is specific to MySQL syntax.
+
+        :param table: Name of the target table
+        :type table: str
+        :param values: The row to insert into the table
+        :type values: tuple of cell values
+        :param target_fields: The names of the columns to fill in the table
+        :type target_fields: iterable of strings
+        :param replace: Whether to replace instead of insert
+        :type replace: bool
+        :return: The generated INSERT or REPLACE SQL statement
+        :rtype: str
+        """
+        placeholders = ["%s", ] * len(values)
+
+        if target_fields:
+            target_fields = ", ".join(target_fields)
+            target_fields = "({})".format(target_fields)
+        else:
+            target_fields = ''
+
+        if not replace:
+            sql = "INSERT INTO "
+        else:
+            sql = "REPLACE INTO "
+        sql += "{0} {1} VALUES ({2})".format(
+            table,
+            target_fields,
+            ",".join(placeholders))
+        return sql
+
     def insert_rows(self, table, rows, target_fields=None, commit_every=1000,
-                    replace=False):
+                    replace=False, **kwargs):
         """
         A generic way to insert a set of tuples into a table,
         a new transaction is created every commit_every rows
@@ -229,11 +264,6 @@ class DbApiHook(BaseHook):
         :param replace: Whether to replace instead of insert
         :type replace: bool
         """
-        if target_fields:
-            target_fields = ", ".join(target_fields)
-            target_fields = "({})".format(target_fields)
-        else:
-            target_fields = ''
         i = 0
         with closing(self.get_conn()) as conn:
             if self.supports_autocommit:
@@ -247,15 +277,9 @@ class DbApiHook(BaseHook):
                     for cell in row:
                         lst.append(self._serialize_cell(cell, conn))
                     values = tuple(lst)
-                    placeholders = ["%s", ] * len(values)
-                    if not replace:
-                        sql = "INSERT INTO "
-                    else:
-                        sql = "REPLACE INTO "
-                    sql += "{0} {1} VALUES ({2})".format(
-                        table,
-                        target_fields,
-                        ",".join(placeholders))
+                    sql = self._generate_insert_sql(
+                        table, values, target_fields, replace, **kwargs
+                    )
                     cur.execute(sql, values)
                     if commit_every and i % commit_every == 0:
                         conn.commit()
diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py
index a6d6523..4c7a324 100644
--- a/airflow/hooks/postgres_hook.py
+++ b/airflow/hooks/postgres_hook.py
@@ -177,3 +177,57 @@ class PostgresHook(DbApiHook):
             client = aws_hook.get_client_type('rds')
             token = client.generate_db_auth_token(conn.host, port, conn.login)
         return login, token, port
+
+    @staticmethod
+    def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
+        """
+        Static helper method that generates the INSERT SQL statement.
+        The REPLACE variant is specific to PostgreSQL's ON CONFLICT syntax.
+
+        :param table: Name of the target table
+        :type table: str
+        :param values: The row to insert into the table
+        :type values: tuple of cell values
+        :param target_fields: The names of the columns to fill in the table
+        :type target_fields: iterable of strings
+        :param replace: Whether to replace instead of insert
+        :type replace: bool
+        :param replace_index: the column or list of column names to act as
+            index for the ON CONFLICT clause
+        :type replace_index: str or list
+        :return: The generated INSERT or INSERT ... ON CONFLICT DO UPDATE SQL statement
+        :rtype: str
+        """
+        placeholders = ["%s", ] * len(values)
+        replace_index = kwargs.get("replace_index", None)
+
+        if target_fields:
+            target_fields_fragment = ", ".join(target_fields)
+            target_fields_fragment = "({})".format(target_fields_fragment)
+        else:
+            target_fields_fragment = ''
+
+        sql = "INSERT INTO {0} {1} VALUES ({2})".format(
+            table,
+            target_fields_fragment,
+            ",".join(placeholders))
+
+        if replace:
+            if target_fields is None:
+                raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names")
+            if replace_index is None:
+                raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index")
+            if isinstance(replace_index, str):
+                replace_index = [replace_index]
+            replace_index_set = set(replace_index)
+
+            replace_target = [
+                "{0} = excluded.{0}".format(col)
+                for col in target_fields
+                if col not in replace_index_set
+            ]
+            sql += " ON CONFLICT ({0}) DO UPDATE SET {1}".format(
+                ", ".join(replace_index),
+                ", ".join(replace_target),
+            )
+        return sql
diff --git a/tests/hooks/test_postgres_hook.py b/tests/hooks/test_postgres_hook.py
index f706d56..061dca0 100644
--- a/tests/hooks/test_postgres_hook.py
+++ b/tests/hooks/test_postgres_hook.py
@@ -183,3 +183,62 @@ class TestPostgresHook(unittest.TestCase):
                     results = [line.rstrip().decode("utf-8") for line in f.readlines()]
 
         self.assertEqual(sorted(input_data), sorted(results))
+
+    @pytest.mark.backend("postgres")
+    def test_insert_rows(self):
+        table = "table"
+        rows = [("hello",),
+                ("world",)]
+
+        self.db_hook.insert_rows(table, rows)
+
+        assert self.conn.close.call_count == 1
+        assert self.cur.close.call_count == 1
+
+        commit_count = 2  # The first and last commit
+        self.assertEqual(commit_count, self.conn.commit.call_count)
+
+        sql = "INSERT INTO {}  VALUES (%s)".format(table)
+        for row in rows:
+            self.cur.execute.assert_any_call(sql, row)
+
+    @pytest.mark.backend("postgres")
+    def test_insert_rows_replace(self):
+        table = "table"
+        rows = [(1, "hello",),
+                (2, "world",)]
+        fields = ("id", "value")
+
+        self.db_hook.insert_rows(
+            table, rows, fields, replace=True, replace_index=fields[0])
+
+        assert self.conn.close.call_count == 1
+        assert self.cur.close.call_count == 1
+
+        commit_count = 2  # The first and last commit
+        self.assertEqual(commit_count, self.conn.commit.call_count)
+
+        sql = "INSERT INTO {0} ({1}, {2}) VALUES (%s,%s) " \
+              "ON CONFLICT ({1}) DO UPDATE SET {2} = excluded.{2}".format(
+                  table, fields[0], fields[1])
+        for row in rows:
+            self.cur.execute.assert_any_call(sql, row)
+
+    @pytest.mark.xfail
+    @pytest.mark.backend("postgres")
+    def test_insert_rows_replace_missing_target_field_arg(self):
+        table = "table"
+        rows = [(1, "hello",),
+                (2, "world",)]
+        fields = ("id", "value")
+        self.db_hook.insert_rows(
+            table, rows, replace=True, replace_index=fields[0])
+
+    @pytest.mark.xfail
+    @pytest.mark.backend("postgres")
+    def test_insert_rows_replace_missing_replace_index_arg(self):
+        table = "table"
+        rows = [(1, "hello",),
+                (2, "world",)]
+        fields = ("id", "value")
+        self.db_hook.insert_rows(table, rows, fields, replace=True)