You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/27 15:50:43 UTC
[airflow] branch main updated: Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5d4abbd58c Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)
5d4abbd58c is described below
commit 5d4abbd58c33e7dfa8505e307d43420459d3df55
Author: Jarek Potiuk <ja...@polidea.com>
AuthorDate: Wed Jul 27 17:50:18 2022 +0200
Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)
* Deprecate hql parameters and synchronize DBApiHook method APIs
Various providers deriving from DbApiHook had some variations in the
methods they inherit from the common DbApiHook. Mostly these were
extra parameters being added and an hql parameter being used instead
of sql. This prevents a truly "common" approach in DbApiHook, as
some common sql operators rely on the signatures being the same.
This introduces breaking changes in a few providers, but those
breaking changes are easy to fix and most of the affected usages
have already been deprecated.
---
airflow/providers/apache/hive/CHANGELOG.rst | 11 ++
airflow/providers/apache/hive/hooks/hive.py | 45 ++++----
.../providers/apache/hive/operators/hive_stats.py | 2 +-
.../apache/hive/transfers/hive_to_mysql.py | 2 +-
.../apache/hive/transfers/hive_to_samba.py | 2 +-
airflow/providers/apache/pinot/hooks/pinot.py | 8 +-
airflow/providers/common/sql/CHANGELOG.rst | 5 +
airflow/providers/common/sql/hooks/sql.py | 9 +-
airflow/providers/exasol/hooks/exasol.py | 9 +-
airflow/providers/google/cloud/hooks/cloud_sql.py | 61 ++++++-----
airflow/providers/presto/CHANGELOG.rst | 6 ++
airflow/providers/presto/hooks/presto.py | 114 +++-----------------
airflow/providers/trino/CHANGELOG.rst | 6 ++
airflow/providers/trino/hooks/trino.py | 120 ++-------------------
tests/providers/apache/hive/hooks/test_hive.py | 2 +-
.../apache/hive/transfers/test_hive_to_mysql.py | 4 +-
.../apache/hive/transfers/test_hive_to_samba.py | 2 +-
17 files changed, 132 insertions(+), 276 deletions(-)
diff --git a/airflow/providers/apache/hive/CHANGELOG.rst b/airflow/providers/apache/hive/CHANGELOG.rst
index 7463edeae8..318e50d698 100644
--- a/airflow/providers/apache/hive/CHANGELOG.rst
+++ b/airflow/providers/apache/hive/CHANGELOG.rst
@@ -24,6 +24,17 @@
Changelog
---------
+Breaking Changes
+~~~~~~~~~~~~~~~~
+
+* The ``hql`` parameter in ``get_records`` of ``HiveServer2Hook`` has been renamed to ``sql`` to match the
+ ``get_records`` DbApiHook signature. If you used it as a positional parameter, nothing changes for you,
+ but if you used it as a keyword parameter, you need to rename it.
+* The ``hive_conf`` parameter of ``get_records`` has been renamed to ``parameters`` and is now the second parameter,
+ to match the ``get_records`` signature from the DbApiHook. You need to rename it if you used it.
+* The ``schema`` parameter in ``get_records`` is now an optional extra keyword argument (passed via ``**kwargs``),
+ to match the signature of ``get_records`` from DbApiHook. It defaults to ``'default'`` when omitted.
+
3.1.0
.....
diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index 63e4652f20..f5e8c24d65 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -24,7 +24,7 @@ import time
import warnings
from collections import OrderedDict
from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
import pandas
import unicodecsv as csv
@@ -857,15 +857,15 @@ class HiveServer2Hook(DbApiHook):
def _get_results(
self,
- hql: Union[str, List[str]],
+ sql: Union[str, List[str]],
schema: str = 'default',
fetch_size: Optional[int] = None,
- hive_conf: Optional[Dict[Any, Any]] = None,
+ hive_conf: Optional[Union[Iterable, Mapping]] = None,
) -> Any:
from pyhive.exc import ProgrammingError
- if isinstance(hql, str):
- hql = [hql]
+ if isinstance(sql, str):
+ sql = [sql]
previous_description = None
with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
@@ -882,7 +882,7 @@ class HiveServer2Hook(DbApiHook):
for k, v in env_context.items():
cur.execute(f"set {k}={v}")
- for statement in hql:
+ for statement in sql:
cur.execute(statement)
# we only get results of statements that returns
lowered_statement = statement.lower().strip()
@@ -911,29 +911,29 @@ class HiveServer2Hook(DbApiHook):
def get_results(
self,
- hql: str,
+ sql: Union[str, List[str]],
schema: str = 'default',
fetch_size: Optional[int] = None,
- hive_conf: Optional[Dict[Any, Any]] = None,
+ hive_conf: Optional[Union[Iterable, Mapping]] = None,
) -> Dict[str, Any]:
"""
Get results of the provided hql in target schema.
- :param hql: hql to be executed.
+ :param sql: hql to be executed.
:param schema: target schema, default to 'default'.
:param fetch_size: max size of result to fetch.
:param hive_conf: hive_conf to execute along with the hql.
:return: results of hql execution, dict with data (list of results) and header
:rtype: dict
"""
- results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
+ results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
results = {'data': list(results_iter), 'header': header}
return results
def to_csv(
self,
- hql: str,
+ sql: str,
csv_filepath: str,
schema: str = 'default',
delimiter: str = ',',
@@ -945,7 +945,7 @@ class HiveServer2Hook(DbApiHook):
"""
Execute hql in target schema and write results to a csv file.
- :param hql: hql to be executed.
+ :param sql: hql to be executed.
:param csv_filepath: filepath of csv to write results into.
:param schema: target schema, default to 'default'.
:param delimiter: delimiter of the csv file, default to ','.
@@ -955,7 +955,7 @@ class HiveServer2Hook(DbApiHook):
:param hive_conf: hive_conf to execute along with the hql.
"""
- results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
+ results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
message = None
@@ -982,14 +982,14 @@ class HiveServer2Hook(DbApiHook):
self.log.info("Done. Loaded a total of %s rows.", i)
def get_records(
- self, hql: str, schema: str = 'default', hive_conf: Optional[Dict[Any, Any]] = None
+ self, sql: Union[str, List[str]], parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs
) -> Any:
"""
- Get a set of records from a Hive query.
+ Get a set of records from a Hive query. You can optionally pass a 'schema' kwarg,
+ which specifies the target schema and defaults to 'default'.
- :param hql: hql to be executed.
- :param schema: target schema, default to 'default'.
- :param hive_conf: hive_conf to execute alone with the hql.
+ :param sql: hql to be executed.
+ :param parameters: optional configuration passed to get_results
:return: result of hive execution
:rtype: list
@@ -998,11 +998,12 @@ class HiveServer2Hook(DbApiHook):
>>> len(hh.get_records(sql))
100
"""
- return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data']
+ schema = kwargs['schema'] if 'schema' in kwargs else 'default'
+ return self.get_results(sql, schema=schema, hive_conf=parameters)['data']
def get_pandas_df( # type: ignore
self,
- hql: str,
+ sql: str,
schema: str = 'default',
hive_conf: Optional[Dict[Any, Any]] = None,
**kwargs,
@@ -1010,7 +1011,7 @@ class HiveServer2Hook(DbApiHook):
"""
Get a pandas dataframe from a Hive query
- :param hql: hql to be executed.
+ :param sql: hql to be executed.
:param schema: target schema, default to 'default'.
:param hive_conf: hive_conf to execute along with the hql.
:param kwargs: (optional) passed into pandas.DataFrame constructor
@@ -1025,6 +1026,6 @@ class HiveServer2Hook(DbApiHook):
:return: pandas.DataFrame
"""
- res = self.get_results(hql, schema=schema, hive_conf=hive_conf)
+ res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
df = pandas.DataFrame(res['data'], columns=[c[0] for c in res['header']], **kwargs)
return df
diff --git a/airflow/providers/apache/hive/operators/hive_stats.py b/airflow/providers/apache/hive/operators/hive_stats.py
index 7cf2002a41..a82040e846 100644
--- a/airflow/providers/apache/hive/operators/hive_stats.py
+++ b/airflow/providers/apache/hive/operators/hive_stats.py
@@ -138,7 +138,7 @@ class HiveStatsCollectionOperator(BaseOperator):
presto = PrestoHook(presto_conn_id=self.presto_conn_id)
self.log.info('Executing SQL check: %s', sql)
- row = presto.get_first(hql=sql)
+ row = presto.get_first(sql)
self.log.info("Record: %s", row)
if not row:
raise AirflowException("The query returned None")
diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
index 65de4cc159..9c01b3162b 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
@@ -111,7 +111,7 @@ class HiveToMySqlOperator(BaseOperator):
mysql = self._call_preoperator()
mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
else:
- hive_results = hive.get_records(self.sql, hive_conf=hive_conf)
+ hive_results = hive.get_records(self.sql, parameters=hive_conf)
mysql = self._call_preoperator()
mysql.insert_rows(table=self.mysql_table, rows=hive_results)
diff --git a/airflow/providers/apache/hive/transfers/hive_to_samba.py b/airflow/providers/apache/hive/transfers/hive_to_samba.py
index c5ab66efa4..63a811c8ce 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_samba.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_samba.py
@@ -68,7 +68,7 @@ class HiveToSambaOperator(BaseOperator):
with NamedTemporaryFile() as tmp_file:
self.log.info("Fetching file from Hive")
hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
- hive.to_csv(hql=self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context))
+ hive.to_csv(self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context))
self.log.info("Pushing to samba")
samba = SambaHook(samba_conn_id=self.samba_conn_id)
samba.push_from_local(self.destination_filepath, tmp_file.name)
diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py
index 794646e46d..90b3804d98 100644
--- a/airflow/providers/apache/pinot/hooks/pinot.py
+++ b/airflow/providers/apache/pinot/hooks/pinot.py
@@ -275,7 +275,9 @@ class PinotDbApiHook(DbApiHook):
endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
return f'{conn_type}://{host}/{endpoint}'
- def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
+ def get_records(
+ self, sql: Union[str, List[str]], parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs
+ ) -> Any:
"""
Executes the sql and returns a set of records.
@@ -287,7 +289,9 @@ class PinotDbApiHook(DbApiHook):
cur.execute(sql)
return cur.fetchall()
- def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
+ def get_first(
+ self, sql: Union[str, List[str]], parameters: Optional[Union[Iterable, Mapping]] = None
+ ) -> Any:
"""
Executes the sql and returns the first resulting row.
diff --git a/airflow/providers/common/sql/CHANGELOG.rst b/airflow/providers/common/sql/CHANGELOG.rst
index e8cd07d33b..d48dafc25d 100644
--- a/airflow/providers/common/sql/CHANGELOG.rst
+++ b/airflow/providers/common/sql/CHANGELOG.rst
@@ -15,6 +15,11 @@
specific language governing permissions and limitations
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please only add notes to the Changelog just below the "Changelog" header when there are breaking changes
+ and you want to explain to the users how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by the release manager.
+
Changelog
---------
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index e6687fa938..76d7980850 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -181,7 +181,12 @@ class DbApiHook(BaseHook):
with closing(self.get_conn()) as conn:
yield from psql.read_sql(sql, con=conn, params=parameters, chunksize=chunksize, **kwargs)
- def get_records(self, sql, parameters=None):
+ def get_records(
+ self,
+ sql: Union[str, List[str]],
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ **kwargs: dict,
+ ):
"""
Executes the sql and returns a set of records.
@@ -197,7 +202,7 @@ class DbApiHook(BaseHook):
cur.execute(sql)
return cur.fetchall()
- def get_first(self, sql, parameters=None):
+ def get_first(self, sql: Union[str, List[str]], parameters=None):
"""
Executes the sql and returns the first resulting row.
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index 537f2fcb6d..6d05b5e7b5 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -77,7 +77,12 @@ class ExasolHook(DbApiHook):
df = conn.export_to_pandas(sql, query_params=parameters, **kwargs)
return df
- def get_records(self, sql: str, parameters: Optional[dict] = None) -> List[Union[dict, Tuple[Any, ...]]]:
+ def get_records(
+ self,
+ sql: Union[str, List[str]],
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ **kwargs: dict,
+ ) -> List[Union[dict, Tuple[Any, ...]]]:
"""
Executes the sql and returns a set of records.
@@ -89,7 +94,7 @@ class ExasolHook(DbApiHook):
with closing(conn.execute(sql, parameters)) as cur:
return cur.fetchall()
- def get_first(self, sql: str, parameters: Optional[dict] = None) -> Optional[Any]:
+ def get_first(self, sql: Union[str, List[str]], parameters: Optional[dict] = None) -> Optional[Any]:
"""
Executes the sql and returns the first resulting row.
diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py
index 95f16fe9c1..5db53677a0 100644
--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -426,11 +426,11 @@ class CloudSqlProxyRunner(LoggingMixin):
self.sql_proxy_was_downloaded = False
self.sql_proxy_version = sql_proxy_version
self.download_sql_proxy_dir = None
- self.sql_proxy_process = None # type: Optional[Popen]
+ self.sql_proxy_process: Optional[Popen] = None
self.instance_specification = instance_specification
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
- self.command_line_parameters = [] # type: List[str]
+ self.command_line_parameters: List[str] = []
self.cloud_sql_proxy_socket_directory = self.path_prefix
self.sql_proxy_path = (
sql_proxy_binary_path if sql_proxy_binary_path else self.path_prefix + "_cloud_sql_proxy"
@@ -705,28 +705,28 @@ class CloudSQLDatabaseHook(BaseHook):
self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
self.cloudsql_connection = self.get_connection(self.gcp_cloudsql_conn_id)
self.extras = self.cloudsql_connection.extra_dejson
- self.project_id = self.extras.get('project_id', default_gcp_project_id) # type: Optional[str]
- self.instance = self.extras.get('instance') # type: Optional[str]
- self.database = self.cloudsql_connection.schema # type: Optional[str]
- self.location = self.extras.get('location') # type: Optional[str]
- self.database_type = self.extras.get('database_type') # type: Optional[str]
- self.use_proxy = self._get_bool(self.extras.get('use_proxy', 'False')) # type: bool
- self.use_ssl = self._get_bool(self.extras.get('use_ssl', 'False')) # type: bool
- self.sql_proxy_use_tcp = self._get_bool(self.extras.get('sql_proxy_use_tcp', 'False')) # type: bool
- self.sql_proxy_version = self.extras.get('sql_proxy_version') # type: Optional[str]
- self.sql_proxy_binary_path = self.extras.get('sql_proxy_binary_path') # type: Optional[str]
- self.user = self.cloudsql_connection.login # type: Optional[str]
- self.password = self.cloudsql_connection.password # type: Optional[str]
- self.public_ip = self.cloudsql_connection.host # type: Optional[str]
- self.public_port = self.cloudsql_connection.port # type: Optional[int]
- self.sslcert = self.extras.get('sslcert') # type: Optional[str]
- self.sslkey = self.extras.get('sslkey') # type: Optional[str]
- self.sslrootcert = self.extras.get('sslrootcert') # type: Optional[str]
+ self.project_id = self.extras.get('project_id', default_gcp_project_id)
+ self.instance = self.extras.get('instance')
+ self.database = self.cloudsql_connection.schema
+ self.location = self.extras.get('location')
+ self.database_type = self.extras.get('database_type')
+ self.use_proxy = self._get_bool(self.extras.get('use_proxy', 'False'))
+ self.use_ssl = self._get_bool(self.extras.get('use_ssl', 'False'))
+ self.sql_proxy_use_tcp = self._get_bool(self.extras.get('sql_proxy_use_tcp', 'False'))
+ self.sql_proxy_version = self.extras.get('sql_proxy_version')
+ self.sql_proxy_binary_path = self.extras.get('sql_proxy_binary_path')
+ self.user = self.cloudsql_connection.login
+ self.password = self.cloudsql_connection.password
+ self.public_ip = self.cloudsql_connection.host
+ self.public_port = self.cloudsql_connection.port
+ self.sslcert = self.extras.get('sslcert')
+ self.sslkey = self.extras.get('sslkey')
+ self.sslrootcert = self.extras.get('sslrootcert')
# Port and socket path and db_hook are automatically generated
self.sql_proxy_tcp_port = None
- self.sql_proxy_unique_path = None # type: Optional[str]
- self.db_hook = None # type: Optional[Union[PostgresHook, MySqlHook]]
- self.reserved_tcp_socket = None # type: Optional[socket.socket]
+ self.sql_proxy_unique_path: Optional[str] = None
+ self.db_hook: Optional[Union[PostgresHook, MySqlHook]] = None
+ self.reserved_tcp_socket: Optional[socket.socket] = None
# Generated based on clock + clock sequence. Unique per host (!).
# This is important as different hosts share the database
self.db_conn_id = str(uuid.uuid1())
@@ -828,18 +828,18 @@ class CloudSQLDatabaseHook(BaseHook):
if not self.database_type:
raise ValueError("The database_type should be set")
- database_uris = CONNECTION_URIS[self.database_type] # type: Dict[str, Dict[str, str]]
+ database_uris = CONNECTION_URIS[self.database_type]
ssl_spec = None
socket_path = None
if self.use_proxy:
- proxy_uris = database_uris['proxy'] # type: Dict[str, str]
+ proxy_uris = database_uris['proxy']
if self.sql_proxy_use_tcp:
format_string = proxy_uris['tcp']
else:
format_string = proxy_uris['socket']
socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}"
else:
- public_uris = database_uris['public'] # type: Dict[str, str]
+ public_uris = database_uris['public']
if self.use_ssl:
format_string = public_uris['ssl']
ssl_spec = {'cert': self.sslcert, 'key': self.sslkey, 'ca': self.sslrootcert}
@@ -876,7 +876,7 @@ class CloudSQLDatabaseHook(BaseHook):
return connection_uri
def _get_instance_socket_name(self) -> str:
- return self.project_id + ":" + self.location + ":" + self.instance # type: ignore
+ return self.project_id + ":" + self.location + ":" + self.instance
def _get_sqlproxy_instance_specification(self) -> str:
instance_specification = self._get_instance_socket_name()
@@ -921,10 +921,13 @@ class CloudSQLDatabaseHook(BaseHook):
that uses proxy or connects directly to the Google Cloud SQL database.
"""
if self.database_type == 'postgres':
- self.db_hook = PostgresHook(connection=connection, schema=self.database)
+ db_hook: Union[PostgresHook, MySqlHook] = PostgresHook(
+ connection=connection, schema=self.database
+ )
else:
- self.db_hook = MySqlHook(connection=connection, schema=self.database)
- return self.db_hook
+ db_hook = MySqlHook(connection=connection, schema=self.database)
+ self.db_hook = db_hook
+ return db_hook
def cleanup_database_hook(self) -> None:
"""Clean up database hook after it was used."""
diff --git a/airflow/providers/presto/CHANGELOG.rst b/airflow/providers/presto/CHANGELOG.rst
index 46bf396dd7..1d7af0a87e 100644
--- a/airflow/providers/presto/CHANGELOG.rst
+++ b/airflow/providers/presto/CHANGELOG.rst
@@ -24,6 +24,12 @@
Changelog
---------
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The deprecated ``hql`` parameter has been removed from the ``get_records``, ``get_first``, ``get_pandas_df`` and ``run``
+methods of the ``PrestoHook``. Pass the query via the ``sql`` parameter instead.
+
3.1.0
.....
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 709a378a8d..454bccb45c 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -17,8 +17,7 @@
# under the License.
import json
import os
-import warnings
-from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload
+from typing import Any, Callable, Iterable, List, Mapping, Optional, Union
import prestodb
from prestodb.exceptions import DatabaseError
@@ -142,82 +141,28 @@ class PrestoHook(DbApiHook):
isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)
- @overload
- def get_records(self, sql: str = "", parameters: Optional[dict] = None):
- """Get a set of records from Presto
-
- :param sql: SQL statement to be executed.
- :param parameters: The parameters to render the SQL query with.
- """
-
- @overload
- def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""):
- """:sphinx-autoapi-skip:"""
-
- def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""):
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
+ def get_records(
+ self,
+ sql: Union[str, List[str]] = "",
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ **kwargs: dict,
+ ):
+ if not isinstance(sql, str):
+ raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!")
try:
return super().get_records(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise PrestoException(e)
- @overload
- def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any:
- """Returns only the first row, regardless of how many rows the query returns.
-
- :param sql: SQL statement to be executed.
- :param parameters: The parameters to render the SQL query with.
- """
-
- @overload
- def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any:
- """:sphinx-autoapi-skip:"""
-
- def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any:
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
+ def get_first(self, sql: Union[str, List[str]] = "", parameters: Optional[dict] = None) -> Any:
+ if not isinstance(sql, str):
+ raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!")
try:
return super().get_first(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise PrestoException(e)
- @overload
def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
- """Get a pandas dataframe from a sql query.
-
- :param sql: SQL statement to be executed.
- :param parameters: The parameters to render the SQL query with.
- """
-
- @overload
- def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs):
- """:sphinx-autoapi-skip:"""
-
- def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs):
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
import pandas
cursor = self.get_cursor()
@@ -234,19 +179,6 @@ class PrestoHook(DbApiHook):
df = pandas.DataFrame(**kwargs)
return df
- @overload
- def run(
- self,
- sql: Union[str, Iterable[str]],
- autocommit: bool = False,
- parameters: Optional[Union[Iterable, Mapping]] = None,
- handler: Optional[Callable] = None,
- split_statements: bool = False,
- return_last: bool = True,
- ) -> Optional[Union[Any, List[Any]]]:
- """Execute the statement against Presto. Can be used to create views."""
-
- @overload
def run(
self,
sql: Union[str, Iterable[str]],
@@ -255,29 +187,7 @@ class PrestoHook(DbApiHook):
handler: Optional[Callable] = None,
split_statements: bool = False,
return_last: bool = True,
- hql: str = "",
) -> Optional[Union[Any, List[Any]]]:
- """:sphinx-autoapi-skip:"""
-
- def run(
- self,
- sql: Union[str, Iterable[str]],
- autocommit: bool = False,
- parameters: Optional[Union[Iterable, Mapping]] = None,
- handler: Optional[Callable] = None,
- split_statements: bool = False,
- return_last: bool = True,
- hql: str = "",
- ) -> Optional[Union[Any, List[Any]]]:
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
return super().run(
sql=sql,
autocommit=autocommit,
diff --git a/airflow/providers/trino/CHANGELOG.rst b/airflow/providers/trino/CHANGELOG.rst
index bfbe975e68..0e59ff69db 100644
--- a/airflow/providers/trino/CHANGELOG.rst
+++ b/airflow/providers/trino/CHANGELOG.rst
@@ -24,6 +24,12 @@
Changelog
---------
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The deprecated ``hql`` parameter has been removed from the ``get_records``, ``get_first``, ``get_pandas_df`` and ``run``
+methods of the ``TrinoHook``. Pass the query via the ``sql`` parameter instead.
+
3.1.0
.....
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index d8ac5148de..a29a48500b 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -17,9 +17,8 @@
# under the License.
import json
import os
-import warnings
from contextlib import closing
-from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload
+from typing import Any, Callable, Iterable, List, Mapping, Optional, Union
import trino
from trino.exceptions import DatabaseError
@@ -148,96 +147,32 @@ class TrinoHook(DbApiHook):
isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)
- @overload
- def get_records(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None):
- """Get a set of records from Trino
-
- :param sql: SQL statement to be executed.
- :param parameters: The parameters to render the SQL query with.
- """
-
- @overload
def get_records(
- self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
- ):
- """:sphinx-autoapi-skip:"""
-
- def get_records(
- self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
+ self,
+ sql: Union[str, List[str]] = "",
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ **kwargs: dict,
):
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
+ if not isinstance(sql, str):
+ raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
try:
return super().get_records(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise TrinoException(e)
- @overload
- def get_first(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
- """Returns only the first row, regardless of how many rows the query returns.
-
- :param sql: SQL statement to be executed.
- :param parameters: The parameters to render the SQL query with.
- """
-
- @overload
def get_first(
- self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
+ self, sql: Union[str, List[str]] = "", parameters: Optional[Union[Iterable, Mapping]] = None
) -> Any:
- """:sphinx-autoapi-skip:"""
-
- def get_first(
- self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
- ) -> Any:
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
+ if not isinstance(sql, str):
+ raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
try:
return super().get_first(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise TrinoException(e)
- @overload
def get_pandas_df(
self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs
): # type: ignore[override]
- """Get a pandas dataframe from a sql query.
-
- :param sql: SQL statement to be executed.
- :param parameters: The parameters to render the SQL query with.
- """
-
- @overload
- def get_pandas_df(
- self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs
- ): # type: ignore[override]
- """:sphinx-autoapi-skip:"""
-
- def get_pandas_df(
- self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs
- ): # type: ignore[override]
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
import pandas
cursor = self.get_cursor()
@@ -254,19 +189,6 @@ class TrinoHook(DbApiHook):
df = pandas.DataFrame(**kwargs)
return df
- @overload
- def run(
- self,
- sql: Union[str, Iterable[str]],
- autocommit: bool = False,
- parameters: Optional[Union[Iterable, Mapping]] = None,
- handler: Optional[Callable] = None,
- split_statements: bool = False,
- return_last: bool = True,
- ) -> Optional[Union[Any, List[Any]]]:
- """Execute the statement against Trino. Can be used to create views."""
-
- @overload
def run(
self,
sql: Union[str, Iterable[str]],
@@ -275,29 +197,7 @@ class TrinoHook(DbApiHook):
handler: Optional[Callable] = None,
split_statements: bool = False,
return_last: bool = True,
- hql: str = "",
) -> Optional[Union[Any, List[Any]]]:
- """:sphinx-autoapi-skip:"""
-
- def run(
- self,
- sql: Union[str, Iterable[str]],
- autocommit: bool = False,
- parameters: Optional[Union[Iterable, Mapping]] = None,
- handler: Optional[Callable] = None,
- split_statements: bool = False,
- return_last: bool = True,
- hql: str = "",
- ) -> Optional[Union[Any, List[Any]]]:
- """:sphinx-autoapi-skip:"""
- if hql:
- warnings.warn(
- "The hql parameter has been deprecated. You should pass the sql parameter.",
- DeprecationWarning,
- stacklevel=2,
- )
- sql = hql
-
return super().run(
sql=sql,
autocommit=autocommit,
diff --git a/tests/providers/apache/hive/hooks/test_hive.py b/tests/providers/apache/hive/hooks/test_hive.py
index 49e26863be..c1fe5bd223 100644
--- a/tests/providers/apache/hive/hooks/test_hive.py
+++ b/tests/providers/apache/hive/hooks/test_hive.py
@@ -826,7 +826,7 @@ class TestHiveServer2Hook(unittest.TestCase):
)
output = '\n'.join(
- res_tuple[0] for res_tuple in hook.get_results(hql=hql, hive_conf={'key': 'value'})['data']
+ res_tuple[0] for res_tuple in hook.get_results(hql, hive_conf={'key': 'value'})['data']
)
assert 'value' in output
assert 'test_dag_id' in output
diff --git a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
index 29b366e4cf..7e056a17ba 100644
--- a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
+++ b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
@@ -45,7 +45,7 @@ class TestHiveToMySqlTransfer(TestHiveEnvironment):
HiveToMySqlOperator(**self.kwargs).execute(context={})
mock_hive_hook.assert_called_once_with(hiveserver2_conn_id=self.kwargs['hiveserver2_conn_id'])
- mock_hive_hook.return_value.get_records.assert_called_once_with('sql', hive_conf={})
+ mock_hive_hook.return_value.get_records.assert_called_once_with('sql', parameters={})
mock_mysql_hook.assert_called_once_with(mysql_conn_id=self.kwargs['mysql_conn_id'])
mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
table=self.kwargs['mysql_table'], rows=mock_hive_hook.return_value.get_records.return_value
@@ -112,7 +112,7 @@ class TestHiveToMySqlTransfer(TestHiveEnvironment):
hive_conf = context_to_airflow_vars(context)
hive_conf.update(self.kwargs['hive_conf'])
- mock_hive_hook.get_records.assert_called_once_with(self.kwargs['sql'], hive_conf=hive_conf)
+ mock_hive_hook.get_records.assert_called_once_with(self.kwargs['sql'], parameters=hive_conf)
@unittest.skipIf(
'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set"
diff --git a/tests/providers/apache/hive/transfers/test_hive_to_samba.py b/tests/providers/apache/hive/transfers/test_hive_to_samba.py
index 3e13e23536..6620943758 100644
--- a/tests/providers/apache/hive/transfers/test_hive_to_samba.py
+++ b/tests/providers/apache/hive/transfers/test_hive_to_samba.py
@@ -64,7 +64,7 @@ class TestHive2SambaOperator(TestHiveEnvironment):
mock_hive_hook.assert_called_once_with(hiveserver2_conn_id=self.kwargs['hiveserver2_conn_id'])
mock_hive_hook.return_value.to_csv.assert_called_once_with(
- hql=self.kwargs['hql'],
+ self.kwargs['hql'],
csv_filepath=mock_tmp_file.name,
hive_conf=context_to_airflow_vars(context),
)