You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by jo...@apache.org on 2023/06/01 19:01:25 UTC
[superset] branch master updated: chore(pre-commit): Add pyupgrade and pycln hooks (#24197)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new a4d5d7c6b9 chore(pre-commit): Add pyupgrade and pycln hooks (#24197)
a4d5d7c6b9 is described below
commit a4d5d7c6b9f470c4ba0a6cc4e896a3f12beb6f43
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Thu Jun 1 12:01:10 2023 -0700
chore(pre-commit): Add pyupgrade and pycln hooks (#24197)
---
.pre-commit-config.yaml | 22 +-
RELEASING/changelog.py | 29 ++-
RELEASING/generate_email.py | 8 +-
docker/pythonpath_dev/superset_config.py | 7 +-
scripts/benchmark_migration.py | 24 +-
scripts/cancel_github_workflows.py | 23 +-
scripts/permissions_cleanup.py | 6 +-
setup.py | 6 +-
.../advanced_data_type/plugins/internet_address.py | 4 +-
.../advanced_data_type/plugins/internet_port.py | 6 +-
superset/advanced_data_type/types.py | 10 +-
superset/annotation_layers/annotations/api.py | 4 +-
.../annotations/commands/bulk_delete.py | 6 +-
.../annotations/commands/create.py | 6 +-
.../annotations/commands/update.py | 6 +-
superset/annotation_layers/annotations/dao.py | 4 +-
superset/annotation_layers/commands/bulk_delete.py | 6 +-
superset/annotation_layers/commands/create.py | 6 +-
superset/annotation_layers/commands/update.py | 6 +-
superset/annotation_layers/dao.py | 6 +-
superset/charts/commands/bulk_delete.py | 6 +-
superset/charts/commands/create.py | 6 +-
superset/charts/commands/export.py | 4 +-
superset/charts/commands/importers/dispatcher.py | 4 +-
superset/charts/commands/importers/v1/__init__.py | 15 +-
superset/charts/commands/importers/v1/utils.py | 4 +-
superset/charts/commands/update.py | 10 +-
superset/charts/dao.py | 6 +-
superset/charts/data/api.py | 18 +-
.../data/commands/create_async_job_command.py | 4 +-
superset/charts/data/commands/get_data_command.py | 4 +-
superset/charts/data/query_context_cache_loader.py | 4 +-
superset/charts/post_processing.py | 26 +-
superset/charts/schemas.py | 6 +-
superset/cli/importexport.py | 6 +-
superset/cli/main.py | 6 +-
superset/cli/native_filters.py | 9 +-
superset/cli/thumbnails.py | 4 +-
superset/commands/base.py | 6 +-
superset/commands/exceptions.py | 12 +-
superset/commands/export/assets.py | 4 +-
superset/commands/export/models.py | 14 +-
superset/commands/importers/v1/__init__.py | 26 +-
superset/commands/importers/v1/assets.py | 30 +--
superset/commands/importers/v1/examples.py | 18 +-
superset/commands/importers/v1/utils.py | 38 +--
superset/commands/utils.py | 10 +-
superset/common/chart_data.py | 3 +-
superset/common/query_actions.py | 24 +-
superset/common/query_context.py | 36 +--
superset/common/query_context_factory.py | 18 +-
superset/common/query_context_processor.py | 38 +--
superset/common/query_object.py | 106 ++++----
superset/common/query_object_factory.py | 22 +-
superset/common/tags.py | 18 +-
superset/common/utils/dataframe_utils.py | 4 +-
superset/common/utils/query_cache_manager.py | 52 ++--
superset/common/utils/time_range_utils.py | 12 +-
superset/config.py | 137 +++++-----
superset/connectors/base/models.py | 121 ++++-----
superset/connectors/sqla/models.py | 169 ++++++------
superset/connectors/sqla/utils.py | 30 +--
superset/connectors/sqla/views.py | 2 +-
superset/css_templates/commands/bulk_delete.py | 6 +-
superset/css_templates/dao.py | 4 +-
superset/dao/base.py | 16 +-
superset/dashboards/commands/bulk_delete.py | 6 +-
superset/dashboards/commands/create.py | 10 +-
superset/dashboards/commands/export.py | 9 +-
.../dashboards/commands/importers/dispatcher.py | 4 +-
superset/dashboards/commands/importers/v0.py | 14 +-
.../dashboards/commands/importers/v1/__init__.py | 20 +-
superset/dashboards/commands/importers/v1/utils.py | 20 +-
superset/dashboards/commands/update.py | 10 +-
superset/dashboards/dao.py | 20 +-
superset/dashboards/filter_sets/commands/create.py | 4 +-
superset/dashboards/filter_sets/commands/update.py | 4 +-
superset/dashboards/filter_sets/dao.py | 4 +-
superset/dashboards/filter_sets/schemas.py | 15 +-
superset/dashboards/filter_state/api.py | 9 +-
superset/dashboards/permalink/types.py | 8 +-
superset/dashboards/schemas.py | 8 +-
superset/databases/api.py | 6 +-
superset/databases/commands/create.py | 6 +-
superset/databases/commands/export.py | 7 +-
.../databases/commands/importers/dispatcher.py | 4 +-
.../databases/commands/importers/v1/__init__.py | 8 +-
superset/databases/commands/importers/v1/utils.py | 4 +-
superset/databases/commands/tables.py | 4 +-
superset/databases/commands/test_connection.py | 4 +-
superset/databases/commands/update.py | 8 +-
superset/databases/commands/validate.py | 4 +-
superset/databases/commands/validate_sql.py | 12 +-
superset/databases/dao.py | 6 +-
superset/databases/filters.py | 4 +-
superset/databases/schemas.py | 26 +-
superset/databases/ssh_tunnel/commands/create.py | 6 +-
superset/databases/ssh_tunnel/commands/update.py | 4 +-
superset/databases/ssh_tunnel/dao.py | 4 +-
superset/databases/ssh_tunnel/models.py | 4 +-
superset/databases/utils.py | 12 +-
superset/dataframe.py | 4 +-
superset/datasets/commands/bulk_delete.py | 6 +-
superset/datasets/commands/create.py | 8 +-
superset/datasets/commands/duplicate.py | 6 +-
superset/datasets/commands/export.py | 4 +-
superset/datasets/commands/importers/dispatcher.py | 4 +-
superset/datasets/commands/importers/v0.py | 8 +-
.../datasets/commands/importers/v1/__init__.py | 10 +-
superset/datasets/commands/importers/v1/utils.py | 6 +-
superset/datasets/commands/update.py | 22 +-
superset/datasets/dao.py | 28 +-
superset/datasets/models.py | 5 +-
superset/datasets/schemas.py | 8 +-
superset/datasource/dao.py | 4 +-
superset/db_engine_specs/__init__.py | 12 +-
superset/db_engine_specs/athena.py | 7 +-
superset/db_engine_specs/base.py | 256 +++++++++----------
superset/db_engine_specs/bigquery.py | 45 ++--
superset/db_engine_specs/clickhouse.py | 22 +-
superset/db_engine_specs/crate.py | 6 +-
superset/db_engine_specs/databricks.py | 22 +-
superset/db_engine_specs/dremio.py | 4 +-
superset/db_engine_specs/drill.py | 10 +-
superset/db_engine_specs/druid.py | 14 +-
superset/db_engine_specs/duckdb.py | 13 +-
superset/db_engine_specs/dynamodb.py | 4 +-
superset/db_engine_specs/elasticsearch.py | 10 +-
superset/db_engine_specs/exasol.py | 4 +-
superset/db_engine_specs/firebird.py | 4 +-
superset/db_engine_specs/firebolt.py | 4 +-
superset/db_engine_specs/gsheets.py | 19 +-
superset/db_engine_specs/hana.py | 4 +-
superset/db_engine_specs/hive.py | 84 +++---
superset/db_engine_specs/impala.py | 6 +-
superset/db_engine_specs/kusto.py | 16 +-
superset/db_engine_specs/kylin.py | 4 +-
superset/db_engine_specs/mssql.py | 9 +-
superset/db_engine_specs/mysql.py | 15 +-
superset/db_engine_specs/ocient.py | 33 ++-
superset/db_engine_specs/oracle.py | 6 +-
superset/db_engine_specs/pinot.py | 8 +-
superset/db_engine_specs/postgres.py | 31 +--
superset/db_engine_specs/presto.py | 146 +++++------
superset/db_engine_specs/redshift.py | 7 +-
superset/db_engine_specs/rockset.py | 4 +-
superset/db_engine_specs/snowflake.py | 33 +--
superset/db_engine_specs/sqlite.py | 9 +-
superset/db_engine_specs/starrocks.py | 31 +--
superset/db_engine_specs/trino.py | 32 +--
superset/embedded/dao.py | 6 +-
superset/errors.py | 6 +-
superset/examples/bart_lines.py | 2 +-
superset/examples/big_data.py | 3 +-
superset/examples/birth_names.py | 8 +-
superset/examples/countries.py | 8 +-
superset/examples/helpers.py | 10 +-
superset/examples/multiformat_time_series.py | 4 +-
superset/examples/paris.py | 2 +-
superset/examples/sf_population_polygons.py | 2 +-
superset/examples/supported_charts_dashboard.py | 3 +-
superset/examples/utils.py | 8 +-
superset/examples/world_bank.py | 3 +-
superset/exceptions.py | 18 +-
superset/explore/commands/get.py | 6 +-
superset/explore/permalink/commands/create.py | 4 +-
superset/explore/permalink/types.py | 6 +-
superset/extensions/__init__.py | 14 +-
superset/extensions/metastore_cache.py | 4 +-
superset/forms.py | 12 +-
superset/initialization/__init__.py | 6 +-
superset/jinja_context.py | 56 ++--
superset/key_value/types.py | 10 +-
superset/key_value/utils.py | 4 +-
superset/legacy.py | 4 +-
superset/migrations/env.py | 3 +-
superset/migrations/shared/migrate_viz/base.py | 8 +-
superset/migrations/shared/security_converge.py | 15 +-
superset/migrations/shared/utils.py | 5 +-
...4_12-31_db0c65b146bd_update_slice_model_json.py | 2 +-
...7c195a_rewriting_url_from_shortner_with_new_.py | 2 +-
.../versions/2017-10-03_14-37_4736ec66ce19_.py | 10 +-
...2-17_11-06_21e88bc06c02_annotation_migration.py | 2 +-
.../2018-02-13_08-07_e866bd2d4976_smaller_grid.py | 4 +-
.../versions/2018-03-20_19-47_f231d82b9b26_.py | 4 +-
...9_bf706ae5eb46_cal_heatmap_metric_to_metrics.py | 2 +-
.../2018-06-13_14-54_bddc498dd179_adhoc_filters.py | 2 -
...1c4c6_migrate_num_period_compare_and_period_.py | 18 +-
..._bebcf3fed1fe_convert_dashboard_v1_positions.py | 16 +-
...08545_migrate_time_range_for_default_filters.py | 4 +-
...127d0d1d_reconvert_legacy_filters_into_adhoc.py | 2 -
...25_10-49_b5998378c225_add_certificate_to_dbs.py | 3 +-
...978245563a02_migrate_iframe_to_dash_markdown.py | 3 +-
...654_fix_data_access_permissions_for_virtual_.py | 2 +-
...3a3a8ff221_migrate_filter_sets_to_new_format.py | 9 +-
...ed7ec95_migrate_native_filters_to_new_schema.py | 11 +-
...15da_migrate_pivot_table_v2_heatmaps_to_new_.py | 1 -
...453f4e2e_migrate_timeseries_limit_metric_to_.py | 1 -
...12_11-15_32646df09c64_update_time_grain_sqla.py | 3 +-
...14-38_a9422eeaae74_new_dataset_models_take_2.py | 8 +-
superset/models/annotations.py | 4 +-
superset/models/core.py | 57 +++--
superset/models/dashboard.py | 48 ++--
superset/models/datasource_access_request.py | 10 +-
superset/models/embedded_dashboard.py | 3 +-
superset/models/filter_set.py | 6 +-
superset/models/helpers.py | 190 +++++++-------
superset/models/slice.py | 38 +--
superset/models/sql_lab.py | 38 +--
superset/models/sql_types/presto_sql_types.py | 12 +-
superset/queries/dao.py | 6 +-
.../queries/saved_queries/commands/bulk_delete.py | 6 +-
superset/queries/saved_queries/commands/export.py | 4 +-
.../saved_queries/commands/importers/dispatcher.py | 4 +-
.../commands/importers/v1/__init__.py | 10 +-
.../saved_queries/commands/importers/v1/utils.py | 4 +-
superset/queries/saved_queries/dao.py | 4 +-
superset/queries/schemas.py | 3 +-
superset/reports/commands/alert.py | 4 +-
superset/reports/commands/base.py | 6 +-
superset/reports/commands/bulk_delete.py | 6 +-
superset/reports/commands/create.py | 10 +-
superset/reports/commands/exceptions.py | 5 +-
superset/reports/commands/execute.py | 10 +-
superset/reports/commands/update.py | 8 +-
superset/reports/dao.py | 22 +-
superset/reports/filters.py | 2 +-
superset/reports/logs/api.py | 4 +-
superset/reports/notifications/__init__.py | 1 -
superset/reports/notifications/base.py | 7 +-
superset/reports/notifications/email.py | 7 +-
superset/reports/notifications/slack.py | 4 +-
superset/reports/schemas.py | 4 +-
superset/result_set.py | 22 +-
.../row_level_security/commands/bulk_delete.py | 5 +-
superset/row_level_security/commands/create.py | 4 +-
superset/row_level_security/commands/update.py | 4 +-
superset/security/api.py | 6 +-
superset/security/guest_token.py | 8 +-
superset/security/manager.py | 61 ++---
superset/sql_lab.py | 28 +-
superset/sql_parse.py | 17 +-
superset/sql_validators/__init__.py | 4 +-
superset/sql_validators/base.py | 6 +-
superset/sql_validators/postgres.py | 6 +-
superset/sql_validators/presto_db.py | 8 +-
superset/sqllab/api.py | 4 +-
superset/sqllab/commands/estimate.py | 8 +-
superset/sqllab/commands/execute.py | 14 +-
superset/sqllab/commands/export.py | 4 +-
superset/sqllab/commands/results.py | 8 +-
superset/sqllab/exceptions.py | 26 +-
superset/sqllab/execution_context_convertor.py | 4 +-
superset/sqllab/query_render.py | 16 +-
superset/sqllab/sql_json_executer.py | 18 +-
superset/sqllab/sqllab_execution_context.py | 36 +--
superset/sqllab/utils.py | 6 +-
superset/stats_logger.py | 16 +-
superset/superset_typing.py | 35 +--
superset/tables/models.py | 13 +-
superset/tags/commands/create.py | 3 +-
superset/tags/commands/delete.py | 3 +-
superset/tags/dao.py | 12 +-
superset/tags/models.py | 40 ++-
superset/tasks/__init__.py | 1 -
superset/tasks/async_queries.py | 16 +-
superset/tasks/cache.py | 18 +-
superset/tasks/cron_util.py | 2 +-
superset/tasks/utils.py | 12 +-
superset/translations/utils.py | 8 +-
superset/utils/async_query_manager.py | 14 +-
superset/utils/cache.py | 22 +-
superset/utils/celery.py | 2 +-
superset/utils/core.py | 284 ++++++++++-----------
superset/utils/csv.py | 6 +-
.../utils/dashboard_filter_scopes_converter.py | 36 +--
superset/utils/database.py | 4 +-
superset/utils/date_parser.py | 8 +-
superset/utils/decorators.py | 9 +-
superset/utils/dict_import_export.py | 6 +-
superset/utils/encrypt.py | 26 +-
superset/utils/feature_flag_manager.py | 5 +-
superset/utils/filters.py | 4 +-
superset/utils/hashing.py | 4 +-
superset/utils/log.py | 67 ++---
superset/utils/machine_auth.py | 4 +-
superset/utils/mock_data.py | 34 ++-
superset/utils/network.py | 4 +-
superset/utils/pandas_postprocessing/aggregate.py | 4 +-
superset/utils/pandas_postprocessing/boxplot.py | 14 +-
superset/utils/pandas_postprocessing/compare.py | 6 +-
.../utils/pandas_postprocessing/contribution.py | 6 +-
superset/utils/pandas_postprocessing/cum.py | 3 +-
superset/utils/pandas_postprocessing/diff.py | 3 +-
superset/utils/pandas_postprocessing/flatten.py | 4 +-
superset/utils/pandas_postprocessing/geography.py | 4 +-
superset/utils/pandas_postprocessing/pivot.py | 8 +-
superset/utils/pandas_postprocessing/rename.py | 4 +-
superset/utils/pandas_postprocessing/rolling.py | 8 +-
superset/utils/pandas_postprocessing/select.py | 8 +-
superset/utils/pandas_postprocessing/sort.py | 6 +-
superset/utils/pandas_postprocessing/utils.py | 13 +-
superset/utils/retries.py | 9 +-
superset/utils/screenshots.py | 44 ++--
superset/utils/ssh_tunnel.py | 8 +-
superset/utils/url_map_converters.py | 4 +-
superset/utils/webdriver.py | 16 +-
superset/views/__init__.py | 2 -
superset/views/all_entities.py | 1 -
superset/views/base.py | 26 +-
superset/views/base_api.py | 48 ++--
superset/views/base_schemas.py | 13 +-
superset/views/core.py | 72 +++---
superset/views/dashboard/views.py | 10 +-
superset/views/database/forms.py | 3 +-
superset/views/database/mixins.py | 2 +-
superset/views/database/validators.py | 4 +-
superset/views/datasource/schemas.py | 4 +-
superset/views/datasource/utils.py | 6 +-
superset/views/log/dao.py | 6 +-
superset/views/tags.py | 1 -
superset/views/users/__init__.py | 1 -
superset/views/utils.py | 42 +--
superset/viz.py | 202 +++++++--------
tests/common/logger_utils.py | 12 +-
tests/common/query_context_generator.py | 12 +-
.../example_data/data_generator/base_generator.py | 5 +-
.../birth_names/birth_names_generator.py | 7 +-
.../data_loading/data_definitions/types.py | 9 +-
.../data_loading/pandas/pandas_data_loader.py | 6 +-
.../data_loading/pandas/pands_data_loading_conf.py | 4 +-
.../data_loading/pandas/table_df_convertor.py | 6 +-
tests/integration_tests/access_tests.py | 2 +-
.../advanced_data_type/api_tests.py | 4 +-
tests/integration_tests/base_tests.py | 12 +-
tests/integration_tests/cachekeys/api_tests.py | 4 +-
tests/integration_tests/charts/api_tests.py | 1 -
tests/integration_tests/charts/data/api_tests.py | 10 +-
tests/integration_tests/conftest.py | 8 +-
tests/integration_tests/core_tests.py | 30 +--
tests/integration_tests/csv_upload_tests.py | 10 +-
tests/integration_tests/dashboard_tests.py | 20 +-
tests/integration_tests/dashboard_utils.py | 6 +-
tests/integration_tests/dashboards/api_tests.py | 16 +-
tests/integration_tests/dashboards/base_case.py | 6 +-
.../dashboards/dashboard_test_utils.py | 10 +-
.../dashboards/filter_sets/conftest.py | 39 +--
.../dashboards/filter_sets/create_api_tests.py | 86 +++----
.../dashboards/filter_sets/delete_api_tests.py | 54 ++--
.../dashboards/filter_sets/get_api_tests.py | 14 +-
.../dashboards/filter_sets/update_api_tests.py | 112 ++++----
.../dashboards/filter_sets/utils.py | 24 +-
.../dashboards/permalink/api_tests.py | 3 +-
.../dashboards/security/base_case.py | 6 +-
.../dashboards/superset_factory_util.py | 26 +-
tests/integration_tests/databases/api_tests.py | 8 +-
.../integration_tests/databases/commands_tests.py | 2 +-
.../ssh_tunnel/commands/commands_tests.py | 2 +-
tests/integration_tests/datasets/api_tests.py | 20 +-
tests/integration_tests/datasets/commands_tests.py | 6 +-
.../db_engine_specs/base_tests.py | 2 -
.../db_engine_specs/bigquery_tests.py | 3 +-
.../db_engine_specs/hive_tests.py | 4 +-
.../integration_tests/dict_import_export_tests.py | 18 +-
tests/integration_tests/email_tests.py | 1 -
tests/integration_tests/event_logger_tests.py | 4 +-
.../explore/permalink/api_tests.py | 9 +-
.../explore/permalink/commands_tests.py | 1 -
.../fixtures/birth_names_dashboard.py | 4 +-
tests/integration_tests/fixtures/datasource.py | 5 +-
.../integration_tests/fixtures/energy_dashboard.py | 7 +-
tests/integration_tests/fixtures/importexport.py | 36 +--
tests/integration_tests/fixtures/query_context.py | 6 +-
.../fixtures/world_bank_dashboard.py | 10 +-
tests/integration_tests/import_export_tests.py | 26 +-
tests/integration_tests/insert_chart_mixin.py | 4 +-
.../key_value/commands/fixtures.py | 3 +-
tests/integration_tests/model_tests.py | 10 +-
tests/integration_tests/query_context_tests.py | 4 +-
tests/integration_tests/reports/alert_tests.py | 8 +-
tests/integration_tests/reports/commands_tests.py | 12 +-
tests/integration_tests/reports/scheduler_tests.py | 3 +-
tests/integration_tests/reports/utils.py | 14 +-
.../security/migrate_roles_tests.py | 1 -
.../security/row_level_security_tests.py | 18 +-
tests/integration_tests/sql_lab/api_tests.py | 4 +-
tests/integration_tests/sql_lab/commands_tests.py | 2 +-
tests/integration_tests/sqla_models_tests.py | 19 +-
tests/integration_tests/sqllab_tests.py | 16 +-
tests/integration_tests/strategy_tests.py | 2 -
tests/integration_tests/superset_test_config.py | 2 +-
...erset_test_config_sqllab_backend_persist_off.py | 2 -
.../superset_test_config_thumbnails.py | 2 +-
tests/integration_tests/tagging_tests.py | 1 -
tests/integration_tests/tags/api_tests.py | 3 -
tests/integration_tests/tags/commands_tests.py | 1 -
tests/integration_tests/tags/dao_tests.py | 3 -
tests/integration_tests/thumbnails_tests.py | 3 +-
tests/integration_tests/users/__init__.py | 1 -
tests/integration_tests/utils/csv_tests.py | 6 +-
tests/integration_tests/utils/encrypt_tests.py | 8 +-
tests/integration_tests/utils/get_dashboards.py | 3 +-
.../utils/public_interfaces_test.py | 4 +-
tests/integration_tests/utils_tests.py | 8 +-
tests/integration_tests/viz_tests.py | 5 +-
tests/unit_tests/charts/dao/dao_tests.py | 2 +-
tests/unit_tests/charts/test_post_processing.py | 1 -
.../unit_tests/common/test_query_object_factory.py | 16 +-
tests/unit_tests/config_test.py | 4 +-
tests/unit_tests/conftest.py | 3 +-
tests/unit_tests/dao/queries_test.py | 3 +-
.../dashboards/commands/importers/v1/utils_test.py | 6 +-
tests/unit_tests/dashboards/dao_tests.py | 2 +-
tests/unit_tests/databases/dao/dao_tests.py | 2 +-
.../databases/ssh_tunnel/commands/create_test.py | 1 -
.../databases/ssh_tunnel/commands/delete_test.py | 2 +-
.../databases/ssh_tunnel/commands/update_test.py | 2 +-
tests/unit_tests/databases/ssh_tunnel/dao_tests.py | 1 -
.../datasets/commands/importers/v1/import_test.py | 6 +-
tests/unit_tests/datasets/conftest.py | 8 +-
tests/unit_tests/datasets/dao/dao_tests.py | 2 +-
tests/unit_tests/datasource/dao_tests.py | 2 +-
tests/unit_tests/db_engine_specs/test_athena.py | 2 +-
tests/unit_tests/db_engine_specs/test_base.py | 6 +-
.../unit_tests/db_engine_specs/test_clickhouse.py | 6 +-
.../db_engine_specs/test_elasticsearch.py | 4 +-
tests/unit_tests/db_engine_specs/test_mssql.py | 6 +-
tests/unit_tests/db_engine_specs/test_mysql.py | 8 +-
tests/unit_tests/db_engine_specs/test_ocient.py | 6 +-
tests/unit_tests/db_engine_specs/test_postgres.py | 6 +-
tests/unit_tests/db_engine_specs/test_presto.py | 6 +-
tests/unit_tests/db_engine_specs/test_starrocks.py | 10 +-
tests/unit_tests/db_engine_specs/test_trino.py | 22 +-
tests/unit_tests/db_engine_specs/utils.py | 14 +-
tests/unit_tests/extensions/ssh_test.py | 1 -
tests/unit_tests/fixtures/assets_configs.py | 14 +-
tests/unit_tests/fixtures/datasets.py | 6 +-
tests/unit_tests/models/core_test.py | 4 +-
tests/unit_tests/pandas_postprocessing/utils.py | 8 +-
tests/unit_tests/sql_parse_tests.py | 5 +-
tests/unit_tests/tasks/test_cron_util.py | 12 +-
tests/unit_tests/tasks/test_utils.py | 16 +-
tests/unit_tests/thumbnails/test_digest.py | 20 +-
tests/unit_tests/utils/cache_test.py | 1 -
tests/unit_tests/utils/date_parser_tests.py | 6 +-
tests/unit_tests/utils/test_core.py | 5 +-
tests/unit_tests/utils/test_file.py | 1 -
tests/unit_tests/utils/urls_tests.py | 1 -
448 files changed, 3084 insertions(+), 3305 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3f524b3658..07544d66d2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,14 +15,28 @@
# limitations under the License.
#
repos:
+ - repo: https://github.com/MarcoGorelli/auto-walrus
+ rev: v0.2.2
+ hooks:
+ - id: auto-walrus
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v3.4.0
+ hooks:
+ - id: pyupgrade
+ args:
+ - --py39-plus
+ - repo: https://github.com/hadialqattan/pycln
+ rev: v2.1.2
+ hooks:
+ - id: pycln
+ args:
+ - --disable-all-dunder-policy
+ - --exclude=superset/config.py
+ - --extend-exclude=tests/integration_tests/superset_test_config.*.py
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- - repo: https://github.com/MarcoGorelli/auto-walrus
- rev: v0.2.2
- hooks:
- - id: auto-walrus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:
diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py
index 68a54e10be..d1ba06a620 100644
--- a/RELEASING/changelog.py
+++ b/RELEASING/changelog.py
@@ -17,8 +17,9 @@ import csv as lib_csv
import os
import re
import sys
+from collections.abc import Iterator
from dataclasses import dataclass
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Optional, Union
import click
from click.core import Context
@@ -67,15 +68,15 @@ class GitChangeLog:
def __init__(
self,
version: str,
- logs: List[GitLog],
+ logs: list[GitLog],
access_token: Optional[str] = None,
risk: Optional[bool] = False,
) -> None:
self._version = version
self._logs = logs
- self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {}
- self._github_login_cache: Dict[str, Optional[str]] = {}
- self._github_prs: Dict[int, Any] = {}
+ self._pr_logs_with_details: dict[int, dict[str, Any]] = {}
+ self._github_login_cache: dict[str, Optional[str]] = {}
+ self._github_prs: dict[int, Any] = {}
self._wait = 10
github_token = access_token or os.environ.get("GITHUB_TOKEN")
self._github = Github(github_token)
@@ -126,7 +127,7 @@ class GitChangeLog:
"superset/migrations/versions/" in file.filename for file in commit.files
)
- def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]:
+ def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]:
pr_number = git_log.pr_number
if pr_number:
detail = self._pr_logs_with_details.get(pr_number)
@@ -156,7 +157,7 @@ class GitChangeLog:
return detail
- def _is_risk_pull_request(self, labels: List[Any]) -> bool:
+ def _is_risk_pull_request(self, labels: list[Any]) -> bool:
for label in labels:
risk_label = re.match(SUPERSET_RISKY_LABELS, label.name)
if risk_label is not None:
@@ -174,8 +175,8 @@ class GitChangeLog:
def _parse_change_log(
self,
- changelog: Dict[str, str],
- pr_info: Dict[str, str],
+ changelog: dict[str, str],
+ pr_info: dict[str, str],
github_login: str,
) -> None:
formatted_pr = (
@@ -227,7 +228,7 @@ class GitChangeLog:
result += f"**{key}** {changelog[key]}\n"
return result
- def __iter__(self) -> Iterator[Dict[str, Any]]:
+ def __iter__(self) -> Iterator[dict[str, Any]]:
for log in self._logs:
yield {
"pr_number": log.pr_number,
@@ -250,20 +251,20 @@ class GitLogs:
def __init__(self, git_ref: str) -> None:
self._git_ref = git_ref
- self._logs: List[GitLog] = []
+ self._logs: list[GitLog] = []
@property
def git_ref(self) -> str:
return self._git_ref
@property
- def logs(self) -> List[GitLog]:
+ def logs(self) -> list[GitLog]:
return self._logs
def fetch(self) -> None:
self._logs = list(map(self._parse_log, self._git_logs()))[::-1]
- def diff(self, git_logs: "GitLogs") -> List[GitLog]:
+ def diff(self, git_logs: "GitLogs") -> list[GitLog]:
return [log for log in git_logs.logs if log not in self._logs]
def __repr__(self) -> str:
@@ -284,7 +285,7 @@ class GitLogs:
print(f"Could not checkout {git_ref}")
sys.exit(1)
- def _git_logs(self) -> List[str]:
+ def _git_logs(self) -> list[str]:
# let's get current git ref so we can revert it back
current_git_ref = self._git_get_current_head()
self._git_checkout(self._git_ref)
diff --git a/RELEASING/generate_email.py b/RELEASING/generate_email.py
index 92536670cd..29142557c0 100755
--- a/RELEASING/generate_email.py
+++ b/RELEASING/generate_email.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Any, Dict, List
+from typing import Any
from click.core import Context
@@ -34,7 +34,7 @@ PROJECT_MODULE = "superset"
PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application"
-def string_comma_to_list(message: str) -> List[str]:
+def string_comma_to_list(message: str) -> list[str]:
if not message:
return []
return [element.strip() for element in message.split(",")]
@@ -52,7 +52,7 @@ def render_template(template_file: str, **kwargs: Any) -> str:
return template.render(kwargs)
-class BaseParameters(object):
+class BaseParameters:
def __init__(
self,
version: str,
@@ -60,7 +60,7 @@ class BaseParameters(object):
) -> None:
self.version = version
self.version_rc = version_rc
- self.template_arguments: Dict[str, Any] = {}
+ self.template_arguments: dict[str, Any] = {}
def __repr__(self) -> str:
return f"Apache Credentials: {self.version}/{self.version_rc}"
diff --git a/docker/pythonpath_dev/superset_config.py b/docker/pythonpath_dev/superset_config.py
index 6ea9abf63c..199e79f66e 100644
--- a/docker/pythonpath_dev/superset_config.py
+++ b/docker/pythonpath_dev/superset_config.py
@@ -22,7 +22,6 @@
#
import logging
import os
-from datetime import timedelta
from typing import Optional
from cachelib.file import FileSystemCache
@@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
error_msg = "The environment variable {} was missing, abort...".format(
var_name
)
- raise EnvironmentError(error_msg)
+ raise OSError(error_msg)
DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT")
@@ -53,7 +52,7 @@ DATABASE_PORT = get_env_variable("DATABASE_PORT")
DATABASE_DB = get_env_variable("DATABASE_DB")
# The SQLAlchemy connection string.
-SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
+SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format(
DATABASE_DIALECT,
DATABASE_USER,
DATABASE_PASSWORD,
@@ -80,7 +79,7 @@ CACHE_CONFIG = {
DATA_CACHE_CONFIG = CACHE_CONFIG
-class CeleryConfig(object):
+class CeleryConfig:
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
imports = ("superset.sql_lab",)
result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"
diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py
index 83c06456a1..466fab6f13 100644
--- a/scripts/benchmark_migration.py
+++ b/scripts/benchmark_migration.py
@@ -23,7 +23,7 @@ from graphlib import TopologicalSorter
from inspect import getsource
from pathlib import Path
from types import ModuleType
-from typing import Any, Dict, List, Set, Type
+from typing import Any
import click
from flask import current_app
@@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
- raise Exception(
- "No module spec found in location: `{path}`".format(path=str(filepath))
- )
+ raise Exception(f"No module spec found in location: `{str(filepath)}`")
-def extract_modified_tables(module: ModuleType) -> Set[str]:
+def extract_modified_tables(module: ModuleType) -> set[str]:
"""
Extract the tables being modified by a migration script.
@@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
actually traversing the AST.
"""
- tables: Set[str] = set()
+ tables: set[str] = set()
for function in {"upgrade", "downgrade"}:
source = getsource(getattr(module, function))
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
@@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
return tables
-def find_models(module: ModuleType) -> List[Type[Model]]:
+def find_models(module: ModuleType) -> list[type[Model]]:
"""
Find all models in a migration script.
"""
- models: List[Type[Model]] = []
+ models: list[type[Model]] = []
tables = extract_modified_tables(module)
# add models defined explicitly in the migration script
@@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
sorter: TopologicalSorter[Any] = TopologicalSorter()
for model in models:
inspector = inspect(model)
- dependent_tables: List[str] = []
+ dependent_tables: list[str] = []
for column in inspector.columns.values():
for foreign_key in column.foreign_keys:
if foreign_key.column.table.name != model.__tablename__:
@@ -174,7 +172,7 @@ def main(
print("\nIdentifying models used in the migration:")
models = find_models(module)
- model_rows: Dict[Type[Model], int] = {}
+ model_rows: dict[type[Model], int] = {}
for model in models:
rows = session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
@@ -182,7 +180,7 @@ def main(
session.close()
print("Benchmarking migration")
- results: Dict[str, float] = {}
+ results: dict[str, float] = {}
start = time.time()
upgrade(revision=revision)
duration = time.time() - start
@@ -190,14 +188,14 @@ def main(
print(f"Migration on current DB took: {duration:.2f} seconds")
min_entities = 10
- new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
+ new_models: dict[type[Model], list[Model]] = defaultdict(list)
while min_entities <= limit:
downgrade(revision=down_revision)
print(f"Running with at least {min_entities} entities of each model")
for model in models:
missing = min_entities - model_rows[model]
if missing > 0:
- entities: List[Model] = []
+ entities: list[Model] = []
print(f"- Adding {missing} entities to the {model.__name__} model")
bar = ChargingBar("Processing", max=missing)
try:
diff --git a/scripts/cancel_github_workflows.py b/scripts/cancel_github_workflows.py
index 4d30d34adf..70744c2954 100755
--- a/scripts/cancel_github_workflows.py
+++ b/scripts/cancel_github_workflows.py
@@ -33,13 +33,13 @@ Example:
./cancel_github_workflows.py 1024 --include-last
"""
import os
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+from collections.abc import Iterable, Iterator
+from typing import Any, Literal, Optional, Union
import click
import requests
from click.exceptions import ClickException
from dateutil import parser
-from typing_extensions import Literal
github_token = os.environ.get("GITHUB_TOKEN")
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
@@ -47,7 +47,7 @@ github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
def request(
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
resp = requests.request(
method,
f"https://api.github.com/{endpoint.lstrip('/')}",
@@ -61,8 +61,8 @@ def request(
def list_runs(
repo: str,
- params: Optional[Dict[str, str]] = None,
-) -> Iterator[Dict[str, Any]]:
+ params: Optional[dict[str, str]] = None,
+) -> Iterator[dict[str, Any]]:
"""List all github workflow runs.
Returns:
An iterator that will iterate through all pages of matching runs."""
@@ -77,16 +77,15 @@ def list_runs(
params={**params, "per_page": 100, "page": page},
)
total_count = result["total_count"]
- for item in result["workflow_runs"]:
- yield item
+ yield from result["workflow_runs"]
page += 1
-def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
+def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]:
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")
-def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
+def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]:
return request("GET", f"/repos/{repo}/pulls/{pull_number}")
@@ -96,7 +95,7 @@ def get_runs(
user: Optional[str] = None,
statuses: Iterable[str] = ("queued", "in_progress"),
events: Iterable[str] = ("pull_request", "push"),
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
"""Get workflow runs associated with the given branch"""
return [
item
@@ -108,7 +107,7 @@ def get_runs(
]
-def print_commit(commit: Dict[str, Any], branch: str) -> None:
+def print_commit(commit: dict[str, Any], branch: str) -> None:
"""Print out commit message for verification"""
indented_message = " \n".join(commit["message"].split("\n"))
date_str = (
@@ -155,7 +154,7 @@ Date: {date_str}
def cancel_github_workflows(
branch_or_pull: Optional[str],
repo: str,
- event: List[str],
+ event: list[str],
include_last: bool,
include_running: bool,
) -> None:
diff --git a/scripts/permissions_cleanup.py b/scripts/permissions_cleanup.py
index 5ca75e394c..0416f55806 100644
--- a/scripts/permissions_cleanup.py
+++ b/scripts/permissions_cleanup.py
@@ -24,7 +24,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
- print("# of permission view menus is: {}".format(len(pvms)))
+ print(f"# of permission view menus is: {len(pvms)}")
pvms_dict = defaultdict(list)
for pvm in pvms:
pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
@@ -43,7 +43,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
- print("Stage 1: # of permission view menus is: {}".format(len(pvms)))
+ print(f"Stage 1: # of permission view menus is: {len(pvms)}")
# 2. Clean up None permissions or view menus
pvms = security_manager.get_session.query(
@@ -57,7 +57,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
- print("Stage 2: # of permission view menus is: {}".format(len(pvms)))
+ print(f"Stage 2: # of permission view menus is: {len(pvms)}")
# 3. Delete empty permission view menus from roles
roles = security_manager.get_session.query(security_manager.role_model).all()
diff --git a/setup.py b/setup.py
index 41f7e11e38..d8adea3285 100644
--- a/setup.py
+++ b/setup.py
@@ -14,21 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import io
import json
import os
import subprocess
-import sys
from setuptools import find_packages, setup
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json")
-with open(PACKAGE_JSON, "r") as package_file:
+with open(PACKAGE_JSON) as package_file:
version_string = json.load(package_file)["version"]
-with io.open("README.md", "r", encoding="utf-8") as f:
+with open("README.md", encoding="utf-8") as f:
long_description = f.read()
diff --git a/superset/advanced_data_type/plugins/internet_address.py b/superset/advanced_data_type/plugins/internet_address.py
index 08a0925846..8ab20fe2d0 100644
--- a/superset/advanced_data_type/plugins/internet_address.py
+++ b/superset/advanced_data_type/plugins/internet_address.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import ipaddress
-from typing import Any, List
+from typing import Any
from sqlalchemy import Column
@@ -77,7 +77,7 @@ def cidr_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse:
# Make this return a single clause
def cidr_translate_filter_func(
- col: Column, operator: FilterOperator, values: List[Any]
+ col: Column, operator: FilterOperator, values: list[Any]
) -> Any:
"""
Convert a passed in column, FilterOperator and
diff --git a/superset/advanced_data_type/plugins/internet_port.py b/superset/advanced_data_type/plugins/internet_port.py
index 60a594bfd9..8983e41422 100644
--- a/superset/advanced_data_type/plugins/internet_port.py
+++ b/superset/advanced_data_type/plugins/internet_port.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import itertools
-from typing import Any, Dict, List
+from typing import Any
from sqlalchemy import Column
@@ -26,7 +26,7 @@ from superset.advanced_data_type.types import (
)
from superset.utils.core import FilterOperator, FilterStringOperators
-port_conversion_dict: Dict[str, List[int]] = {
+port_conversion_dict: dict[str, list[int]] = {
"http": [80],
"ssh": [22],
"https": [443],
@@ -100,7 +100,7 @@ def port_translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeRespo
def port_translate_filter_func(
- col: Column, operator: FilterOperator, values: List[Any]
+ col: Column, operator: FilterOperator, values: list[Any]
) -> Any:
"""
Convert a passed in column, FilterOperator
diff --git a/superset/advanced_data_type/types.py b/superset/advanced_data_type/types.py
index 316922f339..e8d5de9143 100644
--- a/superset/advanced_data_type/types.py
+++ b/superset/advanced_data_type/types.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, TypedDict, Union
+from typing import Any, Callable, Optional, TypedDict, Union
from sqlalchemy import Column
from sqlalchemy.sql.expression import BinaryExpression
@@ -30,7 +30,7 @@ class AdvancedDataTypeRequest(TypedDict):
"""
advanced_data_type: str
- values: List[
+ values: list[
Union[FilterValues, None]
] # unparsed value (usually text when passed from text box)
@@ -41,9 +41,9 @@ class AdvancedDataTypeResponse(TypedDict, total=False):
"""
error_message: Optional[str]
- values: List[Any] # parsed value (can be any value)
+ values: list[Any] # parsed value (can be any value)
display_value: str # The string representation of the parsed values
- valid_filter_operators: List[FilterStringOperators]
+ valid_filter_operators: list[FilterStringOperators]
@dataclass
@@ -54,6 +54,6 @@ class AdvancedDataType:
verbose_name: str
description: str
- valid_data_types: List[str]
+ valid_data_types: list[str]
translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse]
translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression]
diff --git a/superset/annotation_layers/annotations/api.py b/superset/annotation_layers/annotations/api.py
index 0a6a2767f0..70e0a1ad02 100644
--- a/superset/annotation_layers/annotations/api.py
+++ b/superset/annotation_layers/annotations/api.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask import request, Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
@@ -127,7 +127,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
@staticmethod
def _apply_layered_relation_to_rison( # pylint: disable=invalid-name
- layer_id: int, rison_parameters: Dict[str, Any]
+ layer_id: int, rison_parameters: dict[str, Any]
) -> None:
if "filters" not in rison_parameters:
rison_parameters["filters"] = []
diff --git a/superset/annotation_layers/annotations/commands/bulk_delete.py b/superset/annotation_layers/annotations/commands/bulk_delete.py
index 113725050f..dd47047788 100644
--- a/superset/annotation_layers/annotations/commands/bulk_delete.py
+++ b/superset/annotation_layers/annotations/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset.annotation_layers.annotations.commands.exceptions import (
AnnotationBulkDeleteFailedError,
@@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[Annotation]] = None
+ self._models: Optional[list[Annotation]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/annotation_layers/annotations/commands/create.py b/superset/annotation_layers/annotations/commands/create.py
index 0974624561..986b564291 100644
--- a/superset/annotation_layers/annotations/commands/create.py
+++ b/superset/annotation_layers/annotations/commands/create.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -50,7 +50,7 @@ class CreateAnnotationCommand(BaseCommand):
return annotation
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
layer_id: Optional[int] = self._properties.get("layer")
start_dttm: Optional[datetime] = self._properties.get("start_dttm")
end_dttm: Optional[datetime] = self._properties.get("end_dttm")
diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py
index b644ddc362..99ab209165 100644
--- a/superset/annotation_layers/annotations/commands/update.py
+++ b/superset/annotation_layers/annotations/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Annotation] = None
@@ -54,7 +54,7 @@ class UpdateAnnotationCommand(BaseCommand):
return annotation
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
layer_id: Optional[int] = self._properties.get("layer")
short_descr: str = self._properties.get("short_descr", "")
diff --git a/superset/annotation_layers/annotations/dao.py b/superset/annotation_layers/annotations/dao.py
index 0c8a9e47c5..da69e576e5 100644
--- a/superset/annotation_layers/annotations/dao.py
+++ b/superset/annotation_layers/annotations/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -31,7 +31,7 @@ class AnnotationDAO(BaseDAO):
model_cls = Annotation
@staticmethod
- def bulk_delete(models: Optional[List[Annotation]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
try:
db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete(
diff --git a/superset/annotation_layers/commands/bulk_delete.py b/superset/annotation_layers/commands/bulk_delete.py
index b9bc17e82f..4910dc4275 100644
--- a/superset/annotation_layers/commands/bulk_delete.py
+++ b/superset/annotation_layers/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset.annotation_layers.commands.exceptions import (
AnnotationLayerBulkDeleteFailedError,
@@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationLayerCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[AnnotationLayer]] = None
+ self._models: Optional[list[AnnotationLayer]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/annotation_layers/commands/create.py b/superset/annotation_layers/commands/create.py
index 97431568a9..86b0cb3b85 100644
--- a/superset/annotation_layers/commands/create.py
+++ b/superset/annotation_layers/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List
+from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationLayerCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -46,7 +46,7 @@ class CreateAnnotationLayerCommand(BaseCommand):
return annotation_layer
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
name = self._properties.get("name", "")
diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py
index 4a9cc31be5..67d869c005 100644
--- a/superset/annotation_layers/commands/update.py
+++ b/superset/annotation_layers/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationLayerCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[AnnotationLayer] = None
@@ -50,7 +50,7 @@ class UpdateAnnotationLayerCommand(BaseCommand):
return annotation_layer
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
name = self._properties.get("name", "")
self._model = AnnotationLayerDAO.find_by_id(self._model_id)
diff --git a/superset/annotation_layers/dao.py b/superset/annotation_layers/dao.py
index d9db4b582d..67efc19f88 100644
--- a/superset/annotation_layers/dao.py
+++ b/superset/annotation_layers/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional, Union
+from typing import Optional, Union
from sqlalchemy.exc import SQLAlchemyError
@@ -32,7 +32,7 @@ class AnnotationLayerDAO(BaseDAO):
@staticmethod
def bulk_delete(
- models: Optional[List[AnnotationLayer]], commit: bool = True
+ models: Optional[list[AnnotationLayer]], commit: bool = True
) -> None:
item_ids = [model.id for model in models] if models else []
try:
@@ -46,7 +46,7 @@ class AnnotationLayerDAO(BaseDAO):
raise DAODeleteFailedError() from ex
@staticmethod
- def has_annotations(model_id: Union[int, List[int]]) -> bool:
+ def has_annotations(model_id: Union[int, list[int]]) -> bool:
if isinstance(model_id, list):
return (
db.session.query(AnnotationLayer)
diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py
index c252f0be4c..ac801b7421 100644
--- a/superset/charts/commands/bulk_delete.py
+++ b/superset/charts/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from flask_babel import lazy_gettext as _
@@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteChartCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[Slice]] = None
+ self._models: Optional[list[Slice]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py
index 38076fb9cd..78706b3a66 100644
--- a/superset/charts/commands/create.py
+++ b/superset/charts/commands/create.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateChartCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -56,7 +56,7 @@ class CreateChartCommand(CreateMixin, BaseCommand):
datasource_type = self._properties["datasource_type"]
datasource_id = self._properties["datasource_id"]
dashboard_ids = self._properties.get("dashboards", [])
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/Populate datasource
try:
diff --git a/superset/charts/commands/export.py b/superset/charts/commands/export.py
index 9d445cb54e..22310ade99 100644
--- a/superset/charts/commands/export.py
+++ b/superset/charts/commands/export.py
@@ -18,7 +18,7 @@
import json
import logging
-from typing import Iterator, Tuple
+from collections.abc import Iterator
import yaml
@@ -42,7 +42,7 @@ class ExportChartsCommand(ExportModelsCommand):
not_found = ChartNotFoundError
@staticmethod
- def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]:
+ def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]:
file_name = get_filename(model.slice_name, model.id)
file_path = f"charts/{file_name}.yaml"
diff --git a/superset/charts/commands/importers/dispatcher.py b/superset/charts/commands/importers/dispatcher.py
index afeb9c2820..fb5007a50c 100644
--- a/superset/charts/commands/importers/dispatcher.py
+++ b/superset/charts/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -40,7 +40,7 @@ class ImportChartsCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py
index ab88038aaa..132df21b08 100644
--- a/superset/charts/commands/importers/v1/__init__.py
+++ b/superset/charts/commands/importers/v1/__init__.py
@@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-import json
-from typing import Any, Dict, Set
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -40,7 +39,7 @@ class ImportChartsCommand(ImportModelsCommand):
dao = ChartDAO
model_name = "chart"
prefix = "charts/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"datasets/": ImportV1DatasetSchema(),
"databases/": ImportV1DatabaseSchema(),
@@ -49,29 +48,29 @@ class ImportChartsCommand(ImportModelsCommand):
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover datasets associated with charts
- dataset_uuids: Set[str] = set()
+ dataset_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("charts/"):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
- database_uuids: Set[str] = set()
+ database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
- datasets: Dict[str, SqlaTable] = {}
+ datasets: dict[str, SqlaTable] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
diff --git a/superset/charts/commands/importers/v1/utils.py b/superset/charts/commands/importers/v1/utils.py
index d4aeb17a1e..399e6c2243 100644
--- a/superset/charts/commands/importers/v1/utils.py
+++ b/superset/charts/commands/importers/v1/utils.py
@@ -16,7 +16,7 @@
# under the License.
import json
-from typing import Any, Dict
+from typing import Any
from flask import g
from sqlalchemy.orm import Session
@@ -28,7 +28,7 @@ from superset.models.slice import Slice
def import_chart(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Slice:
diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py
index f5fc2616a5..a4265d0835 100644
--- a/superset/charts/commands/update.py
+++ b/superset/charts/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
@@ -42,14 +42,14 @@ from superset.models.slice import Slice
logger = logging.getLogger(__name__)
-def is_query_context_update(properties: Dict[str, Any]) -> bool:
+def is_query_context_update(properties: dict[str, Any]) -> bool:
return set(properties) == {"query_context", "query_context_generation"} and bool(
properties.get("query_context_generation")
)
class UpdateChartCommand(UpdateMixin, BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Slice] = None
@@ -67,9 +67,9 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
return chart
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
dashboard_ids = self._properties.get("dashboards")
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate if datasource_id is provided datasource_type is required
datasource_id = self._properties.get("datasource_id")
diff --git a/superset/charts/dao.py b/superset/charts/dao.py
index 7102e6ad23..9c6b2c26ef 100644
--- a/superset/charts/dao.py
+++ b/superset/charts/dao.py
@@ -17,7 +17,7 @@
# pylint: disable=arguments-renamed
import logging
from datetime import datetime
-from typing import List, Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
from sqlalchemy.exc import SQLAlchemyError
@@ -39,7 +39,7 @@ class ChartDAO(BaseDAO):
base_filter = ChartFilter
@staticmethod
- def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[Slice]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
@@ -71,7 +71,7 @@ class ChartDAO(BaseDAO):
db.session.commit()
@staticmethod
- def favorited_ids(charts: List[Slice]) -> List[FavStar]:
+ def favorited_ids(charts: list[Slice]) -> list[FavStar]:
ids = [chart.id for chart in charts]
return [
star.obj_id
diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py
index 9c620dcf5d..552044ebfa 100644
--- a/superset/charts/data/api.py
+++ b/superset/charts/data/api.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import json
import logging
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING
import simplejson
from flask import current_app, g, make_response, request, Response
@@ -315,7 +315,7 @@ class ChartDataRestApi(ChartRestApi):
return self._get_data_response(command, True)
def _run_async(
- self, form_data: Dict[str, Any], command: ChartDataCommand
+ self, form_data: dict[str, Any], command: ChartDataCommand
) -> Response:
"""
Execute command as an async query.
@@ -344,9 +344,9 @@ class ChartDataRestApi(ChartRestApi):
def _send_chart_response(
self,
- result: Dict[Any, Any],
- form_data: Optional[Dict[str, Any]] = None,
- datasource: Optional[Union[BaseDatasource, Query]] = None,
+ result: dict[Any, Any],
+ form_data: dict[str, Any] | None = None,
+ datasource: BaseDatasource | Query | None = None,
) -> Response:
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format
@@ -408,8 +408,8 @@ class ChartDataRestApi(ChartRestApi):
self,
command: ChartDataCommand,
force_cached: bool = False,
- form_data: Optional[Dict[str, Any]] = None,
- datasource: Optional[Union[BaseDatasource, Query]] = None,
+ form_data: dict[str, Any] | None = None,
+ datasource: BaseDatasource | Query | None = None,
) -> Response:
try:
result = command.run(force_cached=force_cached)
@@ -421,12 +421,12 @@ class ChartDataRestApi(ChartRestApi):
return self._send_chart_response(result, form_data, datasource)
# pylint: disable=invalid-name, no-self-use
- def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
+ def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
return QueryContextCacheLoader.load(cache_key)
# pylint: disable=no-self-use
def _create_query_context_from_form(
- self, form_data: Dict[str, Any]
+ self, form_data: dict[str, Any]
) -> QueryContext:
try:
return ChartDataQueryContextSchema().load(form_data)
diff --git a/superset/charts/data/commands/create_async_job_command.py b/superset/charts/data/commands/create_async_job_command.py
index c4e25f742b..fb6e3f3dbf 100644
--- a/superset/charts/data/commands/create_async_job_command.py
+++ b/superset/charts/data/commands/create_async_job_command.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask import Request
@@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]
- def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
+ def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
diff --git a/superset/charts/data/commands/get_data_command.py b/superset/charts/data/commands/get_data_command.py
index 819693607b..a84870a1dd 100644
--- a/superset/charts/data/commands/get_data_command.py
+++ b/superset/charts/data/commands/get_data_command.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_babel import lazy_gettext as _
@@ -36,7 +36,7 @@ class ChartDataCommand(BaseCommand):
def __init__(self, query_context: QueryContext):
self._query_context = query_context
- def run(self, **kwargs: Any) -> Dict[str, Any]:
+ def run(self, **kwargs: Any) -> dict[str, Any]:
# caching is handled in query_context.get_df_payload
# (also evals `force` property)
cache_query_context = kwargs.get("cache", False)
diff --git a/superset/charts/data/query_context_cache_loader.py b/superset/charts/data/query_context_cache_loader.py
index b5ff3bdae8..97fa733a3e 100644
--- a/superset/charts/data/query_context_cache_loader.py
+++ b/superset/charts/data/query_context_cache_loader.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from superset import cache
from superset.charts.commands.exceptions import ChartDataCacheLoadError
@@ -22,7 +22,7 @@ from superset.charts.commands.exceptions import ChartDataCacheLoadError
class QueryContextCacheLoader: # pylint: disable=too-few-public-methods
@staticmethod
- def load(cache_key: str) -> Dict[str, Any]:
+ def load(cache_key: str) -> dict[str, Any]:
cache_value = cache.get(cache_key)
if not cache_value:
raise ChartDataCacheLoadError("Cached data not found")
diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py
index 1165769fc8..a6b64c08d6 100644
--- a/superset/charts/post_processing.py
+++ b/superset/charts/post_processing.py
@@ -27,7 +27,7 @@ for these chart types.
"""
from io import StringIO
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
import pandas as pd
from flask_babel import gettext as __
@@ -45,14 +45,14 @@ if TYPE_CHECKING:
from superset.models.sql_lab import Query
-def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
+def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]:
"""
Sort columns when combining metrics.
MultiIndex labels have the metric name as the last element in the
tuple. We want to sort these according to the list of passed metrics.
"""
- parts: List[Any] = list(label)
+ parts: list[Any] = list(label)
metric = parts[-1]
parts[-1] = metrics.index(metric)
return tuple(parts)
@@ -60,9 +60,9 @@ def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...
def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
df: pd.DataFrame,
- rows: List[str],
- columns: List[str],
- metrics: List[str],
+ rows: list[str],
+ columns: list[str],
+ metrics: list[str],
aggfunc: str = "Sum",
transpose_pivot: bool = False,
combine_metrics: bool = False,
@@ -194,7 +194,7 @@ def list_unique_values(series: pd.Series) -> str:
"""
List unique values in a series.
"""
- return ", ".join(set(str(v) for v in pd.Series.unique(series)))
+ return ", ".join({str(v) for v in pd.Series.unique(series)})
pivot_v2_aggfunc_map = {
@@ -223,7 +223,7 @@ pivot_v2_aggfunc_map = {
def pivot_table_v2(
df: pd.DataFrame,
- form_data: Dict[str, Any],
+ form_data: dict[str, Any],
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
) -> pd.DataFrame:
"""
@@ -249,7 +249,7 @@ def pivot_table_v2(
def pivot_table(
df: pd.DataFrame,
- form_data: Dict[str, Any],
+ form_data: dict[str, Any],
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
) -> pd.DataFrame:
"""
@@ -285,7 +285,7 @@ def pivot_table(
def table(
df: pd.DataFrame,
- form_data: Dict[str, Any],
+ form_data: dict[str, Any],
datasource: Optional[ # pylint: disable=unused-argument
Union["BaseDatasource", "Query"]
] = None,
@@ -315,10 +315,10 @@ post_processors = {
def apply_post_process(
- result: Dict[Any, Any],
- form_data: Optional[Dict[str, Any]] = None,
+ result: dict[Any, Any],
+ form_data: Optional[dict[str, Any]] = None,
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
-) -> Dict[Any, Any]:
+) -> dict[Any, Any]:
form_data = form_data or {}
viz_type = form_data.get("viz_type")
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 44252ef06f..373600cd08 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -18,7 +18,7 @@
from __future__ import annotations
import inspect
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
@@ -1383,7 +1383,7 @@ class ChartDataQueryObjectSchema(Schema):
class ChartDataQueryContextSchema(Schema):
- query_context_factory: Optional[QueryContextFactory] = None
+ query_context_factory: QueryContextFactory | None = None
datasource = fields.Nested(ChartDataDatasourceSchema)
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
custom_cache_timeout = fields.Integer(
@@ -1407,7 +1407,7 @@ class ChartDataQueryContextSchema(Schema):
# pylint: disable=unused-argument
@post_load
- def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
+ def make_query_context(self, data: dict[str, Any], **kwargs: Any) -> QueryContext:
query_context = self.get_query_context_factory().create(**data)
return query_context
diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py
index c7689569c2..86f6fe9b67 100755
--- a/superset/cli/importexport.py
+++ b/superset/cli/importexport.py
@@ -18,7 +18,7 @@ import logging
import sys
from datetime import datetime
from pathlib import Path
-from typing import List, Optional
+from typing import Optional
from zipfile import is_zipfile, ZipFile
import click
@@ -309,7 +309,7 @@ else:
from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand
path_object = Path(path)
- files: List[Path] = []
+ files: list[Path] = []
if path_object.is_file():
files.append(path_object)
elif path_object.exists() and not recursive:
@@ -363,7 +363,7 @@ else:
sync_metrics = "metrics" in sync_array
path_object = Path(path)
- files: List[Path] = []
+ files: list[Path] = []
if path_object.is_file():
files.append(path_object)
elif path_object.exists() and not recursive:
diff --git a/superset/cli/main.py b/superset/cli/main.py
index 006f8eb5c9..536617cadd 100755
--- a/superset/cli/main.py
+++ b/superset/cli/main.py
@@ -18,7 +18,7 @@
import importlib
import logging
import pkgutil
-from typing import Any, Dict
+from typing import Any
import click
from colorama import Fore, Style
@@ -40,7 +40,7 @@ def superset() -> None:
"""This is a management script for the Superset application."""
@app.shell_context_processor
- def make_shell_context() -> Dict[str, Any]:
+ def make_shell_context() -> dict[str, Any]:
return dict(app=app, db=db)
@@ -79,5 +79,5 @@ def version(verbose: bool) -> None:
)
print(Fore.BLUE + "-=" * 15)
if verbose:
- print("[DB] : " + "{}".format(db.engine))
+ print("[DB] : " + f"{db.engine}")
print(Style.RESET_ALL)
diff --git a/superset/cli/native_filters.py b/superset/cli/native_filters.py
index 63cc185e8e..a25724d38d 100644
--- a/superset/cli/native_filters.py
+++ b/superset/cli/native_filters.py
@@ -17,7 +17,6 @@
import json
from copy import deepcopy
from textwrap import dedent
-from typing import Set, Tuple
import click
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
@@ -102,7 +101,7 @@ def native_filters() -> None:
)
def upgrade(
all_: bool, # pylint: disable=unused-argument
- dashboard_ids: Tuple[int, ...],
+ dashboard_ids: tuple[int, ...],
) -> None:
"""
Upgrade legacy filter-box charts to native dashboard filters.
@@ -251,7 +250,7 @@ def upgrade(
)
def downgrade(
all_: bool, # pylint: disable=unused-argument
- dashboard_ids: Tuple[int, ...],
+ dashboard_ids: tuple[int, ...],
) -> None:
"""
Downgrade native dashboard filters to legacy filter-box charts (where applicable).
@@ -347,7 +346,7 @@ def downgrade(
)
def cleanup(
all_: bool, # pylint: disable=unused-argument
- dashboard_ids: Tuple[int, ...],
+ dashboard_ids: tuple[int, ...],
) -> None:
"""
Cleanup obsolete legacy filter-box charts and interim metadata.
@@ -355,7 +354,7 @@ def cleanup(
Note this operation is irreversible.
"""
- slice_ids: Set[int] = set()
+ slice_ids: set[int] = set()
# Cleanup the dashboard which contains legacy fields used for downgrading.
for dashboard in (
diff --git a/superset/cli/thumbnails.py b/superset/cli/thumbnails.py
index 276d9981c1..325fab6853 100755
--- a/superset/cli/thumbnails.py
+++ b/superset/cli/thumbnails.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Type, Union
+from typing import Union
import click
from celery.utils.abstract import CallableTask
@@ -75,7 +75,7 @@ def compute_thumbnails(
def compute_generic_thumbnail(
friendly_type: str,
- model_cls: Union[Type[Dashboard], Type[Slice]],
+ model_cls: Union[type[Dashboard], type[Slice]],
model_id: int,
compute_func: CallableTask,
) -> None:
diff --git a/superset/commands/base.py b/superset/commands/base.py
index 42d5956312..caca50755d 100644
--- a/superset/commands/base.py
+++ b/superset/commands/base.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any, Optional
from flask_appbuilder.security.sqla.models import User
@@ -45,7 +45,7 @@ class BaseCommand(ABC):
class CreateMixin: # pylint: disable=too-few-public-methods
@staticmethod
- def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
+ def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
"""
Populate list of owners, defaulting to the current user if `owner_ids` is
undefined or empty. If current user is missing in `owner_ids`, current user
@@ -60,7 +60,7 @@ class CreateMixin: # pylint: disable=too-few-public-methods
class UpdateMixin: # pylint: disable=too-few-public-methods
@staticmethod
- def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
+ def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
"""
Populate list of owners. If current user is missing in `owner_ids`, current user
is added unless belonging to the Admin role.
diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py
index db9d1b6c63..4398d740c5 100644
--- a/superset/commands/exceptions.py
+++ b/superset/commands/exceptions.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
@@ -59,7 +59,7 @@ class CommandInvalidError(CommandException):
def __init__(
self,
message: str = "",
- exceptions: Optional[List[ValidationError]] = None,
+ exceptions: Optional[list[ValidationError]] = None,
) -> None:
self._exceptions = exceptions or []
super().__init__(message)
@@ -67,14 +67,14 @@ class CommandInvalidError(CommandException):
def append(self, exception: ValidationError) -> None:
self._exceptions.append(exception)
- def extend(self, exceptions: List[ValidationError]) -> None:
+ def extend(self, exceptions: list[ValidationError]) -> None:
self._exceptions.extend(exceptions)
- def get_list_classnames(self) -> List[str]:
+ def get_list_classnames(self) -> list[str]:
return list(sorted({ex.__class__.__name__ for ex in self._exceptions}))
- def normalized_messages(self) -> Dict[Any, Any]:
- errors: Dict[Any, Any] = {}
+ def normalized_messages(self) -> dict[Any, Any]:
+ errors: dict[Any, Any] = {}
for exception in self._exceptions:
errors.update(exception.normalized_messages())
return errors
diff --git a/superset/commands/export/assets.py b/superset/commands/export/assets.py
index 9f088af428..1bd2cf6d61 100644
--- a/superset/commands/export/assets.py
+++ b/superset/commands/export/assets.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+from collections.abc import Iterator
from datetime import datetime, timezone
-from typing import Iterator, Tuple
import yaml
@@ -36,7 +36,7 @@ class ExportAssetsCommand(BaseCommand):
Command that exports all databases, datasets, charts, dashboards and saved queries.
"""
- def run(self) -> Iterator[Tuple[str, str]]:
+ def run(self) -> Iterator[tuple[str, str]]:
metadata = {
"version": EXPORT_VERSION,
"type": "assets",
diff --git a/superset/commands/export/models.py b/superset/commands/export/models.py
index 4edafaa746..3f21f29281 100644
--- a/superset/commands/export/models.py
+++ b/superset/commands/export/models.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+from collections.abc import Iterator
from datetime import datetime, timezone
-from typing import Iterator, List, Tuple, Type
import yaml
from flask_appbuilder import Model
@@ -30,21 +30,21 @@ METADATA_FILE_NAME = "metadata.yaml"
class ExportModelsCommand(BaseCommand):
- dao: Type[BaseDAO] = BaseDAO
- not_found: Type[CommandException] = CommandException
+ dao: type[BaseDAO] = BaseDAO
+ not_found: type[CommandException] = CommandException
- def __init__(self, model_ids: List[int], export_related: bool = True):
+ def __init__(self, model_ids: list[int], export_related: bool = True):
self.model_ids = model_ids
self.export_related = export_related
# this will be set when calling validate()
- self._models: List[Model] = []
+ self._models: list[Model] = []
@staticmethod
- def _export(model: Model, export_related: bool = True) -> Iterator[Tuple[str, str]]:
+ def _export(model: Model, export_related: bool = True) -> Iterator[tuple[str, str]]:
raise NotImplementedError("Subclasses MUST implement _export")
- def run(self) -> Iterator[Tuple[str, str]]:
+ def run(self) -> Iterator[tuple[str, str]]:
self.validate()
metadata = {
diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py
index a67828bdb2..09830bf3cf 100644
--- a/superset/commands/importers/v1/__init__.py
+++ b/superset/commands/importers/v1/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Optional
from marshmallow import Schema, validate
from marshmallow.exceptions import ValidationError
@@ -40,33 +40,33 @@ class ImportModelsCommand(BaseCommand):
dao = BaseDAO
model_name = "model"
prefix = ""
- schemas: Dict[str, Schema] = {}
+ schemas: dict[str, Schema] = {}
import_error = CommandException
# pylint: disable=unused-argument
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
- self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
- self.ssh_tunnel_passwords: Dict[str, str] = (
+ self.passwords: dict[str, str] = kwargs.get("passwords") or {}
+ self.ssh_tunnel_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
- self.ssh_tunnel_private_keys: Dict[str, str] = (
+ self.ssh_tunnel_private_keys: dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
- self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+ self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
self.overwrite: bool = kwargs.get("overwrite", False)
- self._configs: Dict[str, Any] = {}
+ self._configs: dict[str, Any] = {}
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
raise NotImplementedError("Subclasses MUST implement _import")
@classmethod
- def _get_uuids(cls) -> Set[str]:
+ def _get_uuids(cls) -> set[str]:
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}
def run(self) -> None:
@@ -84,11 +84,11 @@ class ImportModelsCommand(BaseCommand):
raise self.import_error() from ex
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
# verify that the metadata file is present and valid
try:
- metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+ metadata: Optional[dict[str, str]] = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None
@@ -114,7 +114,7 @@ class ImportModelsCommand(BaseCommand):
)
def _prevent_overwrite_existing_model( # pylint: disable=invalid-name
- self, exceptions: List[ValidationError]
+ self, exceptions: list[ValidationError]
) -> None:
"""check if the object exists and shouldn't be overwritten"""
if not self.overwrite:
diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py
index ce8b46c2a0..1ab2e486cf 100644
--- a/superset/commands/importers/v1/assets.py
+++ b/superset/commands/importers/v1/assets.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from marshmallow import Schema
from marshmallow.exceptions import ValidationError
@@ -56,7 +56,7 @@ class ImportAssetsCommand(BaseCommand):
and will overwrite everything.
"""
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@@ -65,24 +65,24 @@ class ImportAssetsCommand(BaseCommand):
}
# pylint: disable=unused-argument
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
- self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
- self.ssh_tunnel_passwords: Dict[str, str] = (
+ self.passwords: dict[str, str] = kwargs.get("passwords") or {}
+ self.ssh_tunnel_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
- self.ssh_tunnel_private_keys: Dict[str, str] = (
+ self.ssh_tunnel_private_keys: dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
- self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+ self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
- self._configs: Dict[str, Any] = {}
+ self._configs: dict[str, Any] = {}
@staticmethod
- def _import(session: Session, configs: Dict[str, Any]) -> None:
+ def _import(session: Session, configs: dict[str, Any]) -> None:
# import databases first
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=True)
@@ -95,7 +95,7 @@ class ImportAssetsCommand(BaseCommand):
import_saved_query(session, config, overwrite=True)
# import datasets
- dataset_info: Dict[str, Dict[str, Any]] = {}
+ dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
config["database_id"] = database_ids[config["database_uuid"]]
@@ -107,7 +107,7 @@ class ImportAssetsCommand(BaseCommand):
}
# import charts
- chart_ids: Dict[str, int] = {}
+ chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("charts/"):
config.update(dataset_info[config["dataset_uuid"]])
@@ -121,7 +121,7 @@ class ImportAssetsCommand(BaseCommand):
dashboard = import_dashboard(session, config, overwrite=True)
# set ref in the dashboard_slices table
- dashboard_chart_ids: List[Dict[str, int]] = []
+ dashboard_chart_ids: list[dict[str, int]] = []
for uuid in find_chart_uuids(config["position"]):
if uuid not in chart_ids:
break
@@ -151,11 +151,11 @@ class ImportAssetsCommand(BaseCommand):
raise ImportFailedError() from ex
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
# verify that the metadata file is present and valid
try:
- metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+ metadata: Optional[dict[str, str]] = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None
diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py
index 35efdb1393..4c20e93ff7 100644
--- a/superset/commands/importers/v1/examples.py
+++ b/superset/commands/importers/v1/examples.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -52,7 +52,7 @@ class ImportExamplesCommand(ImportModelsCommand):
dao = BaseDAO
model_name = "model"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@@ -60,7 +60,7 @@ class ImportExamplesCommand(ImportModelsCommand):
}
import_error = CommandException
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
super().__init__(contents, *args, **kwargs)
self.force_data = kwargs.get("force_data", False)
@@ -81,7 +81,7 @@ class ImportExamplesCommand(ImportModelsCommand):
raise self.import_error() from ex
@classmethod
- def _get_uuids(cls) -> Set[str]:
+ def _get_uuids(cls) -> set[str]:
# pylint: disable=protected-access
return (
ImportDatabasesCommand._get_uuids()
@@ -93,12 +93,12 @@ class ImportExamplesCommand(ImportModelsCommand):
@staticmethod
def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches
session: Session,
- configs: Dict[str, Any],
+ configs: dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
) -> None:
# import databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(
@@ -114,7 +114,7 @@ class ImportExamplesCommand(ImportModelsCommand):
# database was created before its UUID was frozen, so it has a random UUID.
# We need to determine its ID so we can point the dataset to it.
examples_db = get_example_database()
- dataset_info: Dict[str, Dict[str, Any]] = {}
+ dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
# find the ID of the corresponding database
@@ -153,7 +153,7 @@ class ImportExamplesCommand(ImportModelsCommand):
}
# import charts
- chart_ids: Dict[str, int] = {}
+ chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if (
file_name.startswith("charts/")
@@ -175,7 +175,7 @@ class ImportExamplesCommand(ImportModelsCommand):
).fetchall()
# import dashboards
- dashboard_chart_ids: List[Tuple[int, int]] = []
+ dashboard_chart_ids: list[tuple[int, int]] = []
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
try:
diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py
index c8fb97c53d..8ca008b3e2 100644
--- a/superset/commands/importers/v1/utils.py
+++ b/superset/commands/importers/v1/utils.py
@@ -15,7 +15,7 @@
import logging
from pathlib import Path, PurePosixPath
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from zipfile import ZipFile
import yaml
@@ -46,7 +46,7 @@ class MetadataSchema(Schema):
timestamp = fields.DateTime()
-def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
+def load_yaml(file_name: str, content: str) -> dict[str, Any]:
"""Try to load a YAML file"""
try:
return yaml.safe_load(content)
@@ -55,7 +55,7 @@ def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
raise ValidationError({file_name: "Not a valid YAML file"}) from ex
-def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
+def load_metadata(contents: dict[str, str]) -> dict[str, str]:
"""Apply validation and load a metadata file"""
if METADATA_FILE_NAME not in contents:
# if the contents have no METADATA_FILE_NAME this is probably
@@ -80,9 +80,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
def validate_metadata_type(
- metadata: Optional[Dict[str, str]],
+ metadata: Optional[dict[str, str]],
type_: str,
- exceptions: List[ValidationError],
+ exceptions: list[ValidationError],
) -> None:
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
if metadata and "type" in metadata:
@@ -96,35 +96,35 @@ def validate_metadata_type(
# pylint: disable=too-many-locals,too-many-arguments
def load_configs(
- contents: Dict[str, str],
- schemas: Dict[str, Schema],
- passwords: Dict[str, str],
- exceptions: List[ValidationError],
- ssh_tunnel_passwords: Dict[str, str],
- ssh_tunnel_private_keys: Dict[str, str],
- ssh_tunnel_priv_key_passwords: Dict[str, str],
-) -> Dict[str, Any]:
- configs: Dict[str, Any] = {}
+ contents: dict[str, str],
+ schemas: dict[str, Schema],
+ passwords: dict[str, str],
+ exceptions: list[ValidationError],
+ ssh_tunnel_passwords: dict[str, str],
+ ssh_tunnel_private_keys: dict[str, str],
+ ssh_tunnel_priv_key_passwords: dict[str, str],
+) -> dict[str, Any]:
+ configs: dict[str, Any] = {}
# load existing databases so we can apply the password validation
- db_passwords: Dict[str, str] = {
+ db_passwords: dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(Database.uuid, Database.password).all()
}
# load existing ssh_tunnels so we can apply the password validation
- db_ssh_tunnel_passwords: Dict[str, str] = {
+ db_ssh_tunnel_passwords: dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
}
# load existing ssh_tunnels so we can apply the private_key validation
- db_ssh_tunnel_private_keys: Dict[str, str] = {
+ db_ssh_tunnel_private_keys: dict[str, str] = {
str(uuid): private_key
for uuid, private_key in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key
).all()
}
# load existing ssh_tunnels so we can apply the private_key_password validation
- db_ssh_tunnel_priv_key_passws: Dict[str, str] = {
+ db_ssh_tunnel_priv_key_passws: dict[str, str] = {
str(uuid): private_key_password
for uuid, private_key_password in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key_password
@@ -206,7 +206,7 @@ def is_valid_config(file_name: str) -> bool:
return True
-def get_contents_from_bundle(bundle: ZipFile) -> Dict[str, str]:
+def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]:
return {
remove_root(file_name): bundle.read(file_name).decode()
for file_name in bundle.namelist()
diff --git a/superset/commands/utils.py b/superset/commands/utils.py
index ad58bb4050..7bb13984f8 100644
--- a/superset/commands/utils.py
+++ b/superset/commands/utils.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING
from flask import g
from flask_appbuilder.security.sqla.models import Role, User
@@ -37,9 +37,9 @@ if TYPE_CHECKING:
def populate_owners(
- owner_ids: Optional[List[int]],
+ owner_ids: list[int] | None,
default_to_user: bool,
-) -> List[User]:
+) -> list[User]:
"""
Helper function for commands, will fetch all users from owners id's
@@ -63,13 +63,13 @@ def populate_owners(
return owners
-def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]:
+def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
"""
Helper function for commands, will fetch all roles from roles id's
:raises RolesNotFoundValidationError: If a role in the input list is not found
:param role_ids: A List of roles by id's
"""
- roles: List[Role] = []
+ roles: list[Role] = []
if role_ids:
roles = security_manager.find_roles_by_id(role_ids)
if len(roles) != len(role_ids):
diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py
index 659a640159..65c0c43c11 100644
--- a/superset/common/chart_data.py
+++ b/superset/common/chart_data.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
from enum import Enum
-from typing import Set
class ChartDataResultFormat(str, Enum):
@@ -28,7 +27,7 @@ class ChartDataResultFormat(str, Enum):
XLSX = "xlsx"
@classmethod
- def table_like(cls) -> Set["ChartDataResultFormat"]:
+ def table_like(cls) -> set["ChartDataResultFormat"]:
return {cls.CSV} | {cls.XLSX}
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index f6f5a5cd62..22c778b77b 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import copy
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
from flask_babel import _
@@ -49,7 +49,7 @@ def _get_datasource(
def _get_columns(
query_context: QueryContext, query_obj: QueryObject, _: bool
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
@@ -65,7 +65,7 @@ def _get_columns(
def _get_timegrains(
query_context: QueryContext, query_obj: QueryObject, _: bool
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
@@ -83,7 +83,7 @@ def _get_query(
query_context: QueryContext,
query_obj: QueryObject,
_: bool,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
result = {"language": datasource.query_language}
try:
@@ -96,8 +96,8 @@ def _get_query(
def _get_full(
query_context: QueryContext,
query_obj: QueryObject,
- force_cached: Optional[bool] = False,
-) -> Dict[str, Any]:
+ force_cached: bool | None = False,
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
result_type = query_obj.result_type or query_context.result_type
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
@@ -141,7 +141,7 @@ def _get_full(
def _get_samples(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
@@ -162,7 +162,7 @@ def _get_samples(
def _get_drill_detail(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
# todo(yongjie): Remove this function,
# when determining whether samples should be applied to the time filter.
datasource = _get_datasource(query_context, query_obj)
@@ -183,13 +183,13 @@ def _get_drill_detail(
def _get_results(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
payload = _get_full(query_context, query_obj, force_cached)
return payload
-_result_type_functions: Dict[
- ChartDataResultType, Callable[[QueryContext, QueryObject, bool], Dict[str, Any]]
+_result_type_functions: dict[
+ ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]]
] = {
ChartDataResultType.COLUMNS: _get_columns,
ChartDataResultType.TIMEGRAINS: _get_timegrains,
@@ -210,7 +210,7 @@ def get_query_results(
query_context: QueryContext,
query_obj: QueryObject,
force_cached: bool,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""
Return result payload for a chart data request.
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 78eb8800c4..1a8d3c518b 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, ClassVar, TYPE_CHECKING
import pandas as pd
@@ -47,15 +47,15 @@ class QueryContext:
enforce_numerical_metrics: ClassVar[bool] = True
datasource: BaseDatasource
- slice_: Optional[Slice] = None
- queries: List[QueryObject]
- form_data: Optional[Dict[str, Any]]
+ slice_: Slice | None = None
+ queries: list[QueryObject]
+ form_data: dict[str, Any] | None
result_type: ChartDataResultType
result_format: ChartDataResultFormat
force: bool
- custom_cache_timeout: Optional[int]
+ custom_cache_timeout: int | None
- cache_values: Dict[str, Any]
+ cache_values: dict[str, Any]
_processor: QueryContextProcessor
@@ -65,14 +65,14 @@ class QueryContext:
self,
*,
datasource: BaseDatasource,
- queries: List[QueryObject],
- slice_: Optional[Slice],
- form_data: Optional[Dict[str, Any]],
+ queries: list[QueryObject],
+ slice_: Slice | None,
+ form_data: dict[str, Any] | None,
result_type: ChartDataResultType,
result_format: ChartDataResultFormat,
force: bool = False,
- custom_cache_timeout: Optional[int] = None,
- cache_values: Dict[str, Any],
+ custom_cache_timeout: int | None = None,
+ cache_values: dict[str, Any],
) -> None:
self.datasource = datasource
self.slice_ = slice_
@@ -88,18 +88,18 @@ class QueryContext:
def get_data(
self,
df: pd.DataFrame,
- ) -> Union[str, List[Dict[str, Any]]]:
+ ) -> str | list[dict[str, Any]]:
return self._processor.get_data(df)
def get_payload(
self,
- cache_query_context: Optional[bool] = False,
+ cache_query_context: bool | None = False,
force_cached: bool = False,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""
return self._processor.get_payload(cache_query_context, force_cached)
- def get_cache_timeout(self) -> Optional[int]:
+ def get_cache_timeout(self) -> int | None:
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.slice_ and self.slice_.cache_timeout is not None:
@@ -110,14 +110,14 @@ class QueryContext:
return self.datasource.database.cache_timeout
return None
- def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
return self._processor.query_cache_key(query_obj, **kwargs)
def get_df_payload(
self,
query_obj: QueryObject,
- force_cached: Optional[bool] = False,
- ) -> Dict[str, Any]:
+ force_cached: bool | None = False,
+ ) -> dict[str, Any]:
return self._processor.get_df_payload(
query_obj=query_obj,
force_cached=force_cached,
diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py
index 84c0415722..62018def8d 100644
--- a/superset/common/query_context_factory.py
+++ b/superset/common/query_context_factory.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from superset import app, db
from superset.charts.dao import ChartDAO
@@ -48,12 +48,12 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
self,
*,
datasource: DatasourceDict,
- queries: List[Dict[str, Any]],
- form_data: Optional[Dict[str, Any]] = None,
- result_type: Optional[ChartDataResultType] = None,
- result_format: Optional[ChartDataResultFormat] = None,
+ queries: list[dict[str, Any]],
+ form_data: dict[str, Any] | None = None,
+ result_type: ChartDataResultType | None = None,
+ result_format: ChartDataResultFormat | None = None,
force: bool = False,
- custom_cache_timeout: Optional[int] = None,
+ custom_cache_timeout: int | None = None,
) -> QueryContext:
datasource_model_instance = None
if datasource:
@@ -101,13 +101,13 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
datasource_id=int(datasource["id"]),
)
- def _get_slice(self, slice_id: Any) -> Optional[Slice]:
+ def _get_slice(self, slice_id: Any) -> Slice | None:
return ChartDAO.find_by_id(slice_id)
def _process_query_object(
self,
datasource: BaseDatasource,
- form_data: Optional[Dict[str, Any]],
+ form_data: dict[str, Any] | None,
query_object: QueryObject,
) -> QueryObject:
self._apply_granularity(query_object, form_data, datasource)
@@ -117,7 +117,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
def _apply_granularity(
self,
query_object: QueryObject,
- form_data: Optional[Dict[str, Any]],
+ form_data: dict[str, Any] | None,
datasource: BaseDatasource,
) -> None:
temporal_columns = {
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 85a2b5d97a..ecb8db4246 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import copy
import logging
import re
-from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, ClassVar, TYPE_CHECKING
import numpy as np
import pandas as pd
@@ -77,8 +77,8 @@ logger = logging.getLogger(__name__)
class CachedTimeOffset(TypedDict):
df: pd.DataFrame
- queries: List[str]
- cache_keys: List[Optional[str]]
+ queries: list[str]
+ cache_keys: list[str | None]
class QueryContextProcessor:
@@ -102,8 +102,8 @@ class QueryContextProcessor:
enforce_numerical_metrics: ClassVar[bool] = True
def get_df_payload(
- self, query_obj: QueryObject, force_cached: Optional[bool] = False
- ) -> Dict[str, Any]:
+ self, query_obj: QueryObject, force_cached: bool | None = False
+ ) -> dict[str, Any]:
"""Handles caching around the df payload retrieval"""
cache_key = self.query_cache_key(query_obj)
timeout = self.get_cache_timeout()
@@ -181,7 +181,7 @@ class QueryContextProcessor:
"label_map": label_map,
}
- def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
"""
Returns a QueryObject cache key for objects in self.queries
"""
@@ -248,8 +248,8 @@ class QueryContextProcessor:
def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
# todo: should support "python_date_format" and "get_column" in each datasource
def _get_timestamp_format(
- source: BaseDatasource, column: Optional[str]
- ) -> Optional[str]:
+ source: BaseDatasource, column: str | None
+ ) -> str | None:
column_obj = source.get_column(column)
if (
column_obj
@@ -315,9 +315,9 @@ class QueryContextProcessor:
query_context = self._query_context
# ensure query_object is immutable
query_object_clone = copy.copy(query_object)
- queries: List[str] = []
- cache_keys: List[Optional[str]] = []
- rv_dfs: List[pd.DataFrame] = [df]
+ queries: list[str] = []
+ cache_keys: list[str | None] = []
+ rv_dfs: list[pd.DataFrame] = [df]
time_offsets = query_object.time_offsets
outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
@@ -449,7 +449,7 @@ class QueryContextProcessor:
rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
- def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]:
+ def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
if self._query_context.result_format in ChartDataResultFormat.table_like():
include_index = not isinstance(df.index, pd.RangeIndex)
columns = list(df.columns)
@@ -470,9 +470,9 @@ class QueryContextProcessor:
def get_payload(
self,
- cache_query_context: Optional[bool] = False,
+ cache_query_context: bool | None = False,
force_cached: bool = False,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""
# Get all the payloads from the QueryObjects
@@ -522,13 +522,13 @@ class QueryContextProcessor:
return generate_cache_key(cache_dict, key_prefix)
- def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
+ def get_annotation_data(self, query_obj: QueryObject) -> dict[str, Any]:
"""
:param query_context:
:param query_obj:
:return:
"""
- annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
+ annotation_data: dict[str, Any] = self.get_native_annotation_data(query_obj)
for annotation_layer in [
layer
for layer in query_obj.annotation_layers
@@ -541,7 +541,7 @@ class QueryContextProcessor:
return annotation_data
@staticmethod
- def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]:
+ def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]:
annotation_data = {}
annotation_layers = [
layer
@@ -576,8 +576,8 @@ class QueryContextProcessor:
@staticmethod
def get_viz_annotation_data(
- annotation_layer: Dict[str, Any], force: bool
- ) -> Dict[str, Any]:
+ annotation_layer: dict[str, Any], force: bool
+ ) -> dict[str, Any]:
chart = ChartDAO.find_by_id(annotation_layer["value"])
if not chart:
raise QueryObjectValidationError(_("The chart does not exist"))
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 802a1eed5b..dc02b774e5 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -21,7 +21,7 @@ import json
import logging
from datetime import datetime
from pprint import pformat
-from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
+from typing import Any, NamedTuple, TYPE_CHECKING
from flask import g
from flask_babel import gettext as _
@@ -81,58 +81,58 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
and druid. The query objects are constructed on the client.
"""
- annotation_layers: List[Dict[str, Any]]
- applied_time_extras: Dict[str, str]
+ annotation_layers: list[dict[str, Any]]
+ applied_time_extras: dict[str, str]
apply_fetch_values_predicate: bool
- columns: List[Column]
- datasource: Optional[BaseDatasource]
- extras: Dict[str, Any]
- filter: List[QueryObjectFilterClause]
- from_dttm: Optional[datetime]
- granularity: Optional[str]
- inner_from_dttm: Optional[datetime]
- inner_to_dttm: Optional[datetime]
+ columns: list[Column]
+ datasource: BaseDatasource | None
+ extras: dict[str, Any]
+ filter: list[QueryObjectFilterClause]
+ from_dttm: datetime | None
+ granularity: str | None
+ inner_from_dttm: datetime | None
+ inner_to_dttm: datetime | None
is_rowcount: bool
is_timeseries: bool
- metrics: Optional[List[Metric]]
+ metrics: list[Metric] | None
order_desc: bool
- orderby: List[OrderBy]
- post_processing: List[Dict[str, Any]]
- result_type: Optional[ChartDataResultType]
- row_limit: Optional[int]
+ orderby: list[OrderBy]
+ post_processing: list[dict[str, Any]]
+ result_type: ChartDataResultType | None
+ row_limit: int | None
row_offset: int
- series_columns: List[Column]
+ series_columns: list[Column]
series_limit: int
- series_limit_metric: Optional[Metric]
- time_offsets: List[str]
- time_shift: Optional[str]
- time_range: Optional[str]
- to_dttm: Optional[datetime]
+ series_limit_metric: Metric | None
+ time_offsets: list[str]
+ time_shift: str | None
+ time_range: str | None
+ to_dttm: datetime | None
def __init__( # pylint: disable=too-many-locals
self,
*,
- annotation_layers: Optional[List[Dict[str, Any]]] = None,
- applied_time_extras: Optional[Dict[str, str]] = None,
+ annotation_layers: list[dict[str, Any]] | None = None,
+ applied_time_extras: dict[str, str] | None = None,
apply_fetch_values_predicate: bool = False,
- columns: Optional[List[Column]] = None,
- datasource: Optional[BaseDatasource] = None,
- extras: Optional[Dict[str, Any]] = None,
- filters: Optional[List[QueryObjectFilterClause]] = None,
- granularity: Optional[str] = None,
+ columns: list[Column] | None = None,
+ datasource: BaseDatasource | None = None,
+ extras: dict[str, Any] | None = None,
+ filters: list[QueryObjectFilterClause] | None = None,
+ granularity: str | None = None,
is_rowcount: bool = False,
- is_timeseries: Optional[bool] = None,
- metrics: Optional[List[Metric]] = None,
+ is_timeseries: bool | None = None,
+ metrics: list[Metric] | None = None,
order_desc: bool = True,
- orderby: Optional[List[OrderBy]] = None,
- post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
- row_limit: Optional[int],
- row_offset: Optional[int] = None,
- series_columns: Optional[List[Column]] = None,
+ orderby: list[OrderBy] | None = None,
+ post_processing: list[dict[str, Any] | None] | None = None,
+ row_limit: int | None,
+ row_offset: int | None = None,
+ series_columns: list[Column] | None = None,
series_limit: int = 0,
- series_limit_metric: Optional[Metric] = None,
- time_range: Optional[str] = None,
- time_shift: Optional[str] = None,
+ series_limit_metric: Metric | None = None,
+ time_range: str | None = None,
+ time_shift: str | None = None,
**kwargs: Any,
):
self._set_annotation_layers(annotation_layers)
@@ -166,7 +166,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self._move_deprecated_extra_fields(kwargs)
def _set_annotation_layers(
- self, annotation_layers: Optional[List[Dict[str, Any]]]
+ self, annotation_layers: list[dict[str, Any]] | None
) -> None:
self.annotation_layers = [
layer
@@ -175,14 +175,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
if layer["annotationType"] != "FORMULA"
]
- def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
+ def _set_is_timeseries(self, is_timeseries: bool | None) -> None:
# is_timeseries is True if time column is in either columns or groupby
# (both are dimensions)
self.is_timeseries = (
is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
)
- def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None:
+ def _set_metrics(self, metrics: list[Metric] | None = None) -> None:
# Support metric reference/definition in the format of
# 1. 'metric_name' - name of predefined metric
# 2. { label: 'label_name' } - legacy format for a predefined metric
@@ -195,16 +195,16 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
]
def _set_post_processing(
- self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
+ self, post_processing: list[dict[str, Any] | None] | None
) -> None:
post_processing = post_processing or []
self.post_processing = [post_proc for post_proc in post_processing if post_proc]
def _init_series_columns(
self,
- series_columns: Optional[List[Column]],
- metrics: Optional[List[Metric]],
- is_timeseries: Optional[bool],
+ series_columns: list[Column] | None,
+ metrics: list[Metric] | None,
+ is_timeseries: bool | None,
) -> None:
if series_columns:
self.series_columns = series_columns
@@ -213,7 +213,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
else:
self.series_columns = []
- def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
+ def _rename_deprecated_fields(self, kwargs: dict[str, Any]) -> None:
# rename deprecated fields
for field in DEPRECATED_FIELDS:
if field.old_name in kwargs:
@@ -233,7 +233,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
setattr(self, field.new_name, value)
- def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None:
+ def _move_deprecated_extra_fields(self, kwargs: dict[str, Any]) -> None:
# move deprecated extras fields to extras
for field in DEPRECATED_EXTRAS_FIELDS:
if field.old_name in kwargs:
@@ -256,19 +256,19 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self.extras[field.new_name] = value
@property
- def metric_names(self) -> List[str]:
+ def metric_names(self) -> list[str]:
"""Return metrics names (labels), coerce adhoc metrics to strings."""
return get_metric_names(self.metrics or [])
@property
- def column_names(self) -> List[str]:
+ def column_names(self) -> list[str]:
"""Return column names (labels). Gives priority to groupbys if both groupbys
and metrics are non-empty, otherwise returns column labels."""
return get_column_names(self.columns)
def validate(
- self, raise_exceptions: Optional[bool] = True
- ) -> Optional[QueryObjectValidationError]:
+ self, raise_exceptions: bool | None = True
+ ) -> QueryObjectValidationError | None:
"""Validate query object"""
try:
self._validate_there_are_no_missing_series()
@@ -314,7 +314,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
)
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
query_object_dict = {
"apply_fetch_values_predicate": self.apply_fetch_values_predicate,
"columns": self.columns,
diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py
index 88cc7ca1b4..5676dc9eda 100644
--- a/superset/common/query_object_factory.py
+++ b/superset/common/query_object_factory.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
@@ -31,13 +31,13 @@ if TYPE_CHECKING:
class QueryObjectFactory: # pylint: disable=too-few-public-methods
- _config: Dict[str, Any]
+ _config: dict[str, Any]
_datasource_dao: DatasourceDAO
_session_maker: sessionmaker
def __init__(
self,
- app_configurations: Dict[str, Any],
+ app_configurations: dict[str, Any],
_datasource_dao: DatasourceDAO,
session_maker: sessionmaker,
):
@@ -48,11 +48,11 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
def create( # pylint: disable=too-many-arguments
self,
parent_result_type: ChartDataResultType,
- datasource: Optional[DatasourceDict] = None,
- extras: Optional[Dict[str, Any]] = None,
- row_limit: Optional[int] = None,
- time_range: Optional[str] = None,
- time_shift: Optional[str] = None,
+ datasource: DatasourceDict | None = None,
+ extras: dict[str, Any] | None = None,
+ row_limit: int | None = None,
+ time_range: str | None = None,
+ time_shift: str | None = None,
**kwargs: Any,
) -> QueryObject:
datasource_model_instance = None
@@ -84,13 +84,13 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
def _process_extras( # pylint: disable=no-self-use
self,
- extras: Optional[Dict[str, Any]],
- ) -> Dict[str, Any]:
+ extras: dict[str, Any] | None,
+ ) -> dict[str, Any]:
extras = extras or {}
return extras
def _process_row_limit(
- self, row_limit: Optional[int], result_type: ChartDataResultType
+ self, row_limit: int | None, result_type: ChartDataResultType
) -> int:
default_row_limit = (
self._config["SAMPLES_ROW_LIMIT"]
diff --git a/superset/common/tags.py b/superset/common/tags.py
index 706192913a..6066d0eec7 100644
--- a/superset/common/tags.py
+++ b/superset/common/tags.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List
+from typing import Any
from sqlalchemy import MetaData
from sqlalchemy.exc import IntegrityError
@@ -25,7 +25,7 @@ from superset.tags.models import ObjectTypes, TagTypes
def add_types_to_charts(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
slices = metadata.tables["slices"]
@@ -57,7 +57,7 @@ def add_types_to_charts(
def add_types_to_dashboards(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
@@ -89,7 +89,7 @@ def add_types_to_dashboards(
def add_types_to_saved_queries(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
saved_query = metadata.tables["saved_query"]
@@ -121,7 +121,7 @@ def add_types_to_saved_queries(
def add_types_to_datasets(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
tables = metadata.tables["tables"]
@@ -237,7 +237,7 @@ def add_types(metadata: MetaData) -> None:
def add_owners_to_charts(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
slices = metadata.tables["slices"]
@@ -273,7 +273,7 @@ def add_owners_to_charts(
def add_owners_to_dashboards(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
@@ -309,7 +309,7 @@ def add_owners_to_dashboards(
def add_owners_to_saved_queries(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
saved_query = metadata.tables["saved_query"]
@@ -345,7 +345,7 @@ def add_owners_to_saved_queries(
def add_owners_to_datasets(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
tables = metadata.tables["tables"]
diff --git a/superset/common/utils/dataframe_utils.py b/superset/common/utils/dataframe_utils.py
index 4dd62e3b5d..a3421f6431 100644
--- a/superset/common/utils/dataframe_utils.py
+++ b/superset/common/utils/dataframe_utils.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import datetime
-from typing import Any, List, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
import numpy as np
import pandas as pd
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
def left_join_df(
left_df: pd.DataFrame,
right_df: pd.DataFrame,
- join_keys: List[str],
+ join_keys: list[str],
) -> pd.DataFrame:
df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
df.reset_index(inplace=True)
diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py
index 6c1b268f46..a0fb65b20d 100644
--- a/superset/common/utils/query_cache_manager.py
+++ b/superset/common/utils/query_cache_manager.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any
from flask_caching import Cache
from pandas import DataFrame
@@ -37,7 +37,7 @@ config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
-_cache: Dict[CacheRegion, Cache] = {
+_cache: dict[CacheRegion, Cache] = {
CacheRegion.DEFAULT: cache_manager.cache,
CacheRegion.DATA: cache_manager.data_cache,
}
@@ -53,17 +53,17 @@ class QueryCacheManager:
self,
df: DataFrame = DataFrame(),
query: str = "",
- annotation_data: Optional[Dict[str, Any]] = None,
- applied_template_filters: Optional[List[str]] = None,
- applied_filter_columns: Optional[List[Column]] = None,
- rejected_filter_columns: Optional[List[Column]] = None,
- status: Optional[str] = None,
- error_message: Optional[str] = None,
+ annotation_data: dict[str, Any] | None = None,
+ applied_template_filters: list[str] | None = None,
+ applied_filter_columns: list[Column] | None = None,
+ rejected_filter_columns: list[Column] | None = None,
+ status: str | None = None,
+ error_message: str | None = None,
is_loaded: bool = False,
- stacktrace: Optional[str] = None,
- is_cached: Optional[bool] = None,
- cache_dttm: Optional[str] = None,
- cache_value: Optional[Dict[str, Any]] = None,
+ stacktrace: str | None = None,
+ is_cached: bool | None = None,
+ cache_dttm: str | None = None,
+ cache_value: dict[str, Any] | None = None,
) -> None:
self.df = df
self.query = query
@@ -85,10 +85,10 @@ class QueryCacheManager:
self,
key: str,
query_result: QueryResult,
- annotation_data: Optional[Dict[str, Any]] = None,
- force_query: Optional[bool] = False,
- timeout: Optional[int] = None,
- datasource_uid: Optional[str] = None,
+ annotation_data: dict[str, Any] | None = None,
+ force_query: bool | None = False,
+ timeout: int | None = None,
+ datasource_uid: str | None = None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
"""
@@ -136,11 +136,11 @@ class QueryCacheManager:
@classmethod
def get(
cls,
- key: Optional[str],
+ key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
- force_query: Optional[bool] = False,
- force_cached: Optional[bool] = False,
- ) -> "QueryCacheManager":
+ force_query: bool | None = False,
+ force_cached: bool | None = False,
+ ) -> QueryCacheManager:
"""
Initialize QueryCacheManager by query-cache key
"""
@@ -190,10 +190,10 @@ class QueryCacheManager:
@staticmethod
def set(
- key: Optional[str],
- value: Dict[str, Any],
- timeout: Optional[int] = None,
- datasource_uid: Optional[str] = None,
+ key: str | None,
+ value: dict[str, Any],
+ timeout: int | None = None,
+ datasource_uid: str | None = None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
"""
@@ -204,7 +204,7 @@ class QueryCacheManager:
@staticmethod
def delete(
- key: Optional[str],
+ key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
if key:
@@ -212,7 +212,7 @@ class QueryCacheManager:
@staticmethod
def has(
- key: Optional[str],
+ key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> bool:
return bool(_cache[region].get(key)) if key else False
diff --git a/superset/common/utils/time_range_utils.py b/superset/common/utils/time_range_utils.py
index fa6a5244b2..5f9139c047 100644
--- a/superset/common/utils/time_range_utils.py
+++ b/superset/common/utils/time_range_utils.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from datetime import datetime
-from typing import Any, cast, Dict, Optional, Tuple
+from typing import Any, cast
from superset import app
from superset.common.query_object import QueryObject
@@ -26,10 +26,10 @@ from superset.utils.date_parser import get_since_until
def get_since_until_from_time_range(
- time_range: Optional[str] = None,
- time_shift: Optional[str] = None,
- extras: Optional[Dict[str, Any]] = None,
-) -> Tuple[Optional[datetime], Optional[datetime]]:
+ time_range: str | None = None,
+ time_shift: str | None = None,
+ extras: dict[str, Any] | None = None,
+) -> tuple[datetime | None, datetime | None]:
return get_since_until(
relative_start=(extras or {}).get(
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
@@ -45,7 +45,7 @@ def get_since_until_from_time_range(
# pylint: disable=invalid-name
def get_since_until_from_query_object(
query_object: QueryObject,
-) -> Tuple[Optional[datetime], Optional[datetime]]:
+) -> tuple[datetime | None, datetime | None]:
"""
this function will return since and until by tuple if
1) the time_range is in the query object.
diff --git a/superset/config.py b/superset/config.py
index 7d9359d14f..434456386d 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -33,20 +33,7 @@ import sys
from collections import OrderedDict
from datetime import timedelta
from email.mime.multipart import MIMEMultipart
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- Literal,
- Optional,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- TypedDict,
- Union,
-)
+from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
import pkg_resources
from cachelib.base import BaseCache
@@ -114,17 +101,17 @@ PACKAGE_JSON_FILE = pkg_resources.resource_filename(
FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
-def _try_json_readversion(filepath: str) -> Optional[str]:
+def _try_json_readversion(filepath: str) -> str | None:
try:
- with open(filepath, "r") as f:
+ with open(filepath) as f:
return json.load(f).get("version")
except Exception: # pylint: disable=broad-except
return None
-def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
+def _try_json_readsha(filepath: str, length: int) -> str | None:
try:
- with open(filepath, "r") as f:
+ with open(filepath) as f:
return json.load(f).get("GIT_SHA")[:length]
except Exception: # pylint: disable=broad-except
return None
@@ -275,7 +262,7 @@ ENABLE_PROXY_FIX = False
PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1}
# Configuration for scheduling queries from SQL Lab.
-SCHEDULED_QUERIES: Dict[str, Any] = {}
+SCHEDULED_QUERIES: dict[str, Any] = {}
# ------------------------------
# GLOBALS FOR APP Builder
@@ -294,7 +281,7 @@ LOGO_TARGET_PATH = None
LOGO_TOOLTIP = ""
# Specify any text that should appear to the right of the logo
-LOGO_RIGHT_TEXT: Union[Callable[[], str], str] = ""
+LOGO_RIGHT_TEXT: Callable[[], str] | str = ""
# Enables SWAGGER UI for superset openapi spec
# ex: http://localhost:8080/swagger/v1
@@ -347,7 +334,7 @@ AUTH_TYPE = AUTH_DB
# Grant public role the same set of permissions as for a selected builtin role.
# This is useful if one wants to enable anonymous users to view
# dashboards. Explicit grant on specific datasets is still required.
-PUBLIC_ROLE_LIKE: Optional[str] = None
+PUBLIC_ROLE_LIKE: str | None = None
# ---------------------------------------------------
# Babel config for translations
@@ -390,8 +377,8 @@ LANGUAGES = {}
class D3Format(TypedDict, total=False):
decimal: str
thousands: str
- grouping: List[int]
- currency: List[str]
+ grouping: list[int]
+ currency: list[str]
D3_FORMAT: D3Format = {}
@@ -404,7 +391,7 @@ D3_FORMAT: D3Format = {}
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
-DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
+DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# Experimental feature introducing a client (browser) cache
"CLIENT_CACHE": False, # deprecated
"DISABLE_DATASET_SOURCE_EDIT": False, # deprecated
@@ -527,7 +514,7 @@ DEFAULT_FEATURE_FLAGS.update(
)
# This is merely a default.
-FEATURE_FLAGS: Dict[str, bool] = {}
+FEATURE_FLAGS: dict[str, bool] = {}
# A function that receives a dict of all feature flags
# (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS)
@@ -543,7 +530,7 @@ FEATURE_FLAGS: Dict[str, bool] = {}
# if hasattr(g, "user") and g.user.is_active:
# feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
# return feature_flags_dict
-GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
+GET_FEATURE_FLAGS_FUNC: Callable[[dict[str, bool]], dict[str, bool]] | None = None
# A function that receives a feature flag name and an optional default value.
# Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the
# evaluation of all feature flags when just evaluating a single one.
@@ -551,7 +538,7 @@ GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] =
# Note that the default `get_feature_flags` will evaluate each feature with this
# callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC
# and IS_FEATURE_ENABLED_FUNC in conjunction.
-IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
+IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
# A function that expands/overrides the frontend `bootstrap_data.common` object.
# Can be used to implement custom frontend functionality,
# or dynamically change certain configs.
@@ -563,7 +550,7 @@ IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
# Takes as a parameter the common bootstrap payload before transformations.
# Returns a dict containing data that should be added or overridden to the payload.
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
- [Dict[str, Any]], Dict[str, Any]
+ [dict[str, Any]], dict[str, Any]
] = lambda data: {} # default: empty dict
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
@@ -580,7 +567,7 @@ COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
# }]
# This is merely a default
-EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
+EXTRA_CATEGORICAL_COLOR_SCHEMES: list[dict[str, Any]] = []
# THEME_OVERRIDES is used for adding custom theme to superset
# example code for "My theme" custom scheme
@@ -599,7 +586,7 @@ EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
# }
# }
-THEME_OVERRIDES: Dict[str, Any] = {}
+THEME_OVERRIDES: dict[str, Any] = {}
# EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes
# EXTRA_SEQUENTIAL_COLOR_SCHEMES = [
@@ -615,7 +602,7 @@ THEME_OVERRIDES: Dict[str, Any] = {}
# }]
# This is merely a default
-EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
+EXTRA_SEQUENTIAL_COLOR_SCHEMES: list[dict[str, Any]] = []
# ---------------------------------------------------
# Thumbnail config (behind feature flag)
@@ -626,7 +613,7 @@ EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
# `superset.tasks.types.ExecutorType` for a full list of executor options.
# To always use a fixed user account, use the following configuration:
# THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM]
-THUMBNAIL_SELENIUM_USER: Optional[str] = "admin"
+THUMBNAIL_SELENIUM_USER: str | None = "admin"
THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
# By default, thumbnail digests are calculated based on various parameters in the
@@ -639,10 +626,10 @@ THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
# `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in
# user if the executor type is equal to `ExecutorType.CURRENT_USER`)
# and return the final digest string:
-THUMBNAIL_DASHBOARD_DIGEST_FUNC: Optional[
+THUMBNAIL_DASHBOARD_DIGEST_FUNC: None | (
Callable[[Dashboard, ExecutorType, str], str]
-] = None
-THUMBNAIL_CHART_DIGEST_FUNC: Optional[Callable[[Slice, ExecutorType, str], str]] = None
+) = None
+THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str] | None = None
THUMBNAIL_CACHE_CONFIG: CacheConfig = {
"CACHE_TYPE": "NullCache",
@@ -714,7 +701,7 @@ STORE_CACHE_KEYS_IN_METADATA_DB = False
# CORS Options
ENABLE_CORS = False
-CORS_OPTIONS: Dict[Any, Any] = {}
+CORS_OPTIONS: dict[Any, Any] = {}
# Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner.
# Disabling this option is not recommended for security reasons. If you wish to allow
@@ -736,7 +723,7 @@ HTML_SANITIZATION = True
# }
# }
# Be careful when extending the default schema to avoid XSS attacks.
-HTML_SANITIZATION_SCHEMA_EXTENSIONS: Dict[str, Any] = {}
+HTML_SANITIZATION_SCHEMA_EXTENSIONS: dict[str, Any] = {}
# Chrome allows up to 6 open connections per domain at a time. When there are more
# than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for
@@ -768,13 +755,13 @@ EXCEL_EXPORT = {"encoding": "utf-8"}
# time grains in superset/db_engine_specs/base.py).
# For example: to disable 1 second time grain:
# TIME_GRAIN_DENYLIST = ['PT1S']
-TIME_GRAIN_DENYLIST: List[str] = []
+TIME_GRAIN_DENYLIST: list[str] = []
# Additional time grains to be supported using similar definitions as in
# superset/db_engine_specs/base.py.
# For example: To add a new 2 second time grain:
# TIME_GRAIN_ADDONS = {'PT2S': '2 second'}
-TIME_GRAIN_ADDONS: Dict[str, str] = {}
+TIME_GRAIN_ADDONS: dict[str, str] = {}
# Implementation of additional time grains per engine.
# The column to be truncated is denoted `{col}` in the expression.
@@ -784,7 +771,7 @@ TIME_GRAIN_ADDONS: Dict[str, str] = {}
# 'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)'
# }
# }
-TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
+TIME_GRAIN_ADDON_EXPRESSIONS: dict[str, dict[str, str]] = {}
# ---------------------------------------------------
# List of viz_types not allowed in your environment
@@ -792,7 +779,7 @@ TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
# VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap']
# ---------------------------------------------------
-VIZ_TYPE_DENYLIST: List[str] = []
+VIZ_TYPE_DENYLIST: list[str] = []
# --------------------------------------------------
# Modules, datasources and middleware to be registered
@@ -802,8 +789,8 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
("superset.connectors.sqla.models", ["SqlaTable"]),
]
)
-ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
-ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
+ADDITIONAL_MODULE_DS_MAP: dict[str, list[str]] = {}
+ADDITIONAL_MIDDLEWARE: list[Callable[..., Any]] = []
# 1) https://docs.python-guide.org/writing/logging/
# 2) https://docs.python.org/2/library/logging.config.html
@@ -925,9 +912,9 @@ CELERY_CONFIG = CeleryConfig # pylint: disable=invalid-name
# within the app
# OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will
# override anything set within the app
-DEFAULT_HTTP_HEADERS: Dict[str, Any] = {}
-OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {}
-HTTP_HEADERS: Dict[str, Any] = {}
+DEFAULT_HTTP_HEADERS: dict[str, Any] = {}
+OVERRIDE_HTTP_HEADERS: dict[str, Any] = {}
+HTTP_HEADERS: dict[str, Any] = {}
# The db id here results in selecting this one as a default in SQL Lab
DEFAULT_DB_ID = None
@@ -974,8 +961,8 @@ SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
# return out
#
# QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter}
-QUERY_COST_FORMATTERS_BY_ENGINE: Dict[
- str, Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]
+QUERY_COST_FORMATTERS_BY_ENGINE: dict[
+ str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]]
] = {}
# Flag that controls if limit should be enforced on the CTA (create table as queries).
@@ -1000,13 +987,13 @@ SQLLAB_CTAS_NO_LIMIT = False
# else:
# return f'tmp_{schema}'
# Function accepts database object, user object, schema name and sql that will be run.
-SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
+SQLLAB_CTAS_SCHEMA_NAME_FUNC: None | (
Callable[[Database, models.User, str, str], str]
-] = None
+) = None
# If enabled, it can be used to store the results of long-running queries
# in SQL Lab by using the "Run Async" button/feature
-RESULTS_BACKEND: Optional[BaseCache] = None
+RESULTS_BACKEND: BaseCache | None = None
# Use PyArrow and MessagePack for async query results serialization,
# rather than JSON. This feature requires additional testing from the
@@ -1028,7 +1015,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
database: Database,
user: models.User, # pylint: disable=unused-argument
- schema: Optional[str],
+ schema: str | None,
) -> str:
# Note the final empty path enforces a trailing slash.
return os.path.join(
@@ -1038,14 +1025,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# The namespace within hive where the tables created from
# uploading CSVs will be stored.
-UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
+UPLOADED_CSV_HIVE_NAMESPACE: str | None = None
# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_file_upload
# db configuration and a result of this function.
# mypy doesn't catch that if case ensures list content being always str
-ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], List[str]] = (
+ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = (
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
if UPLOADED_CSV_HIVE_NAMESPACE
else []
@@ -1062,7 +1049,7 @@ CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
# It's important to make sure that the objects exposed (as well as objects attached
# to those objets) are harmless. We recommend only exposing simple/pure functions that
# return native types.
-JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
+JINJA_CONTEXT_ADDONS: dict[str, Callable[..., Any]] = {}
# A dictionary of macro template processors (by engine) that gets merged into global
# template processors. The existing template processors get updated with this
@@ -1070,7 +1057,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
# dictionary. The customized addons don't necessarily need to use Jinja templating
# language. This allows you to define custom logic to process templates on a per-engine
# basis. Example value = `{"presto": CustomPrestoTemplateProcessor}`
-CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
+CUSTOM_TEMPLATE_PROCESSORS: dict[str, type[BaseTemplateProcessor]] = {}
# Roles that are controlled by the API / Superset and should not be changes
# by humans.
@@ -1125,7 +1112,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app
-BLUEPRINTS: List[Blueprint] = []
+BLUEPRINTS: list[Blueprint] = []
# Provide a callable that receives a tracking_url and returns another
# URL. This is used to translate internal Hadoop job tracker URL
@@ -1142,7 +1129,7 @@ TRACKING_URL_TRANSFORMER = lambda url: url
# customize the polling time of each engine
-DB_POLL_INTERVAL_SECONDS: Dict[str, int] = {}
+DB_POLL_INTERVAL_SECONDS: dict[str, int] = {}
# Interval between consecutive polls when using Presto Engine
# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression
@@ -1159,7 +1146,7 @@ PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds())
# "another_auth_method": auth_method,
# },
# }
-ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {}
+ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {}
# The id of a template dashboard that should be copied to every new user
DASHBOARD_TEMPLATE_ID = None
@@ -1224,14 +1211,14 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# Owners, filters for created_by, etc.
# The users can also be excluded by overriding the get_exclude_users_from_lists method
# in security manager
-EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None
+EXCLUDE_USERS_FROM_LISTS: list[str] | None = None
# For database connections, this dictionary will remove engines from the available
# list/dropdown if you do not want these dbs to show as available.
# The available list is generated by driver installed, and some engines have multiple
# drivers.
# e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}}
-DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {}
+DBS_AVAILABLE_DENYLIST: dict[str, set[str]] = {}
# This auth provider is used by background (offline) tasks that need to access
# protected resources. Can be overridden by end users in order to support
@@ -1261,7 +1248,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True
# ExecutorType.OWNER,
# ExecutorType.SELENIUM,
# ]
-ALERT_REPORTS_EXECUTE_AS: List[ExecutorType] = [ExecutorType.OWNER]
+ALERT_REPORTS_EXECUTE_AS: list[ExecutorType] = [ExecutorType.OWNER]
# if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout
# Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG
ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds())
@@ -1286,7 +1273,7 @@ EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] "
EMAIL_REPORTS_CTA = "Explore in Superset"
# Slack API token for the superset reports, either string or callable
-SLACK_API_TOKEN: Optional[Union[Callable[[], str], str]] = None
+SLACK_API_TOKEN: Callable[[], str] | str | None = None
SLACK_PROXY = None
# The webdriver to use for generating reports. Use one of the following
@@ -1310,7 +1297,7 @@ WEBDRIVER_WINDOW = {
WEBDRIVER_AUTH_FUNC = None
# Any config options to be passed as-is to the webdriver
-WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {"service_log_path": "/dev/null"}
+WEBDRIVER_CONFIGURATION: dict[Any, Any] = {"service_log_path": "/dev/null"}
# Additional args to be passed as arguments to the config object
# Note: If using Chrome, you'll want to add the "--marionette" arg.
@@ -1353,7 +1340,7 @@ SQL_VALIDATORS_BY_ENGINE = {
# displayed prominently in the "Add Database" dialog. You should
# use the "engine_name" attribute of the corresponding DB engine spec
# in `superset/db_engine_specs/`.
-PREFERRED_DATABASES: List[str] = [
+PREFERRED_DATABASES: list[str] = [
"PostgreSQL",
"Presto",
"MySQL",
@@ -1386,7 +1373,7 @@ TALISMAN_CONFIG = {
#
SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS?
SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls?
-SESSION_COOKIE_SAMESITE: Optional[Literal["None", "Lax", "Strict"]] = "Lax"
+SESSION_COOKIE_SAMESITE: Literal["None", "Lax", "Strict"] | None = "Lax"
# Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection
SESSION_PROTECTION = "strong"
@@ -1418,7 +1405,7 @@ DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]
# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
-SSL_CERT_PATH: Optional[str] = None
+SSL_CERT_PATH: str | None = None
# SQLA table mutator, every time we fetch the metadata for a certain table
# (superset.connectors.sqla.models.SqlaTable), we call this hook
@@ -1443,9 +1430,9 @@ GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
-GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: Optional[
+GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (
Literal["None", "Lax", "Strict"]
-] = None
+) = None
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
@@ -1461,7 +1448,7 @@ GUEST_TOKEN_JWT_ALGO = "HS256"
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
# Guest token audience for the embedded superset, either string or callable
-GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
+GUEST_TOKEN_JWT_AUDIENCE: Callable[[], str] | str | None = None
# A SQL dataset health check. Note if enabled it is strongly advised that the callable
# be memoized to aid with performance, i.e.,
@@ -1492,7 +1479,7 @@ GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
# cache_manager.cache.delete_memoized(func)
# cache_manager.cache.set(name, code, timeout=0)
#
-DATASET_HEALTH_CHECK: Optional[Callable[["SqlaTable"], str]] = None
+DATASET_HEALTH_CHECK: Callable[[SqlaTable], str] | None = None
# Do not show user info or profile in the menu
MENU_HIDE_USER_INFO = False
@@ -1502,7 +1489,7 @@ MENU_HIDE_USER_INFO = False
ENABLE_BROAD_ACTIVITY_ACCESS = True
# the advanced data type key should correspond to that set in the column metadata
-ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
+ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = {
"internet_address": internet_address,
"port": internet_port,
}
@@ -1514,9 +1501,9 @@ ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
# "Xyz",
# [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}],
# )
-WELCOME_PAGE_LAST_TAB: Union[
- Literal["examples", "all"], Tuple[str, List[Dict[str, Any]]]
-] = "all"
+WELCOME_PAGE_LAST_TAB: (
+ Literal["examples", "all"] | tuple[str, list[dict[str, Any]]]
+) = "all"
# Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag.
# 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base)
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 2cb0d54c51..d43d078639 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -16,21 +16,12 @@
# under the License.
from __future__ import annotations
+import builtins
import json
+from collections.abc import Hashable
from datetime import datetime
from enum import Enum
-from typing import (
- Any,
- Dict,
- Hashable,
- List,
- Optional,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, TYPE_CHECKING
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __
@@ -89,23 +80,23 @@ class BaseDatasource(
# ---------------------------------------------------------------
# class attributes to define when deriving BaseDatasource
# ---------------------------------------------------------------
- __tablename__: Optional[str] = None # {connector_name}_datasource
- baselink: Optional[str] = None # url portion pointing to ModelView endpoint
+ __tablename__: str | None = None # {connector_name}_datasource
+ baselink: str | None = None # url portion pointing to ModelView endpoint
@property
- def column_class(self) -> Type["BaseColumn"]:
+ def column_class(self) -> type[BaseColumn]:
# link to derivative of BaseColumn
raise NotImplementedError()
@property
- def metric_class(self) -> Type["BaseMetric"]:
+ def metric_class(self) -> type[BaseMetric]:
# link to derivative of BaseMetric
raise NotImplementedError()
- owner_class: Optional[User] = None
+ owner_class: User | None = None
# Used to do code highlighting when displaying the query in the UI
- query_language: Optional[str] = None
+ query_language: str | None = None
# Only some datasources support Row Level Security
is_rls_supported: bool = False
@@ -131,9 +122,9 @@ class BaseDatasource(
is_managed_externally = Column(Boolean, nullable=False, default=False)
external_url = Column(Text, nullable=True)
- sql: Optional[str] = None
- owners: List[User]
- update_from_object_fields: List[str]
+ sql: str | None = None
+ owners: list[User]
+ update_from_object_fields: list[str]
extra_import_fields = ["is_managed_externally", "external_url"]
@@ -142,7 +133,7 @@ class BaseDatasource(
return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
@property
- def owners_data(self) -> List[Dict[str, Any]]:
+ def owners_data(self) -> list[dict[str, Any]]:
return [
{
"first_name": o.first_name,
@@ -167,8 +158,8 @@ class BaseDatasource(
),
)
- columns: List["BaseColumn"] = []
- metrics: List["BaseMetric"] = []
+ columns: list[BaseColumn] = []
+ metrics: list[BaseMetric] = []
@property
def type(self) -> str:
@@ -180,11 +171,11 @@ class BaseDatasource(
return f"{self.id}__{self.type}"
@property
- def column_names(self) -> List[str]:
+ def column_names(self) -> list[str]:
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property
- def columns_types(self) -> Dict[str, str]:
+ def columns_types(self) -> dict[str, str]:
return {c.column_name: c.type for c in self.columns}
@property
@@ -196,26 +187,26 @@ class BaseDatasource(
raise NotImplementedError()
@property
- def connection(self) -> Optional[str]:
+ def connection(self) -> str | None:
"""String representing the context of the Datasource"""
return None
@property
- def schema(self) -> Optional[str]:
+ def schema(self) -> str | None:
"""String representing the schema of the Datasource (if it applies)"""
return None
@property
- def filterable_column_names(self) -> List[str]:
+ def filterable_column_names(self) -> list[str]:
return sorted([c.column_name for c in self.columns if c.filterable])
@property
- def dttm_cols(self) -> List[str]:
+ def dttm_cols(self) -> list[str]:
return []
@property
def url(self) -> str:
- return "/{}/edit/{}".format(self.baselink, self.id)
+ return f"/{self.baselink}/edit/{self.id}"
@property
def explore_url(self) -> str:
@@ -224,10 +215,10 @@ class BaseDatasource(
return f"/explore/?datasource_type={self.type}&datasource_id={self.id}"
@property
- def column_formats(self) -> Dict[str, Optional[str]]:
+ def column_formats(self) -> dict[str, str | None]:
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
- def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
+ def add_missing_metrics(self, metrics: list[BaseMetric]) -> None:
existing_metrics = {m.metric_name for m in self.metrics}
for metric in metrics:
if metric.metric_name not in existing_metrics:
@@ -235,7 +226,7 @@ class BaseDatasource(
self.metrics.append(metric)
@property
- def short_data(self) -> Dict[str, Any]:
+ def short_data(self) -> dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
"edit_url": self.url,
@@ -249,11 +240,11 @@ class BaseDatasource(
}
@property
- def select_star(self) -> Optional[str]:
+ def select_star(self) -> str | None:
pass
@property
- def order_by_choices(self) -> List[Tuple[str, str]]:
+ def order_by_choices(self) -> list[tuple[str, str]]:
choices = []
# self.column_names return sorted column_names
for column_name in self.column_names:
@@ -267,7 +258,7 @@ class BaseDatasource(
return choices
@property
- def verbose_map(self) -> Dict[str, str]:
+ def verbose_map(self) -> dict[str, str]:
verb_map = {"__timestamp": "Time"}
verb_map.update(
{o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
@@ -278,7 +269,7 @@ class BaseDatasource(
return verb_map
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
# simple fields
@@ -313,8 +304,8 @@ class BaseDatasource(
}
def data_for_slices( # pylint: disable=too-many-locals
- self, slices: List[Slice]
- ) -> Dict[str, Any]:
+ self, slices: list[Slice]
+ ) -> dict[str, Any]:
"""
The representation of the datasource containing only the required data
to render the provided slices.
@@ -381,8 +372,8 @@ class BaseDatasource(
if metric["metric_name"] in metric_names
]
- filtered_columns: List[Column] = []
- column_types: Set[GenericDataType] = set()
+ filtered_columns: list[Column] = []
+ column_types: set[GenericDataType] = set()
for column in data["columns"]:
generic_type = column.get("type_generic")
if generic_type is not None:
@@ -413,18 +404,18 @@ class BaseDatasource(
@staticmethod
def filter_values_handler( # pylint: disable=too-many-arguments
- values: Optional[FilterValues],
+ values: FilterValues | None,
operator: str,
target_generic_type: GenericDataType,
- target_native_type: Optional[str] = None,
+ target_native_type: str | None = None,
is_list_target: bool = False,
- db_engine_spec: Optional[Type[BaseEngineSpec]] = None,
- db_extra: Optional[Dict[str, Any]] = None,
- ) -> Optional[FilterValues]:
+ db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
+ db_extra: dict[str, Any] | None = None,
+ ) -> FilterValues | None:
if values is None:
return None
- def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
+ def handle_single_value(value: FilterValue | None) -> FilterValue | None:
if operator == utils.FilterOperator.TEMPORAL_RANGE:
return value
if (
@@ -464,7 +455,7 @@ class BaseDatasource(
values = values[0] if values else None
return values
- def external_metadata(self) -> List[Dict[str, str]]:
+ def external_metadata(self) -> list[dict[str, str]]:
"""Returns column information from the external system"""
raise NotImplementedError()
@@ -483,7 +474,7 @@ class BaseDatasource(
"""
raise NotImplementedError()
- def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
@@ -494,7 +485,7 @@ class BaseDatasource(
def default_query(qry: Query) -> Query:
return qry
- def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
+ def get_column(self, column_name: str | None) -> BaseColumn | None:
if not column_name:
return None
for col in self.columns:
@@ -504,11 +495,11 @@ class BaseDatasource(
@staticmethod
def get_fk_many_from_list(
- object_list: List[Any],
- fkmany: List[Column],
- fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
+ object_list: list[Any],
+ fkmany: list[Column],
+ fkmany_class: builtins.type[BaseColumn | BaseMetric],
key_attr: str,
- ) -> List[Column]:
+ ) -> list[Column]:
"""Update ORM one-to-many list from object list
Used for syncing metrics and columns using the same code"""
@@ -541,7 +532,7 @@ class BaseDatasource(
fkmany += new_fks
return fkmany
- def update_from_object(self, obj: Dict[str, Any]) -> None:
+ def update_from_object(self, obj: dict[str, Any]) -> None:
"""Update datasource from a data structure
The UI's table editor crafts a complex data structure that
@@ -578,7 +569,7 @@ class BaseDatasource(
def get_extra_cache_keys( # pylint: disable=no-self-use
self, query_obj: QueryObjectDict # pylint: disable=unused-argument
- ) -> List[Hashable]:
+ ) -> list[Hashable]:
"""If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
@@ -607,14 +598,14 @@ class BaseDatasource(
@classmethod
def get_datasource_by_name(
cls, session: Session, datasource_name: str, schema: str, database_name: str
- ) -> Optional["BaseDatasource"]:
+ ) -> BaseDatasource | None:
raise NotImplementedError()
class BaseColumn(AuditMixinNullable, ImportExportMixin):
"""Interface for column"""
- __tablename__: Optional[str] = None # {connector_name}_column
+ __tablename__: str | None = None # {connector_name}_column
id = Column(Integer, primary_key=True)
column_name = Column(String(255), nullable=False)
@@ -628,7 +619,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
is_dttm = None
# [optional] Set this to support import/export functionality
- export_fields: List[Any] = []
+ export_fields: list[Any] = []
def __repr__(self) -> str:
return str(self.column_name)
@@ -666,7 +657,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types))
@property
- def type_generic(self) -> Optional[utils.GenericDataType]:
+ def type_generic(self) -> utils.GenericDataType | None:
if self.is_string:
return utils.GenericDataType.STRING
if self.is_boolean:
@@ -686,7 +677,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
raise NotImplementedError()
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"id",
"column_name",
@@ -705,7 +696,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
class BaseMetric(AuditMixinNullable, ImportExportMixin):
"""Interface for Metrics"""
- __tablename__: Optional[str] = None # {connector_name}_metric
+ __tablename__: str | None = None # {connector_name}_metric
id = Column(Integer, primary_key=True)
metric_name = Column(String(255), nullable=False)
@@ -730,7 +721,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
"""
@property
- def perm(self) -> Optional[str]:
+ def perm(self) -> str | None:
raise NotImplementedError()
@property
@@ -738,7 +729,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
raise NotImplementedError()
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"id",
"metric_name",
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8833d6f6cb..41a9c89757 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -22,21 +22,10 @@ import json
import logging
import re
from collections import defaultdict
+from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- Hashable,
- List,
- Optional,
- Set,
- Tuple,
- Type,
- Union,
-)
+from typing import Any, Callable, cast
import dateutil.parser
import numpy as np
@@ -136,9 +125,9 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
@dataclass
class MetadataResult:
- added: List[str] = field(default_factory=list)
- removed: List[str] = field(default_factory=list)
- modified: List[str] = field(default_factory=list)
+ added: list[str] = field(default_factory=list)
+ removed: list[str] = field(default_factory=list)
+ modified: list[str] = field(default_factory=list)
class AnnotationDatasource(BaseDatasource):
@@ -190,7 +179,7 @@ class AnnotationDatasource(BaseDatasource):
def get_query_str(self, query_obj: QueryObjectDict) -> str:
raise NotImplementedError()
- def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
raise NotImplementedError()
@@ -201,7 +190,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
- table: Mapped["SqlaTable"] = relationship(
+ table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="columns",
)
@@ -263,15 +252,15 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return self.type_generic == GenericDataType.TEMPORAL
@property
- def db_engine_spec(self) -> Type[BaseEngineSpec]:
+ def db_engine_spec(self) -> type[BaseEngineSpec]:
return self.table.db_engine_spec
@property
- def db_extra(self) -> Dict[str, Any]:
+ def db_extra(self) -> dict[str, Any]:
return self.table.database.get_extra()
@property
- def type_generic(self) -> Optional[utils.GenericDataType]:
+ def type_generic(self) -> utils.GenericDataType | None:
if self.is_dttm:
return GenericDataType.TEMPORAL
@@ -310,8 +299,8 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def get_sqla_col(
self,
- label: Optional[str] = None,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ label: str | None = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.column_name
db_engine_spec = self.db_engine_spec
@@ -332,10 +321,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def get_timestamp_expression(
self,
- time_grain: Optional[str],
- label: Optional[str] = None,
- template_processor: Optional[BaseTemplateProcessor] = None,
- ) -> Union[TimestampExpression, Label]:
+ time_grain: str | None,
+ label: str | None = None,
+ template_processor: BaseTemplateProcessor | None = None,
+ ) -> TimestampExpression | Label:
"""
Return a SQLAlchemy Core element representation of self to be used in a query.
@@ -365,7 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return self.table.make_sqla_column_compatible(time_expr, label)
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"id",
"column_name",
@@ -399,7 +388,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
- table: Mapped["SqlaTable"] = relationship(
+ table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="metrics",
)
@@ -425,8 +414,8 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
def get_sqla_col(
self,
- label: Optional[str] = None,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ label: str | None = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.metric_name
expression = self.expression
@@ -437,7 +426,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
return self.table.make_sqla_column_compatible(sqla_col, label)
@property
- def perm(self) -> Optional[str]:
+ def perm(self) -> str | None:
return (
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
obj=self, parent_name=self.table.full_name
@@ -446,11 +435,11 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
else None
)
- def get_perm(self) -> Optional[str]:
+ def get_perm(self) -> str | None:
return self.perm
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"is_certified",
"certified_by",
@@ -473,11 +462,11 @@ sqlatable_user = Table(
def _process_sql_expression(
- expression: Optional[str],
+ expression: str | None,
database_id: int,
schema: str,
- template_processor: Optional[BaseTemplateProcessor] = None,
-) -> Optional[str]:
+ template_processor: BaseTemplateProcessor | None = None,
+) -> str | None:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
@@ -501,12 +490,12 @@ class SqlaTable(
type = "table"
query_language = "sql"
is_rls_supported = True
- columns: Mapped[List[TableColumn]] = relationship(
+ columns: Mapped[list[TableColumn]] = relationship(
TableColumn,
back_populates="table",
cascade="all, delete-orphan",
)
- metrics: Mapped[List[SqlMetric]] = relationship(
+ metrics: Mapped[list[SqlMetric]] = relationship(
SqlMetric,
back_populates="table",
cascade="all, delete-orphan",
@@ -577,11 +566,11 @@ class SqlaTable(
return self.name
@property
- def db_extra(self) -> Dict[str, Any]:
+ def db_extra(self) -> dict[str, Any]:
return self.database.get_extra()
@staticmethod
- def _apply_cte(sql: str, cte: Optional[str]) -> str:
+ def _apply_cte(sql: str, cte: str | None) -> str:
"""
Append a CTE before the SELECT statement if defined
@@ -594,7 +583,7 @@ class SqlaTable(
return sql
@property
- def db_engine_spec(self) -> Type[BaseEngineSpec]:
+ def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]:
return self.database.db_engine_spec
@property
@@ -637,9 +626,9 @@ class SqlaTable(
cls,
session: Session,
datasource_name: str,
- schema: Optional[str],
+ schema: str | None,
database_name: str,
- ) -> Optional[SqlaTable]:
+ ) -> SqlaTable | None:
schema = schema or None
query = (
session.query(cls)
@@ -660,7 +649,7 @@ class SqlaTable(
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
return Markup(anchor)
- def get_schema_perm(self) -> Optional[str]:
+ def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
return security_manager.get_schema_perm(self.database, self.schema)
@@ -685,18 +674,18 @@ class SqlaTable(
)
@property
- def dttm_cols(self) -> List[str]:
+ def dttm_cols(self) -> list[str]:
l = [c.column_name for c in self.columns if c.is_dttm]
if self.main_dttm_col and self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
- def num_cols(self) -> List[str]:
+ def num_cols(self) -> list[str]:
return [c.column_name for c in self.columns if c.is_numeric]
@property
- def any_dttm_col(self) -> Optional[str]:
+ def any_dttm_col(self) -> str | None:
cols = self.dttm_cols
return cols[0] if cols else None
@@ -713,7 +702,7 @@ class SqlaTable(
def sql_url(self) -> str:
return self.database.sql_url + "?table_name=" + str(self.table_name)
- def external_metadata(self) -> List[Dict[str, str]]:
+ def external_metadata(self) -> list[dict[str, str]]:
# todo(yongjie): create a physical table column type in a separate PR
if self.sql:
return get_virtual_table_metadata(dataset=self) # type: ignore
@@ -724,14 +713,14 @@ class SqlaTable(
)
@property
- def time_column_grains(self) -> Dict[str, Any]:
+ def time_column_grains(self) -> dict[str, Any]:
return {
"time_columns": self.dttm_cols,
"time_grains": [grain.name for grain in self.database.grains()],
}
@property
- def select_star(self) -> Optional[str]:
+ def select_star(self) -> str | None:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
@@ -739,20 +728,20 @@ class SqlaTable(
)
@property
- def health_check_message(self) -> Optional[str]:
+ def health_check_message(self) -> str | None:
check = config["DATASET_HEALTH_CHECK"]
return check(self) if check else None
@property
- def granularity_sqla(self) -> List[Tuple[Any, Any]]:
+ def granularity_sqla(self) -> list[tuple[Any, Any]]:
return utils.choicify(self.dttm_cols)
@property
- def time_grain_sqla(self) -> List[Tuple[Any, Any]]:
+ def time_grain_sqla(self) -> list[tuple[Any, Any]]:
return [(g.duration, g.name) for g in self.database.grains() or []]
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
data_ = super().data
if self.type == "table":
data_["granularity_sqla"] = self.granularity_sqla
@@ -767,7 +756,7 @@ class SqlaTable(
return data_
@property
- def extra_dict(self) -> Dict[str, Any]:
+ def extra_dict(self) -> dict[str, Any]:
try:
return json.loads(self.extra)
except (TypeError, json.JSONDecodeError):
@@ -775,7 +764,7 @@ class SqlaTable(
def get_fetch_values_predicate(
self,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> TextClause:
fetch_values_predicate = self.fetch_values_predicate
if template_processor:
@@ -792,7 +781,7 @@ class SqlaTable(
)
) from ex
- def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@@ -869,8 +858,8 @@ class SqlaTable(
return tbl
def get_from_clause(
- self, template_processor: Optional[BaseTemplateProcessor] = None
- ) -> Tuple[Union[TableClause, Alias], Optional[str]]:
+ self, template_processor: BaseTemplateProcessor | None = None
+ ) -> tuple[TableClause | Alias, str | None]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
@@ -899,7 +888,7 @@ class SqlaTable(
return from_clause, cte
def get_rendered_sql(
- self, template_processor: Optional[BaseTemplateProcessor] = None
+ self, template_processor: BaseTemplateProcessor | None = None
) -> str:
"""
Render sql with template engine (Jinja).
@@ -928,8 +917,8 @@ class SqlaTable(
def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
- columns_by_name: Dict[str, TableColumn],
- template_processor: Optional[BaseTemplateProcessor] = None,
+ columns_by_name: dict[str, TableColumn],
+ template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc metric into a sqlalchemy column.
@@ -946,7 +935,7 @@ class SqlaTable(
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
metric_column = metric.get("column") or {}
column_name = cast(str, metric_column.get("column_name"))
- table_column: Optional[TableColumn] = columns_by_name.get(column_name)
+ table_column: TableColumn | None = columns_by_name.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col(
template_processor=template_processor
@@ -971,7 +960,7 @@ class SqlaTable(
self,
col: AdhocColumn,
force_type_check: bool = False,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc column into a sqlalchemy column.
@@ -1021,7 +1010,7 @@ class SqlaTable(
return self.make_sqla_column_compatible(sqla_column, label)
def make_sqla_column_compatible(
- self, sqla_col: ColumnElement, label: Optional[str] = None
+ self, sqla_col: ColumnElement, label: str | None = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
@@ -1038,7 +1027,7 @@ class SqlaTable(
return sqla_col
def make_orderby_compatible(
- self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
+ self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
"""
If needed, make sure aliases for selected columns are not used in
@@ -1069,7 +1058,7 @@ class SqlaTable(
def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
- ) -> List[TextClause]:
+ ) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
@@ -1078,8 +1067,8 @@ class SqlaTable(
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
- all_filters: List[TextClause] = []
- filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
+ all_filters: list[TextClause] = []
+ filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
@@ -1114,9 +1103,9 @@ class SqlaTable(
def _get_series_orderby(
self,
series_limit_metric: Metric,
- metrics_by_name: Dict[str, SqlMetric],
- columns_by_name: Dict[str, TableColumn],
- template_processor: Optional[BaseTemplateProcessor] = None,
+ metrics_by_name: dict[str, SqlMetric],
+ columns_by_name: dict[str, TableColumn],
+ template_processor: BaseTemplateProcessor | None = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
@@ -1138,8 +1127,8 @@ class SqlaTable(
self,
row: pd.Series,
dimension: str,
- columns_by_name: Dict[str, TableColumn],
- ) -> Union[str, int, float, bool, Text]:
+ columns_by_name: dict[str, TableColumn],
+ ) -> str | int | float | bool | Text:
"""
Convert a prequery result type to its equivalent Python type.
@@ -1159,7 +1148,7 @@ class SqlaTable(
value = value.item()
column_ = columns_by_name[dimension]
- db_extra: Dict[str, Any] = self.database.get_extra()
+ db_extra: dict[str, Any] = self.database.get_extra()
if column_.type and column_.is_temporal and isinstance(value, str):
sql = self.db_engine_spec.convert_dttm(
@@ -1174,9 +1163,9 @@ class SqlaTable(
def _get_top_groups(
self,
df: pd.DataFrame,
- dimensions: List[str],
- groupby_exprs: Dict[str, Any],
- columns_by_name: Dict[str, TableColumn],
+ dimensions: list[str],
+ groupby_exprs: dict[str, Any],
+ columns_by_name: dict[str, TableColumn],
) -> ColumnElement:
groups = []
for _unused, row in df.iterrows():
@@ -1201,7 +1190,7 @@ class SqlaTable(
errors = None
error_message = None
- def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
+ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
"""
Some engines change the case or generate bespoke column names, either by
default or due to lack of support for aliasing. This function ensures that
@@ -1283,7 +1272,7 @@ class SqlaTable(
else self.columns
)
- old_columns_by_name: Dict[str, TableColumn] = {
+ old_columns_by_name: dict[str, TableColumn] = {
col.column_name: col for col in old_columns
}
results = MetadataResult(
@@ -1341,8 +1330,8 @@ class SqlaTable(
session: Session,
database: Database,
datasource_name: str,
- schema: Optional[str] = None,
- ) -> List[SqlaTable]:
+ schema: str | None = None,
+ ) -> list[SqlaTable]:
query = (
session.query(cls)
.filter_by(database_id=database.id)
@@ -1357,9 +1346,9 @@ class SqlaTable(
cls,
session: Session,
database: Database,
- permissions: Set[str],
- schema_perms: Set[str],
- ) -> List[SqlaTable]:
+ permissions: set[str],
+ schema_perms: set[str],
+ ) -> list[SqlaTable]:
# TODO(hughhhh): add unit test
return (
session.query(cls)
@@ -1389,7 +1378,7 @@ class SqlaTable(
)
@classmethod
- def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
+ def get_all_datasources(cls, session: Session) -> list[SqlaTable]:
qry = session.query(cls)
qry = cls.default_query(qry)
return qry.all()
@@ -1409,7 +1398,7 @@ class SqlaTable(
:param query_obj: query object to analyze
:return: True if there are call(s) to an `ExtraCache` method, False otherwise
"""
- templatable_statements: List[str] = []
+ templatable_statements: list[str] = []
if self.sql:
templatable_statements.append(self.sql)
if self.fetch_values_predicate:
@@ -1428,7 +1417,7 @@ class SqlaTable(
return True
return False
- def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]:
+ def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
"""
The cache key of a SqlaTable needs to consider any keys added by the parent
class and any keys added via `ExtraCache`.
@@ -1489,7 +1478,7 @@ class SqlaTable(
@staticmethod
def update_column( # pylint: disable=unused-argument
- mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
+ mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn
) -> None:
"""
:param mapper: Unused.
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 698311dab6..d41c0555d3 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -17,19 +17,9 @@
from __future__ import annotations
import logging
+from collections.abc import Iterable, Iterator
from functools import lru_cache
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Type,
- TYPE_CHECKING,
- TypeVar,
-)
+from typing import Any, Callable, TYPE_CHECKING, TypeVar
from uuid import UUID
from flask_babel import lazy_gettext as _
@@ -58,8 +48,8 @@ if TYPE_CHECKING:
def get_physical_table_metadata(
database: Database,
table_name: str,
- schema_name: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+ schema_name: str | None = None,
+) -> list[dict[str, Any]]:
"""Use SQLAlchemy inspector to get table metadata"""
db_engine_spec = database.db_engine_spec
db_dialect = database.get_dialect()
@@ -103,7 +93,7 @@ def get_physical_table_metadata(
return cols
-def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
+def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
"""Use SQLparser to get virtual dataset metadata"""
if not dataset.sql:
raise SupersetGenericDBErrorException(
@@ -150,7 +140,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
def get_columns_description(
database: Database,
query: str,
-) -> List[ResultSetColumnType]:
+) -> list[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
with database.get_raw_connection() as conn:
@@ -171,7 +161,7 @@ def get_dialect_name(drivername: str) -> str:
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
-def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
+def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]:
return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote
@@ -181,9 +171,9 @@ logger = logging.getLogger(__name__)
def find_cached_objects_in_session(
session: Session,
- cls: Type[DeclarativeModel],
- ids: Optional[Iterable[int]] = None,
- uuids: Optional[Iterable[UUID]] = None,
+ cls: type[DeclarativeModel],
+ ids: Iterable[int] | None = None,
+ uuids: Iterable[UUID] | None = None,
) -> Iterator[DeclarativeModel]:
"""Find known ORM instances in cached SQLA session states.
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py
index 0989a545fd..9116b9636e 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -447,7 +447,7 @@ class TableModelView( # pylint: disable=too-many-ancestors
resp = super().edit(pk)
if isinstance(resp, str):
return resp
- return redirect("/explore/?datasource_type=table&datasource_id={}".format(pk))
+ return redirect(f"/explore/?datasource_type=table&datasource_id={pk}")
@expose("/list/")
@has_access
diff --git a/superset/css_templates/commands/bulk_delete.py b/superset/css_templates/commands/bulk_delete.py
index 93564208c4..57612d9048 100644
--- a/superset/css_templates/commands/bulk_delete.py
+++ b/superset/css_templates/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset.commands.base import BaseCommand
from superset.css_templates.commands.exceptions import (
@@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteCssTemplateCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[CssTemplate]] = None
+ self._models: Optional[list[CssTemplate]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/css_templates/dao.py b/superset/css_templates/dao.py
index 1862fb7aaf..bc1a796269 100644
--- a/superset/css_templates/dao.py
+++ b/superset/css_templates/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -31,7 +31,7 @@ class CssTemplateDAO(BaseDAO):
model_cls = CssTemplate
@staticmethod
- def bulk_delete(models: Optional[List[CssTemplate]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
try:
db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete(
diff --git a/superset/dao/base.py b/superset/dao/base.py
index d3675a0e17..539dbab2d5 100644
--- a/superset/dao/base.py
+++ b/superset/dao/base.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=isinstance-second-argument-not-valid-type
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Optional, Union
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
@@ -37,7 +37,7 @@ class BaseDAO:
Base DAO, implement base CRUD sqlalchemy operations
"""
- model_cls: Optional[Type[Model]] = None
+ model_cls: Optional[type[Model]] = None
"""
Child classes need to state the Model class so they don't need to implement basic
create, update and delete methods
@@ -75,10 +75,10 @@ class BaseDAO:
@classmethod
def find_by_ids(
cls,
- model_ids: Union[List[str], List[int]],
+ model_ids: Union[list[str], list[int]],
session: Session = None,
skip_base_filter: bool = False,
- ) -> List[Model]:
+ ) -> list[Model]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
@@ -95,7 +95,7 @@ class BaseDAO:
return query.all()
@classmethod
- def find_all(cls) -> List[Model]:
+ def find_all(cls) -> list[Model]:
"""
Get all that fit the `base_filter`
"""
@@ -121,7 +121,7 @@ class BaseDAO:
return query.filter_by(**filter_by).one_or_none()
@classmethod
- def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
+ def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
"""
Generic for creating models
:raises: DAOCreateFailedError
@@ -163,7 +163,7 @@ class BaseDAO:
@classmethod
def update(
- cls, model: Model, properties: Dict[str, Any], commit: bool = True
+ cls, model: Model, properties: dict[str, Any], commit: bool = True
) -> Model:
"""
Generic update a model
@@ -196,7 +196,7 @@ class BaseDAO:
return model
@classmethod
- def bulk_delete(cls, models: List[Model], commit: bool = True) -> None:
+ def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
try:
for model in models:
cls.delete(model, False)
diff --git a/superset/dashboards/commands/bulk_delete.py b/superset/dashboards/commands/bulk_delete.py
index 13541cd946..385f1fbc6d 100644
--- a/superset/dashboards/commands/bulk_delete.py
+++ b/superset/dashboards/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from flask_babel import lazy_gettext as _
@@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteDashboardCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[Dashboard]] = None
+ self._models: Optional[list[Dashboard]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py
index 0ad8ddee7c..58acc379ba 100644
--- a/superset/dashboards/commands/create.py
+++ b/superset/dashboards/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class CreateDashboardCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -48,9 +48,9 @@ class CreateDashboardCommand(CreateMixin, BaseCommand):
return dashboard
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owner_ids: Optional[List[int]] = self._properties.get("owners")
- role_ids: Optional[List[int]] = self._properties.get("roles")
+ exceptions: list[ValidationError] = []
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
+ role_ids: Optional[list[int]] = self._properties.get("roles")
slug: str = self._properties.get("slug", "")
# Validate slug uniqueness
diff --git a/superset/dashboards/commands/export.py b/superset/dashboards/commands/export.py
index 886b84ffa6..2e70e29bb0 100644
--- a/superset/dashboards/commands/export.py
+++ b/superset/dashboards/commands/export.py
@@ -20,7 +20,8 @@ import json
import logging
import random
import string
-from typing import Any, Dict, Iterator, Optional, Set, Tuple
+from typing import Any, Optional
+from collections.abc import Iterator
import yaml
@@ -52,7 +53,7 @@ def suffix(length: int = 8) -> str:
)
-def get_default_position(title: str) -> Dict[str, Any]:
+def get_default_position(title: str) -> dict[str, Any]:
return {
"DASHBOARD_VERSION_KEY": "v2",
"ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
@@ -66,7 +67,7 @@ def get_default_position(title: str) -> Dict[str, Any]:
}
-def append_charts(position: Dict[str, Any], charts: Set[Slice]) -> Dict[str, Any]:
+def append_charts(position: dict[str, Any], charts: set[Slice]) -> dict[str, Any]:
chart_hashes = [f"CHART-{suffix()}" for _ in charts]
# if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid
@@ -109,7 +110,7 @@ class ExportDashboardsCommand(ExportModelsCommand):
@staticmethod
def _export(
model: Dashboard, export_related: bool = True
- ) -> Iterator[Tuple[str, str]]:
+ ) -> Iterator[tuple[str, str]]:
file_name = get_filename(model.dashboard_title, model.id)
file_path = f"dashboards/{file_name}.yaml"
diff --git a/superset/dashboards/commands/importers/dispatcher.py b/superset/dashboards/commands/importers/dispatcher.py
index dd0121f3e3..d5323b4fe4 100644
--- a/superset/dashboards/commands/importers/dispatcher.py
+++ b/superset/dashboards/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -43,7 +43,7 @@ class ImportDashboardsCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py
index e49c931896..012dbbc5c9 100644
--- a/superset/dashboards/commands/importers/v0.py
+++ b/superset/dashboards/commands/importers/v0.py
@@ -19,7 +19,7 @@ import logging
import time
from copy import copy
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import make_transient, Session
@@ -83,7 +83,7 @@ def import_chart(
def import_dashboard(
# pylint: disable=too-many-locals,too-many-statements
dashboard_to_import: Dashboard,
- dataset_id_mapping: Optional[Dict[int, int]] = None,
+ dataset_id_mapping: Optional[dict[int, int]] = None,
import_time: Optional[int] = None,
) -> int:
"""Imports the dashboard from the object to the database.
@@ -97,7 +97,7 @@ def import_dashboard(
"""
def alter_positions(
- dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
+ dashboard: Dashboard, old_to_new_slc_id_dict: dict[int, int]
) -> None:
"""Updates slice_ids in the position json.
@@ -166,7 +166,7 @@ def import_dashboard(
dashboard_to_import.slug = None
old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}")
- old_to_new_slc_id_dict: Dict[int, int] = {}
+ old_to_new_slc_id_dict: dict[int, int] = {}
new_timed_refresh_immune_slices = []
new_expanded_slices = {}
new_filter_scopes = {}
@@ -268,7 +268,7 @@ def import_dashboard(
return dashboard_to_import.id # type: ignore
-def decode_dashboards(o: Dict[str, Any]) -> Any:
+def decode_dashboards(o: dict[str, Any]) -> Any:
"""
Function to be passed into json.loads obj_hook parameter
Recreates the dashboard object from a json representation.
@@ -302,7 +302,7 @@ def import_dashboards(
data = json.loads(content, object_hook=decode_dashboards)
if not data:
raise DashboardImportException(_("No data in file"))
- dataset_id_mapping: Dict[int, int] = {}
+ dataset_id_mapping: dict[int, int] = {}
for table in data["datasources"]:
new_dataset_id = import_dataset(table, database_id, import_time=import_time)
params = json.loads(table.params)
@@ -324,7 +324,7 @@ class ImportDashboardsCommand(BaseCommand):
# pylint: disable=unused-argument
def __init__(
- self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any
+ self, contents: dict[str, str], database_id: Optional[int] = None, **kwargs: Any
):
self.contents = contents
self.database_id = database_id
diff --git a/superset/dashboards/commands/importers/v1/__init__.py b/superset/dashboards/commands/importers/v1/__init__.py
index 5d83a580bd..597adba6d9 100644
--- a/superset/dashboards/commands/importers/v1/__init__.py
+++ b/superset/dashboards/commands/importers/v1/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -47,7 +47,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
dao = DashboardDAO
model_name = "dashboard"
prefix = "dashboards/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@@ -59,11 +59,11 @@ class ImportDashboardsCommand(ImportModelsCommand):
# pylint: disable=too-many-branches, too-many-locals
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover charts and datasets associated with dashboards
- chart_uuids: Set[str] = set()
- dataset_uuids: Set[str] = set()
+ chart_uuids: set[str] = set()
+ dataset_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
chart_uuids.update(find_chart_uuids(config["position"]))
@@ -77,20 +77,20 @@ class ImportDashboardsCommand(ImportModelsCommand):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
- database_uuids: Set[str] = set()
+ database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
- dataset_info: Dict[str, Dict[str, Any]] = {}
+ dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
@@ -105,7 +105,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
}
# import charts with the correct parent ref
- chart_ids: Dict[str, int] = {}
+ chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if (
file_name.startswith("charts/")
@@ -129,7 +129,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
).fetchall()
# import dashboards
- dashboard_chart_ids: List[Tuple[int, int]] = []
+ dashboard_chart_ids: list[tuple[int, int]] = []
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)
diff --git a/superset/dashboards/commands/importers/v1/utils.py b/superset/dashboards/commands/importers/v1/utils.py
index 9f0ffc36a1..1deb44949a 100644
--- a/superset/dashboards/commands/importers/v1/utils.py
+++ b/superset/dashboards/commands/importers/v1/utils.py
@@ -17,7 +17,7 @@
import json
import logging
-from typing import Any, Dict, Set
+from typing import Any
from flask import g
from sqlalchemy.orm import Session
@@ -32,12 +32,12 @@ logger = logging.getLogger(__name__)
JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"}
-def find_chart_uuids(position: Dict[str, Any]) -> Set[str]:
+def find_chart_uuids(position: dict[str, Any]) -> set[str]:
return set(build_uuid_to_id_map(position))
-def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
- uuids: Set[str] = set()
+def find_native_filter_datasets(metadata: dict[str, Any]) -> set[str]:
+ uuids: set[str] = set()
for native_filter in metadata.get("native_filter_configuration", []):
targets = native_filter.get("targets", [])
for target in targets:
@@ -47,7 +47,7 @@ def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
return uuids
-def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
+def build_uuid_to_id_map(position: dict[str, Any]) -> dict[str, int]:
return {
child["meta"]["uuid"]: child["meta"]["chartId"]
for child in position.values()
@@ -60,10 +60,10 @@ def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
def update_id_refs( # pylint: disable=too-many-locals
- config: Dict[str, Any],
- chart_ids: Dict[str, int],
- dataset_info: Dict[str, Dict[str, Any]],
-) -> Dict[str, Any]:
+ config: dict[str, Any],
+ chart_ids: dict[str, int],
+ dataset_info: dict[str, dict[str, Any]],
+) -> dict[str, Any]:
"""Update dashboard metadata to use new IDs"""
fixed = config.copy()
@@ -147,7 +147,7 @@ def update_id_refs( # pylint: disable=too-many-locals
def import_dashboard(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Dashboard:
diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py
index 11833a64be..fefa65e3f6 100644
--- a/superset/dashboards/commands/update.py
+++ b/superset/dashboards/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import json
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class UpdateDashboardCommand(UpdateMixin, BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Dashboard] = None
@@ -64,9 +64,9 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
return dashboard
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owners_ids: Optional[List[int]] = self._properties.get("owners")
- roles_ids: Optional[List[int]] = self._properties.get("roles")
+ exceptions: list[ValidationError] = []
+ owners_ids: Optional[list[int]] = self._properties.get("owners")
+ roles_ids: Optional[list[int]] = self._properties.get("roles")
slug: Optional[str] = self._properties.get("slug")
# Validate/populate model exists
diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py
index 5355d602be..d88fb431b7 100644
--- a/superset/dashboards/dao.py
+++ b/superset/dashboards/dao.py
@@ -17,7 +17,7 @@
import json
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
from flask import g
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -68,12 +68,12 @@ class DashboardDAO(BaseDAO):
return dashboard
@staticmethod
- def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]:
+ def get_datasets_for_dashboard(id_or_slug: str) -> list[Any]:
dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug)
return dashboard.datasets_trimmed_for_slices()
@staticmethod
- def get_charts_for_dashboard(id_or_slug: str) -> List[Slice]:
+ def get_charts_for_dashboard(id_or_slug: str) -> list[Slice]:
return DashboardDAO.get_by_id_or_slug(id_or_slug).slices
@staticmethod
@@ -173,7 +173,7 @@ class DashboardDAO(BaseDAO):
return model
@staticmethod
- def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[Dashboard]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
@@ -196,8 +196,8 @@ class DashboardDAO(BaseDAO):
@staticmethod
def set_dash_metadata( # pylint: disable=too-many-locals
dashboard: Dashboard,
- data: Dict[Any, Any],
- old_to_new_slice_ids: Optional[Dict[int, int]] = None,
+ data: dict[Any, Any],
+ old_to_new_slice_ids: Optional[dict[int, int]] = None,
commit: bool = False,
) -> Dashboard:
new_filter_scopes = {}
@@ -235,7 +235,7 @@ class DashboardDAO(BaseDAO):
if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore
- slc_id_dict: Dict[int, int] = {}
+ slc_id_dict: dict[int, int] = {}
if old_to_new_slice_ids:
slc_id_dict = {
old: new
@@ -288,7 +288,7 @@ class DashboardDAO(BaseDAO):
return dashboard
@staticmethod
- def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
+ def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]:
ids = [dash.id for dash in dashboards]
return [
star.obj_id
@@ -303,7 +303,7 @@ class DashboardDAO(BaseDAO):
@classmethod
def copy_dashboard(
- cls, original_dash: Dashboard, data: Dict[str, Any]
+ cls, original_dash: Dashboard, data: dict[str, Any]
) -> Dashboard:
dash = Dashboard()
dash.owners = [g.user] if g.user else []
@@ -311,7 +311,7 @@ class DashboardDAO(BaseDAO):
dash.css = data.get("css")
metadata = json.loads(data["json_metadata"])
- old_to_new_slice_ids: Dict[int, int] = {}
+ old_to_new_slice_ids: dict[int, int] = {}
if data.get("duplicate_slices"):
# Duplicating slices as well, mapping old ids to new ones
for slc in original_dash.slices:
diff --git a/superset/dashboards/filter_sets/commands/create.py b/superset/dashboards/filter_sets/commands/create.py
index de1d70daf7..63c4534786 100644
--- a/superset/dashboards/filter_sets/commands/create.py
+++ b/superset/dashboards/filter_sets/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_appbuilder.models.sqla import Model
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class CreateFilterSetCommand(BaseFilterSetCommand):
# pylint: disable=C0103
- def __init__(self, dashboard_id: int, data: Dict[str, Any]):
+ def __init__(self, dashboard_id: int, data: dict[str, Any]):
super().__init__(dashboard_id)
self._properties = data.copy()
diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py
index 07d59f93ae..722672d668 100644
--- a/superset/dashboards/filter_sets/commands/update.py
+++ b/superset/dashboards/filter_sets/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_appbuilder.models.sqla import Model
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class UpdateFilterSetCommand(BaseFilterSetCommand):
- def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]):
+ def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]):
super().__init__(dashboard_id)
self._filter_set_id = filter_set_id
self._properties = data.copy()
diff --git a/superset/dashboards/filter_sets/dao.py b/superset/dashboards/filter_sets/dao.py
index 949aa6d3fd..5f2b0ba418 100644
--- a/superset/dashboards/filter_sets/dao.py
+++ b/superset/dashboards/filter_sets/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_appbuilder.models.sqla import Model
from sqlalchemy.exc import SQLAlchemyError
@@ -40,7 +40,7 @@ class FilterSetDAO(BaseDAO):
model_cls = FilterSet
@classmethod
- def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
+ def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
raise DAOConfigError()
model = FilterSet()
diff --git a/superset/dashboards/filter_sets/schemas.py b/superset/dashboards/filter_sets/schemas.py
index c1a13b424e..2309eea99f 100644
--- a/superset/dashboards/filter_sets/schemas.py
+++ b/superset/dashboards/filter_sets/schemas.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, cast, Dict, Mapping
+from collections.abc import Mapping
+from typing import Any, cast
from marshmallow import fields, post_load, Schema, ValidationError
from marshmallow.validate import Length, OneOf
@@ -64,11 +65,11 @@ class FilterSetPostSchema(FilterSetSchema):
@post_load
def validate(
self, data: Mapping[Any, Any], *, many: Any, partial: Any
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data:
raise ValidationError("owner_id is mandatory when owner_type is User")
- return cast(Dict[str, Any], data)
+ return cast(dict[str, Any], data)
class FilterSetPutSchema(FilterSetSchema):
@@ -84,14 +85,14 @@ class FilterSetPutSchema(FilterSetSchema):
@post_load
def validate( # pylint: disable=unused-argument
self, data: Mapping[Any, Any], *, many: Any, partial: Any
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
if JSON_METADATA_FIELD in data:
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
- return cast(Dict[str, Any], data)
+ return cast(dict[str, Any], data)
-def validate_pair(first_field: str, second_field: str, data: Dict[str, Any]) -> None:
+def validate_pair(first_field: str, second_field: str, data: dict[str, Any]) -> None:
if first_field in data and second_field not in data:
raise ValidationError(
- "{} must be included alongside {}".format(first_field, second_field)
+ f"{first_field} must be included alongside {second_field}"
)
diff --git a/superset/dashboards/filter_state/api.py b/superset/dashboards/filter_state/api.py
index 7a771d6b54..a1b855ca9e 100644
--- a/superset/dashboards/filter_state/api.py
+++ b/superset/dashboards/filter_state/api.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Type
from flask import Response
from flask_appbuilder.api import expose, protect, safe
@@ -35,16 +34,16 @@ class DashboardFilterStateRestApi(TemporaryCacheRestApi):
resource_name = "dashboard"
openapi_spec_tag = "Dashboard Filter State"
- def get_create_command(self) -> Type[CreateFilterStateCommand]:
+ def get_create_command(self) -> type[CreateFilterStateCommand]:
return CreateFilterStateCommand
- def get_update_command(self) -> Type[UpdateFilterStateCommand]:
+ def get_update_command(self) -> type[UpdateFilterStateCommand]:
return UpdateFilterStateCommand
- def get_get_command(self) -> Type[GetFilterStateCommand]:
+ def get_get_command(self) -> type[GetFilterStateCommand]:
return GetFilterStateCommand
- def get_delete_command(self) -> Type[DeleteFilterStateCommand]:
+ def get_delete_command(self) -> type[DeleteFilterStateCommand]:
return DeleteFilterStateCommand
@expose("/<int:pk>/filter_state", methods=("POST",))
diff --git a/superset/dashboards/permalink/types.py b/superset/dashboards/permalink/types.py
index 91c5a9620c..4961d2a17b 100644
--- a/superset/dashboards/permalink/types.py
+++ b/superset/dashboards/permalink/types.py
@@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Tuple, TypedDict
+from typing import Any, Optional, TypedDict
class DashboardPermalinkState(TypedDict):
- dataMask: Optional[Dict[str, Any]]
- activeTabs: Optional[List[str]]
+ dataMask: Optional[dict[str, Any]]
+ activeTabs: Optional[list[str]]
anchor: Optional[str]
- urlParams: Optional[List[Tuple[str, str]]]
+ urlParams: Optional[list[tuple[str, str]]]
class DashboardPermalinkValue(TypedDict):
diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py
index ab93e4130f..846ed39e82 100644
--- a/superset/dashboards/schemas.py
+++ b/superset/dashboards/schemas.py
@@ -16,7 +16,7 @@
# under the License.
import json
import re
-from typing import Any, Dict, Union
+from typing import Any, Union
from marshmallow import fields, post_load, pre_load, Schema
from marshmallow.validate import Length, ValidationError
@@ -144,9 +144,9 @@ class DashboardJSONMetadataSchema(Schema):
@pre_load
def remove_show_native_filters( # pylint: disable=unused-argument, no-self-use
self,
- data: Dict[str, Any],
+ data: dict[str, Any],
**kwargs: Any,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
Remove ``show_native_filters`` from the JSON metadata.
@@ -254,7 +254,7 @@ class DashboardDatasetSchema(Schema):
class BaseDashboardSchema(Schema):
# pylint: disable=no-self-use,unused-argument
@post_load
- def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def post_load(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
if data.get("slug"):
data["slug"] = data["slug"].strip()
data["slug"] = data["slug"].replace(" ", "-")
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 77f9596182..c214065a27 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -19,7 +19,7 @@ import json
import logging
from datetime import datetime
from io import BytesIO
-from typing import Any, cast, Dict, List, Optional
+from typing import Any, cast, Optional
from zipfile import is_zipfile, ZipFile
from flask import request, Response, send_file
@@ -1328,13 +1328,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
- preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", [])
+ preferred_databases: list[str] = app.config.get("PREFERRED_DATABASES", [])
available_databases = []
for engine_spec, drivers in get_available_engine_specs().items():
if not drivers:
continue
- payload: Dict[str, Any] = {
+ payload: dict[str, Any] = {
"name": engine_spec.engine_name,
"engine": engine_spec.engine,
"available_drivers": sorted(drivers),
diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py
index 16d27835b3..e3fd667130 100644
--- a/superset/databases/commands/create.py
+++ b/superset/databases/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
@@ -47,7 +47,7 @@ stats_logger = current_app.config["STATS_LOGGER"]
class CreateDatabaseCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -128,7 +128,7 @@ class CreateDatabaseCommand(BaseCommand):
return database
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
database_name: Optional[str] = self._properties.get("database_name")
if not sqlalchemy_uri:
diff --git a/superset/databases/commands/export.py b/superset/databases/commands/export.py
index e1f8fc2a25..889cb86c8f 100644
--- a/superset/databases/commands/export.py
+++ b/superset/databases/commands/export.py
@@ -18,7 +18,8 @@
import json
import logging
-from typing import Any, Dict, Iterator, Tuple
+from typing import Any
+from collections.abc import Iterator
import yaml
@@ -33,7 +34,7 @@ from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__)
-def parse_extra(extra_payload: str) -> Dict[str, Any]:
+def parse_extra(extra_payload: str) -> dict[str, Any]:
try:
extra = json.loads(extra_payload)
except json.decoder.JSONDecodeError:
@@ -57,7 +58,7 @@ class ExportDatabasesCommand(ExportModelsCommand):
@staticmethod
def _export(
model: Database, export_related: bool = True
- ) -> Iterator[Tuple[str, str]]:
+ ) -> Iterator[tuple[str, str]]:
db_file_name = get_filename(model.database_name, model.id, skip_id=True)
file_path = f"databases/{db_file_name}.yaml"
diff --git a/superset/databases/commands/importers/dispatcher.py b/superset/databases/commands/importers/dispatcher.py
index 88d38bf13b..70031b09e4 100644
--- a/superset/databases/commands/importers/dispatcher.py
+++ b/superset/databases/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -38,7 +38,7 @@ class ImportDatabasesCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/databases/commands/importers/v1/__init__.py b/superset/databases/commands/importers/v1/__init__.py
index 239bd0977f..ba119beaaa 100644
--- a/superset/databases/commands/importers/v1/__init__.py
+++ b/superset/databases/commands/importers/v1/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -36,7 +36,7 @@ class ImportDatabasesCommand(ImportModelsCommand):
dao = DatabaseDAO
model_name = "database"
prefix = "databases/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
@@ -44,10 +44,10 @@ class ImportDatabasesCommand(ImportModelsCommand):
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# first import databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=overwrite)
diff --git a/superset/databases/commands/importers/v1/utils.py b/superset/databases/commands/importers/v1/utils.py
index c0c0ee60d9..8881f78a9c 100644
--- a/superset/databases/commands/importers/v1/utils.py
+++ b/superset/databases/commands/importers/v1/utils.py
@@ -16,7 +16,7 @@
# under the License.
import json
-from typing import Any, Dict
+from typing import Any
from sqlalchemy.orm import Session
@@ -28,7 +28,7 @@ from superset.models.core import Database
def import_database(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Database:
diff --git a/superset/databases/commands/tables.py b/superset/databases/commands/tables.py
index 48e9227dea..b7dbb4d461 100644
--- a/superset/databases/commands/tables.py
+++ b/superset/databases/commands/tables.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, cast, Dict
+from typing import Any, cast
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
@@ -40,7 +40,7 @@ class TablesDatabaseCommand(BaseCommand):
self._schema_name = schema_name
self._force = force
- def run(self) -> Dict[str, Any]:
+ def run(self) -> dict[str, Any]:
self.validate()
try:
tables = security_manager.get_datasources_accessible_by_user(
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index 9809641d5c..2680c5e8c1 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -17,7 +17,7 @@
import logging
import sqlite3
from contextlib import closing
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask import current_app as app
from flask_babel import gettext as _
@@ -64,7 +64,7 @@ def get_log_connection_action(
class TestConnectionDatabaseCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
self._model: Optional[Database] = None
diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py
index 746f7a8152..f12706fa1d 100644
--- a/superset/databases/commands/update.py
+++ b/superset/databases/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
@@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand):
raise DatabaseConnectionFailedError() from ex
# Update database schema permissions
- new_schemas: List[str] = []
+ new_schemas: list[str] = []
for schema in schemas:
old_view_menu_name = security_manager.get_schema_perm(
@@ -164,7 +164,7 @@ class UpdateDatabaseCommand(BaseCommand):
chart.schema_perm = new_view_menu_name
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index 2a624e32c7..d97ad33af9 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -16,7 +16,7 @@
# under the License.
import json
from contextlib import closing
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask_babel import gettext as __
@@ -38,7 +38,7 @@ BYPASS_VALIDATION_ENGINES = {"bigquery"}
class ValidateDatabaseParametersCommand(BaseCommand):
- def __init__(self, properties: Dict[str, Any]):
+ def __init__(self, properties: dict[str, Any]):
self._properties = properties.copy()
self._model: Optional[Database] = None
diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py
index 346d684a0d..40d88af745 100644
--- a/superset/databases/commands/validate_sql.py
+++ b/superset/databases/commands/validate_sql.py
@@ -16,7 +16,7 @@
# under the License.
import logging
import re
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
from flask import current_app
from flask_babel import gettext as __
@@ -41,13 +41,13 @@ logger = logging.getLogger(__name__)
class ValidateSQLCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
- self._validator: Optional[Type[BaseSQLValidator]] = None
+ self._validator: Optional[type[BaseSQLValidator]] = None
- def run(self) -> List[Dict[str, Any]]:
+ def run(self) -> list[dict[str, Any]]:
"""
Validates a SQL statement
@@ -97,9 +97,7 @@ class ValidateSQLCommand(BaseCommand):
if not validators_by_engine or spec.engine not in validators_by_engine:
raise NoValidatorConfigFoundError(
SupersetError(
- message=__(
- "no SQL validator is configured for {}".format(spec.engine)
- ),
+ message=__(f"no SQL validator is configured for {spec.engine}"),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
),
diff --git a/superset/databases/dao.py b/superset/databases/dao.py
index c82f0db574..9ce3b5e73e 100644
--- a/superset/databases/dao.py
+++ b/superset/databases/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
@@ -38,7 +38,7 @@ class DatabaseDAO(BaseDAO):
def update(
cls,
model: Database,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> Database:
"""
@@ -93,7 +93,7 @@ class DatabaseDAO(BaseDAO):
)
@classmethod
- def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
+ def get_related_objects(cls, database_id: int) -> dict[str, Any]:
database: Any = cls.find_by_id(database_id)
datasets = database.tables
dataset_ids = [dataset.id for dataset in datasets]
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 86564e8f15..2ca77b77d1 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Set
+from typing import Any
from flask import g
from flask_babel import lazy_gettext as _
@@ -30,7 +30,7 @@ from superset.views.base import BaseFilter
def can_access_databases(
view_menu_name: str,
-) -> Set[str]:
+) -> set[str]:
return {
security_manager.unpack_database_and_schema(vm).database
for vm in security_manager.user_view_menu_names(view_menu_name)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 00e8c3ca53..01a00e8b80 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -19,7 +19,7 @@
import inspect
import json
-from typing import Any, Dict, List
+from typing import Any
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -263,8 +263,8 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
@pre_load
def build_sqlalchemy_uri(
- self, data: Dict[str, Any], **kwargs: Any
- ) -> Dict[str, Any]:
+ self, data: dict[str, Any], **kwargs: Any
+ ) -> dict[str, Any]:
"""
Build SQLAlchemy URI from separate parameters.
@@ -325,9 +325,9 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
def rename_encrypted_extra(
self: Schema,
- data: Dict[str, Any],
+ data: dict[str, Any],
**kwargs: Any,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""
Rename ``encrypted_extra`` to ``masked_encrypted_extra``.
@@ -707,8 +707,8 @@ class DatabaseFunctionNamesResponse(Schema):
class ImportV1DatabaseExtraSchema(Schema):
@pre_load
def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name
- self, data: Dict[str, Any], **kwargs: Any
- ) -> Dict[str, Any]:
+ self, data: dict[str, Any], **kwargs: Any
+ ) -> dict[str, Any]:
"""
Fixes for ``schemas_allowed_for_csv_upload``.
"""
@@ -744,8 +744,8 @@ class ImportV1DatabaseExtraSchema(Schema):
class ImportV1DatabaseSchema(Schema):
@pre_load
def fix_allow_csv_upload(
- self, data: Dict[str, Any], **kwargs: Any
- ) -> Dict[str, Any]:
+ self, data: dict[str, Any], **kwargs: Any
+ ) -> dict[str, Any]:
"""
Fix for ``allow_csv_upload`` .
"""
@@ -775,7 +775,7 @@ class ImportV1DatabaseSchema(Schema):
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
@validates_schema
- def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
+ def validate_password(self, data: dict[str, Any], **kwargs: Any) -> None:
"""If sqlalchemy_uri has a masked password, password is required"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
@@ -789,7 +789,7 @@ class ImportV1DatabaseSchema(Schema):
@validates_schema
def validate_ssh_tunnel_credentials(
- self, data: Dict[str, Any], **kwargs: Any
+ self, data: dict[str, Any], **kwargs: Any
) -> None:
"""If ssh_tunnel has a masked credentials, credentials are required"""
uuid = data["uuid"]
@@ -829,7 +829,7 @@ class ImportV1DatabaseSchema(Schema):
# or there're times where it's masked.
# If both are masked, we need to return a list of errors
# so the UI ask for both fields at the same time if needed
- exception_messages: List[str] = []
+ exception_messages: list[str] = []
if private_key is None or private_key == PASSWORD_MASK:
# If we get here we need to ask for the private key
exception_messages.append(
@@ -864,7 +864,7 @@ class EncryptedDict(EncryptedField, fields.Dict):
pass
-def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore
+def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore
ret = {}
if isinstance(field, EncryptedField):
if self.openapi_version.major > 2:
diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py
index 45e5af5f44..9c41b83392 100644
--- a/superset/databases/ssh_tunnel/commands/create.py
+++ b/superset/databases/ssh_tunnel/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
- def __init__(self, database_id: int, data: Dict[str, Any]):
+ def __init__(self, database_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
@@ -61,7 +61,7 @@ class CreateSSHTunnelCommand(BaseCommand):
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py
index 42925d1caa..37fd4a94b9 100644
--- a/superset/databases/ssh_tunnel/commands/update.py
+++ b/superset/databases/ssh_tunnel/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class UpdateSSHTunnelCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
diff --git a/superset/databases/ssh_tunnel/dao.py b/superset/databases/ssh_tunnel/dao.py
index 89562fc05d..731f9183b3 100644
--- a/superset/databases/ssh_tunnel/dao.py
+++ b/superset/databases/ssh_tunnel/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from superset.dao.base import BaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
@@ -31,7 +31,7 @@ class SSHTunnelDAO(BaseDAO):
def update(
cls,
model: SSHTunnel,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> SSHTunnel:
"""
diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py
index 3384679cb7..d9462a63db 100644
--- a/superset/databases/ssh_tunnel/models.py
+++ b/superset/databases/ssh_tunnel/models.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
import sqlalchemy as sa
from flask import current_app
@@ -82,7 +82,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
]
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
output = {
"id": self.id,
"server_address": self.server_address,
diff --git a/superset/databases/utils.py b/superset/databases/utils.py
index 9229bb8cba..74943f4747 100644
--- a/superset/databases/utils.py
+++ b/superset/databases/utils.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
from sqlalchemy.engine.url import make_url, URL
@@ -25,7 +25,7 @@ def get_foreign_keys_metadata(
database: Any,
table_name: str,
schema_name: Optional[str],
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
foreign_keys = database.get_foreign_keys(table_name, schema_name)
for fk in foreign_keys:
fk["column_names"] = fk.pop("constrained_columns")
@@ -35,14 +35,14 @@ def get_foreign_keys_metadata(
def get_indexes_metadata(
database: Any, table_name: str, schema_name: Optional[str]
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
indexes = database.get_indexes(table_name, schema_name)
for idx in indexes:
idx["type"] = "index"
return indexes
-def get_col_type(col: Dict[Any, Any]) -> str:
+def get_col_type(col: dict[Any, Any]) -> str:
try:
dtype = f"{col['type']}"
except Exception: # pylint: disable=broad-except
@@ -53,7 +53,7 @@ def get_col_type(col: Dict[Any, Any]) -> str:
def get_table_metadata(
database: Any, table_name: str, schema_name: Optional[str]
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
@@ -73,7 +73,7 @@ def get_table_metadata(
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
indexes = get_indexes_metadata(database, table_name, schema_name)
keys += foreign_keys + indexes
- payload_columns: List[Dict[str, Any]] = []
+ payload_columns: list[dict[str, Any]] = []
table_comment = database.get_table_comment(table_name, schema_name)
for col in columns:
dtype = get_col_type(col)
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 8abeedc095..8083993294 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -17,7 +17,7 @@
""" Superset utilities for pandas.DataFrame.
"""
import logging
-from typing import Any, Dict, List
+from typing import Any
import pandas as pd
@@ -37,7 +37,7 @@ def _convert_big_integers(val: Any) -> Any:
return str(val) if isinstance(val, int) and abs(val) > JS_MAX_INTEGER else val
-def df_to_records(dframe: pd.DataFrame) -> List[Dict[str, Any]]:
+def df_to_records(dframe: pd.DataFrame) -> list[dict[str, Any]]:
"""
Convert a DataFrame to a set of records.
diff --git a/superset/datasets/commands/bulk_delete.py b/superset/datasets/commands/bulk_delete.py
index 643ac784ec..fd13351809 100644
--- a/superset/datasets/commands/bulk_delete.py
+++ b/superset/datasets/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset import security_manager
from superset.commands.base import BaseCommand
@@ -34,9 +34,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteDatasetCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[SqlaTable]] = None
+ self._models: Optional[list[SqlaTable]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py
index 04f54339d0..1c864ad196 100644
--- a/superset/datasets/commands/create.py
+++ b/superset/datasets/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateDatasetCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -55,12 +55,12 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
return dataset
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
database_id = self._properties["database"]
table_name = self._properties["table_name"]
schema = self._properties.get("schema", None)
sql = self._properties.get("sql", None)
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate uniqueness
if not DatasetDAO.validate_uniqueness(database_id, schema, table_name):
diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py
index 5fc642cbe3..5a4a84fdf9 100644
--- a/superset/datasets/commands/duplicate.py
+++ b/superset/datasets/commands/duplicate.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List
+from typing import Any
from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as __
@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
class DuplicateDatasetCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]) -> None:
+ def __init__(self, data: dict[str, Any]) -> None:
self._base_model: SqlaTable = SqlaTable()
self._properties = data.copy()
@@ -105,7 +105,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
return table
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
base_model_id = self._properties["base_model_id"]
duplicate_name = self._properties["table_name"]
diff --git a/superset/datasets/commands/export.py b/superset/datasets/commands/export.py
index c6fe43c89d..8c02a23f29 100644
--- a/superset/datasets/commands/export.py
+++ b/superset/datasets/commands/export.py
@@ -18,7 +18,7 @@
import json
import logging
-from typing import Iterator, Tuple
+from collections.abc import Iterator
import yaml
@@ -43,7 +43,7 @@ class ExportDatasetsCommand(ExportModelsCommand):
@staticmethod
def _export(
model: SqlaTable, export_related: bool = True
- ) -> Iterator[Tuple[str, str]]:
+ ) -> Iterator[tuple[str, str]]:
db_file_name = get_filename(
model.database.database_name, model.database.id, skip_id=True
)
diff --git a/superset/datasets/commands/importers/dispatcher.py b/superset/datasets/commands/importers/dispatcher.py
index 74f1129d23..6be8635da2 100644
--- a/superset/datasets/commands/importers/dispatcher.py
+++ b/superset/datasets/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -43,7 +43,7 @@ class ImportDatasetsCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py
index f706ecf38b..c530be3c14 100644
--- a/superset/datasets/commands/importers/v0.py
+++ b/superset/datasets/commands/importers/v0.py
@@ -16,7 +16,7 @@
# under the License.
import json
import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
import yaml
from flask_appbuilder import Model
@@ -213,7 +213,7 @@ def import_simple_obj(
def import_from_dict(
- session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
+ session: Session, data: dict[str, Any], sync: Optional[list[str]] = None
) -> None:
"""Imports databases from dictionary"""
if not sync:
@@ -238,12 +238,12 @@ class ImportDatasetsCommand(BaseCommand):
# pylint: disable=unused-argument
def __init__(
self,
- contents: Dict[str, str],
+ contents: dict[str, str],
*args: Any,
**kwargs: Any,
):
self.contents = contents
- self._configs: Dict[str, Any] = {}
+ self._configs: dict[str, Any] = {}
self.sync = []
if kwargs.get("sync_columns"):
diff --git a/superset/datasets/commands/importers/v1/__init__.py b/superset/datasets/commands/importers/v1/__init__.py
index e73213319d..e753138ab8 100644
--- a/superset/datasets/commands/importers/v1/__init__.py
+++ b/superset/datasets/commands/importers/v1/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Set
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -36,7 +36,7 @@ class ImportDatasetsCommand(ImportModelsCommand):
dao = DatasetDAO
model_name = "dataset"
prefix = "datasets/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
@@ -44,16 +44,16 @@ class ImportDatasetsCommand(ImportModelsCommand):
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover databases associated with datasets
- database_uuids: Set[str] = set()
+ database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
database_uuids.add(config["database_uuid"])
# import related databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index 52f46829b5..ae47fc411a 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -18,7 +18,7 @@ import gzip
import json
import logging
import re
-from typing import Any, Dict
+from typing import Any
from urllib import request
import pandas as pd
@@ -69,7 +69,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
raise Exception(f"Unknown type: {native_type}")
-def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
+def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]:
return {
column.column_name: get_sqla_type(column.type)
for column in dataset.columns
@@ -101,7 +101,7 @@ def validate_data_uri(data_uri: str) -> None:
def import_dataset(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
ignore_permissions: bool = False,
diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py
index cc9f480a41..be9625709f 100644
--- a/superset/datasets/commands/update.py
+++ b/superset/datasets/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from collections import Counter
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
@@ -52,7 +52,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
def __init__(
self,
model_id: int,
- data: Dict[str, Any],
+ data: dict[str, Any],
override_columns: Optional[bool] = False,
):
self._model_id = model_id
@@ -76,8 +76,8 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
raise DatasetUpdateFailedError()
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ exceptions: list[ValidationError] = []
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/populate model exists
self._model = DatasetDAO.find_by_id(self._model_id)
if not self._model:
@@ -125,14 +125,14 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
raise DatasetInvalidError(exceptions=exceptions)
def _validate_columns(
- self, columns: List[Dict[str, Any]], exceptions: List[ValidationError]
+ self, columns: list[dict[str, Any]], exceptions: list[ValidationError]
) -> None:
# Validate duplicates on data
if self._get_duplicates(columns, "column_name"):
exceptions.append(DatasetColumnsDuplicateValidationError())
else:
# validate invalid id's
- columns_ids: List[int] = [
+ columns_ids: list[int] = [
column["id"] for column in columns if "id" in column
]
if not DatasetDAO.validate_columns_exist(self._model_id, columns_ids):
@@ -140,7 +140,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
# validate new column names uniqueness
if not self.override_columns:
- columns_names: List[str] = [
+ columns_names: list[str] = [
column["column_name"] for column in columns if "id" not in column
]
if not DatasetDAO.validate_columns_uniqueness(
@@ -149,26 +149,26 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
exceptions.append(DatasetColumnsExistsValidationError())
def _validate_metrics(
- self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError]
+ self, metrics: list[dict[str, Any]], exceptions: list[ValidationError]
) -> None:
if self._get_duplicates(metrics, "metric_name"):
exceptions.append(DatasetMetricsDuplicateValidationError())
else:
# validate invalid id's
- metrics_ids: List[int] = [
+ metrics_ids: list[int] = [
metric["id"] for metric in metrics if "id" in metric
]
if not DatasetDAO.validate_metrics_exist(self._model_id, metrics_ids):
exceptions.append(DatasetMetricsNotFoundValidationError())
# validate new metric names uniqueness
- metric_names: List[str] = [
+ metric_names: list[str] = [
metric["metric_name"] for metric in metrics if "id" not in metric
]
if not DatasetDAO.validate_metrics_uniqueness(self._model_id, metric_names):
exceptions.append(DatasetMetricsExistsValidationError())
@staticmethod
- def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]:
+ def _get_duplicates(data: list[dict[str, Any]], key: str) -> list[str]:
duplicates = [
name
for name, count in Counter([item[key] for item in data]).items()
diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py
index b158fce1fe..f4d46be109 100644
--- a/superset/datasets/dao.py
+++ b/superset/datasets/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -44,7 +44,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return None
@staticmethod
- def get_related_objects(database_id: int) -> Dict[str, Any]:
+ def get_related_objects(database_id: int) -> dict[str, Any]:
charts = (
db.session.query(Slice)
.filter(
@@ -108,7 +108,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return not db.session.query(dataset_query.exists()).scalar()
@staticmethod
- def validate_columns_exist(dataset_id: int, columns_ids: List[int]) -> bool:
+ def validate_columns_exist(dataset_id: int, columns_ids: list[int]) -> bool:
dataset_query = (
db.session.query(TableColumn.id).filter(
TableColumn.table_id == dataset_id, TableColumn.id.in_(columns_ids)
@@ -117,7 +117,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return len(columns_ids) == len(dataset_query)
@staticmethod
- def validate_columns_uniqueness(dataset_id: int, columns_names: List[str]) -> bool:
+ def validate_columns_uniqueness(dataset_id: int, columns_names: list[str]) -> bool:
dataset_query = (
db.session.query(TableColumn.id).filter(
TableColumn.table_id == dataset_id,
@@ -127,7 +127,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return len(dataset_query) == 0
@staticmethod
- def validate_metrics_exist(dataset_id: int, metrics_ids: List[int]) -> bool:
+ def validate_metrics_exist(dataset_id: int, metrics_ids: list[int]) -> bool:
dataset_query = (
db.session.query(SqlMetric.id).filter(
SqlMetric.table_id == dataset_id, SqlMetric.id.in_(metrics_ids)
@@ -136,7 +136,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return len(metrics_ids) == len(dataset_query)
@staticmethod
- def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bool:
+ def validate_metrics_uniqueness(dataset_id: int, metrics_names: list[str]) -> bool:
dataset_query = (
db.session.query(SqlMetric.id).filter(
SqlMetric.table_id == dataset_id,
@@ -149,7 +149,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update(
cls,
model: SqlaTable,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> Optional[SqlaTable]:
"""
@@ -173,7 +173,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_columns(
cls,
model: SqlaTable,
- property_columns: List[Dict[str, Any]],
+ property_columns: list[dict[str, Any]],
commit: bool = True,
override_columns: bool = False,
) -> None:
@@ -239,7 +239,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_metrics(
cls,
model: SqlaTable,
- property_metrics: List[Dict[str, Any]],
+ property_metrics: list[dict[str, Any]],
commit: bool = True,
) -> None:
"""
@@ -304,14 +304,14 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_column(
cls,
model: TableColumn,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> TableColumn:
return DatasetColumnDAO.update(model, properties, commit=commit)
@classmethod
def create_column(
- cls, properties: Dict[str, Any], commit: bool = True
+ cls, properties: dict[str, Any], commit: bool = True
) -> TableColumn:
"""
Creates a Dataset model on the metadata DB
@@ -346,7 +346,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_metric(
cls,
model: SqlMetric,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> SqlMetric:
return DatasetMetricDAO.update(model, properties, commit=commit)
@@ -354,7 +354,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
@classmethod
def create_metric(
cls,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> SqlMetric:
"""
@@ -363,7 +363,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return DatasetMetricDAO.create(properties, commit=commit)
@staticmethod
- def bulk_delete(models: Optional[List[SqlaTable]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[SqlaTable]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
diff --git a/superset/datasets/models.py b/superset/datasets/models.py
index b433709f2c..50aeea7b51 100644
--- a/superset/datasets/models.py
+++ b/superset/datasets/models.py
@@ -24,7 +24,6 @@ dataset, new models for columns, metrics, and tables were also introduced.
These models are not fully implemented, and shouldn't be used yet.
"""
-from typing import List
import sqlalchemy as sa
from flask_appbuilder import Model
@@ -87,7 +86,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
# The relationship between datasets and columns is 1:n, but we use a
# many-to-many association table to avoid adding two mutually exclusive
# columns(dataset_id and table_id) to Column
- columns: List[Column] = relationship(
+ columns: list[Column] = relationship(
"Column",
secondary=dataset_column_association_table,
cascade="all, delete-orphan",
@@ -97,7 +96,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
owners = relationship(
security_manager.user_model, secondary=dataset_user_association_table
)
- tables: List[Table] = relationship(
+ tables: list[Table] = relationship(
"Table", secondary=dataset_table_association_table, backref="datasets"
)
diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py
index f248fc70ff..9a2af98066 100644
--- a/superset/datasets/schemas.py
+++ b/superset/datasets/schemas.py
@@ -16,7 +16,7 @@
# under the License.
import json
import re
-from typing import Any, Dict
+from typing import Any
from flask_babel import lazy_gettext as _
from marshmallow import fields, pre_load, Schema, ValidationError
@@ -150,7 +150,7 @@ class DatasetRelatedObjectsResponse(Schema):
class ImportV1ColumnSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
- def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
@@ -176,7 +176,7 @@ class ImportV1ColumnSchema(Schema):
class ImportV1MetricSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
- def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
@@ -198,7 +198,7 @@ class ImportV1MetricSchema(Schema):
class ImportV1DatasetSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
- def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
diff --git a/superset/datasource/dao.py b/superset/datasource/dao.py
index 158a32c7fd..4682f070e2 100644
--- a/superset/datasource/dao.py
+++ b/superset/datasource/dao.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Dict, Type, Union
+from typing import Union
from sqlalchemy.orm import Session
@@ -34,7 +34,7 @@ Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
class DatasourceDAO(BaseDAO):
- sources: Dict[Union[DatasourceType, str], Type[Datasource]] = {
+ sources: dict[Union[DatasourceType, str], type[Datasource]] = {
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
DatasourceType.SAVEDQUERY: SavedQuery,
diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py
index f19dffd4a3..20cdfcc51f 100644
--- a/superset/db_engine_specs/__init__.py
+++ b/superset/db_engine_specs/__init__.py
@@ -33,7 +33,7 @@ import pkgutil
from collections import defaultdict
from importlib import import_module
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Type
+from typing import Any, Optional
import sqlalchemy.databases
import sqlalchemy.dialects
@@ -58,11 +58,11 @@ def is_engine_spec(obj: Any) -> bool:
)
-def load_engine_specs() -> List[Type[BaseEngineSpec]]:
+def load_engine_specs() -> list[type[BaseEngineSpec]]:
"""
Load all engine specs, native and 3rd party.
"""
- engine_specs: List[Type[BaseEngineSpec]] = []
+ engine_specs: list[type[BaseEngineSpec]] = []
# load standard engines
db_engine_spec_dir = str(Path(__file__).parent)
@@ -85,7 +85,7 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
return engine_specs
-def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
+def get_engine_spec(backend: str, driver: Optional[str] = None) -> type[BaseEngineSpec]:
"""
Return the DB engine spec associated with a given SQLAlchemy URL.
@@ -120,11 +120,11 @@ backend_replacements = {
}
-def get_available_engine_specs() -> Dict[Type[BaseEngineSpec], Set[str]]:
+def get_available_engine_specs() -> dict[type[BaseEngineSpec], set[str]]:
"""
Return available engine specs and installed drivers for them.
"""
- drivers: Dict[str, Set[str]] = defaultdict(set)
+ drivers: dict[str, set[str]] = defaultdict(set)
# native SQLAlchemy dialects
for attr in sqlalchemy.databases.__all__:
diff --git a/superset/db_engine_specs/athena.py b/superset/db_engine_specs/athena.py
index 047952402d..ad6bed113d 100644
--- a/superset/db_engine_specs/athena.py
+++ b/superset/db_engine_specs/athena.py
@@ -16,7 +16,8 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
from flask_babel import gettext as __
from sqlalchemy import types
@@ -51,7 +52,7 @@ class AthenaEngineSpec(BaseEngineSpec):
date_add('day', 1, CAST({col} AS TIMESTAMP))))",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
SYNTAX_ERROR_REGEX: (
__(
"Please check your query for syntax errors at or "
@@ -64,7 +65,7 @@ class AthenaEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index a7ff862272..ef922a5e63 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -22,22 +22,8 @@ import json
import logging
import re
from datetime import datetime
-from typing import (
- Any,
- Callable,
- ContextManager,
- Dict,
- List,
- Match,
- NamedTuple,
- Optional,
- Pattern,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
-)
+from re import Match, Pattern
+from typing import Any, Callable, ContextManager, NamedTuple, TYPE_CHECKING, Union
import pandas as pd
import sqlparse
@@ -77,7 +63,7 @@ if TYPE_CHECKING:
from superset.models.core import Database
from superset.models.sql_lab import Query
-ColumnTypeMapping = Tuple[
+ColumnTypeMapping = tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
@@ -90,10 +76,10 @@ class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
function: str
- duration: Optional[str]
+ duration: str | None
-builtin_time_grains: Dict[Optional[str], str] = {
+builtin_time_grains: dict[str | None, str] = {
"PT1S": __("Second"),
"PT5S": __("5 second"),
"PT30S": __("30 second"),
@@ -160,12 +146,12 @@ class MetricType(TypedDict, total=False):
metric_name: str
expression: str
- verbose_name: Optional[str]
- metric_type: Optional[str]
- description: Optional[str]
- d3format: Optional[str]
- warning_text: Optional[str]
- extra: Optional[str]
+ verbose_name: str | None
+ metric_type: str | None
+ description: str | None
+ d3format: str | None
+ warning_text: str | None
+ extra: str | None
class BaseEngineSpec: # pylint: disable=too-many-public-methods
@@ -182,19 +168,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
having to add the same aggregation in SELECT.
"""
- engine_name: Optional[str] = None # for user messages, overridden in child classes
+ engine_name: str | None = None # for user messages, overridden in child classes
# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
# see the ``supports_url`` and ``supports_backend`` methods below.
engine = "base" # str as defined in sqlalchemy.engine.engine
- engine_aliases: Set[str] = set()
- drivers: Dict[str, str] = {}
- default_driver: Optional[str] = None
+ engine_aliases: set[str] = set()
+ drivers: dict[str, str] = {}
+ default_driver: str | None = None
disable_ssh_tunneling = False
- _date_trunc_functions: Dict[str, str] = {}
- _time_grain_expressions: Dict[Optional[str], str] = {}
- _default_column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
+ _date_trunc_functions: dict[str, str] = {}
+ _time_grain_expressions: dict[str | None, str] = {}
+ _default_column_type_mappings: tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
@@ -312,7 +298,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
),
)
# engine-specific type mappings to check prior to the defaults
- column_type_mappings: Tuple[ColumnTypeMapping, ...] = ()
+ column_type_mappings: tuple[ColumnTypeMapping, ...] = ()
# Does database support join-free timeslot grouping
time_groupby_inline = False
@@ -351,23 +337,23 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
allow_limit_clause = True
# This set will give keywords for select statements
# to consider for the engines with TOP SQL parsing
- select_keywords: Set[str] = {"SELECT"}
+ select_keywords: set[str] = {"SELECT"}
# This set will give the keywords for data limit statements
# to consider for the engines with TOP SQL parsing
- top_keywords: Set[str] = {"TOP"}
+ top_keywords: set[str] = {"TOP"}
# A set of disallowed connection query parameters by driver name
- disallow_uri_query_params: Dict[str, Set[str]] = {}
+ disallow_uri_query_params: dict[str, set[str]] = {}
# A Dict of query parameters that will always be used on every connection
# by driver name
- enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}
+ enforce_uri_query_params: dict[str, dict[str, Any]] = {}
force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
- custom_errors: Dict[
- Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
+ custom_errors: dict[
+ Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
] = {}
# Whether the engine supports file uploads
@@ -422,7 +408,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return cls.supports_backend(backend, driver)
@classmethod
- def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
+ def supports_backend(cls, backend: str, driver: str | None = None) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
"""
@@ -439,7 +425,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return driver in cls.drivers
@classmethod
- def get_default_schema(cls, database: Database) -> Optional[str]:
+ def get_default_schema(cls, database: Database) -> str | None:
"""
Return the default schema in a given database.
"""
@@ -450,8 +436,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_schema_from_engine_params( # pylint: disable=unused-argument
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
- ) -> Optional[str]:
+ connect_args: dict[str, Any],
+ ) -> str | None:
"""
Return the schema configured in a SQLALchemy URI and connection argments, if any.
"""
@@ -462,7 +448,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
query: Query,
- ) -> Optional[str]:
+ ) -> str | None:
"""
Return the default schema for a given query.
@@ -501,7 +487,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return cls.get_default_schema(database)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
"""
Each engine can implement and converge its own specific exceptions into
Superset DBAPI exceptions
@@ -541,7 +527,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_allow_cost_estimate( # pylint: disable=unused-argument
cls,
- extra: Dict[str, Any],
+ extra: dict[str, Any],
) -> bool:
return False
@@ -561,8 +547,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_engine(
cls,
database: Database,
- schema: Optional[str] = None,
- source: Optional[utils.QuerySource] = None,
+ schema: str | None = None,
+ source: utils.QuerySource | None = None,
) -> ContextManager[Engine]:
"""
Return an engine context manager.
@@ -578,8 +564,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_timestamp_expr(
cls,
col: ColumnClause,
- pdf: Optional[str],
- time_grain: Optional[str],
+ pdf: str | None,
+ time_grain: str | None,
) -> TimestampExpression:
"""
Construct a TimestampExpression to be used in a SQLAlchemy query.
@@ -616,7 +602,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return TimestampExpression(time_expr, col, type_=col.type)
@classmethod
- def get_time_grains(cls) -> Tuple[TimeGrain, ...]:
+ def get_time_grains(cls) -> tuple[TimeGrain, ...]:
"""
Generate a tuple of supported time grains.
@@ -634,8 +620,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def _sort_time_grains(
- cls, val: Tuple[Optional[str], str], index: int
- ) -> Union[float, int, str]:
+ cls, val: tuple[str | None, str], index: int
+ ) -> float | int | str:
"""
Return an ordered time-based value of a portion of a time grain
for sorting
@@ -695,7 +681,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return plist.get(index, 0)
@classmethod
- def get_time_grain_expressions(cls) -> Dict[Optional[str], str]:
+ def get_time_grain_expressions(cls) -> dict[str | None, str]:
"""
Return a dict of all supported time grains including any potential added grains
but excluding any potentially disabled grains in the config file.
@@ -706,7 +692,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
time_grain_expressions = cls._time_grain_expressions.copy()
grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
- denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"]
+ denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"]
for key in denylist:
time_grain_expressions.pop(key, None)
@@ -723,9 +709,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
)
@classmethod
- def fetch_data(
- cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
"""
:param cursor: Cursor instance
@@ -743,9 +727,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def expand_data(
- cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]]
- ) -> Tuple[
- List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType]
+ cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]]
+ ) -> tuple[
+ list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType]
]:
"""
Some engines support expanding nested fields. See implementation in Presto
@@ -759,7 +743,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return columns, data, []
@classmethod
- def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
+ def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
"""Allow altering default column attributes when first detected/added
For instance special column like `__time` for Druid can be
@@ -789,7 +773,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return cls.epoch_to_dttm().replace("{col}", "({col}/1000)")
@classmethod
- def get_datatype(cls, type_code: Any) -> Optional[str]:
+ def get_datatype(cls, type_code: Any) -> str | None:
"""
Change column type code from cursor description to string representation.
@@ -802,7 +786,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
@deprecated(deprecated_in="3.0")
- def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
@@ -818,8 +802,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
table_name: str,
- schema_name: Optional[str],
- ) -> Dict[str, Any]:
+ schema_name: str | None,
+ ) -> dict[str, Any]:
"""
Returns engine-specific table metadata
@@ -872,7 +856,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
sql_remainder = None
sql = sql.strip(" \t\n;")
sql_statement = sqlparse.format(sql, strip_comments=True)
- query_limit: Optional[int] = sql_parse.extract_top_from_query(
+ query_limit: int | None = sql_parse.extract_top_from_query(
sql_statement, cls.top_keywords
)
if not limit:
@@ -928,7 +912,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return True
@classmethod
- def get_limit_from_sql(cls, sql: str) -> Optional[int]:
+ def get_limit_from_sql(cls, sql: str) -> int | None:
"""
Extract limit from SQL query
@@ -951,7 +935,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return parsed_query.set_or_update_query_limit(limit)
@classmethod
- def get_cte_query(cls, sql: str) -> Optional[str]:
+ def get_cte_query(cls, sql: str) -> str | None:
"""
Convert the input CTE based SQL to the SQL for virtual table conversion
@@ -981,7 +965,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
@@ -1012,8 +996,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def convert_dttm( # pylint: disable=unused-argument
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
"""
Convert a Python `datetime` object to a SQL expression.
@@ -1044,8 +1028,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def extract_errors(
- cls, ex: Exception, context: Optional[Dict[str, Any]] = None
- ) -> List[SupersetError]:
+ cls, ex: Exception, context: dict[str, Any] | None = None
+ ) -> list[SupersetError]:
raw_message = cls._extract_error_message(ex)
context = context or {}
@@ -1076,10 +1060,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def adjust_engine_params( # pylint: disable=unused-argument
cls,
uri: URL,
- connect_args: Dict[str, Any],
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ connect_args: dict[str, Any],
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> tuple[URL, dict[str, Any]]:
"""
Return a new URL and ``connect_args`` for a specific catalog/schema.
@@ -1116,7 +1100,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get all catalogs from database.
@@ -1126,7 +1110,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return []
@classmethod
- def get_schema_names(cls, inspector: Inspector) -> List[str]:
+ def get_schema_names(cls, inspector: Inspector) -> list[str]:
"""
Get all schemas from database
@@ -1140,8 +1124,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the real table names within the specified schema.
@@ -1168,8 +1152,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the view names within the specified schema.
@@ -1197,8 +1181,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database, # pylint: disable=unused-argument
inspector: Inspector,
table_name: str,
- schema: Optional[str],
- ) -> List[Dict[str, Any]]:
+ schema: str | None,
+ ) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
@@ -1213,8 +1197,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_table_comment(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> Optional[str]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> str | None:
"""
Get comment of table from a given schema and table
@@ -1237,8 +1221,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
"""
Get all columns from a given schema and table
@@ -1255,8 +1239,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
inspector: Inspector,
table_name: str,
- schema: Optional[str],
- ) -> List[MetricType]:
+ schema: str | None,
+ ) -> list[MetricType]:
"""
Get all metrics from a given schema and table.
"""
@@ -1273,11 +1257,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument
cls,
table_name: str,
- schema: Optional[str],
+ schema: str | None,
database: Database,
query: Select,
- columns: Optional[List[Dict[str, Any]]] = None,
- ) -> Optional[Select]:
+ columns: list[dict[str, Any]] | None = None,
+ ) -> Select | None:
"""
Add a where clause to a query to reference only the most recent partition
@@ -1293,7 +1277,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
return [column(c["name"]) for c in cols]
@classmethod
@@ -1302,12 +1286,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
table_name: str,
engine: Engine,
- schema: Optional[str] = None,
+ schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: list[dict[str, Any]] | None = None,
) -> str:
"""
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.
@@ -1326,7 +1310,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return: SQL query
"""
# pylint: disable=redefined-outer-name
- fields: Union[str, List[Any]] = "*"
+ fields: str | list[Any] = "*"
cols = cols or []
if (show_cols or latest_partition) and not cols:
cols = database.get_columns(table_name, schema)
@@ -1355,7 +1339,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return sql
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
@@ -1367,8 +1351,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
"""
Format cost estimate.
@@ -1405,8 +1389,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
schema: str,
sql: str,
- source: Optional[utils.QuerySource] = None,
- ) -> List[Dict[str, Any]]:
+ source: utils.QuerySource | None = None,
+ ) -> list[dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
@@ -1433,7 +1417,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: Optional[str]
+ cls, url: URL, impersonate_user: bool, username: str | None
) -> URL:
"""
Return a modified URL with the username set.
@@ -1450,9 +1434,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def update_impersonation_config(
cls,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
uri: str,
- username: Optional[str],
+ username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -1490,7 +1474,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
raise cls.get_dbapi_mapped_exception(ex) from ex
@classmethod
- def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
+ def make_label_compatible(cls, label: str) -> str | quoted_name:
"""
Conditionally mutate and/or quote a sqlalchemy expression label. If
force_column_alias_quotes is set to True, return the label as a
@@ -1515,8 +1499,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_column_types(
cls,
- column_type: Optional[str],
- ) -> Optional[Tuple[TypeEngine, GenericDataType]]:
+ column_type: str | None,
+ ) -> tuple[TypeEngine, GenericDataType] | None:
"""
Return a sqlalchemy native column type and generic data type that corresponds
to the column type defined in the data source (return None to use default type
@@ -1598,7 +1582,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_function_names( # pylint: disable=unused-argument
cls,
database: Database,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -1609,7 +1593,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return []
@staticmethod
- def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
+ def pyodbc_rows_to_tuples(data: list[Any]) -> list[tuple[Any, ...]]:
"""
Convert pyodbc.Row objects from `fetch_data` to tuples.
@@ -1634,7 +1618,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@staticmethod
- def get_extra_params(database: Database) -> Dict[str, Any]:
+ def get_extra_params(database: Database) -> dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
@@ -1642,7 +1626,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
- extra: Dict[str, Any] = {}
+ extra: dict[str, Any] = {}
if database.extra:
try:
extra = json.loads(database.extra)
@@ -1653,7 +1637,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@staticmethod
def update_params_from_encrypted_extra( # pylint: disable=invalid-name
- database: Database, params: Dict[str, Any]
+ database: Database, params: dict[str, Any]
) -> None:
"""
Some databases require some sensitive information which do not conform to
@@ -1691,10 +1675,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_column_spec( # pylint: disable=unused-argument
cls,
- native_type: Optional[str],
- db_extra: Optional[Dict[str, Any]] = None,
+ native_type: str | None,
+ db_extra: dict[str, Any] | None = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
- ) -> Optional[ColumnSpec]:
+ ) -> ColumnSpec | None:
"""
Get generic type related specs regarding a native column type.
@@ -1714,10 +1698,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_sqla_column_type(
cls,
- native_type: Optional[str],
- db_extra: Optional[Dict[str, Any]] = None,
+ native_type: str | None,
+ db_extra: dict[str, Any] | None = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
- ) -> Optional[TypeEngine]:
+ ) -> TypeEngine | None:
"""
Converts native database type to sqlalchemy column type.
@@ -1761,7 +1745,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
cursor: Any,
query: Query,
- ) -> Optional[str]:
+ ) -> str | None:
"""
Select identifiers from the database engine that uniquely identifies the
queries to cancel. The identifier is typically a session id, process id
@@ -1794,11 +1778,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return False
@classmethod
- def parse_sql(cls, sql: str) -> List[str]:
+ def parse_sql(cls, sql: str) -> list[str]:
return [str(s).strip(" ;") for s in sqlparse.parse(sql)]
@classmethod
- def get_impersonation_key(cls, user: Optional[User]) -> Any:
+ def get_impersonation_key(cls, user: User | None) -> Any:
"""
Construct an impersonation key, by default it's the given username.
@@ -1809,7 +1793,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return user.username if user else None
@classmethod
- def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]:
+ def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None:
"""
Mask ``encrypted_extra``.
@@ -1822,9 +1806,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# pylint: disable=unused-argument
@classmethod
- def unmask_encrypted_extra(
- cls, old: Optional[str], new: Optional[str]
- ) -> Optional[str]:
+ def unmask_encrypted_extra(cls, old: str | None, new: str | None) -> str | None:
"""
Remove masks from ``encrypted_extra``.
@@ -1835,7 +1817,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return new
@classmethod
- def get_public_information(cls) -> Dict[str, Any]:
+ def get_public_information(cls) -> dict[str, Any]:
"""
Construct a Dict with properties we want to expose.
@@ -1891,12 +1873,12 @@ class BasicParametersSchema(Schema):
class BasicParametersType(TypedDict, total=False):
- username: Optional[str]
- password: Optional[str]
+ username: str | None
+ password: str | None
host: str
port: int
database: str
- query: Dict[str, Any]
+ query: dict[str, Any]
encryption: bool
@@ -1929,13 +1911,13 @@ class BasicParametersMixin:
# query parameter to enable encryption in the database connection
# for Postgres this would be `{"sslmode": "verify-ca"}`, eg.
- encryption_parameters: Dict[str, str] = {}
+ encryption_parameters: dict[str, str] = {}
@classmethod
def build_sqlalchemy_uri( # pylint: disable=unused-argument
cls,
parameters: BasicParametersType,
- encrypted_extra: Optional[Dict[str, str]] = None,
+ encrypted_extra: dict[str, str] | None = None,
) -> str:
# make a copy so that we don't update the original
query = parameters.get("query", {}).copy()
@@ -1958,7 +1940,7 @@ class BasicParametersMixin:
@classmethod
def get_parameters_from_uri( # pylint: disable=unused-argument
- cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
+ cls, uri: str, encrypted_extra: dict[str, Any] | None = None
) -> BasicParametersType:
url = make_url_safe(uri)
query = {
@@ -1982,14 +1964,14 @@ class BasicParametersMixin:
@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
- ) -> List[SupersetError]:
+ ) -> list[SupersetError]:
"""
Validates any number of parameters, for progressive validation.
If only the hostname is present it will check if the name is resolvable. As more
parameters are present in the request, more validation is done.
"""
- errors: List[SupersetError] = []
+ errors: list[SupersetError] = []
required = {"host", "port", "username", "database"}
parameters = properties.get("parameters", {})
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 1f5068ad04..3b62f4bbb8 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -18,7 +18,8 @@ import json
import re
import urllib
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Tuple, Type, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
import pandas as pd
from apispec import APISpec
@@ -99,8 +100,8 @@ class BigQueryParametersSchema(Schema):
class BigQueryParametersType(TypedDict):
- credentials_info: Dict[str, Any]
- query: Dict[str, Any]
+ credentials_info: dict[str, Any]
+ query: dict[str, Any]
class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods
@@ -173,7 +174,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
"P1Y": "{func}({col}, YEAR)",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_DATABASE_PERMISSIONS_REGEX: (
__(
"Unable to connect. Verify that the following roles are set "
@@ -219,7 +220,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -235,7 +236,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Support type BigQuery Row, introduced here PR #4071
# google.cloud.bigquery.table.Row
@@ -280,7 +281,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
@deprecated(deprecated_in="3.0")
- def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
@@ -305,7 +306,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
inspector: Inspector,
table_name: str,
schema: Optional[str],
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
@@ -321,7 +322,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def extra_table_metadata(
cls, database: "Database", table_name: str, schema_name: Optional[str]
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
indexes = database.get_indexes(table_name, schema_name)
if not indexes:
return {}
@@ -354,7 +355,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
database: "Database",
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
@@ -421,7 +422,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
schema: str,
sql: str,
source: Optional[utils.QuerySource] = None,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
@@ -448,7 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
cls,
database: "Database",
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get all catalogs.
@@ -462,11 +463,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return sorted(project.project_id for project in projects)
@classmethod
- def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
with cls.get_engine(cursor) as engine:
client = cls._get_client(engine)
job_config = bigquery.QueryJobConfig(dry_run=True)
@@ -503,15 +504,15 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
return [{k: str(v) for k, v in row.items()} for row in raw_cost]
@classmethod
def build_sqlalchemy_uri(
cls,
parameters: BigQueryParametersType,
- encrypted_extra: Optional[Dict[str, Any]] = None,
+ encrypted_extra: Optional[dict[str, Any]] = None,
) -> str:
query = parameters.get("query", {})
query_params = urllib.parse.urlencode(query)
@@ -533,7 +534,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
def get_parameters_from_uri(
cls,
uri: str,
- encrypted_extra: Optional[Dict[str, Any]] = None,
+ encrypted_extra: Optional[dict[str, Any]] = None,
) -> Any:
value = make_url_safe(uri)
@@ -592,7 +593,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return json.dumps(new_config)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel
from google.auth.exceptions import DefaultCredentialsError
@@ -602,7 +603,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
def validate_parameters(
cls,
properties: BasicPropertiesType, # pylint: disable=unused-argument
- ) -> List[SupersetError]:
+ ) -> list[SupersetError]:
return []
@classmethod
@@ -636,7 +637,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: Optional[list[dict[str, Any]]] = None,
) -> str:
"""
Remove array structures from `SELECT *`.
@@ -699,7 +700,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
)
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
"""
Label columns using their fully qualified name.
diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py
index a62087bc6a..af38c15e0b 100644
--- a/superset/db_engine_specs/clickhouse.py
+++ b/superset/db_engine_specs/clickhouse.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
import re
from datetime import datetime
-from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, cast, TYPE_CHECKING
from flask import current_app
from flask_babel import gettext as __
@@ -124,8 +124,8 @@ class ClickHouseBaseEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -145,7 +145,7 @@ class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
supports_file_upload = False
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
return {NewConnectionError: SupersetDBAPIDatabaseError}
@classmethod
@@ -159,7 +159,7 @@ class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
@classmethod
@cache_manager.cache.memoize()
- def get_function_names(cls, database: Database) -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -256,7 +256,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
engine_name = "ClickHouse Connect (Superset)"
default_driver = "connect"
- _function_names: List[str] = []
+ _function_names: list[str] = []
sqlalchemy_uri_placeholder = (
"clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]"
@@ -265,7 +265,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
encryption_parameters = {"secure": "true"}
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
return {}
@classmethod
@@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
return new_exception(str(exception))
@classmethod
- def get_function_names(cls, database: Database) -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver.exceptions import ClickHouseError
@@ -304,7 +304,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
def build_sqlalchemy_uri(
cls,
parameters: BasicParametersType,
- encrypted_extra: Optional[Dict[str, str]] = None,
+ encrypted_extra: dict[str, str] | None = None,
) -> str:
url_params = parameters.copy()
if url_params.get("encryption"):
@@ -318,7 +318,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
@classmethod
def get_parameters_from_uri(
- cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
+ cls, uri: str, encrypted_extra: dict[str, Any] | None = None
) -> BasicParametersType:
url = make_url_safe(uri)
query = url.query
@@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
- ) -> List[SupersetError]:
+ ) -> list[SupersetError]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver import default_port
diff --git a/superset/db_engine_specs/crate.py b/superset/db_engine_specs/crate.py
index 6eafae829e..d8d91c6796 100644
--- a/superset/db_engine_specs/crate.py
+++ b/superset/db_engine_specs/crate.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from datetime import datetime
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from sqlalchemy import types
@@ -53,8 +53,8 @@ class CrateEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.TIMESTAMP):
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 5f12f3174d..5df24be65d 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -17,7 +17,7 @@
import json
from datetime import datetime
-from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
@@ -135,7 +135,7 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
return HiveEngineSpec.convert_dttm(target_type, dttm, db_extra=db_extra)
@@ -160,14 +160,14 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
encryption_parameters = {"ssl": "1"}
@staticmethod
- def get_extra_params(database: "Database") -> Dict[str, Any]:
+ def get_extra_params(database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
Trim whitespace from connect_args to avoid databricks driver errors
"""
- extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
- engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
- connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
+ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+ engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
+ connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)])
connect_args.setdefault("_user_agent_entry", USER_AGENT)
@@ -184,7 +184,7 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
database: "Database",
inspector: Inspector,
schema: Optional[str],
- ) -> Set[str]:
+ ) -> set[str]:
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
@@ -213,8 +213,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
@classmethod
def extract_errors(
- cls, ex: Exception, context: Optional[Dict[str, Any]] = None
- ) -> List[SupersetError]:
+ cls, ex: Exception, context: Optional[dict[str, Any]] = None
+ ) -> list[SupersetError]:
raw_message = cls._extract_error_message(ex)
context = context or {}
@@ -271,8 +271,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
def validate_parameters( # type: ignore
cls,
properties: DatabricksPropertiesType,
- ) -> List[SupersetError]:
- errors: List[SupersetError] = []
+ ) -> list[SupersetError]:
+ errors: list[SupersetError] = []
required = {"access_token", "host", "port", "database", "extra"}
extra = json.loads(properties.get("extra", "{}"))
engine_params = extra.get("engine_params", {})
diff --git a/superset/db_engine_specs/dremio.py b/superset/db_engine_specs/dremio.py
index 7fae3014d6..7b4c0458cd 100644
--- a/superset/db_engine_specs/dremio.py
+++ b/superset/db_engine_specs/dremio.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -46,7 +46,7 @@ class DremioEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 16ac89212a..946544863d 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional
from urllib import parse
from sqlalchemy import types
@@ -59,7 +59,7 @@ class DrillEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -74,10 +74,10 @@ class DrillEngineSpec(BaseEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
@@ -87,7 +87,7 @@ class DrillEngineSpec(BaseEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py
index 83829ec22a..43ce310a40 100644
--- a/superset/db_engine_specs/druid.py
+++ b/superset/db_engine_specs/druid.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import json
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
@@ -79,7 +79,7 @@ class DruidEngineSpec(BaseEngineSpec):
orm_col.is_dttm = True
@staticmethod
- def get_extra_params(database: Database) -> Dict[str, Any]:
+ def get_extra_params(database: Database) -> dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.
@@ -104,8 +104,8 @@ class DruidEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -130,15 +130,15 @@ class DruidEngineSpec(BaseEngineSpec):
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
"""
Update the Druid type map.
"""
return super().get_columns(inspector, table_name, schema)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel
from requests import exceptions as requests_exceptions
diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py
index 1248287b84..3bbf9ecc38 100644
--- a/superset/db_engine_specs/duckdb.py
+++ b/superset/db_engine_specs/duckdb.py
@@ -18,7 +18,8 @@ from __future__ import annotations
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy import types
@@ -51,7 +52,7 @@ class DuckDBEngineSpec(BaseEngineSpec):
"P1Y": "DATE_TRUNC('year', {col})",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__('We can\'t seem to resolve the column "%(column_name)s"'),
SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
@@ -65,8 +66,8 @@ class DuckDBEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, (types.String, types.DateTime)):
@@ -75,6 +76,6 @@ class DuckDBEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
- cls, database: Database, inspector: Inspector, schema: Optional[str]
- ) -> Set[str]:
+ cls, database: Database, inspector: Inspector, schema: str | None
+ ) -> set[str]:
return set(inspector.get_table_names(schema))
diff --git a/superset/db_engine_specs/dynamodb.py b/superset/db_engine_specs/dynamodb.py
index c398a9c1df..5f7a9e2b71 100644
--- a/superset/db_engine_specs/dynamodb.py
+++ b/superset/db_engine_specs/dynamodb.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -55,7 +55,7 @@ class DynamoDBEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/elasticsearch.py b/superset/db_engine_specs/elasticsearch.py
index 934aa0bb03..d717c52bf5 100644
--- a/superset/db_engine_specs/elasticsearch.py
+++ b/superset/db_engine_specs/elasticsearch.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
from packaging.version import Version
from sqlalchemy import types
@@ -50,10 +50,10 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
"P1Y": "HISTOGRAM({col}, INTERVAL 1 YEAR)",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-error,import-outside-toplevel
import es.exceptions as es_exceptions
@@ -65,7 +65,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
db_extra = db_extra or {}
@@ -117,7 +117,7 @@ class OpenDistroEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py
index c06fbd826d..6da56e2fee 100644
--- a/superset/db_engine_specs/exasol.py
+++ b/superset/db_engine_specs/exasol.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional
from superset.db_engine_specs.base import BaseEngineSpec
@@ -42,7 +42,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
diff --git a/superset/db_engine_specs/firebird.py b/superset/db_engine_specs/firebird.py
index 306a642dc3..4448074157 100644
--- a/superset/db_engine_specs/firebird.py
+++ b/superset/db_engine_specs/firebird.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -72,7 +72,7 @@ class FirebirdEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/firebolt.py b/superset/db_engine_specs/firebolt.py
index 65cd714352..ace3d6b3b2 100644
--- a/superset/db_engine_specs/firebolt.py
+++ b/superset/db_engine_specs/firebolt.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -43,7 +43,7 @@ class FireboltEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 73a66c464f..abf5bac48f 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -16,7 +16,8 @@
# under the License.
import json
import re
-from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
@@ -56,12 +57,12 @@ class GSheetsParametersSchema(Schema):
class GSheetsParametersType(TypedDict):
service_account_info: str
- catalog: Optional[Dict[str, str]]
+ catalog: Optional[dict[str, str]]
class GSheetsPropertiesType(TypedDict):
parameters: GSheetsParametersType
- catalog: Dict[str, str]
+ catalog: dict[str, str]
class GSheetsEngineSpec(SqliteEngineSpec):
@@ -77,7 +78,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
default_driver = "apsw"
sqlalchemy_uri_placeholder = "gsheets://"
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
SYNTAX_ERROR_REGEX: (
__(
'Please check your query for syntax errors near "%(server_error)s". '
@@ -110,7 +111,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
database: "Database",
table_name: str,
schema_name: Optional[str],
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
with database.get_raw_connection(schema=schema_name) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
@@ -127,7 +128,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
cls,
_: GSheetsParametersType,
encrypted_extra: Optional[ # pylint: disable=unused-argument
- Dict[str, Any]
+ dict[str, Any]
] = None,
) -> str:
return "gsheets://"
@@ -136,7 +137,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
def get_parameters_from_uri(
cls,
uri: str, # pylint: disable=unused-argument
- encrypted_extra: Optional[Dict[str, Any]] = None,
+ encrypted_extra: Optional[dict[str, Any]] = None,
) -> Any:
# Building parameters from encrypted_extra and uri
if encrypted_extra:
@@ -214,8 +215,8 @@ class GSheetsEngineSpec(SqliteEngineSpec):
def validate_parameters(
cls,
properties: GSheetsPropertiesType,
- ) -> List[SupersetError]:
- errors: List[SupersetError] = []
+ ) -> list[SupersetError]:
+ errors: list[SupersetError] = []
# backwards compatible just incase people are send data
# via parameters for validation
diff --git a/superset/db_engine_specs/hana.py b/superset/db_engine_specs/hana.py
index e579550b2e..108838f9d2 100644
--- a/superset/db_engine_specs/hana.py
+++ b/superset/db_engine_specs/hana.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -45,7 +45,7 @@ class HanaEngineSpec(PostgresBaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 6d8986c1c7..7601ebb2cd 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -22,7 +22,7 @@ import re
import tempfile
import time
from datetime import datetime
-from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from urllib import parse
import numpy as np
@@ -150,9 +150,7 @@ class HiveEngineSpec(PrestoEngineSpec):
hive.Cursor.fetch_logs = fetch_logs
@classmethod
- def fetch_data(
- cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
# pylint: disable=import-outside-toplevel
import pyhive
from TCLIService import ttypes
@@ -168,10 +166,10 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def df_to_sql(
cls,
- database: "Database",
+ database: Database,
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
@@ -248,8 +246,8 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -263,10 +261,10 @@ class HiveEngineSpec(PrestoEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ connect_args: dict[str, Any],
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> tuple[URL, dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))
@@ -276,8 +274,8 @@ class HiveEngineSpec(PrestoEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
- ) -> Optional[str]:
+ connect_args: dict[str, Any],
+ ) -> str | None:
"""
Return the configured schema.
"""
@@ -292,10 +290,10 @@ class HiveEngineSpec(PrestoEngineSpec):
return msg
@classmethod
- def progress(cls, log_lines: List[str]) -> int:
+ def progress(cls, log_lines: list[str]) -> int:
total_jobs = 1 # assuming there's at least 1 job
current_job = 1
- stages: Dict[int, float] = {}
+ stages: dict[int, float] = {}
for line in log_lines:
match = cls.jobs_stats_r.match(line)
if match:
@@ -323,7 +321,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return int(progress)
@classmethod
- def get_tracking_url_from_logs(cls, log_lines: List[str]) -> Optional[str]:
+ def get_tracking_url_from_logs(cls, log_lines: list[str]) -> str | None:
lkp = "Tracking URL = "
for line in log_lines:
if lkp in line:
@@ -407,19 +405,19 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
return inspector.get_columns(table_name, schema)
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
table_name: str,
- schema: Optional[str],
- database: "Database",
+ schema: str | None,
+ database: Database,
query: Select,
- columns: Optional[List[Dict[str, Any]]] = None,
- ) -> Optional[Select]:
+ columns: list[dict[str, Any]] | None = None,
+ ) -> Select | None:
try:
col_names, values = cls.latest_partition(
table_name, schema, database, show_first=True
@@ -437,18 +435,18 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod
def latest_sub_partition( # type: ignore
- cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
+ cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
) -> str:
# TODO(bogdan): implement`
pass
@classmethod
- def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
+ def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
"""Hive partitions look like ds={partition name}/ds={partition name}"""
if not df.empty:
return [
@@ -461,12 +459,12 @@ class HiveEngineSpec(PrestoEngineSpec):
def _partition_query( # pylint: disable=too-many-arguments
cls,
table_name: str,
- schema: Optional[str],
- indexes: List[Dict[str, Any]],
- database: "Database",
+ schema: str | None,
+ indexes: list[dict[str, Any]],
+ database: Database,
limit: int = 0,
- order_by: Optional[List[Tuple[str, bool]]] = None,
- filters: Optional[Dict[Any, Any]] = None,
+ order_by: list[tuple[str, bool]] | None = None,
+ filters: dict[Any, Any] | None = None,
) -> str:
full_table_name = f"{schema}.{table_name}" if schema else table_name
return f"SHOW PARTITIONS {full_table_name}"
@@ -474,15 +472,15 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
- database: "Database",
+ database: Database,
table_name: str,
engine: Engine,
- schema: Optional[str] = None,
+ schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: list[dict[str, Any]] | None = None,
) -> str:
return super( # pylint: disable=bad-super-call
PrestoEngineSpec, cls
@@ -500,7 +498,7 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: Optional[str]
+ cls, url: URL, impersonate_user: bool, username: str | None
) -> URL:
"""
Return a modified URL with the username set.
@@ -516,9 +514,9 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def update_impersonation_config(
cls,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
uri: str,
- username: Optional[str],
+ username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -549,7 +547,7 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
@cache_manager.cache.memoize()
- def get_function_names(cls, database: "Database") -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -600,10 +598,10 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_view_names(
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the view names within the specified schema.
@@ -635,9 +633,9 @@ class HiveEngineSpec(PrestoEngineSpec):
# TODO: contribute back to pyhive.
def fetch_logs( # pylint: disable=protected-access
- self: "Cursor",
+ self: Cursor,
_max_rows: int = 1024,
- orientation: Optional["TFetchOrientation"] = None,
+ orientation: TFetchOrientation | None = None,
) -> str:
"""Mocked. Retrieve the logs produced by the execution of the query.
Can be called multiple times to fetch the logs produced after
diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py
index e59c2b74fb..cd1c9e4732 100644
--- a/superset/db_engine_specs/impala.py
+++ b/superset/db_engine_specs/impala.py
@@ -18,7 +18,7 @@ import logging
import re
import time
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import current_app
from sqlalchemy import types
@@ -57,7 +57,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -68,7 +68,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
return None
@classmethod
- def get_schema_names(cls, inspector: Inspector) -> List[str]:
+ def get_schema_names(cls, inspector: Inspector) -> list[str]:
schemas = [
row[0]
for row in inspector.engine.execute("SHOW SCHEMAS")
diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py
index 9fddb23d26..17147d5cc0 100644
--- a/superset/db_engine_specs/kusto.py
+++ b/superset/db_engine_specs/kusto.py
@@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
from sqlalchemy import types
from sqlalchemy.dialects.mssql.base import SMALLDATETIME
@@ -61,7 +61,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
" DATEDIFF(week, 0, DATEADD(day, -1, {col})), 0)",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
column_type_mappings = (
(
@@ -72,7 +72,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel,import-error
import sqlalchemy_kusto.errors as kusto_exceptions
@@ -84,7 +84,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -128,10 +128,10 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"P1Y": "datetime_diff('year',CreateDate, datetime(0001-01-01 00:00:00))+1",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel,import-error
import sqlalchemy_kusto.errors as kusto_exceptions
@@ -143,7 +143,7 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -168,7 +168,7 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
return not parsed_query.sql.startswith(".")
@classmethod
- def parse_sql(cls, sql: str) -> List[str]:
+ def parse_sql(cls, sql: str) -> list[str]:
"""
Kusto supports a single query statement, but it could include sub queries
and variables declared via let keyword.
diff --git a/superset/db_engine_specs/kylin.py b/superset/db_engine_specs/kylin.py
index e340daea51..f522602a48 100644
--- a/superset/db_engine_specs/kylin.py
+++ b/superset/db_engine_specs/kylin.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -42,7 +42,7 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index 8b38ec7421..3e0879b904 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -17,7 +17,8 @@
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
from flask_babel import gettext as __
from sqlalchemy import types
@@ -80,7 +81,7 @@ class MssqlEngineSpec(BaseEngineSpec):
),
)
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__(
'Either the username "%(username)s", password, '
@@ -115,7 +116,7 @@ class MssqlEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -132,7 +133,7 @@ class MssqlEngineSpec(BaseEngineSpec):
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 6258f6b21a..9f853d577c 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -16,7 +16,8 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
from urllib import parse
from flask_babel import gettext as __
@@ -143,9 +144,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
"INTERVAL 1 DAY)) - 1 DAY))",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__('Either the username "%(username)s" or the password is incorrect.'),
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
@@ -186,7 +187,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -201,10 +202,10 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
uri, new_connect_args = super().adjust_engine_params(
uri,
connect_args,
@@ -221,7 +222,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py
index 4b8a59117e..59fa52a656 100644
--- a/superset/db_engine_specs/ocient.py
+++ b/superset/db_engine_specs/ocient.py
@@ -17,7 +17,8 @@
import re
import threading
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Pattern, Set, Tuple
+from re import Pattern
+from typing import Any, Callable, List, NamedTuple, Optional
from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
@@ -178,15 +179,13 @@ def _polygon_to_geo_json(
# Sanitization function for column values
SanitizeFunc = Callable[[Any], Any]
+
# Represents a pair of a column index and the sanitization function
# to apply to its values.
-PlacedSanitizeFunc = NamedTuple(
- "PlacedSanitizeFunc",
- [
- ("column_index", int),
- ("sanitize_func", SanitizeFunc),
- ],
-)
+class PlacedSanitizeFunc(NamedTuple):
+ column_index: int
+ sanitize_func: SanitizeFunc
+
# This map contains functions used to sanitize values for column types
# that cannot be processed natively by Superset.
@@ -199,7 +198,7 @@ PlacedSanitizeFunc = NamedTuple(
try:
from pyocient import TypeCodes
- _sanitized_ocient_type_codes: Dict[int, SanitizeFunc] = {
+ _sanitized_ocient_type_codes: dict[int, SanitizeFunc] = {
TypeCodes.BINARY: _to_hex,
TypeCodes.ST_POINT: _point_to_geo_json,
TypeCodes.IP: str,
@@ -211,7 +210,7 @@ except ImportError as e:
_sanitized_ocient_type_codes = {}
-def _find_columns_to_sanitize(cursor: Any) -> List[PlacedSanitizeFunc]:
+def _find_columns_to_sanitize(cursor: Any) -> list[PlacedSanitizeFunc]:
"""
Cleans the column value for consumption by Superset.
@@ -238,10 +237,10 @@ class OcientEngineSpec(BaseEngineSpec):
# Store mapping of superset Query id -> Ocient ID
# These are inserted into the cache when executing the query
# They are then removed, either upon cancellation or query completion
- query_id_mapping: Dict[str, str] = dict()
+ query_id_mapping: dict[str, str] = dict()
query_id_mapping_lock = threading.Lock()
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_INVALID_USERNAME_REGEX: (
__('The username "%(username)s" does not exist.'),
SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR,
@@ -309,15 +308,15 @@ class OcientEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
- ) -> Set[str]:
+ ) -> set[str]:
return inspector.get_table_names(schema)
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
try:
- rows: List[Tuple[Any, ...]] = super().fetch_data(cursor, limit)
+ rows: list[tuple[Any, ...]] = super().fetch_data(cursor, limit)
except Exception as exception:
with OcientEngineSpec.query_id_mapping_lock:
del OcientEngineSpec.query_id_mapping[
@@ -329,7 +328,7 @@ class OcientEngineSpec(BaseEngineSpec):
if len(rows) > 0 and type(rows[0]).__name__ == "Row":
# Peek at the schema to determine which column values, if any,
# require sanitization.
- columns_to_sanitize: List[PlacedSanitizeFunc] = _find_columns_to_sanitize(
+ columns_to_sanitize: list[PlacedSanitizeFunc] = _find_columns_to_sanitize(
cursor
)
@@ -341,7 +340,7 @@ class OcientEngineSpec(BaseEngineSpec):
# Use the identity function if the column type doesn't need to be
# sanitized.
- sanitization_functions: List[SanitizeFunc] = [
+ sanitization_functions: list[SanitizeFunc] = [
identity for _ in range(len(cursor.description))
]
for info in columns_to_sanitize:
diff --git a/superset/db_engine_specs/oracle.py b/superset/db_engine_specs/oracle.py
index 4a219919bb..1199b74406 100644
--- a/superset/db_engine_specs/oracle.py
+++ b/superset/db_engine_specs/oracle.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
from sqlalchemy import types
@@ -43,7 +43,7 @@ class OracleEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -68,7 +68,7 @@ class OracleEngineSpec(BaseEngineSpec):
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
"""
:param cursor: Cursor instance
:param limit: Maximum number of rows to be returned by the cursor
diff --git a/superset/db_engine_specs/pinot.py b/superset/db_engine_specs/pinot.py
index cebdd693a4..bfec8b2947 100644
--- a/superset/db_engine_specs/pinot.py
+++ b/superset/db_engine_specs/pinot.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, Optional
+from typing import Optional
from sqlalchemy.sql.expression import ColumnClause
@@ -30,7 +30,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
allows_alias_in_orderby = False
# Pinot does its own conversion below
- _time_grain_expressions: Dict[Optional[str], str] = {
+ _time_grain_expressions: dict[Optional[str], str] = {
"PT1S": "1:SECONDS",
"PT1M": "1:MINUTES",
"PT5M": "5:MINUTES",
@@ -45,7 +45,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"P1Y": "year",
}
- _python_to_java_time_patterns: Dict[str, str] = {
+ _python_to_java_time_patterns: dict[str, str] = {
"%Y": "yyyy",
"%m": "MM",
"%d": "dd",
@@ -54,7 +54,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"%S": "ss",
}
- _use_date_trunc_function: Dict[str, bool] = {
+ _use_date_trunc_function: dict[str, bool] = {
"PT1S": False,
"PT1M": False,
"PT5M": False,
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index e809187af6..2088782f83 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -18,7 +18,8 @@ import json
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
@@ -73,7 +74,7 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P<syntax_error>.*?)"')
-def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]:
+def parse_options(connect_args: dict[str, Any]) -> dict[str, str]:
"""
Parse ``options`` from ``connect_args`` into a dictionary.
"""
@@ -109,7 +110,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
"P1Y": "DATE_TRUNC('year', {col})",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_INVALID_USERNAME_REGEX: (
__('The username "%(username)s" does not exist.'),
SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR,
@@ -169,7 +170,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
if not cursor.description:
return []
return super().fetch_data(cursor, limit)
@@ -221,7 +222,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
@@ -253,10 +254,10 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
if not schema:
return uri, connect_args
@@ -269,11 +270,11 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
return uri, connect_args
@classmethod
- def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
sql = f"EXPLAIN {statement}"
cursor.execute(sql)
@@ -289,8 +290,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
return [{k: str(v) for k, v in row.items()} for row in raw_cost]
@classmethod
@@ -298,7 +299,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
cls,
database: "Database",
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Return all catalogs.
@@ -317,7 +318,7 @@ WHERE datistemplate = false;
@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
- ) -> Set[str]:
+ ) -> set[str]:
"""Need to consider foreign tables for PostgreSQL"""
return set(inspector.get_table_names(schema)) | set(
inspector.get_foreign_table_names(schema)
@@ -325,7 +326,7 @@ WHERE datistemplate = false;
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -337,7 +338,7 @@ WHERE datistemplate = false;
return None
@staticmethod
- def get_extra_params(database: "Database") -> Dict[str, Any]:
+ def get_extra_params(database: "Database") -> dict[str, Any]:
"""
For Postgres, the path to a SSL certificate is placed in `connect_args`.
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 82b05e53e3..d5a2ab7605 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -23,19 +23,9 @@ import time
from abc import ABCMeta
from collections import defaultdict, deque
from datetime import datetime
+from re import Pattern
from textwrap import dedent
-from typing import (
- Any,
- cast,
- Dict,
- List,
- Optional,
- Pattern,
- Set,
- Tuple,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, cast, Optional, TYPE_CHECKING
from urllib import parse
import pandas as pd
@@ -78,7 +68,7 @@ if TYPE_CHECKING:
# need try/catch because pyhive may not be installed
try:
- from pyhive.presto import Cursor # pylint: disable=unused-import
+ from pyhive.presto import Cursor
except ImportError:
pass
@@ -107,7 +97,7 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
logger = logging.getLogger(__name__)
-def get_children(column: ResultSetColumnType) -> List[ResultSetColumnType]:
+def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
"""
Get the children of a complex Presto type (row or array).
@@ -276,8 +266,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
"""
Convert a Python `datetime` object to a SQL expression.
:param target_type: The target type of expression
@@ -304,10 +294,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ connect_args: dict[str, Any],
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> tuple[URL, dict[str, Any]]:
database = uri.database
if schema and database:
schema = parse.quote(schema, safe="")
@@ -323,8 +313,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
- ) -> Optional[str]:
+ connect_args: dict[str, Any],
+ ) -> str | None:
"""
Return the configured schema.
@@ -341,7 +331,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return parse.unquote(database.split("/")[1])
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param statement: A single SQL statement
@@ -369,8 +359,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
"""
Format cost estimate.
:param raw_cost: JSON estimate from Trino
@@ -401,7 +391,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
("networkCost", "Network cost", ""),
]
for row in raw_cost:
- estimate: Dict[str, float] = row.get("estimate", {})
+ estimate: dict[str, float] = row.get("estimate", {})
statement_cost = {}
for key, label, suffix in columns:
if key in estimate:
@@ -412,7 +402,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
@cache_manager.data_cache.memoize()
- def get_function_names(cls, database: Database) -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -426,12 +416,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument
cls,
table_name: str,
- schema: Optional[str],
- indexes: List[Dict[str, Any]],
+ schema: str | None,
+ indexes: list[dict[str, Any]],
database: Database,
limit: int = 0,
- order_by: Optional[List[Tuple[str, bool]]] = None,
- filters: Optional[Dict[Any, Any]] = None,
+ order_by: list[tuple[str, bool]] | None = None,
+ filters: dict[Any, Any] | None = None,
) -> str:
"""
Return a partition query.
@@ -449,7 +439,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
order
:param filters: dict of field name and filter value combinations
"""
- limit_clause = "LIMIT {}".format(limit) if limit else ""
+ limit_clause = f"LIMIT {limit}" if limit else ""
order_by_clause = ""
if order_by:
l = []
@@ -492,11 +482,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
table_name: str,
- schema: Optional[str],
+ schema: str | None,
database: Database,
query: Select,
- columns: Optional[List[Dict[str, Any]]] = None,
- ) -> Optional[Select]:
+ columns: list[dict[str, Any]] | None = None,
+ ) -> Select | None:
try:
col_names, values = cls.latest_partition(
table_name, schema, database, show_first=True
@@ -525,7 +515,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return query
@classmethod
- def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
+ def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
if not df.empty:
return df.to_records(index=False)[0].item()
return None
@@ -535,10 +525,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def latest_partition(
cls,
table_name: str,
- schema: Optional[str],
+ schema: str | None,
database: Database,
show_first: bool = False,
- ) -> Tuple[List[str], Optional[List[str]]]:
+ ) -> tuple[list[str], list[str] | None]:
"""Returns col name and the latest (max) partition value for a table
:param table_name: the name of the table
@@ -589,7 +579,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def latest_sub_partition(
- cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
+ cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
) -> Any:
"""Returns the latest (max) partition value for a table
@@ -652,7 +642,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
engine_name = "Presto"
allows_alias_to_source_column = False
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__(
'We can\'t seem to resolve the column "%(column_name)s" at '
@@ -708,16 +698,16 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
}
@classmethod
- def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
version = extra.get("version")
return version is not None and Version(version) >= Version("0.319")
@classmethod
def update_impersonation_config(
cls,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
uri: str,
- username: Optional[str],
+ username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -741,8 +731,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the real table names within the specified schema.
@@ -769,8 +759,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the view names within the specified schema.
@@ -817,7 +807,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get all catalogs.
"""
@@ -826,7 +816,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
Create column info object
:param name: column name
@@ -836,7 +826,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return {"name": name, "type": f"{data_type}"}
@classmethod
- def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
+ def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
"""
Get the full column name
:param names: list of all individual column names
@@ -860,7 +850,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
)
@classmethod
- def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
+ def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
"""
Split data type based on given delimiter. Do not split the string if the
delimiter is enclosed in quotes
@@ -869,16 +859,14 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
comma, whitespace)
:return: list of strings after breaking it by the delimiter
"""
- return re.split(
- r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type
- )
+ return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
@classmethod
def _parse_structural_column( # pylint: disable=too-many-locals
cls,
parent_column_name: str,
parent_data_type: str,
- result: List[Dict[str, Any]],
+ result: list[dict[str, Any]],
) -> None:
"""
Parse a row or array column
@@ -893,7 +881,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# split on open parenthesis ( to get the structural
# data type and its component types
data_types = cls._split_data_type(full_data_type, r"\(")
- stack: List[Tuple[str, str]] = []
+ stack: list[tuple[str, str]] = []
for data_type in data_types:
# split on closed parenthesis ) to track which component
# types belong to what structural data type
@@ -962,8 +950,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _show_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[ResultRow]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[ResultRow]:
"""
Show presto column names
:param inspector: object that performs database schema inspection
@@ -974,13 +962,13 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name)
if schema:
- full_table = "{}.{}".format(quote(schema), full_table)
+ full_table = f"{quote(schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
@@ -991,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
(i.e. column name and data type)
"""
columns = cls._show_columns(inspector, table_name, schema)
- result: List[Dict[str, Any]] = []
+ result: list[dict[str, Any]] = []
for column in columns:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
@@ -1031,7 +1019,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"')
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
"""
Format column clauses where names are in quotes and labels are specified
:param cols: columns
@@ -1053,7 +1041,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# quote each column name if it is not already quoted
for index, col_name in enumerate(col_names):
if not cls._is_column_name_quoted(col_name):
- col_names[index] = '"{}"'.format(col_name)
+ col_names[index] = f'"{col_name}"'
quoted_col_name = ".".join(
col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"'
for col_name in col_names
@@ -1069,12 +1057,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
database: Database,
table_name: str,
engine: Engine,
- schema: Optional[str] = None,
+ schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: list[dict[str, Any]] | None = None,
) -> str:
"""
Include selecting properties of row objects. We cannot easily break arrays into
@@ -1102,9 +1090,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def expand_data( # pylint: disable=too-many-locals
- cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]]
- ) -> Tuple[
- List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType]
+ cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]]
+ ) -> tuple[
+ list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType]
]:
"""
We do not immediately display rows and arrays clearly in the data grid. This
@@ -1133,7 +1121,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# process each column, unnesting ARRAY types and
# expanding ROW types into new columns
to_process = deque((column, 0) for column in columns)
- all_columns: List[ResultSetColumnType] = []
+ all_columns: list[ResultSetColumnType] = []
expanded_columns = []
current_array_level = None
while to_process:
@@ -1147,11 +1135,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# added by the first. every time we change a level in the nested arrays
# we reinitialize this.
if level != current_array_level:
- unnested_rows: Dict[int, int] = defaultdict(int)
+ unnested_rows: dict[int, int] = defaultdict(int)
current_array_level = level
name = column["name"]
- values: Optional[Union[str, List[Any]]]
+ values: str | list[Any] | None
if column["type"] and column["type"].startswith("ARRAY("):
# keep processing array children; we append to the right so that
@@ -1198,7 +1186,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
for row in data:
values = row.get(name) or []
if isinstance(values, str):
- values = cast(Optional[List[Any]], destringify(values))
+ values = cast(Optional[list[Any]], destringify(values))
row[name] = values
for value, col in zip(values or [], expanded):
row[col["name"]] = value
@@ -1211,8 +1199,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def extra_table_metadata(
- cls, database: Database, table_name: str, schema_name: Optional[str]
- ) -> Dict[str, Any]:
+ cls, database: Database, table_name: str, schema_name: str | None
+ ) -> dict[str, Any]:
metadata = {}
if indexes := database.get_indexes(table_name, schema_name):
@@ -1243,8 +1231,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_create_view(
- cls, database: Database, schema: Optional[str], table: str
- ) -> Optional[str]:
+ cls, database: Database, schema: str | None, table: str
+ ) -> str | None:
"""
Return a CREATE VIEW statement, or `None` if not a view.
@@ -1267,7 +1255,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return rows[0][0]
@classmethod
- def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
+ def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
if cursor.last_query_id:
# pylint: disable=protected-access, line-too-long
@@ -1277,7 +1265,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return None
@classmethod
- def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
+ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
"""Updates progress information"""
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py
index 27b749e418..2e746a6349 100644
--- a/superset/db_engine_specs/redshift.py
+++ b/superset/db_engine_specs/redshift.py
@@ -16,7 +16,8 @@
# under the License.
import logging
import re
-from typing import Any, Dict, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
import pandas as pd
from flask_babel import gettext as __
@@ -66,7 +67,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
encryption_parameters = {"sslmode": "verify-ca"}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__('Either the username "%(username)s" or the password is incorrect.'),
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
@@ -106,7 +107,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
database: Database,
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
diff --git a/superset/db_engine_specs/rockset.py b/superset/db_engine_specs/rockset.py
index cc215054be..71adca0b10 100644
--- a/superset/db_engine_specs/rockset.py
+++ b/superset/db_engine_specs/rockset.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
from sqlalchemy import types
@@ -51,7 +51,7 @@ class RocksetEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 69ccf55931..32ade649b0 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -18,7 +18,8 @@ import json
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from urllib import parse
from apispec import APISpec
@@ -107,7 +108,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
"P1Y": "DATE_TRUNC('YEAR', {col})",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
OBJECT_DOES_NOT_EXIST_REGEX: (
__("%(object)s does not exist in this database."),
SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR,
@@ -124,13 +125,13 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
}
@staticmethod
- def get_extra_params(database: "Database") -> Dict[str, Any]:
+ def get_extra_params(database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
- extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
- engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
- connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
+ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+ engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
+ connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("application", USER_AGENT)
@@ -140,10 +141,10 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
database = uri.database
if "/" in database:
database = database.split("/")[0]
@@ -157,7 +158,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
@@ -174,7 +175,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
database: "Database",
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Return all catalogs.
@@ -197,7 +198,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -261,7 +262,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
parameters: SnowflakeParametersType,
encrypted_extra: Optional[ # pylint: disable=unused-argument
- Dict[str, Any]
+ dict[str, Any]
] = None,
) -> str:
return str(
@@ -283,7 +284,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
uri: str,
encrypted_extra: Optional[ # pylint: disable=unused-argument
- Dict[str, str]
+ dict[str, str]
] = None,
) -> Any:
url = make_url_safe(uri)
@@ -300,8 +301,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
- ) -> List[SupersetError]:
- errors: List[SupersetError] = []
+ ) -> list[SupersetError]:
+ errors: list[SupersetError] = []
required = {
"warehouse",
"username",
@@ -346,7 +347,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@staticmethod
def update_params_from_encrypted_extra(
database: "Database",
- params: Dict[str, Any],
+ params: dict[str, Any],
) -> None:
if not database.encrypted_extra:
return
diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py
index a414143296..767d0a20ad 100644
--- a/superset/db_engine_specs/sqlite.py
+++ b/superset/db_engine_specs/sqlite.py
@@ -16,7 +16,8 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy import types
@@ -60,7 +61,7 @@ class SqliteEngineSpec(BaseEngineSpec):
),
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__('We can\'t seem to resolve the column "%(column_name)s"'),
SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
@@ -74,7 +75,7 @@ class SqliteEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, (types.String, types.DateTime)):
@@ -84,6 +85,6 @@ class SqliteEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
- ) -> Set[str]:
+ ) -> set[str]:
"""Need to disregard the schema for Sqlite"""
return set(inspector.get_table_names())
diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py
index f687fdbdb3..63269439af 100644
--- a/superset/db_engine_specs/starrocks.py
+++ b/superset/db_engine_specs/starrocks.py
@@ -17,7 +17,8 @@
import logging
import re
-from typing import Any, Dict, List, Optional, Pattern, Tuple, Type
+from re import Pattern
+from typing import Any, Optional
from urllib import parse
from flask_babel import gettext as __
@@ -40,11 +41,11 @@ CONNECTION_UNKNOWN_DATABASE_REGEX = re.compile("Unknown database '(?P<database>.
logger = logging.getLogger(__name__)
-class TINYINT(Integer): # pylint: disable=no-init
+class TINYINT(Integer):
__visit_name__ = "TINYINT"
-class DOUBLE(Numeric): # pylint: disable=no-init
+class DOUBLE(Numeric):
__visit_name__ = "DOUBLE"
@@ -52,7 +53,7 @@ class ARRAY(TypeEngine): # pylint: disable=no-init
__visit_name__ = "ARRAY"
@property
- def python_type(self) -> Optional[Type[List[Any]]]:
+ def python_type(self) -> Optional[type[list[Any]]]:
return list
@@ -60,7 +61,7 @@ class MAP(TypeEngine): # pylint: disable=no-init
__visit_name__ = "MAP"
@property
- def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
+ def python_type(self) -> Optional[type[dict[Any, Any]]]:
return dict
@@ -68,7 +69,7 @@ class STRUCT(TypeEngine): # pylint: disable=no-init
__visit_name__ = "STRUCT"
@property
- def python_type(self) -> Optional[Type[Any]]:
+ def python_type(self) -> Optional[type[Any]]:
return None
@@ -117,7 +118,7 @@ class StarRocksEngineSpec(MySQLEngineSpec):
(re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING),
)
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__('Either the username "%(username)s" or the password is incorrect.'),
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
@@ -134,10 +135,10 @@ class StarRocksEngineSpec(MySQLEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
database = uri.database
if schema and database:
schema = parse.quote(schema, safe="")
@@ -152,9 +153,9 @@ class StarRocksEngineSpec(MySQLEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
columns = cls._show_columns(inspector, table_name, schema)
- result: List[Dict[str, Any]] = []
+ result: list[dict[str, Any]] = []
for column in columns:
column_spec = cls.get_column_spec(column.Type)
column_type = column_spec.sqla_type if column_spec else None
@@ -174,7 +175,7 @@ class StarRocksEngineSpec(MySQLEngineSpec):
@classmethod
def _show_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[ResultRow]:
+ ) -> list[ResultRow]:
"""
Show starrocks column names
:param inspector: object that performs database schema inspection
@@ -185,13 +186,13 @@ class StarRocksEngineSpec(MySQLEngineSpec):
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name)
if schema:
- full_table = "{}.{}".format(quote(schema), full_table)
+ full_table = f"{quote(schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
Create column info object
:param name: column name
@@ -204,7 +205,7 @@ class StarRocksEngineSpec(MySQLEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 0fa4d05cbc..f05bd67ec3 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, Dict, Optional, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
import simplejson as json
from flask import current_app
@@ -53,8 +53,8 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
... 14475 lines suppressed ...