You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/11 16:55:26 UTC

(superset) 01/01: fix: pass valid SQL to SM

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch jinja2-sm
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 1f8c47af1d76172f145e703bc3cf816946fd0fd7
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Mon Mar 11 12:53:54 2024 -0400

    fix: pass valid SQL to SM
---
 superset/commands/dataset/create.py              |  9 +++-
 tests/unit_tests/commands/dataset/__init__.py    | 16 +++++++
 tests/unit_tests/commands/dataset/create_test.py | 59 ++++++++++++++++++++++++
 3 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py
index 16b87a567a..295f951cce 100644
--- a/superset/commands/dataset/create.py
+++ b/superset/commands/dataset/create.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Optional
+from typing import Any, cast, Optional
 
 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError
 from sqlalchemy.exc import SQLAlchemyError
 
+from superset import jinja_context
 from superset.commands.base import BaseCommand, CreateMixin
 from superset.commands.dataset.exceptions import (
     DatabaseNotFoundValidationError,
@@ -34,6 +35,7 @@ from superset.daos.dataset import DatasetDAO
 from superset.daos.exceptions import DAOCreateFailedError
 from superset.exceptions import SupersetSecurityException
 from superset.extensions import db, security_manager
+from superset.models.core import Database
 
 logger = logging.getLogger(__name__)
 
@@ -73,6 +75,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
         database = DatasetDAO.get_database_by_id(database_id)
         if not database:
             exceptions.append(DatabaseNotFoundValidationError())
+        database = cast(Database, database)
         self._properties["database"] = database
 
         # Validate table exists on dataset if sql is not provided
@@ -85,10 +88,12 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
             exceptions.append(TableNotFoundValidationError(table_name))
 
         if sql:
+            processor = jinja_context.get_template_processor(database=database)
+            rendered_sql = processor.process_template(sql)
             try:
                 security_manager.raise_for_access(
                     database=database,
-                    sql=sql,
+                    sql=rendered_sql,
                     schema=schema,
                 )
             except SupersetSecurityException as ex:
diff --git a/tests/unit_tests/commands/dataset/__init__.py b/tests/unit_tests/commands/dataset/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/unit_tests/commands/dataset/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/commands/dataset/create_test.py b/tests/unit_tests/commands/dataset/create_test.py
new file mode 100644
index 0000000000..a77c078a8e
--- /dev/null
+++ b/tests/unit_tests/commands/dataset/create_test.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name
+
+from pytest_mock import MockerFixture
+
+from superset.commands.dataset.create import CreateDatasetCommand
+
+
+def test_jionja_in_sql(mocker: MockerFixture) -> None:
+    """
+    Test that we pass valid (rendered) SQL to the security manager.
+
+    See discussion in https://github.com/apache/superset/pull/26476. Before, the raw
+    Jinja template was passed to the security manager as if it were SQL, which could
+    fail to parse (since a template is not necessarily valid SQL).
+
+    Now the template is processed first; this test mocks the template processor and
+    asserts that `raise_for_access` receives its output, not the raw template.
+    """
+    DatasetDAO = mocker.patch("superset.commands.dataset.create.DatasetDAO")
+    DatasetDAO.validate_uniqueness.return_value = True
+    database = mocker.MagicMock()
+    DatasetDAO.get_database_by_id.return_value = database
+    jinja_context = mocker.patch("superset.commands.dataset.create.jinja_context")
+    jinja_context.get_template_processor().process_template.return_value = "SELECT '42'"
+    mocker.patch.object(CreateDatasetCommand, "populate_owners")
+    security_manager = mocker.patch("superset.commands.dataset.create.security_manager")
+
+    data = {
+        "database": 1,
+        "table_name": "tmp_table",
+        "schema": "main",
+        "sql": "SELECT '{{ answer }}'",
+        "owners": [1],
+    }
+    command = CreateDatasetCommand(data)
+    command.validate()
+
+    security_manager.raise_for_access.assert_called_with(
+        database=database,
+        sql="SELECT '42'",
+        schema="main",
+    )