You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/07/28 06:05:35 UTC

[spark] branch branch-3.2 updated: [SPARK-33865][SPARK-36202][SQL] When HiveDDL, we need check avro schema too

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 2f4f793  [SPARK-33865][SPARK-36202][SQL] When HiveDDL, we need check avro schema too
2f4f793 is described below

commit 2f4f7936fdc06a84abbb264d4f7899b9084e606c
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Wed Jul 28 14:04:24 2021 +0800

    [SPARK-33865][SPARK-36202][SQL] When HiveDDL, we need check avro schema too
    
    ### What changes were proposed in this pull request?
    Unify the schema-check code of FileFormat, and also check Avro schema field names during CREATE TABLE DDL.
    
    ### Why are the changes needed?
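    Refactor code: share a single field-name check across file formats, and validate Avro column names at CREATE TABLE time as well.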
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test in AvroSuite ("SPARK-33865: CREATE TABLE DDL with avro should check col name").
    
    Closes #33441 from AngersZhuuuu/SPARK-36202.
    
    Authored-by: Angerszhuuuu <an...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 86f44578e5204487930f334aecdd97255681a3fc)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../org/apache/spark/sql/avro/AvroFileFormat.scala | 12 ++++++++
 .../org/apache/spark/sql/avro/AvroSuite.scala      | 30 +++++++++++++++++++
 .../apache/spark/sql/execution/command/ddl.scala   | 35 +++++++++++++++++-----
 .../execution/datasources/DataSourceUtils.scala    | 16 ++++++++++
 .../sql/execution/datasources/FileFormat.scala     |  6 ++++
 .../execution/datasources/orc/OrcFileFormat.scala  | 28 ++++++-----------
 .../datasources/parquet/ParquetFileFormat.scala    |  4 +++
 .../parquet/ParquetSchemaConverter.scala           | 10 -------
 8 files changed, 104 insertions(+), 37 deletions(-)

diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index c2cea41..398cb02 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -153,6 +153,18 @@ private[sql] class AvroFileFormat extends FileFormat
   }
 
   override def supportDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
+
+  override def supportFieldName(name: String): Boolean = {
+    if (name.length == 0) {
+      false
+    } else {
+      name.zipWithIndex.forall {
+        case (c, 0) if !Character.isLetter(c) && c != '_' => false
+        case (c, _) if !Character.isLetterOrDigit(c) && c != '_' => false
+        case _ => true
+      }
+    }
+  }
 }
 
 private[avro] object AvroFileFormat {
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index ffad851..f93c61a 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -2158,6 +2158,36 @@ abstract class AvroSuite
       }
     }
   }
+
+  test("SPARK-33865: CREATE TABLE DDL with avro should check col name") {
+    withTable("test_ddl") {
+      withView("v") {
+        spark.range(1).createTempView("v")
+        withTempDir { dir =>
+          val e = intercept[AnalysisException] {
+            sql(
+              s"""
+                 |CREATE TABLE test_ddl USING AVRO
+                 |LOCATION '${dir}'
+                 |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
+          }.getMessage
+          assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
+        }
+
+        withTempDir { dir =>
+          spark.sql(
+            s"""
+               |CREATE TABLE test_ddl USING AVRO
+               |LOCATION '${dir}'
+               |AS SELECT ID, IF(ID=1,ID,0) AS A, ABS(ID) AS B
+               |FROM v""".stripMargin)
+          val expectedSchema = StructType(Seq(StructField("ID", LongType, true),
+            StructField("A", LongType, true), StructField("B", LongType, true)))
+          assert(spark.table("test_ddl").schema == expectedSchema)
+        }
+      }
+    }
+  }
 }
 
 class AvroV1Suite extends AvroSuite {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 140f9d7..ea1b656 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -40,9 +41,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
-import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.PartitioningUtils
@@ -860,7 +860,7 @@ case class AlterTableSetLocationCommand(
 }
 
 
-object DDLUtils {
+object DDLUtils extends Logging {
   val HIVE_PROVIDER = "hive"
 
   def isHiveTable(table: CatalogTable): Boolean = {
@@ -933,19 +933,38 @@ object DDLUtils {
         case HIVE_PROVIDER =>
           val serde = table.storage.serde
           if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
-            OrcFileFormat.checkFieldNames(schema)
+            checkDataColNames("orc", schema)
           } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
             serde == Some("parquet.hive.serde.ParquetHiveSerDe") ||
             serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) {
-            ParquetSchemaConverter.checkFieldNames(schema)
+            checkDataColNames("parquet", schema)
+          } else if (serde == HiveSerDe.sourceToSerDe("avro").get.serde) {
+            checkDataColNames("avro", schema)
           }
-        case "parquet" => ParquetSchemaConverter.checkFieldNames(schema)
-        case "orc" => OrcFileFormat.checkFieldNames(schema)
+        case "parquet" => checkDataColNames("parquet", schema)
+        case "orc" => checkDataColNames("orc", schema)
+        case "avro" => checkDataColNames("avro", schema)
         case _ =>
       }
     }
   }
 
+  def checkDataColNames(provider: String, schema: StructType): Unit = {
+    val source = try {
+      DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
+    } catch {
+      case e: Throwable =>
+        logError(s"Failed to find data source: $provider when check data column names.", e)
+        return
+    }
+    source match {
+      case f: FileFormat => DataSourceUtils.checkFieldNames(f, schema)
+      case f: FileDataSourceV2 =>
+        DataSourceUtils.checkFieldNames(f.fallbackFileFormat.newInstance(), schema)
+      case _ =>
+    }
+  }
+
   /**
    * Throws exception if outputPath tries to overwrite inputpath.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 2b10e4e..b562d44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -59,6 +59,22 @@ object DataSourceUtils {
   }
 
   /**
+   * Verify that all field names, including nested struct fields, are supported by the file
+   * format. This verification should be done on the driver side.
+   */
+  def checkFieldNames(format: FileFormat, schema: StructType): Unit = {
+    schema.foreach { field =>
+      if (!format.supportFieldName(field.name)) {
+        throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(field.name)
+      }
+      field.dataType match {
+        case s: StructType => checkFieldNames(format, s)
+        case _ =>
+      }
+    }
+  }
+
+  /**
    * Verify if the schema is supported in datasource. This verification should be done
    * in a driver side.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 7fd48ca..beb1f4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -163,6 +163,12 @@ trait FileFormat {
    * By default all data types are supported.
    */
   def supportDataType(dataType: DataType): Boolean = true
+
+  /**
+   * Returns whether this format supports the given field name in the read/write path.
+   * By default all field names are supported.
+   */
+  def supportFieldName(name: String): Boolean = true
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 9024c78..85c0ff0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -36,31 +36,12 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 private[sql] object OrcFileFormat {
-  private def checkFieldName(name: String): Unit = {
-    try {
-      TypeDescription.fromString(s"struct<`$name`:int>")
-    } catch {
-      case _: IllegalArgumentException =>
-        throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name)
-    }
-  }
-
-  def checkFieldNames(schema: StructType): Unit = {
-    schema.foreach { field =>
-      checkFieldName(field.name)
-      field.dataType match {
-        case s: StructType => checkFieldNames(s)
-        case _ =>
-      }
-    }
-  }
 
   def getQuotedSchemaString(dataType: DataType): String = dataType match {
     case _: AtomicType => dataType.catalogString
@@ -279,4 +260,13 @@ class OrcFileFormat
 
     case _ => false
   }
+
+  override def supportFieldName(name: String): Boolean = {
+    try {
+      TypeDescription.fromString(s"struct<`$name`:int>")
+      true
+    } catch {
+      case _: IllegalArgumentException => false
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index ee229a3..586952a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -386,6 +386,10 @@ class ParquetFileFormat
 
     case _ => false
   }
+
+  override def supportFieldName(name: String): Boolean = {
+    !name.matches(".*[ ,;{}()\n\t=].*")
+  }
 }
 
 object ParquetFileFormat extends Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index f3bfd99..217c020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -591,16 +591,6 @@ private[sql] object ParquetSchemaConverter {
     }
   }
 
-  def checkFieldNames(schema: StructType): Unit = {
-    schema.foreach { field =>
-      checkFieldName(field.name)
-      field.dataType match {
-        case s: StructType => checkFieldNames(s)
-        case _ =>
-      }
-    }
-  }
-
   def checkConversionRequirement(f: => Boolean, message: String): Unit = {
     if (!f) {
       throw new AnalysisException(message)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org