You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/01/04 19:52:06 UTC
[incubator-superset] 05/22: Split up tests/db_engine_specs_test.py (#8449)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch 0.35
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
commit ddeb43677f8d95255e78f7a371b5c0d14661e2a7
Author: Will Barrett <wi...@preset.io>
AuthorDate: Thu Oct 24 20:46:45 2019 -0700
Split up tests/db_engine_specs_test.py (#8449)
* Split up db_engine_specs_test.py into a number of targeted files
* Remove db_engine_specs_test.py
* isort
---
tests/db_engine_specs/base_engine_spec_tests.py | 204 ++++++
tests/db_engine_specs/base_tests.py | 28 +
tests/db_engine_specs/bigquery_tests.py | 39 ++
tests/db_engine_specs/hive_tests.py | 152 +++++
tests/db_engine_specs/mssql_tests.py | 71 +++
tests/db_engine_specs/mysql_tests.py | 30 +
tests/db_engine_specs/oracle_tests.py | 36 ++
tests/db_engine_specs/pinot_tests.py | 33 +
tests/db_engine_specs/postgres_tests.py | 72 +++
tests/db_engine_specs/presto_tests.py | 343 ++++++++++
tests/db_engine_specs_test.py | 810 ------------------------
11 files changed, 1008 insertions(+), 810 deletions(-)
diff --git a/tests/db_engine_specs/base_engine_spec_tests.py b/tests/db_engine_specs/base_engine_spec_tests.py
new file mode 100644
index 0000000..13f7b67
--- /dev/null
+++ b/tests/db_engine_specs/base_engine_spec_tests.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import mock
+
+from superset import app
+from superset.db_engine_specs import engines
+from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
+from superset.db_engine_specs.sqlite import SqliteEngineSpec
+from superset.utils.core import get_example_database
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class DbEngineSpecsTests(DbEngineSpecTestCase):
+ def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
+ q0 = "select * from table"
+ q1 = "select * from mytable limit 10"
+ q2 = "select * from (select * from my_subquery limit 10) where col=1 limit 20"
+ q3 = "select * from (select * from my_subquery limit 10);"
+ q4 = "select * from (select * from my_subquery limit 10) where col=1 limit 20;"
+ q5 = "select * from mytable limit 20, 10"
+ q6 = "select * from mytable limit 10 offset 20"
+ q7 = "select * from mytable limit"
+ q8 = "select * from mytable limit 10.0"
+ q9 = "select * from mytable limit x"
+ q10 = "select * from mytable limit 20, x"
+ q11 = "select * from mytable limit x offset 20"
+
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
+ self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
+
+ def test_wrapped_semi_tabs(self):
+ self.sql_limit_regex(
+ "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
+ )
+
+ def test_simple_limit_query(self):
+ self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")
+
+ def test_modify_limit_query(self):
+ self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")
+
+ def test_limit_query_with_limit_subquery(self): # pylint: disable=invalid-name
+ self.sql_limit_regex(
+ "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
+ "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
+ )
+
+ def test_limit_with_expr(self):
+ self.sql_limit_regex(
+ """
+ SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 99990""",
+ """SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 1000""",
+ )
+
+ def test_limit_expr_and_semicolon(self):
+ self.sql_limit_regex(
+ """
+ SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 99990 ;""",
+ """SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 1000""",
+ )
+
+ def test_get_datatype(self):
+ self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
+
+ def test_limit_with_implicit_offset(self):
+ self.sql_limit_regex(
+ """
+ SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 99990, 999999""",
+ """SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 99990, 1000""",
+ )
+
+ def test_limit_with_explicit_offset(self):
+ self.sql_limit_regex(
+ """
+ SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 99990
+ OFFSET 999999""",
+ """SELECT
+ 'LIMIT 777' AS a
+ , b
+ FROM
+ table
+ LIMIT 1000
+ OFFSET 999999""",
+ )
+
+ def test_limit_with_non_token_limit(self):
+ self.sql_limit_regex(
+ """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
+ )
+
+ def test_time_grain_blacklist(self):
+ with app.app_context():
+ app.config["TIME_GRAIN_BLACKLIST"] = ["PT1M"]
+ time_grain_functions = SqliteEngineSpec.get_time_grain_functions()
+ self.assertNotIn("PT1M", time_grain_functions)
+
+ def test_time_grain_addons(self):
+ with app.app_context():
+ app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
+ app.config["TIME_GRAIN_ADDON_FUNCTIONS"] = {
+ "sqlite": {"PTXM": "ABC({col})"}
+ }
+ time_grains = SqliteEngineSpec.get_time_grains()
+ time_grain_addon = time_grains[-1]
+ self.assertEqual("PTXM", time_grain_addon.duration)
+ self.assertEqual("x seconds", time_grain_addon.label)
+
+ def test_engine_time_grain_validity(self):
+ time_grains = set(builtin_time_grains.keys())
+ # loop over all subclasses of BaseEngineSpec
+ for engine in engines.values():
+ if engine is not BaseEngineSpec:
+ # make sure time grain functions have been defined
+ self.assertGreater(len(engine.get_time_grain_functions()), 0)
+ # make sure all defined time grains are supported
+ defined_grains = {grain.duration for grain in engine.get_time_grains()}
+ intersection = time_grains.intersection(defined_grains)
+ self.assertSetEqual(defined_grains, intersection, engine)
+
+ def test_get_table_names(self):
+ inspector = mock.Mock()
+ inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
+ inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
+
+ """ Make sure base engine spec removes schema name from table name
+ ie. when try_remove_schema_from_table_name == True. """
+ base_result_expected = ["table", "table_2"]
+ base_result = BaseEngineSpec.get_table_names(
+ database=mock.ANY, schema="schema", inspector=inspector
+ )
+ self.assertListEqual(base_result_expected, base_result)
+
+ def test_column_datatype_to_string(self):
+ example_db = get_example_database()
+ sqla_table = example_db.get_table("energy_usage")
+ dialect = example_db.get_dialect()
+ col_names = [
+ example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
+ for c in sqla_table.columns
+ ]
+ if example_db.backend == "postgresql":
+ expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
+ else:
+ expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
+ self.assertEqual(col_names, expected)
diff --git a/tests/db_engine_specs/base_tests.py b/tests/db_engine_specs/base_tests.py
new file mode 100644
index 0000000..812e6b8
--- /dev/null
+++ b/tests/db_engine_specs/base_tests.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.db_engine_specs.mysql import MySQLEngineSpec
+from superset.models.core import Database
+from tests.base_tests import SupersetTestCase
+
+
+class DbEngineSpecTestCase(SupersetTestCase):
+ def sql_limit_regex(
+ self, sql, expected_sql, engine_spec_class=MySQLEngineSpec, limit=1000
+ ):
+ main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
+ limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
+ self.assertEqual(expected_sql, limited)
diff --git a/tests/db_engine_specs/bigquery_tests.py b/tests/db_engine_specs/bigquery_tests.py
new file mode 100644
index 0000000..ec23e86
--- /dev/null
+++ b/tests/db_engine_specs/bigquery_tests.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from sqlalchemy import column
+
+from superset.db_engine_specs.bigquery import BigQueryEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class BigQueryTestCase(DbEngineSpecTestCase):
+ def test_bigquery_sqla_column_label(self):
+ label = BigQueryEngineSpec.make_label_compatible(column("Col").name)
+ label_expected = "Col"
+ self.assertEqual(label, label_expected)
+
+ label = BigQueryEngineSpec.make_label_compatible(column("SUM(x)").name)
+ label_expected = "SUM_x__5f110"
+ self.assertEqual(label, label_expected)
+
+ label = BigQueryEngineSpec.make_label_compatible(column("SUM[x]").name)
+ label_expected = "SUM_x__7ebe1"
+ self.assertEqual(label, label_expected)
+
+ label = BigQueryEngineSpec.make_label_compatible(column("12345_col").name)
+ label_expected = "_12345_col_8d390"
+ self.assertEqual(label, label_expected)
diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py
new file mode 100644
index 0000000..94a474d
--- /dev/null
+++ b/tests/db_engine_specs/hive_tests.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import mock
+
+from superset.db_engine_specs.hive import HiveEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class HiveTests(DbEngineSpecTestCase):
+ def test_0_progress(self):
+ log = """
+ 17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
+ 17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
+ """.split(
+ "\n"
+ )
+ self.assertEqual(0, HiveEngineSpec.progress(log))
+
+ def test_number_of_jobs_progress(self):
+ log = """
+ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
+ """.split(
+ "\n"
+ )
+ self.assertEqual(0, HiveEngineSpec.progress(log))
+
+ def test_job_1_launched_progress(self):
+ log = """
+ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
+ 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
+ """.split(
+ "\n"
+ )
+ self.assertEqual(0, HiveEngineSpec.progress(log))
+
+ def test_job_1_launched_stage_1(self):
+ log = """
+ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
+ 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
+ """.split(
+ "\n"
+ )
+ self.assertEqual(0, HiveEngineSpec.progress(log))
+
+ def test_job_1_launched_stage_1_map_40_progress(
+ self
+ ): # pylint: disable=invalid-name
+ log = """
+ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
+ 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
+ """.split(
+ "\n"
+ )
+ self.assertEqual(10, HiveEngineSpec.progress(log))
+
+ def test_job_1_launched_stage_1_map_80_reduce_40_progress(
+ self
+ ): # pylint: disable=invalid-name
+ log = """
+ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
+ 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
+ """.split(
+ "\n"
+ )
+ self.assertEqual(30, HiveEngineSpec.progress(log))
+
+ def test_job_1_launched_stage_2_stages_progress(
+ self
+ ): # pylint: disable=invalid-name
+ log = """
+ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
+ 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
+ """.split(
+ "\n"
+ )
+ self.assertEqual(12, HiveEngineSpec.progress(log))
+
+ def test_job_2_launched_stage_2_stages_progress(
+ self
+ ): # pylint: disable=invalid-name
+ log = """
+ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
+ 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
+ 17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
+ 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
+ """.split(
+ "\n"
+ )
+ self.assertEqual(60, HiveEngineSpec.progress(log))
+
+ def test_hive_error_msg(self):
+ msg = (
+ '{...} errorMessage="Error while compiling statement: FAILED: '
+ "SemanticException [Error 10001]: Line 4"
+ ":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
+ "sqlState='42S02', errorCode=10001)){...}"
+ )
+ self.assertEqual(
+ (
+ "hive error: Error while compiling statement: FAILED: "
+ "SemanticException [Error 10001]: Line 4:5 "
+ "Table not found 'fact_ridesfdslakj'"
+ ),
+ HiveEngineSpec.extract_error_message(Exception(msg)),
+ )
+
+ e = Exception("Some string that doesn't match the regex")
+ self.assertEqual(f"hive error: {e}", HiveEngineSpec.extract_error_message(e))
+
+ msg = (
+ "errorCode=10001, "
+ 'errorMessage="Error while compiling statement"), operationHandle'
+ '=None)"'
+ )
+ self.assertEqual(
+ ("hive error: Error while compiling statement"),
+ HiveEngineSpec.extract_error_message(Exception(msg)),
+ )
+
+ def test_hive_get_view_names_return_empty_list(
+ self
+ ): # pylint: disable=invalid-name
+ self.assertEqual(
+ [], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
+ )
diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py
new file mode 100644
index 0000000..989fa8c
--- /dev/null
+++ b/tests/db_engine_specs/mssql_tests.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from sqlalchemy import column, table
+from sqlalchemy.dialects import mssql
+from sqlalchemy.sql import select
+from sqlalchemy.types import String, UnicodeText
+
+from superset.db_engine_specs.mssql import MssqlEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class MssqlEngineSpecTest(DbEngineSpecTestCase):
+ def test_mssql_column_types(self):
+ def assert_type(type_string, type_expected):
+ type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
+ if type_expected is None:
+ self.assertIsNone(type_assigned)
+ else:
+ self.assertIsInstance(type_assigned, type_expected)
+
+ assert_type("INT", None)
+ assert_type("STRING", String)
+ assert_type("CHAR(10)", String)
+ assert_type("VARCHAR(10)", String)
+ assert_type("TEXT", String)
+ assert_type("NCHAR(10)", UnicodeText)
+ assert_type("NVARCHAR(10)", UnicodeText)
+ assert_type("NTEXT", UnicodeText)
+
+ def test_where_clause_n_prefix(self):
+ dialect = mssql.dialect()
+ spec = MssqlEngineSpec
+ str_col = column("col", type_=spec.get_sqla_column_type("VARCHAR(10)"))
+ unicode_col = column("unicode_col", type_=spec.get_sqla_column_type("NTEXT"))
+ tbl = table("tbl")
+ sel = (
+ select([str_col, unicode_col])
+ .select_from(tbl)
+ .where(str_col == "abc")
+ .where(unicode_col == "abc")
+ )
+
+ query = str(
+ sel.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
+ )
+ query_expected = (
+ "SELECT col, unicode_col \n"
+ "FROM tbl \n"
+ "WHERE col = 'abc' AND unicode_col = N'abc'"
+ )
+ self.assertEqual(query, query_expected)
+
+ def test_time_exp_mixd_case_col_1y(self):
+ col = column("MixedCase")
+ expr = MssqlEngineSpec.get_timestamp_expr(col, None, "P1Y")
+ result = str(expr.compile(None, dialect=mssql.dialect()))
+ self.assertEqual(result, "DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)")
diff --git a/tests/db_engine_specs/mysql_tests.py b/tests/db_engine_specs/mysql_tests.py
new file mode 100644
index 0000000..22205a8
--- /dev/null
+++ b/tests/db_engine_specs/mysql_tests.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+
+from superset.db_engine_specs.mysql import MySQLEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class MySQLEngineSpecsTestCase(DbEngineSpecTestCase):
+ @unittest.skipUnless(
+ DbEngineSpecTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
+ )
+ def test_get_datatype_mysql(self):
+ """Tests related to datatype mapping for MySQL"""
+ self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
+ self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
diff --git a/tests/db_engine_specs/oracle_tests.py b/tests/db_engine_specs/oracle_tests.py
new file mode 100644
index 0000000..285f616
--- /dev/null
+++ b/tests/db_engine_specs/oracle_tests.py
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from sqlalchemy import column
+from sqlalchemy.dialects import oracle
+
+from superset.db_engine_specs.oracle import OracleEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class OracleTestCase(DbEngineSpecTestCase):
+ def test_oracle_sqla_column_name_length_exceeded(self):
+ col = column("This_Is_32_Character_Column_Name")
+ label = OracleEngineSpec.make_label_compatible(col.name)
+ self.assertEqual(label.quote, True)
+ label_expected = "3b26974078683be078219674eeb8f5"
+ self.assertEqual(label, label_expected)
+
+ def test_oracle_time_expression_reserved_keyword_1m_grain(self):
+ col = column("decimal")
+ expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
+ result = str(expr.compile(dialect=oracle.dialect()))
+ self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
diff --git a/tests/db_engine_specs/pinot_tests.py b/tests/db_engine_specs/pinot_tests.py
new file mode 100644
index 0000000..a96e9c1
--- /dev/null
+++ b/tests/db_engine_specs/pinot_tests.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from sqlalchemy import column
+
+from superset.db_engine_specs.pinot import PinotEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class PinotTestCase(DbEngineSpecTestCase):
+ """ Tests pertaining to our Pinot database support """
+
+ def test_pinot_time_expression_sec_one_1m_grain(self):
+ col = column("tstamp")
+ expr = PinotEngineSpec.get_timestamp_expr(col, "epoch_s", "P1M")
+ result = str(expr.compile())
+ self.assertEqual(
+ result,
+ 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")',
+ ) # noqa
diff --git a/tests/db_engine_specs/postgres_tests.py b/tests/db_engine_specs/postgres_tests.py
new file mode 100644
index 0000000..3204c53
--- /dev/null
+++ b/tests/db_engine_specs/postgres_tests.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import mock
+
+from sqlalchemy import column, literal_column
+from sqlalchemy.dialects import postgresql
+
+from superset.db_engine_specs.postgres import PostgresEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class PostgresTests(DbEngineSpecTestCase):
+ def test_get_table_names(self):
+ """ Make sure postgres doesn't try to remove schema name from table name
+ ie. when try_remove_schema_from_table_name == False. """
+ inspector = mock.Mock()
+ inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
+ inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
+
+ pg_result_expected = ["schema.table", "table_2", "table_3"]
+ pg_result = PostgresEngineSpec.get_table_names(
+ database=mock.ANY, schema="schema", inspector=inspector
+ )
+ self.assertListEqual(pg_result_expected, pg_result)
+
+ def test_time_exp_literal_no_grain(self):
+ col = literal_column("COALESCE(a, b)")
+ expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
+ result = str(expr.compile(None, dialect=postgresql.dialect()))
+ self.assertEqual(result, "COALESCE(a, b)")
+
+ def test_time_exp_literal_1y_grain(self):
+ col = literal_column("COALESCE(a, b)")
+ expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
+ result = str(expr.compile(None, dialect=postgresql.dialect()))
+ self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
+
+ def test_time_ex_lowr_col_no_grain(self):
+ col = column("lower_case")
+ expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
+ result = str(expr.compile(None, dialect=postgresql.dialect()))
+ self.assertEqual(result, "lower_case")
+
+ def test_time_exp_lowr_col_sec_1y(self):
+ col = column("lower_case")
+ expr = PostgresEngineSpec.get_timestamp_expr(col, "epoch_s", "P1Y")
+ result = str(expr.compile(None, dialect=postgresql.dialect()))
+ self.assertEqual(
+ result,
+ "DATE_TRUNC('year', "
+ "(timestamp 'epoch' + lower_case * interval '1 second'))",
+ )
+
+ def test_time_exp_mixd_case_col_1y(self):
+ col = column("MixedCase")
+ expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
+ result = str(expr.compile(None, dialect=postgresql.dialect()))
+ self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py
new file mode 100644
index 0000000..b727310
--- /dev/null
+++ b/tests/db_engine_specs/presto_tests.py
@@ -0,0 +1,343 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import mock, skipUnless
+
+import pandas as pd
+from sqlalchemy.engine.result import RowProxy
+from sqlalchemy.sql import select
+
+from superset.db_engine_specs.presto import PrestoEngineSpec
+from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
+
+
+class PrestoTests(DbEngineSpecTestCase):
+ @skipUnless(
+ DbEngineSpecTestCase.is_module_installed("pyhive"), "pyhive not installed"
+ )
+ def test_get_datatype_presto(self):
+ self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))
+
+ def test_presto_get_view_names_return_empty_list(
+ self
+ ): # pylint: disable=invalid-name
+ self.assertEqual(
+ [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
+ )
+
+ def verify_presto_column(self, column, expected_results):
+ inspector = mock.Mock()
+ inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
+ keymap = {
+ "Column": (None, None, 0),
+ "Type": (None, None, 1),
+ "Null": (None, None, 2),
+ }
+ row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
+ inspector.bind.execute = mock.Mock(return_value=[row])
+ results = PrestoEngineSpec.get_columns(inspector, "", "")
+ self.assertEqual(len(expected_results), len(results))
+ for expected_result, result in zip(expected_results, results):
+ self.assertEqual(expected_result[0], result["name"])
+ self.assertEqual(expected_result[1], str(result["type"]))
+
+ def test_presto_get_column(self):
+ presto_column = ("column_name", "boolean", "")
+ expected_results = [("column_name", "BOOLEAN")]
+ self.verify_presto_column(presto_column, expected_results)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_get_simple_row_column(self):
+ presto_column = ("column_name", "row(nested_obj double)", "")
+ expected_results = [("column_name", "ROW"), ("column_name.nested_obj", "FLOAT")]
+ self.verify_presto_column(presto_column, expected_results)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_get_simple_row_column_with_name_containing_whitespace(self):
+ presto_column = ("column name", "row(nested_obj double)", "")
+ expected_results = [("column name", "ROW"), ("column name.nested_obj", "FLOAT")]
+ self.verify_presto_column(presto_column, expected_results)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_get_simple_row_column_with_tricky_nested_field_name(self):
+ presto_column = ("column_name", 'row("Field Name(Tricky, Name)" double)', "")
+ expected_results = [
+ ("column_name", "ROW"),
+ ('column_name."Field Name(Tricky, Name)"', "FLOAT"),
+ ]
+ self.verify_presto_column(presto_column, expected_results)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_get_simple_array_column(self):
+ presto_column = ("column_name", "array(double)", "")
+ expected_results = [("column_name", "ARRAY")]
+ self.verify_presto_column(presto_column, expected_results)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_get_row_within_array_within_row_column(self):
+ presto_column = (
+ "column_name",
+ "row(nested_array array(row(nested_row double)), nested_obj double)",
+ "",
+ )
+ expected_results = [
+ ("column_name", "ROW"),
+ ("column_name.nested_array", "ARRAY"),
+ ("column_name.nested_array.nested_row", "FLOAT"),
+ ("column_name.nested_obj", "FLOAT"),
+ ]
+ self.verify_presto_column(presto_column, expected_results)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_get_array_within_row_within_array_column(self):
+ presto_column = (
+ "column_name",
+ "array(row(nested_array array(double), nested_obj double))",
+ "",
+ )
+ expected_results = [
+ ("column_name", "ARRAY"),
+ ("column_name.nested_array", "ARRAY"),
+ ("column_name.nested_obj", "FLOAT"),
+ ]
+ self.verify_presto_column(presto_column, expected_results)
+
+ def test_presto_get_fields(self):
+ cols = [
+ {"name": "column"},
+ {"name": "column.nested_obj"},
+ {"name": 'column."quoted.nested obj"'},
+ ]
+ actual_results = PrestoEngineSpec._get_fields(cols)
+ expected_results = [
+ {"name": '"column"', "label": "column"},
+ {"name": '"column"."nested_obj"', "label": "column.nested_obj"},
+ {
+ "name": '"column"."quoted.nested obj"',
+ "label": 'column."quoted.nested obj"',
+ },
+ ]
+ for actual_result, expected_result in zip(actual_results, expected_results):
+ self.assertEqual(actual_result.element.name, expected_result["name"])
+ self.assertEqual(actual_result.name, expected_result["label"])
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_expand_data_with_simple_structural_columns(self):
+ cols = [
+ {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"},
+ {"name": "array_column", "type": "ARRAY(BIGINT)"},
+ ]
+ data = [
+ {"row_column": ["a"], "array_column": [1, 2, 3]},
+ {"row_column": ["b"], "array_column": [4, 5, 6]},
+ ]
+ actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
+ cols, data
+ )
+ expected_cols = [
+ {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"},
+ {"name": "row_column.nested_obj", "type": "VARCHAR"},
+ {"name": "array_column", "type": "ARRAY(BIGINT)"},
+ ]
+
+ expected_data = [
+ {"array_column": 1, "row_column": ["a"], "row_column.nested_obj": "a"},
+ {"array_column": 2, "row_column": "", "row_column.nested_obj": ""},
+ {"array_column": 3, "row_column": "", "row_column.nested_obj": ""},
+ {"array_column": 4, "row_column": ["b"], "row_column.nested_obj": "b"},
+ {"array_column": 5, "row_column": "", "row_column.nested_obj": ""},
+ {"array_column": 6, "row_column": "", "row_column.nested_obj": ""},
+ ]
+
+ expected_expanded_cols = [{"name": "row_column.nested_obj", "type": "VARCHAR"}]
+ self.assertEqual(actual_cols, expected_cols)
+ self.assertEqual(actual_data, expected_data)
+ self.assertEqual(actual_expanded_cols, expected_expanded_cols)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_expand_data_with_complex_row_columns(self):
+ cols = [
+ {
+ "name": "row_column",
+ "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
+ }
+ ]
+ data = [{"row_column": ["a1", ["a2"]]}, {"row_column": ["b1", ["b2"]]}]
+ actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
+ cols, data
+ )
+ expected_cols = [
+ {
+ "name": "row_column",
+ "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
+ },
+ {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
+ {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
+ {"name": "row_column.nested_obj1", "type": "VARCHAR"},
+ ]
+ expected_data = [
+ {
+ "row_column": ["a1", ["a2"]],
+ "row_column.nested_obj1": "a1",
+ "row_column.nested_row": ["a2"],
+ "row_column.nested_row.nested_obj2": "a2",
+ },
+ {
+ "row_column": ["b1", ["b2"]],
+ "row_column.nested_obj1": "b1",
+ "row_column.nested_row": ["b2"],
+ "row_column.nested_row.nested_obj2": "b2",
+ },
+ ]
+
+ expected_expanded_cols = [
+ {"name": "row_column.nested_obj1", "type": "VARCHAR"},
+ {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
+ {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
+ ]
+ self.assertEqual(actual_cols, expected_cols)
+ self.assertEqual(actual_data, expected_data)
+ self.assertEqual(actual_expanded_cols, expected_expanded_cols)
+
+ @mock.patch.dict(
+ "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+ )
+ def test_presto_expand_data_with_complex_array_columns(self):
+ cols = [
+ {"name": "int_column", "type": "BIGINT"},
+ {
+ "name": "array_column",
+ "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
+ },
+ ]
+ data = [
+ {"int_column": 1, "array_column": [[[["a"], ["b"]]], [[["c"], ["d"]]]]},
+ {"int_column": 2, "array_column": [[[["e"], ["f"]]], [[["g"], ["h"]]]]},
+ ]
+ actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
+ cols, data
+ )
+ expected_cols = [
+ {"name": "int_column", "type": "BIGINT"},
+ {
+ "name": "array_column",
+ "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
+ },
+ {
+ "name": "array_column.nested_array",
+ "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
+ },
+ {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
+ ]
+ expected_data = [
+ {
+ "array_column": [[["a"], ["b"]]],
+ "array_column.nested_array": ["a"],
+ "array_column.nested_array.nested_obj": "a",
+ "int_column": 1,
+ },
+ {
+ "array_column": "",
+ "array_column.nested_array": ["b"],
+ "array_column.nested_array.nested_obj": "b",
+ "int_column": "",
+ },
+ {
+ "array_column": [[["c"], ["d"]]],
+ "array_column.nested_array": ["c"],
+ "array_column.nested_array.nested_obj": "c",
+ "int_column": "",
+ },
+ {
+ "array_column": "",
+ "array_column.nested_array": ["d"],
+ "array_column.nested_array.nested_obj": "d",
+ "int_column": "",
+ },
+ {
+ "array_column": [[["e"], ["f"]]],
+ "array_column.nested_array": ["e"],
+ "array_column.nested_array.nested_obj": "e",
+ "int_column": 2,
+ },
+ {
+ "array_column": "",
+ "array_column.nested_array": ["f"],
+ "array_column.nested_array.nested_obj": "f",
+ "int_column": "",
+ },
+ {
+ "array_column": [[["g"], ["h"]]],
+ "array_column.nested_array": ["g"],
+ "array_column.nested_array.nested_obj": "g",
+ "int_column": "",
+ },
+ {
+ "array_column": "",
+ "array_column.nested_array": ["h"],
+ "array_column.nested_array.nested_obj": "h",
+ "int_column": "",
+ },
+ ]
+ expected_expanded_cols = [
+ {
+ "name": "array_column.nested_array",
+ "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
+ },
+ {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
+ ]
+ self.assertEqual(actual_cols, expected_cols)
+ self.assertEqual(actual_data, expected_data)
+ self.assertEqual(actual_expanded_cols, expected_expanded_cols)
+
+ def test_presto_extra_table_metadata(self):
+ db = mock.Mock()
+ db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
+ db.get_extra = mock.Mock(return_value={})
+ df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
+ db.get_df = mock.Mock(return_value=df)
+ PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
+ result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
+ self.assertEqual({"ds": "01-01-19", "hour": 1}, result["partitions"]["latest"])
+
+ def test_presto_where_latest_partition(self):
+ db = mock.Mock()
+ db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
+ db.get_extra = mock.Mock(return_value={})
+ df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
+ db.get_df = mock.Mock(return_value=df)
+ columns = [{"name": "ds"}, {"name": "hour"}]
+ result = PrestoEngineSpec.where_latest_partition(
+ "test_table", "test_schema", db, select(), columns
+ )
+ query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
+ self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
deleted file mode 100644
index 619ae4f..0000000
--- a/tests/db_engine_specs_test.py
+++ /dev/null
@@ -1,810 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import unittest
-from unittest import mock
-
-import pandas as pd
-from sqlalchemy import column, literal_column, table
-from sqlalchemy.dialects import mssql, oracle, postgresql
-from sqlalchemy.engine.result import RowProxy
-from sqlalchemy.sql import select
-from sqlalchemy.types import String, UnicodeText
-
-from superset import app
-from superset.db_engine_specs import engines
-from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
-from superset.db_engine_specs.bigquery import BigQueryEngineSpec
-from superset.db_engine_specs.hive import HiveEngineSpec
-from superset.db_engine_specs.mssql import MssqlEngineSpec
-from superset.db_engine_specs.mysql import MySQLEngineSpec
-from superset.db_engine_specs.oracle import OracleEngineSpec
-from superset.db_engine_specs.pinot import PinotEngineSpec
-from superset.db_engine_specs.postgres import PostgresEngineSpec
-from superset.db_engine_specs.presto import PrestoEngineSpec
-from superset.db_engine_specs.sqlite import SqliteEngineSpec
-from superset.models.core import Database
-from superset.utils.core import get_example_database
-
-from .base_tests import SupersetTestCase
-
-
-class DbEngineSpecsTestCase(SupersetTestCase):
- def test_0_progress(self):
- log = """
- 17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
- 17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
- """.split(
- "\n"
- )
- self.assertEqual(0, HiveEngineSpec.progress(log))
-
- def test_number_of_jobs_progress(self):
- log = """
- 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
- """.split(
- "\n"
- )
- self.assertEqual(0, HiveEngineSpec.progress(log))
-
- def test_job_1_launched_progress(self):
- log = """
- 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
- 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
- """.split(
- "\n"
- )
- self.assertEqual(0, HiveEngineSpec.progress(log))
-
- def test_job_1_launched_stage_1_0_progress(self):
- log = """
- 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
- 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
- """.split(
- "\n"
- )
- self.assertEqual(0, HiveEngineSpec.progress(log))
-
- def test_job_1_launched_stage_1_map_40_progress(self):
- log = """
- 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
- 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
- """.split(
- "\n"
- )
- self.assertEqual(10, HiveEngineSpec.progress(log))
-
- def test_job_1_launched_stage_1_map_80_reduce_40_progress(self):
- log = """
- 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
- 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
- """.split(
- "\n"
- )
- self.assertEqual(30, HiveEngineSpec.progress(log))
-
- def test_job_1_launched_stage_2_stages_progress(self):
- log = """
- 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
- 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
- """.split(
- "\n"
- )
- self.assertEqual(12, HiveEngineSpec.progress(log))
-
- def test_job_2_launched_stage_2_stages_progress(self):
- log = """
- 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
- 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
- 17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
- 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
- """.split(
- "\n"
- )
- self.assertEqual(60, HiveEngineSpec.progress(log))
-
- def test_hive_error_msg(self):
- msg = (
- '{...} errorMessage="Error while compiling statement: FAILED: '
- "SemanticException [Error 10001]: Line 4"
- ":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
- "sqlState='42S02', errorCode=10001)){...}"
- )
- self.assertEqual(
- (
- "hive error: Error while compiling statement: FAILED: "
- "SemanticException [Error 10001]: Line 4:5 "
- "Table not found 'fact_ridesfdslakj'"
- ),
- HiveEngineSpec.extract_error_message(Exception(msg)),
- )
-
- e = Exception("Some string that doesn't match the regex")
- self.assertEqual(f"hive error: {e}", HiveEngineSpec.extract_error_message(e))
-
- msg = (
- "errorCode=10001, "
- 'errorMessage="Error while compiling statement"), operationHandle'
- '=None)"'
- )
- self.assertEqual(
- ("hive error: Error while compiling statement"),
- HiveEngineSpec.extract_error_message(Exception(msg)),
- )
-
- def get_generic_database(self):
- return Database(database_name="test_database", sqlalchemy_uri="sqlite://")
-
- def sql_limit_regex(
- self, sql, expected_sql, engine_spec_class=MySQLEngineSpec, limit=1000
- ):
- main = self.get_generic_database()
- limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
- self.assertEqual(expected_sql, limited)
-
- def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
- q0 = "select * from table"
- q1 = "select * from mytable limit 10"
- q2 = "select * from (select * from my_subquery limit 10) where col=1 limit 20"
- q3 = "select * from (select * from my_subquery limit 10);"
- q4 = "select * from (select * from my_subquery limit 10) where col=1 limit 20;"
- q5 = "select * from mytable limit 20, 10"
- q6 = "select * from mytable limit 10 offset 20"
- q7 = "select * from mytable limit"
- q8 = "select * from mytable limit 10.0"
- q9 = "select * from mytable limit x"
- q10 = "select * from mytable limit 20, x"
- q11 = "select * from mytable limit x offset 20"
-
- self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
- self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
-
- def test_wrapped_query(self):
- self.sql_limit_regex(
- "SELECT * FROM a",
- "SELECT * \nFROM (SELECT * FROM a) AS inner_qry\n LIMIT 1000 OFFSET 0",
- MssqlEngineSpec,
- )
-
- @unittest.skipUnless(
- SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
- )
- def test_wrapped_semi_tabs(self):
- self.sql_limit_regex(
- "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
- )
-
- def test_simple_limit_query(self):
- self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")
-
- def test_modify_limit_query(self):
- self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")
-
- def test_limit_query_with_limit_subquery(self):
- self.sql_limit_regex(
- "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
- "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
- )
-
- def test_limit_with_expr(self):
- self.sql_limit_regex(
- """
- SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 99990""",
- """SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 1000""",
- )
-
- def test_limit_expr_and_semicolon(self):
- self.sql_limit_regex(
- """
- SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 99990 ;""",
- """SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 1000""",
- )
-
- @unittest.skipUnless(
- SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
- )
- def test_get_datatype_mysql(self):
- self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
- self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
-
- @unittest.skipUnless(
- SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
- )
- def test_get_datatype_presto(self):
- self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))
-
- def test_get_datatype(self):
- self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
-
- def test_limit_with_implicit_offset(self):
- self.sql_limit_regex(
- """
- SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 99990, 999999""",
- """SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 99990, 1000""",
- )
-
- def test_limit_with_explicit_offset(self):
- self.sql_limit_regex(
- """
- SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 99990
- OFFSET 999999""",
- """SELECT
- 'LIMIT 777' AS a
- , b
- FROM
- table
- LIMIT 1000
- OFFSET 999999""",
- )
-
- def test_limit_with_non_token_limit(self):
- self.sql_limit_regex(
- """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
- )
-
- def test_time_grain_blacklist(self):
- with app.app_context():
- app.config["TIME_GRAIN_BLACKLIST"] = ["PT1M"]
- time_grain_functions = SqliteEngineSpec.get_time_grain_functions()
- self.assertNotIn("PT1M", time_grain_functions)
-
- def test_time_grain_addons(self):
- with app.app_context():
- app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
- app.config["TIME_GRAIN_ADDON_FUNCTIONS"] = {
- "sqlite": {"PTXM": "ABC({col})"}
- }
- time_grains = SqliteEngineSpec.get_time_grains()
- time_grain_addon = time_grains[-1]
- self.assertEqual("PTXM", time_grain_addon.duration)
- self.assertEqual("x seconds", time_grain_addon.label)
-
- def test_engine_time_grain_validity(self):
- time_grains = set(builtin_time_grains.keys())
- # loop over all subclasses of BaseEngineSpec
- for engine in engines.values():
- if engine is not BaseEngineSpec:
- # make sure time grain functions have been defined
- self.assertGreater(len(engine.get_time_grain_functions()), 0)
- # make sure all defined time grains are supported
- defined_grains = {grain.duration for grain in engine.get_time_grains()}
- intersection = time_grains.intersection(defined_grains)
- self.assertSetEqual(defined_grains, intersection, engine)
-
- def test_presto_get_view_names_return_empty_list(self):
- self.assertEqual(
- [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
- )
-
- def verify_presto_column(self, column, expected_results):
- inspector = mock.Mock()
- inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
- keymap = {
- "Column": (None, None, 0),
- "Type": (None, None, 1),
- "Null": (None, None, 2),
- }
- row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
- inspector.bind.execute = mock.Mock(return_value=[row])
- results = PrestoEngineSpec.get_columns(inspector, "", "")
- self.assertEqual(len(expected_results), len(results))
- for expected_result, result in zip(expected_results, results):
- self.assertEqual(expected_result[0], result["name"])
- self.assertEqual(expected_result[1], str(result["type"]))
-
- def test_presto_get_column(self):
- presto_column = ("column_name", "boolean", "")
- expected_results = [("column_name", "BOOLEAN")]
- self.verify_presto_column(presto_column, expected_results)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_get_simple_row_column(self):
- presto_column = ("column_name", "row(nested_obj double)", "")
- expected_results = [("column_name", "ROW"), ("column_name.nested_obj", "FLOAT")]
- self.verify_presto_column(presto_column, expected_results)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_get_simple_row_column_with_name_containing_whitespace(self):
- presto_column = ("column name", "row(nested_obj double)", "")
- expected_results = [("column name", "ROW"), ("column name.nested_obj", "FLOAT")]
- self.verify_presto_column(presto_column, expected_results)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_get_simple_row_column_with_tricky_nested_field_name(self):
- presto_column = ("column_name", 'row("Field Name(Tricky, Name)" double)', "")
- expected_results = [
- ("column_name", "ROW"),
- ('column_name."Field Name(Tricky, Name)"', "FLOAT"),
- ]
- self.verify_presto_column(presto_column, expected_results)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_get_simple_array_column(self):
- presto_column = ("column_name", "array(double)", "")
- expected_results = [("column_name", "ARRAY")]
- self.verify_presto_column(presto_column, expected_results)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_get_row_within_array_within_row_column(self):
- presto_column = (
- "column_name",
- "row(nested_array array(row(nested_row double)), nested_obj double)",
- "",
- )
- expected_results = [
- ("column_name", "ROW"),
- ("column_name.nested_array", "ARRAY"),
- ("column_name.nested_array.nested_row", "FLOAT"),
- ("column_name.nested_obj", "FLOAT"),
- ]
- self.verify_presto_column(presto_column, expected_results)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_get_array_within_row_within_array_column(self):
- presto_column = (
- "column_name",
- "array(row(nested_array array(double), nested_obj double))",
- "",
- )
- expected_results = [
- ("column_name", "ARRAY"),
- ("column_name.nested_array", "ARRAY"),
- ("column_name.nested_obj", "FLOAT"),
- ]
- self.verify_presto_column(presto_column, expected_results)
-
- def test_presto_get_fields(self):
- cols = [
- {"name": "column"},
- {"name": "column.nested_obj"},
- {"name": 'column."quoted.nested obj"'},
- ]
- actual_results = PrestoEngineSpec._get_fields(cols)
- expected_results = [
- {"name": '"column"', "label": "column"},
- {"name": '"column"."nested_obj"', "label": "column.nested_obj"},
- {
- "name": '"column"."quoted.nested obj"',
- "label": 'column."quoted.nested obj"',
- },
- ]
- for actual_result, expected_result in zip(actual_results, expected_results):
- self.assertEqual(actual_result.element.name, expected_result["name"])
- self.assertEqual(actual_result.name, expected_result["label"])
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_expand_data_with_simple_structural_columns(self):
- cols = [
- {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"},
- {"name": "array_column", "type": "ARRAY(BIGINT)"},
- ]
- data = [
- {"row_column": ["a"], "array_column": [1, 2, 3]},
- {"row_column": ["b"], "array_column": [4, 5, 6]},
- ]
- actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
- cols, data
- )
- expected_cols = [
- {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"},
- {"name": "row_column.nested_obj", "type": "VARCHAR"},
- {"name": "array_column", "type": "ARRAY(BIGINT)"},
- ]
-
- expected_data = [
- {"array_column": 1, "row_column": ["a"], "row_column.nested_obj": "a"},
- {"array_column": 2, "row_column": "", "row_column.nested_obj": ""},
- {"array_column": 3, "row_column": "", "row_column.nested_obj": ""},
- {"array_column": 4, "row_column": ["b"], "row_column.nested_obj": "b"},
- {"array_column": 5, "row_column": "", "row_column.nested_obj": ""},
- {"array_column": 6, "row_column": "", "row_column.nested_obj": ""},
- ]
-
- expected_expanded_cols = [{"name": "row_column.nested_obj", "type": "VARCHAR"}]
- self.assertEqual(actual_cols, expected_cols)
- self.assertEqual(actual_data, expected_data)
- self.assertEqual(actual_expanded_cols, expected_expanded_cols)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_expand_data_with_complex_row_columns(self):
- cols = [
- {
- "name": "row_column",
- "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
- }
- ]
- data = [{"row_column": ["a1", ["a2"]]}, {"row_column": ["b1", ["b2"]]}]
- actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
- cols, data
- )
- expected_cols = [
- {
- "name": "row_column",
- "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
- },
- {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
- {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
- {"name": "row_column.nested_obj1", "type": "VARCHAR"},
- ]
- expected_data = [
- {
- "row_column": ["a1", ["a2"]],
- "row_column.nested_obj1": "a1",
- "row_column.nested_row": ["a2"],
- "row_column.nested_row.nested_obj2": "a2",
- },
- {
- "row_column": ["b1", ["b2"]],
- "row_column.nested_obj1": "b1",
- "row_column.nested_row": ["b2"],
- "row_column.nested_row.nested_obj2": "b2",
- },
- ]
-
- expected_expanded_cols = [
- {"name": "row_column.nested_obj1", "type": "VARCHAR"},
- {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
- {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
- ]
- self.assertEqual(actual_cols, expected_cols)
- self.assertEqual(actual_data, expected_data)
- self.assertEqual(actual_expanded_cols, expected_expanded_cols)
-
- @mock.patch.dict(
- "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
- )
- def test_presto_expand_data_with_complex_array_columns(self):
- cols = [
- {"name": "int_column", "type": "BIGINT"},
- {
- "name": "array_column",
- "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
- },
- ]
- data = [
- {"int_column": 1, "array_column": [[[["a"], ["b"]]], [[["c"], ["d"]]]]},
- {"int_column": 2, "array_column": [[[["e"], ["f"]]], [[["g"], ["h"]]]]},
- ]
- actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
- cols, data
- )
- expected_cols = [
- {"name": "int_column", "type": "BIGINT"},
- {
- "name": "array_column",
- "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
- },
- {
- "name": "array_column.nested_array",
- "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
- },
- {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
- ]
- expected_data = [
- {
- "array_column": [[["a"], ["b"]]],
- "array_column.nested_array": ["a"],
- "array_column.nested_array.nested_obj": "a",
- "int_column": 1,
- },
- {
- "array_column": "",
- "array_column.nested_array": ["b"],
- "array_column.nested_array.nested_obj": "b",
- "int_column": "",
- },
- {
- "array_column": [[["c"], ["d"]]],
- "array_column.nested_array": ["c"],
- "array_column.nested_array.nested_obj": "c",
- "int_column": "",
- },
- {
- "array_column": "",
- "array_column.nested_array": ["d"],
- "array_column.nested_array.nested_obj": "d",
- "int_column": "",
- },
- {
- "array_column": [[["e"], ["f"]]],
- "array_column.nested_array": ["e"],
- "array_column.nested_array.nested_obj": "e",
- "int_column": 2,
- },
- {
- "array_column": "",
- "array_column.nested_array": ["f"],
- "array_column.nested_array.nested_obj": "f",
- "int_column": "",
- },
- {
- "array_column": [[["g"], ["h"]]],
- "array_column.nested_array": ["g"],
- "array_column.nested_array.nested_obj": "g",
- "int_column": "",
- },
- {
- "array_column": "",
- "array_column.nested_array": ["h"],
- "array_column.nested_array.nested_obj": "h",
- "int_column": "",
- },
- ]
- expected_expanded_cols = [
- {
- "name": "array_column.nested_array",
- "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
- },
- {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
- ]
- self.assertEqual(actual_cols, expected_cols)
- self.assertEqual(actual_data, expected_data)
- self.assertEqual(actual_expanded_cols, expected_expanded_cols)
-
- def test_presto_extra_table_metadata(self):
- db = mock.Mock()
- db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
- db.get_extra = mock.Mock(return_value={})
- df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
- db.get_df = mock.Mock(return_value=df)
- PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
- result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
- self.assertEqual({"ds": "01-01-19", "hour": 1}, result["partitions"]["latest"])
-
- def test_presto_where_latest_partition(self):
- db = mock.Mock()
- db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
- db.get_extra = mock.Mock(return_value={})
- df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
- db.get_df = mock.Mock(return_value=df)
- columns = [{"name": "ds"}, {"name": "hour"}]
- result = PrestoEngineSpec.where_latest_partition(
- "test_table", "test_schema", db, select(), columns
- )
- query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
- self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
-
- def test_hive_get_view_names_return_empty_list(self):
- self.assertEqual(
- [], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
- )
-
- def test_bigquery_sqla_column_label(self):
- label = BigQueryEngineSpec.make_label_compatible(column("Col").name)
- label_expected = "Col"
- self.assertEqual(label, label_expected)
-
- label = BigQueryEngineSpec.make_label_compatible(column("SUM(x)").name)
- label_expected = "SUM_x__5f110"
- self.assertEqual(label, label_expected)
-
- label = BigQueryEngineSpec.make_label_compatible(column("SUM[x]").name)
- label_expected = "SUM_x__7ebe1"
- self.assertEqual(label, label_expected)
-
- label = BigQueryEngineSpec.make_label_compatible(column("12345_col").name)
- label_expected = "_12345_col_8d390"
- self.assertEqual(label, label_expected)
-
- def test_oracle_sqla_column_name_length_exceeded(self):
- col = column("This_Is_32_Character_Column_Name")
- label = OracleEngineSpec.make_label_compatible(col.name)
- self.assertEqual(label.quote, True)
- label_expected = "3b26974078683be078219674eeb8f5"
- self.assertEqual(label, label_expected)
-
- def test_mssql_column_types(self):
- def assert_type(type_string, type_expected):
- type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
- if type_expected is None:
- self.assertIsNone(type_assigned)
- else:
- self.assertIsInstance(type_assigned, type_expected)
-
- assert_type("INT", None)
- assert_type("STRING", String)
- assert_type("CHAR(10)", String)
- assert_type("VARCHAR(10)", String)
- assert_type("TEXT", String)
- assert_type("NCHAR(10)", UnicodeText)
- assert_type("NVARCHAR(10)", UnicodeText)
- assert_type("NTEXT", UnicodeText)
-
- def test_mssql_where_clause_n_prefix(self):
- dialect = mssql.dialect()
- spec = MssqlEngineSpec
- str_col = column("col", type_=spec.get_sqla_column_type("VARCHAR(10)"))
- unicode_col = column("unicode_col", type_=spec.get_sqla_column_type("NTEXT"))
- tbl = table("tbl")
- sel = (
- select([str_col, unicode_col])
- .select_from(tbl)
- .where(str_col == "abc")
- .where(unicode_col == "abc")
- )
-
- query = str(
- sel.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
- )
- query_expected = (
- "SELECT col, unicode_col \n"
- "FROM tbl \n"
- "WHERE col = 'abc' AND unicode_col = N'abc'"
- )
- self.assertEqual(query, query_expected)
-
- def test_get_table_names(self):
- inspector = mock.Mock()
- inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
- inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
-
- """ Make sure base engine spec removes schema name from table name
- ie. when try_remove_schema_from_table_name == True. """
- base_result_expected = ["table", "table_2"]
- base_result = BaseEngineSpec.get_table_names(
- database=mock.ANY, schema="schema", inspector=inspector
- )
- self.assertListEqual(base_result_expected, base_result)
-
- """ Make sure postgres doesn't try to remove schema name from table name
- ie. when try_remove_schema_from_table_name == False. """
- pg_result_expected = ["schema.table", "table_2", "table_3"]
- pg_result = PostgresEngineSpec.get_table_names(
- database=mock.ANY, schema="schema", inspector=inspector
- )
- self.assertListEqual(pg_result_expected, pg_result)
-
- def test_pg_time_expression_literal_no_grain(self):
- col = literal_column("COALESCE(a, b)")
- expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
- result = str(expr.compile(dialect=postgresql.dialect()))
- self.assertEqual(result, "COALESCE(a, b)")
-
- def test_pg_time_expression_literal_1y_grain(self):
- col = literal_column("COALESCE(a, b)")
- expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
- result = str(expr.compile(dialect=postgresql.dialect()))
- self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
-
- def test_pg_time_expression_lower_column_no_grain(self):
- col = column("lower_case")
- expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
- result = str(expr.compile(dialect=postgresql.dialect()))
- self.assertEqual(result, "lower_case")
-
- def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
- col = column("lower_case")
- expr = PostgresEngineSpec.get_timestamp_expr(col, "epoch_s", "P1Y")
- result = str(expr.compile(dialect=postgresql.dialect()))
- self.assertEqual(
- result,
- "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))",
- )
-
- def test_pg_time_expression_mixed_case_column_1y_grain(self):
- col = column("MixedCase")
- expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
- result = str(expr.compile(dialect=postgresql.dialect()))
- self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
-
- def test_mssql_time_expression_mixed_case_column_1y_grain(self):
- col = column("MixedCase")
- expr = MssqlEngineSpec.get_timestamp_expr(col, None, "P1Y")
- result = str(expr.compile(dialect=mssql.dialect()))
- self.assertEqual(result, "DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)")
-
- def test_oracle_time_expression_reserved_keyword_1m_grain(self):
- col = column("decimal")
- expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
- result = str(expr.compile(dialect=oracle.dialect()))
- self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
-
- def test_pinot_time_expression_sec_1m_grain(self):
- col = column("tstamp")
- expr = PinotEngineSpec.get_timestamp_expr(col, "epoch_s", "P1M")
- result = str(expr.compile())
- self.assertEqual(
- result,
- 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")',
- )
-
- def test_column_datatype_to_string(self):
- example_db = get_example_database()
- sqla_table = example_db.get_table("energy_usage")
- dialect = example_db.get_dialect()
- col_names = [
- example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
- for c in sqla_table.columns
- ]
- if example_db.backend == "postgresql":
- expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
- else:
- expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
- self.assertEqual(col_names, expected)