Posted to commits@spark.apache.org by we...@apache.org on 2021/07/28 06:05:35 UTC
[spark] branch branch-3.2 updated: [SPARK-33865][SPARK-36202][SQL] When using Hive DDL, we need to check the Avro schema too
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 2f4f793 [SPARK-33865][SPARK-36202][SQL] When using Hive DDL, we need to check the Avro schema too
2f4f793 is described below
commit 2f4f7936fdc06a84abbb264d4f7899b9084e606c
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Wed Jul 28 14:04:24 2021 +0800
[SPARK-33865][SPARK-36202][SQL] When using Hive DDL, we need to check the Avro schema too
### What changes were proposed in this pull request?
Unify the schema-check code across FileFormat implementations, and also validate Avro schema field names during CREATE TABLE DDL.
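Concretely, after this change a CTAS whose auto-generated column name contains characters Avro cannot represent fails at DDL time. An illustration adapted from the new AvroSuite test below ("v" is a temp view over spark.range(1)):

    spark.range(1).createTempView("v")
    spark.sql("CREATE TABLE test_ddl USING AVRO AS SELECT ID, IF(ID=1,1,0) FROM v")
    // => AnalysisException: Column name "(IF((ID = 1), 1, 0))" contains invalid character(s).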
### Why are the changes needed?
Code refactoring: the per-format field-name checks were duplicated across formats, and the Avro check was missing from CREATE TABLE DDL.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
A new unit test was added in AvroSuite.
Closes #33441 from AngersZhuuuu/SPARK-36202.
Authored-by: Angerszhuuuu <an...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 86f44578e5204487930f334aecdd97255681a3fc)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../org/apache/spark/sql/avro/AvroFileFormat.scala | 12 ++++++++
.../org/apache/spark/sql/avro/AvroSuite.scala | 30 +++++++++++++++++++
.../apache/spark/sql/execution/command/ddl.scala | 35 +++++++++++++++++-----
.../execution/datasources/DataSourceUtils.scala | 16 ++++++++++
.../sql/execution/datasources/FileFormat.scala | 6 ++++
.../execution/datasources/orc/OrcFileFormat.scala | 28 ++++++-----------
.../datasources/parquet/ParquetFileFormat.scala | 4 +++
.../parquet/ParquetSchemaConverter.scala | 10 -------
8 files changed, 104 insertions(+), 37 deletions(-)
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index c2cea41..398cb02 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -153,6 +153,18 @@ private[sql] class AvroFileFormat extends FileFormat
}
override def supportDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
+
+ override def supportFieldName(name: String): Boolean = {
+ if (name.length == 0) {
+ false
+ } else {
+ name.zipWithIndex.forall {
+ case (c, 0) if !Character.isLetter(c) && c != '_' => false
+ case (c, _) if !Character.isLetterOrDigit(c) && c != '_' => false
+ case _ => true
+ }
+ }
+ }
}
private[avro] object AvroFileFormat {
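For reference, the rule added above accepts a name only if it starts with a letter or underscore and otherwise contains nothing but letters, digits, and underscores. A standalone sketch of the same predicate (illustration only, not part of the patch):

    def isValidAvroFieldName(name: String): Boolean =
      name.nonEmpty &&
        (Character.isLetter(name.head) || name.head == '_') &&
        name.tail.forall(c => Character.isLetterOrDigit(c) || c == '_')

    isValidAvroFieldName("_tmp")                     // true
    isValidAvroFieldName("a1")                       // true
    isValidAvroFieldName("(IF((ID = 1), 1, 0))")     // false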
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index ffad851..f93c61a 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -2158,6 +2158,36 @@ abstract class AvroSuite
}
}
}
+
+ test("SPARK-33865: CREATE TABLE DDL with avro should check col name") {
+ withTable("test_ddl") {
+ withView("v") {
+ spark.range(1).createTempView("v")
+ withTempDir { dir =>
+ val e = intercept[AnalysisException] {
+ sql(
+ s"""
+ |CREATE TABLE test_ddl USING AVRO
+ |LOCATION '${dir}'
+ |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
+ }.getMessage
+ assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
+ }
+
+ withTempDir { dir =>
+ spark.sql(
+ s"""
+ |CREATE TABLE test_ddl USING AVRO
+ |LOCATION '${dir}'
+ |AS SELECT ID, IF(ID=1,ID,0) AS A, ABS(ID) AS B
+ |FROM v""".stripMargin)
+ val expectedSchema = StructType(Seq(StructField("ID", LongType, true),
+ StructField("A", LongType, true), StructField("B", LongType, true)))
+ assert(spark.table("test_ddl").schema == expectedSchema)
+ }
+ }
+ }
+ }
}
class AvroV1Suite extends AvroSuite {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 140f9d7..ea1b656 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
+import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
@@ -40,9 +41,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
-import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.PartitioningUtils
@@ -860,7 +860,7 @@ case class AlterTableSetLocationCommand(
}
-object DDLUtils {
+object DDLUtils extends Logging {
val HIVE_PROVIDER = "hive"
def isHiveTable(table: CatalogTable): Boolean = {
@@ -933,19 +933,38 @@ object DDLUtils {
case HIVE_PROVIDER =>
val serde = table.storage.serde
if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
- OrcFileFormat.checkFieldNames(schema)
+ checkDataColNames("orc", schema)
} else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
serde == Some("parquet.hive.serde.ParquetHiveSerDe") ||
serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) {
- ParquetSchemaConverter.checkFieldNames(schema)
+ checkDataColNames("parquet", schema)
+ } else if (serde == HiveSerDe.sourceToSerDe("avro").get.serde) {
+ checkDataColNames("avro", schema)
}
- case "parquet" => ParquetSchemaConverter.checkFieldNames(schema)
- case "orc" => OrcFileFormat.checkFieldNames(schema)
+ case "parquet" => checkDataColNames("parquet", schema)
+ case "orc" => checkDataColNames("orc", schema)
+ case "avro" => checkDataColNames("avro", schema)
case _ =>
}
}
}
+ def checkDataColNames(provider: String, schema: StructType): Unit = {
+ val source = try {
+ DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
+ } catch {
+ case e: Throwable =>
+ logError(s"Failed to find data source: $provider when check data column names.", e)
+ return
+ }
+ source match {
+ case f: FileFormat => DataSourceUtils.checkFieldNames(f, schema)
+ case f: FileDataSourceV2 =>
+ DataSourceUtils.checkFieldNames(f.fallbackFileFormat.newInstance(), schema)
+ case _ =>
+ }
+ }
+
/**
* Throws exception if outputPath tries to overwrite inputpath.
*/
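The new DDLUtils.checkDataColNames resolves the provider string to a FileFormat (using fallbackFileFormat for FileDataSourceV2 sources) and delegates to DataSourceUtils.checkFieldNames; a failed lookup is logged and skipped rather than failing the DDL. A hedged usage sketch (the schema is made up for illustration):

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("bad name", StringType)))  // space is illegal for parquet

    // Inside Spark, after this patch:
    // DDLUtils.checkDataColNames("parquet", schema)
    // => AnalysisException: Column name "bad name" contains invalid character(s).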
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 2b10e4e..b562d44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -59,6 +59,22 @@ object DataSourceUtils {
}
/**
+ * Verify that the field name is supported by the datasource. This verification should be
+ * done on the driver side.
+ */
+ def checkFieldNames(format: FileFormat, schema: StructType): Unit = {
+ schema.foreach { field =>
+ if (!format.supportFieldName(field.name)) {
+ throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(field.name)
+ }
+ field.dataType match {
+ case s: StructType => checkFieldNames(format, s)
+ case _ =>
+ }
+ }
+ }
+
+ /**
* Verify if the schema is supported in datasource. This verification should be done
* in a driver side.
*/
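Note that checkFieldNames recurses into nested StructType fields, so an invalid name is caught at any depth. A small hedged illustration:

    import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
    import org.apache.spark.sql.types._

    val nested = StructType(Seq(
      StructField("outer", StructType(Seq(
        StructField("inner;bad", IntegerType))))))  // semicolon is illegal for parquet

    // DataSourceUtils.checkFieldNames(new ParquetFileFormat(), nested)
    // => fails on "inner;bad" even though the top-level name is fine.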
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 7fd48ca..beb1f4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -163,6 +163,12 @@ trait FileFormat {
* By default all data types are supported.
*/
def supportDataType(dataType: DataType): Boolean = true
+
+ /**
+ * Returns whether this format supports the given field name in the read/write path.
+ * By default all field names are supported.
+ */
+ def supportFieldName(name: String): Boolean = true
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 9024c78..85c0ff0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -36,31 +36,12 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.{SerializableConfiguration, Utils}
private[sql] object OrcFileFormat {
- private def checkFieldName(name: String): Unit = {
- try {
- TypeDescription.fromString(s"struct<`$name`:int>")
- } catch {
- case _: IllegalArgumentException =>
- throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name)
- }
- }
-
- def checkFieldNames(schema: StructType): Unit = {
- schema.foreach { field =>
- checkFieldName(field.name)
- field.dataType match {
- case s: StructType => checkFieldNames(s)
- case _ =>
- }
- }
- }
def getQuotedSchemaString(dataType: DataType): String = dataType match {
case _: AtomicType => dataType.catalogString
@@ -279,4 +260,13 @@ class OrcFileFormat
case _ => false
}
+
+ override def supportFieldName(name: String): Boolean = {
+ try {
+ TypeDescription.fromString(s"struct<`$name`:int>")
+ true
+ } catch {
+ case _: IllegalArgumentException => false
+ }
+ }
}
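The ORC rule above simply asks the ORC library whether the name parses as a quoted struct field. An equivalent standalone check (assuming org.apache.orc.TypeDescription is on the classpath):

    import org.apache.orc.TypeDescription

    def orcSupportsFieldName(name: String): Boolean =
      try {
        TypeDescription.fromString(s"struct<`$name`:int>")
        true
      } catch {
        case _: IllegalArgumentException => false
      }

    orcSupportsFieldName("ok_name")  // true
    orcSupportsFieldName("a`b")      // false: the backquote ends the quoted name early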
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index ee229a3..586952a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -386,6 +386,10 @@ class ParquetFileFormat
case _ => false
}
+
+ override def supportFieldName(name: String): Boolean = {
+ !name.matches(".*[ ,;{}()\n\t=].*")
+ }
}
object ParquetFileFormat extends Logging {
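Parquet's rule rejects any name containing a space, comma, semicolon, brace, parenthesis, newline, tab, or '=', mirroring the check previously done in ParquetSchemaConverter. A quick standalone version using the same regex:

    def parquetSupportsFieldName(name: String): Boolean =
      !name.matches(".*[ ,;{}()\n\t=].*")

    parquetSupportsFieldName("col_1")  // true
    parquetSupportsFieldName("a b")    // false (space)
    parquetSupportsFieldName("x=y")    // false ('=')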
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index f3bfd99..217c020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -591,16 +591,6 @@ private[sql] object ParquetSchemaConverter {
}
}
- def checkFieldNames(schema: StructType): Unit = {
- schema.foreach { field =>
- checkFieldName(field.name)
- field.dataType match {
- case s: StructType => checkFieldNames(s)
- case _ =>
- }
- }
- }
-
def checkConversionRequirement(f: => Boolean, message: String): Unit = {
if (!f) {
throw new AnalysisException(message)