You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2022/06/07 12:51:47 UTC

[spark] branch master updated: [SPARK-39359][SQL] Restrict DEFAULT columns to allowlist of supported data source types

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 778f650a388 [SPARK-39359][SQL] Restrict DEFAULT columns to allowlist of supported data source types
778f650a388 is described below

commit 778f650a38850fe64f2608f44026137205b6b15e
Author: Daniel Tenedorio <da...@databricks.com>
AuthorDate: Tue Jun 7 05:51:19 2022 -0700

    [SPARK-39359][SQL] Restrict DEFAULT columns to allowlist of supported data source types
    
    ### What changes were proposed in this pull request?
    
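    Restrict DEFAULT column values to an allowlist of supported data source types. The allowlist is controlled by the new internal configuration `spark.sql.defaultColumn.allowedProviders` (a comma-separated, case-insensitive list of table providers; default `csv,json,orc,parquet`). CREATE TABLE, REPLACE TABLE, and ALTER TABLE ADD COLUMNS commands that assign a DEFAULT value to a column of a table whose provider is not in this list now fail with an AnalysisException.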
    
    Example:
    
    ```
    > create table t(a string) using avro
    > alter table t add column(b int default 42)
    AnalysisException: Failed to execute command because DEFAULT values are not supported for target data source with table provider: "avro"
    ```
    
    ### Why are the changes needed?
    
    This change is necessary for correctness: every data source that accepts DEFAULT column values must also support returning those values when they are missing from the corresponding storage, so DEFAULT values must be rejected for table providers that do not implement this support.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This PR adds unit test coverage.
    
    Closes #36745 from dtenedor/restrict-datasources.
    
    Authored-by: Daniel Tenedorio <da...@databricks.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../catalyst/util/ResolveDefaultColumnsUtil.scala  | 10 +++++
 .../spark/sql/errors/QueryCompilationErrors.scala  |  6 +++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 11 ++++++
 .../sql/catalyst/catalog/SessionCatalogSuite.scala |  7 +++-
 .../spark/sql/execution/command/tables.scala       |  8 ++--
 .../execution/datasources/DataSourceStrategy.scala |  8 ++--
 .../datasources/v2/DataSourceV2Strategy.scala      |  4 +-
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 26 +++++++------
 .../org/apache/spark/sql/sources/InsertSuite.scala | 43 ++++++++++++++++++++--
 9 files changed, 96 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index d44a4782576..319095e541c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -83,16 +83,26 @@ object ResolveDefaultColumns {
    *
    * @param analyzer      used for analyzing the result of parsing the expression stored as text.
    * @param tableSchema   represents the names and types of the columns of the statement to process.
+   * @param tableProvider provider of the target table to store default values for, if any.
    * @param statementType name of the statement being processed, such as INSERT; useful for errors.
    * @return a copy of `tableSchema` with field metadata updated with the constant-folded values.
    */
   def constantFoldCurrentDefaultsToExistDefaults(
       analyzer: Analyzer,
       tableSchema: StructType,
+      tableProvider: Option[String],
       statementType: String): StructType = {
     if (SQLConf.get.enableDefaultColumns) {
+      val allowedTableProviders: Array[String] =
+        SQLConf.get.getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)
+          .toLowerCase().split(",").map(_.trim)
+      val givenTableProvider: String = tableProvider.getOrElse("").toLowerCase()
       val newFields: Seq[StructField] = tableSchema.fields.map { field =>
         if (field.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) {
+          // Make sure that the target table has a provider that supports default column values.
+          if (!allowedTableProviders.contains(givenTableProvider)) {
+            throw QueryCompilationErrors.defaultReferencesNotAllowedInDataSource(givenTableProvider)
+          }
           val analyzed: Expression = analyze(analyzer, field, statementType)
           val newMetadata: Metadata = new MetadataBuilder().withMetadata(field.metadata)
             .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, analyzed.sql).build()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 27ca00b489d..afd5cb2a073 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2445,4 +2445,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
       s"Invalid DEFAULT value for column $fieldName: $defaultValue fails to parse as a valid " +
         "literal value")
   }
+
+  def defaultReferencesNotAllowedInDataSource(dataSource: String): Throwable = {
+    new AnalysisException(
+      s"Failed to execute command because DEFAULT values are not supported for target data " +
+        "source with table provider: \"" + dataSource + "\"")
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6dfc46f6a3f..8c7702efd47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2881,6 +2881,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val DEFAULT_COLUMN_ALLOWED_PROVIDERS =
+    buildConf("spark.sql.defaultColumn.allowedProviders")
+      .internal()
+      .doc("List of table providers wherein SQL commands are permitted to assign DEFAULT column " +
+        "values. Comma-separated list, whitespace ignored, case-insensitive.")
+      .version("3.4.0")
+      .stringConf
+      .createWithDefault("csv,json,orc,parquet")
+
   val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES =
     buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues")
       .internal()
@@ -4454,6 +4463,8 @@ class SQLConf extends Serializable with Logging {
 
   def enableDefaultColumns: Boolean = getConf(SQLConf.ENABLE_DEFAULT_COLUMNS)
 
+  def defaultColumnAllowedProviders: String = getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)
+
   def useNullsForMissingDefaultColumnValues: Boolean =
     getConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 2d9a17716d4..bf9ce8791b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -122,7 +122,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
   }
 
   test("create table with default columns") {
-    withBasicCatalog { catalog =>
+    def test: Unit = withBasicCatalog { catalog =>
       assert(catalog.externalCatalog.listTables("db1").isEmpty)
       assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
       catalog.createTable(newTable(
@@ -174,12 +174,15 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
       // disabled.
       withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "false") {
         val result: StructType = ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          analyzer, db1tbl3.schema, "CREATE TABLE")
+          analyzer, db1tbl3.schema, db1tbl3.provider, "CREATE TABLE")
         val columnEWithFeatureDisabled: StructField = findField("e", result)
         // No constant-folding has taken place to the EXISTS_DEFAULT metadata.
         assert(!columnEWithFeatureDisabled.metadata.contains("EXISTS_DEFAULT"))
       }
     }
+    withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> "csv,hive,json,orc,parquet") {
+      test
+    }
   }
 
   test("create databases using invalid names") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 536baed9419..07246033265 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -230,7 +230,8 @@ case class AlterTableAddColumnsCommand(
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
     val catalogTable = verifyAlterTableAddColumn(sparkSession.sessionState.conf, catalog, table)
-    val colsWithProcessedDefaults = constantFoldCurrentDefaultsToExistDefaults(sparkSession)
+    val colsWithProcessedDefaults =
+      constantFoldCurrentDefaultsToExistDefaults(sparkSession, catalogTable.provider)
 
     CommandUtils.uncacheTableOrView(sparkSession, table.quotedString)
     catalog.refreshTable(table)
@@ -285,11 +286,12 @@ case class AlterTableAddColumnsCommand(
    * in a separate column metadata entry, then returns the updated column definitions.
    */
   private def constantFoldCurrentDefaultsToExistDefaults(
-      sparkSession: SparkSession): Seq[StructField] = {
+      sparkSession: SparkSession, tableProvider: Option[String]): Seq[StructField] = {
     colsToAdd.map { col: StructField =>
       if (col.metadata.contains(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY)) {
         val foldedStructType = ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          sparkSession.sessionState.analyzer, StructType(Seq(col)), "ALTER TABLE ADD COLUMNS")
+          sparkSession.sessionState.analyzer, StructType(Seq(col)), tableProvider,
+          "ALTER TABLE ADD COLUMNS")
         foldedStructType.fields(0)
       } else {
         col
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 429b7072cae..a82c222ea1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -147,10 +147,10 @@ case class DataSourceAnalysis(analyzer: Analyzer) extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
-      val newTableDesc: CatalogTable =
-        tableDesc.copy(schema =
-          ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-            analyzer, tableDesc.schema, "CREATE TABLE"))
+      val newSchema: StructType =
+        ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
+          analyzer, tableDesc.schema, tableDesc.provider, "CREATE TABLE")
+      val newTableDesc = tableDesc.copy(schema = newSchema)
       CreateDataSourceTableCommand(newTableDesc, ignoreIfExists = mode == SaveMode.Ignore)
 
     case CreateTable(tableDesc, mode, Some(query))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 707cdb2ec3e..401427ac41e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -173,7 +173,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         tableSpec, ifNotExists) =>
       val newSchema: StructType =
         ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          session.sessionState.analyzer, schema, "CREATE TABLE")
+          session.sessionState.analyzer, schema, tableSpec.provider, "CREATE TABLE")
       CreateTableExec(catalog.asTableCatalog, ident.asIdentifier, newSchema,
         partitioning, qualifyLocInTableSpec(tableSpec), ifNotExists) :: Nil
 
@@ -195,7 +195,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case ReplaceTable(ResolvedDBObjectName(catalog, ident), schema, parts, tableSpec, orCreate) =>
       val newSchema: StructType =
         ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          session.sessionState.analyzer, schema, "CREATE TABLE")
+          session.sessionState.analyzer, schema, tableSpec.provider, "CREATE TABLE")
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicReplaceTableExec(staging, ident.asIdentifier, newSchema, parts,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 6302dbb4f0c..45c721100f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -614,21 +614,23 @@ class DataSourceV2SQLSuite
     val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
     assert(table.asInstanceOf[InMemoryTable].rows.nonEmpty)
 
-    spark.sql("REPLACE TABLE testcat.table_name (id bigint NOT NULL DEFAULT 41 + 1) USING foo")
-    val replaced = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+    withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> "foo") {
+      spark.sql("REPLACE TABLE testcat.table_name (id bigint NOT NULL DEFAULT 41 + 1) USING foo")
+      val replaced = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
 
-    assert(replaced.asInstanceOf[InMemoryTable].rows.isEmpty,
+      assert(replaced.asInstanceOf[InMemoryTable].rows.isEmpty,
         "Replaced table should have no rows after committing.")
-    assert(replaced.schema().fields.length === 1,
+      assert(replaced.schema().fields.length === 1,
         "Replaced table should have new schema.")
-    val actual = replaced.schema().fields(0)
-    val expected = StructField("id", LongType, nullable = false,
-      new MetadataBuilder().putString(
-        ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "41 + 1")
-        .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "CAST(42 AS BIGINT)")
-        .build())
-    assert(actual === expected,
-      "Replaced table should have new schema with DEFAULT column metadata.")
+      val actual = replaced.schema().fields(0)
+      val expected = StructField("id", LongType, nullable = false,
+        new MetadataBuilder().putString(
+          ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "41 + 1")
+          .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "CAST(42 AS BIGINT)")
+          .build())
+      assert(actual === expected,
+        "Replaced table should have new schema with DEFAULT column metadata.")
+    }
   }
 
   test("ReplaceTableAsSelect: CREATE OR REPLACE new table has same behavior as CTAS.") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 938d9b87d7c..a1d00361dfc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -1578,10 +1578,30 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
           "y timestamp default timestamp'0000', " +
           "z timestamp_ntz default cast(timestamp'0000' as timestamp_ntz), " +
           "a1 timestamp_ltz default cast(timestamp'0000' as timestamp_ltz), " +
-          "a2 decimal(5, 2) default 123.45)")
-        checkAnswer(sql("select s, t, u, v, w, x is not null, " +
-          "y is not null, z is not null, a1 is not null, a2 is not null from t"),
-          Row(true, null, 42, 0.0f, 0.0d, true, true, true, true, true))
+          "a2 decimal(5, 2) default 123.45," +
+          "a3 bigint default 43," +
+          "a4 smallint default cast(5 as smallint)," +
+          "a5 tinyint default cast(6 as tinyint))")
+        // Manually inspect the result row values rather than using the 'checkAnswer' helper method
+        // in order to ensure the values' correctness while avoiding minor type incompatibilities.
+        val result: Array[Row] =
+          sql("select s, t, u, v, w, x, y, z, a1, a2, a3, a4, a5 from t").collect()
+        assert(result.length == 1)
+        val row: Row = result(0)
+        assert(row.length == 13)
+        assert(row(0) == true)
+        assert(row(1) == null)
+        assert(row(2) == 42)
+        assert(row(3) == 0.0f)
+        assert(row(4) == 0.0d)
+        assert(row(5).toString == "0001-01-01")
+        assert(row(6).toString == "0001-01-01 00:00:00.0")
+        assert(row(7).toString == "0000-01-01T00:00")
+        assert(row(8).toString == "0001-01-01 00:00:00.0")
+        assert(row(9).toString == "123.45")
+        assert(row(10) == 43L)
+        assert(row(11) == 5)
+        assert(row(12) == 6)
       }
     }
 
@@ -1637,6 +1657,21 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-39359 Restrict DEFAULT columns to allowlist of supported data source types") {
+    withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> "csv,json,orc") {
+      val unsupported = "DEFAULT values are not supported for target data source"
+      assert(intercept[AnalysisException] {
+        sql(s"create table t(a string default 'abc') using parquet")
+      }.getMessage.contains(unsupported))
+      withTable("t") {
+        sql(s"create table t(a string, b int) using parquet")
+        assert(intercept[AnalysisException] {
+          sql("alter table t add column s bigint default 42")
+        }.getMessage.contains(unsupported))
+      }
+    }
+  }
+
   test("Stop task set if FileAlreadyExistsException was thrown") {
     Seq(true, false).foreach { fastFail =>
       withSQLConf("fs.file.impl" -> classOf[FileExistingTestFileSystem].getName,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org