You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/14 19:46:00 UTC
[airflow] branch v1-10-test updated: [AIRFLOW-4734] Upsert
functionality for PostgresHook.insert_rows() (#8625)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-test by this push:
new cafeb81 [AIRFLOW-4734] Upsert functionality for PostgresHook.insert_rows() (#8625)
cafeb81 is described below
commit cafeb81d60a269cc7fb48bd177bbbe46833ea79f
Author: William Tran <wi...@pager.com>
AuthorDate: Thu Apr 30 05:16:18 2020 -0400
[AIRFLOW-4734] Upsert functionality for PostgresHook.insert_rows() (#8625)
PostgresHook's parent class, DbApiHook, implements upsert in its insert_rows() method
with the replace=True flag. However, the underlying generated SQL is specific to MySQL's
"REPLACE INTO" syntax and is not applicable to PostgreSQL.
This pulls the SQL generation code for insert/upsert out into a method that is then
overridden in the PostgreSQL subclass to generate the "INSERT ... ON CONFLICT DO
UPDATE" syntax (available since PostgreSQL 9.5).
(cherry picked from commit a28c66f23d373cd0f8bfc765a515f21d4b66a0e9)
---
airflow/contrib/hooks/bigquery_hook.py | 2 +-
airflow/hooks/dbapi_hook.py | 54 ++++++++++++++++++++++---------
airflow/hooks/postgres_hook.py | 54 +++++++++++++++++++++++++++++++
tests/hooks/test_postgres_hook.py | 59 ++++++++++++++++++++++++++++++++++
4 files changed, 153 insertions(+), 16 deletions(-)
diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index 930d212..07a2ab8 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -85,7 +85,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
return build(
'bigquery', 'v2', http=http_authorized, cache_discovery=False)
- def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
+ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, **kwargs):
"""
Insertion is currently unsupported. Theoretically, you could use
BigQuery's streaming API to insert rows into a table, but this hasn't
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index 218ff83..ac54881 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -211,8 +211,43 @@ class DbApiHook(BaseHook):
"""
return self.get_conn().cursor()
+ @staticmethod
+ def _generate_insert_sql(table, values, target_fields, replace, **kwargs):  # kwargs: ignored here, consumed by subclass overrides
+ """
+ Static helper method that generates the INSERT SQL statement.
+ With replace=True it emits REPLACE INTO, which is MySQL-style, non-standard SQL.
+
+ :param table: Name of the target table
+ :type table: str
+ :param values: The row to insert into the table (only its length is used here)
+ :type values: tuple of cell values
+ :param target_fields: The names of the columns to fill in the table
+ :type target_fields: iterable of strings
+ :param replace: Whether to replace instead of insert
+ :type replace: bool
+ :return: The generated INSERT or REPLACE SQL statement
+ :rtype: str
+ """
+ placeholders = ["%s", ] * len(values)  # one %s per cell; the caller binds values via cur.execute
+
+ if target_fields:
+ target_fields = ", ".join(target_fields)
+ target_fields = "({})".format(target_fields)
+ else:
+ target_fields = ''
+
+ if not replace:
+ sql = "INSERT INTO "
+ else:
+ sql = "REPLACE INTO "
+ sql += "{0} {1} VALUES ({2})".format(
+ table,
+ target_fields,
+ ",".join(placeholders))
+ return sql
+
def insert_rows(self, table, rows, target_fields=None, commit_every=1000,
- replace=False):
+ replace=False, **kwargs):
"""
A generic way to insert a set of tuples into a table,
a new transaction is created every commit_every rows
@@ -229,11 +264,6 @@ class DbApiHook(BaseHook):
:param replace: Whether to replace instead of insert
:type replace: bool
"""
- if target_fields:
- target_fields = ", ".join(target_fields)
- target_fields = "({})".format(target_fields)
- else:
- target_fields = ''
i = 0
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
@@ -247,15 +277,9 @@ class DbApiHook(BaseHook):
for cell in row:
lst.append(self._serialize_cell(cell, conn))
values = tuple(lst)
- placeholders = ["%s", ] * len(values)
- if not replace:
- sql = "INSERT INTO "
- else:
- sql = "REPLACE INTO "
- sql += "{0} {1} VALUES ({2})".format(
- table,
- target_fields,
- ",".join(placeholders))
+ sql = self._generate_insert_sql(
+ table, values, target_fields, replace, **kwargs
+ )
cur.execute(sql, values)
if commit_every and i % commit_every == 0:
conn.commit()
diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py
index a6d6523..4c7a324 100644
--- a/airflow/hooks/postgres_hook.py
+++ b/airflow/hooks/postgres_hook.py
@@ -177,3 +177,57 @@ class PostgresHook(DbApiHook):
client = aws_hook.get_client_type('rds')
token = client.generate_db_auth_token(conn.host, port, conn.login)
return login, token, port
+
+ @staticmethod
+ def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
+ """
+ Static helper method that generates the INSERT SQL statement.
+ With replace=True it appends PostgreSQL's ON CONFLICT upsert clause.
+
+ :param table: Name of the target table
+ :type table: str
+ :param values: The row to insert into the table
+ :type values: tuple of cell values
+ :param target_fields: The names of the columns to fill in the table
+ :type target_fields: iterable of strings
+ :param replace: Whether to upsert; requires target_fields and replace_index (else ValueError)
+ :type replace: bool
+ :param replace_index: the column or list of column names to act as
+ index for the ON CONFLICT clause (DO NOTHING if no other field is left)
+ :type replace_index: str or list
+ :return: The generated INSERT, optionally with an ON CONFLICT clause
+ :rtype: str
+ """
+ placeholders = ["%s", ] * len(values)
+ replace_index = kwargs.get("replace_index", None)
+
+ if target_fields:
+ target_fields_fragment = ", ".join(target_fields)
+ target_fields_fragment = "({})".format(target_fields_fragment)
+ else:
+ target_fields_fragment = ''
+
+ sql = "INSERT INTO {0} {1} VALUES ({2})".format(
+ table,
+ target_fields_fragment,
+ ",".join(placeholders))
+
+ if replace:
+ if not target_fields:  # also rejects empty sequences, which would yield malformed SQL
+ raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names")
+ if not replace_index:  # an empty index list would render "ON CONFLICT ()"
+ raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index")
+ if isinstance(replace_index, str):
+ replace_index = [replace_index]
+ replace_index_set = set(replace_index)
+
+ replace_target = [
+ "{0} = excluded.{0}".format(col)
+ for col in target_fields
+ if col not in replace_index_set
+ ]
+ # If every field is an index column there is nothing to SET; an empty SET list is invalid SQL.
+ conflict_action = ("DO UPDATE SET " + ", ".join(replace_target)
+ if replace_target else "DO NOTHING")
+ sql += " ON CONFLICT ({0}) {1}".format(", ".join(replace_index), conflict_action)
+ return sql
diff --git a/tests/hooks/test_postgres_hook.py b/tests/hooks/test_postgres_hook.py
index f706d56..061dca0 100644
--- a/tests/hooks/test_postgres_hook.py
+++ b/tests/hooks/test_postgres_hook.py
@@ -183,3 +183,62 @@ class TestPostgresHook(unittest.TestCase):
results = [line.rstrip().decode("utf-8") for line in f.readlines()]
self.assertEqual(sorted(input_data), sorted(results))
+
+ @pytest.mark.backend("postgres")
+ def test_insert_rows(self):
+ table = "table"
+ rows = [("hello",),
+ ("world",)]
+
+ self.db_hook.insert_rows(table, rows)  # plain insert: no target_fields, replace left at default
+
+ assert self.conn.close.call_count == 1
+ assert self.cur.close.call_count == 1
+
+ commit_count = 2 # The first and last commit
+ self.assertEqual(commit_count, self.conn.commit.call_count)
+
+ sql = "INSERT INTO {} VALUES (%s)".format(table)  # no column list when target_fields is omitted
+ for row in rows:
+ self.cur.execute.assert_any_call(sql, row)
+
+ @pytest.mark.backend("postgres")
+ def test_insert_rows_replace(self):
+ table = "table"
+ rows = [(1, "hello",),
+ (2, "world",)]
+ fields = ("id", "value")
+
+ self.db_hook.insert_rows(
+ table, rows, fields, replace=True, replace_index=fields[0])  # a single column name is accepted as index
+
+ assert self.conn.close.call_count == 1
+ assert self.cur.close.call_count == 1
+
+ commit_count = 2 # The first and last commit
+ self.assertEqual(commit_count, self.conn.commit.call_count)
+
+ sql = "INSERT INTO {0} ({1}, {2}) VALUES (%s,%s) " \
+ "ON CONFLICT ({1}) DO UPDATE SET {2} = excluded.{2}".format(
+ table, fields[0], fields[1])  # only the non-index column appears in SET
+ for row in rows:
+ self.cur.execute.assert_any_call(sql, row)
+
+ @pytest.mark.backend("postgres")
+ def test_insert_rows_replace_missing_target_field_arg(self):
+ table = "table"
+ rows = [(1, "hello",),
+ (2, "world",)]
+ fields = ("id", "value")
+ with pytest.raises(ValueError, match="requires column names"):  # upsert without target_fields must be rejected
+ self.db_hook.insert_rows(
+ table, rows, replace=True, replace_index=fields[0])
+
+ @pytest.mark.backend("postgres")
+ def test_insert_rows_replace_missing_replace_index_arg(self):
+ table = "table"
+ rows = [(1, "hello",),
+ (2, "world",)]
+ fields = ("id", "value")
+ with pytest.raises(ValueError, match="requires an unique index"):  # upsert without replace_index must be rejected
+ self.db_hook.insert_rows(table, rows, fields, replace=True)