Posted to commits@spark.apache.org by we...@apache.org on 2020/12/07 13:42:39 UTC
[spark] branch master updated: [SPARK-33641][SQL] Invalidate new char/varchar types in public APIs that produce incorrect results
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new da72b87 [SPARK-33641][SQL] Invalidate new char/varchar types in public APIs that produce incorrect results
da72b87 is described below
commit da72b87374a7be5416b99ed016dc2fc9da0ed88a
Author: Kent Yao <ya...@hotmail.com>
AuthorDate: Mon Dec 7 13:40:15 2020 +0000
[SPARK-33641][SQL] Invalidate new char/varchar types in public APIs that produce incorrect results
### What changes were proposed in this pull request?
In this PR, we propose to narrow the use cases of the char/varchar data types, which are invalid now or will become invalid later; a sketch of the new behavior follows.
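For illustration, here is a minimal sketch of the new fail-fast behavior (assuming a build with this patch and the default configuration; the exception text is the message added by `CharVarcharUtils.failIfHasCharVarchar` in the diff below):
```scala
// With this change, char/varchar types passed to the restricted public APIs
// now fail fast with an AnalysisException instead of producing incorrect results.
scala> spark.read.schema("a varchar(2)").text("./README.md").show(100)
org.apache.spark.sql.AnalysisException: char/varchar type can only be used in
the table schema. You can set spark.sql.legacy.charVarcharAsString to true,
so that Spark treat them as string type as same as Spark 3.0 and earlier
```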
### Why are the changes needed?
1. udf
```scala
scala> spark.udf.register("abcd", () => "12345", org.apache.spark.sql.types.VarcharType(2))
scala> spark.sql("select abcd()").show
scala.MatchError: CharType(2) (of class org.apache.spark.sql.types.VarcharType)
at org.apache.spark.sql.catalyst.encoders.RowEncoder$.externalDataTypeFor(RowEncoder.scala:215)
at org.apache.spark.sql.catalyst.encoders.RowEncoder$.externalDataTypeForInput(RowEncoder.scala:212)
at org.apache.spark.sql.catalyst.expressions.objects.ValidateExternalType.<init>(objects.scala:1741)
at org.apache.spark.sql.catalyst.encoders.RowEncoder$.$anonfun$serializerFor$3(RowEncoder.scala:175)
at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:245)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.flatMap(TraversableLike.scala:245)
at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:242)
at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
at org.apache.spark.sql.catalyst.encoders.RowEncoder$.serializerFor(RowEncoder.scala:171)
at org.apache.spark.sql.catalyst.encoders.RowEncoder$.apply(RowEncoder.scala:66)
at org.apache.spark.sql.Dataset$.$anonfun$ofRows$2(Dataset.scala:99)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:768)
at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:96)
at org.apache.spark.sql.SparkSession.$anonfun$sql$1(SparkSession.scala:611)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:768)
at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:606)
... 47 elided
```
(Note: the `MatchError` above prints `CharType(2)` for a `VarcharType` value because `VarcharType.toString` incorrectly returned `CharType($length)`; this PR fixes that as well.)
2. spark.createDataFrame
```scala
scala> spark.createDataFrame(spark.read.text("README.md").rdd, new org.apache.spark.sql.types.StructType().add("c", "char(1)")).show
+--------------------+
| c|
+--------------------+
| # Apache Spark|
| |
|Spark is a unifie...|
|high-level APIs i...|
|supports general ...|
|rich set of highe...|
|MLlib for machine...|
|and Structured St...|
| |
|<https://spark.ap...|
| |
|[![Jenkins Build]...|
|[![AppVeyor Build...|
|[![PySpark Covera...|
| |
| |
```
3. reader.schema
```scala
scala> spark.read.schema("a varchar(2)").text("./README.md").show(100)
+--------------------+
| a|
+--------------------+
| # Apache Spark|
| |
|Spark is a unifie...|
|high-level APIs i...|
|supports general ...|
```
4. etc
### Does this PR introduce _any_ user-facing change?
No, we intend to avoid a potential breaking change; a legacy config restores the old behavior, as sketched below.
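For users who relied on the old behavior, a sketch of the escape hatch (assuming the new internal config `spark.sql.legacy.charVarcharAsString` added in this PR is set at runtime):
```scala
// Opting back into the Spark 3.0-and-earlier behavior: char/varchar in these
// APIs are treated as plain string type again instead of failing.
scala> spark.conf.set("spark.sql.legacy.charVarcharAsString", "true")
scala> spark.read.schema("a varchar(2)").text("./README.md").printSchema()
root
 |-- a: string (nullable = true)
```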
### How was this patch tested?
New tests.
Closes #30586 from yaooqinn/SPARK-33641.
Authored-by: Kent Yao <ya...@hotmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/expressions/ExprUtils.scala | 6 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 19 +---
.../spark/sql/catalyst/parser/ParseDriver.scala | 5 -
.../sql/catalyst/parser/ParserInterface.scala | 6 --
.../spark/sql/catalyst/util/CharVarcharUtils.scala | 38 ++++++-
.../org/apache/spark/sql/internal/SQLConf.scala | 13 +++
.../org/apache/spark/sql/types/VarcharType.scala | 2 +-
.../sql/catalyst/parser/DataTypeParserSuite.scala | 14 +--
.../catalyst/parser/TableSchemaParserSuite.scala | 4 +-
.../org/apache/spark/sql/types/DataTypeSuite.scala | 10 ++
.../main/scala/org/apache/spark/sql/Column.scala | 2 +-
.../org/apache/spark/sql/DataFrameReader.scala | 7 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 10 +-
.../org/apache/spark/sql/UDFRegistration.scala | 73 ++++++++-----
.../sql/execution/datasources/jdbc/JdbcUtils.scala | 7 +-
.../scala/org/apache/spark/sql/functions.scala | 12 ++-
.../apache/spark/sql/CharVarcharTestSuite.scala | 114 ++++++++++++++-------
.../spark/sql/SparkSessionExtensionSuite.scala | 3 -
.../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 5 +-
.../spark/sql/hive/client/HiveClientImpl.scala | 2 +-
20 files changed, 226 insertions(+), 126 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 56bd3d7..b45bbe4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -21,7 +21,7 @@ import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -30,7 +30,9 @@ object ExprUtils {
def evalTypeExpr(exp: Expression): DataType = {
if (exp.foldable) {
exp.eval() match {
- case s: UTF8String if s != null => DataType.fromDDL(s.toString)
+ case s: UTF8String if s != null =>
+ val dataType = DataType.fromDDL(s.toString)
+ CharVarcharUtils.failIfHasCharVarchar(dataType)
case _ => throw new AnalysisException(
s"The expression '${exp.sql}' is not a valid schema string.")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 12c5e0d..a22383c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -95,19 +95,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
}
override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
- visitSparkDataType(ctx.dataType)
+ typedVisit[DataType](ctx.dataType)
}
override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
- val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
- StructType(visitColTypeList(ctx.colTypeList)))
+ val schema = StructType(visitColTypeList(ctx.colTypeList))
withOrigin(ctx)(schema)
}
- def parseRawDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
- typedVisit[DataType](ctx.dataType())
- }
-
/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
@@ -1550,7 +1545,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
* Create a [[Cast]] expression.
*/
override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
- Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType))
+ val rawDataType = typedVisit[DataType](ctx.dataType())
+ val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
+ Cast(expression(ctx.expression), dataType)
}
/**
@@ -2229,12 +2226,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
/* ********************************************************************************************
* DataType parsing
* ******************************************************************************************** */
- /**
- * Create a Spark DataType.
- */
- private def visitSparkDataType(ctx: DataTypeContext): DataType = {
- CharVarcharUtils.replaceCharVarcharWithString(typedVisit(ctx))
- }
/**
* Resolve/create a primitive type.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index ac3fbbf..d08be46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -39,11 +39,6 @@ abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with
astBuilder.visitSingleDataType(parser.singleDataType())
}
- /** Similar to `parseDataType`, but without CHAR/VARCHAR replacement. */
- override def parseRawDataType(sqlText: String): DataType = parse(sqlText) { parser =>
- astBuilder.parseRawDataType(parser.singleDataType())
- }
-
/** Creates Expression for a given SQL string. */
override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
astBuilder.visitSingleExpression(parser.singleExpression())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index d724933..77e357a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -70,10 +70,4 @@ trait ParserInterface {
*/
@throws[ParseException]("Text cannot be parsed to a DataType")
def parseDataType(sqlText: String): DataType
-
- /**
- * Parse a string to a raw [[DataType]] without CHAR/VARCHAR replacement.
- */
- @throws[ParseException]("Text cannot be parsed to a DataType")
- def parseRawDataType(sqlText: String): DataType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index 0cbe5ab..b551d96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.catalyst.util
import scala.collection.mutable
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-object CharVarcharUtils {
+object CharVarcharUtils extends Logging {
private val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING"
@@ -53,6 +56,19 @@ object CharVarcharUtils {
}
/**
+ * Validate the given [[DataType]] to fail if it is char or varchar types or contains nested ones
+ */
+ def failIfHasCharVarchar(dt: DataType): DataType = {
+ if (!SQLConf.get.charVarcharAsString && hasCharVarchar(dt)) {
+ throw new AnalysisException("char/varchar type can only be used in the table schema. " +
+ s"You can set ${SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key} to true, so that Spark" +
+ s" treat them as string type as same as Spark 3.0 and earlier")
+ } else {
+ replaceCharVarcharWithString(dt)
+ }
+ }
+
+ /**
* Replaces CharType/VarcharType with StringType recursively in the given data type.
*/
def replaceCharVarcharWithString(dt: DataType): DataType = dt match {
@@ -70,6 +86,24 @@ object CharVarcharUtils {
}
/**
+ * Replaces CharType/VarcharType with StringType recursively in the given data type, with a
+ * warning message if it has char or varchar types
+ */
+ def replaceCharVarcharWithStringForCast(dt: DataType): DataType = {
+ if (SQLConf.get.charVarcharAsString) {
+ replaceCharVarcharWithString(dt)
+ } else if (hasCharVarchar(dt)) {
+ logWarning("The Spark cast operator does not support char/varchar type and simply treats" +
+ " them as string type. Please use string type directly to avoid confusion. Otherwise," +
+ s" you can set ${SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key} to true, so that Spark treat" +
+ s" them as string type as same as Spark 3.0 and earlier")
+ replaceCharVarcharWithString(dt)
+ } else {
+ dt
+ }
+ }
+
+ /**
* Removes the metadata entry that contains the original type string of CharType/VarcharType from
* the given attribute's metadata.
*/
@@ -85,7 +119,7 @@ object CharVarcharUtils {
*/
def getRawType(metadata: Metadata): Option[DataType] = {
if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) {
- Some(CatalystSqlParser.parseRawDataType(
+ Some(CatalystSqlParser.parseDataType(
metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
} else {
None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ea30832..69f04e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2962,6 +2962,17 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val LEGACY_CHAR_VARCHAR_AS_STRING =
+ buildConf("spark.sql.legacy.charVarcharAsString")
+ .internal()
+ .doc("When true, Spark will not fail if user uses char and varchar type directly in those" +
+ " APIs that accept or parse data types as parameters, e.g." +
+ " `SparkSession.read.schema(...)`, `SparkSession.udf.register(...)` but treat them as" +
+ " string type as Spark 3.0 and earlier.")
+ .version("3.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -3612,6 +3623,8 @@ class SQLConf extends Serializable with Logging {
def disabledJdbcConnectionProviders: String = getConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST)
+ def charVarcharAsString: Boolean = getConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
index 8d78640..2e30820 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
@@ -32,6 +32,6 @@ case class VarcharType(length: Int) extends AtomicType {
override def defaultSize: Int = length
override def typeName: String = s"varchar($length)"
- override def toString: String = s"CharType($length)"
+ override def toString: String = s"VarcharType($length)"
private[spark] override def asNullable: VarcharType = this
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 655b1d2..b9f9840 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -56,10 +56,10 @@ class DataTypeParserSuite extends SparkFunSuite {
checkDataType("DATE", DateType)
checkDataType("timestamp", TimestampType)
checkDataType("string", StringType)
- checkDataType("ChaR(5)", StringType)
- checkDataType("ChaRacter(5)", StringType)
- checkDataType("varchAr(20)", StringType)
- checkDataType("cHaR(27)", StringType)
+ checkDataType("ChaR(5)", CharType(5))
+ checkDataType("ChaRacter(5)", CharType(5))
+ checkDataType("varchAr(20)", VarcharType(20))
+ checkDataType("cHaR(27)", CharType(27))
checkDataType("BINARY", BinaryType)
checkDataType("void", NullType)
checkDataType("interval", CalendarIntervalType)
@@ -103,9 +103,9 @@ class DataTypeParserSuite extends SparkFunSuite {
StructType(
StructField("deciMal", DecimalType.USER_DEFAULT, true) ::
StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) ::
- StructField("MAP", MapType(TimestampType, StringType), true) ::
+ StructField("MAP", MapType(TimestampType, VarcharType(10)), true) ::
StructField("arrAy", ArrayType(DoubleType, true), true) ::
- StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
+ StructField("anotherArray", ArrayType(CharType(9), true), true) :: Nil)
)
// Use backticks to quote column names having special characters.
checkDataType(
@@ -113,7 +113,7 @@ class DataTypeParserSuite extends SparkFunSuite {
StructType(
StructField("x+y", IntegerType, true) ::
StructField("!@#$%^&*()", StringType, true) ::
- StructField("1_2.345<>:\"", StringType, true) :: Nil)
+ StructField("1_2.345<>:\"", VarcharType(20), true) :: Nil)
)
// Empty struct.
checkDataType("strUCt<>", StructType(Nil))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
index 95851d4..5519f01 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types._
class TableSchemaParserSuite extends SparkFunSuite {
@@ -69,8 +68,7 @@ class TableSchemaParserSuite extends SparkFunSuite {
StructField("arrAy", ArrayType(DoubleType)) ::
StructField("anotherArray", ArrayType(CharType(9))) :: Nil)) :: Nil)
- assert(parse(tableSchemaString) ===
- CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedDataType))
+ assert(parse(tableSchemaString) === expectedDataType)
}
// Negative cases
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 9442a3e..8c2e5db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -249,6 +249,12 @@ class DataTypeSuite extends SparkFunSuite {
checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false))
checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false))
+ checkDataTypeFromJson(CharType(1))
+ checkDataTypeFromDDL(CharType(1))
+
+ checkDataTypeFromJson(VarcharType(10))
+ checkDataTypeFromDDL(VarcharType(11))
+
val metadata = new MetadataBuilder()
.putString("name", "age")
.build()
@@ -310,6 +316,10 @@ class DataTypeSuite extends SparkFunSuite {
checkDefaultSize(MapType(IntegerType, StringType, true), 24)
checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 12)
checkDefaultSize(structType, 20)
+ checkDefaultSize(CharType(5), 5)
+ checkDefaultSize(CharType(100), 100)
+ checkDefaultSize(VarcharType(5), 5)
+ checkDefaultSize(VarcharType(10), 10)
def checkEqualsIgnoreCompatibleNullability(
from: DataType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 86ba813..4ef23d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1185,7 +1185,7 @@ class Column(val expr: Expression) extends Logging {
* @since 1.3.0
*/
def cast(to: DataType): Column = withExpr {
- Cast(expr, CharVarcharUtils.replaceCharVarcharWithString(to))
+ Cast(expr, CharVarcharUtils.replaceCharVarcharWithStringForCast(to))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 007df18..b94c42a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -73,7 +73,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def schema(schema: StructType): DataFrameReader = {
- this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+ this.userSpecifiedSchema = Option(replaced)
this
}
@@ -89,7 +90,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 2.3.0
*/
def schema(schemaString: String): DataFrameReader = {
- this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
+ val rawSchema = StructType.fromDDL(schemaString)
+ val schema = CharVarcharUtils.failIfHasCharVarchar(rawSchema).asInstanceOf[StructType]
+ this.userSpecifiedSchema = Option(schema)
this
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3a9b069..a2c9406 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ExternalCommandExecutor
@@ -347,9 +348,10 @@ class SparkSession private(
*/
@DeveloperApi
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(replaced)
val toRow = encoder.createSerializer()
val catalystRows = rowRDD.map(toRow)
internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema)
@@ -365,7 +367,8 @@ class SparkSession private(
*/
@DeveloperApi
def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
- createDataFrame(rowRDD.rdd, schema)
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+ createDataFrame(rowRDD.rdd, replaced)
}
/**
@@ -378,7 +381,8 @@ class SparkSession private(
*/
@DeveloperApi
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive {
- Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala.toSeq))
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+ Dataset.ofRows(self, LocalRelation.fromExternalRows(replaced.toAttributes, rows.asScala.toSeq))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index cceb385..237cfe1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
@@ -162,9 +163,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
| * @since $version
| */
|def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
+ | val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
| val func = $funcCall
| def builder(e: Seq[Expression]) = if (e.length == $i) {
- | ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ | ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
| } else {
| throw new AnalysisException("Invalid number of arguments for function " + name +
| ". Expected: $i; Found: " + e.length)
@@ -753,9 +755,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 2.3.0
*/
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = () => f.asInstanceOf[UDF0[Any]].call()
def builder(e: Seq[Expression]) = if (e.length == 0) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 0; Found: " + e.length)
@@ -768,9 +771,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
def builder(e: Seq[Expression]) = if (e.length == 1) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 1; Found: " + e.length)
@@ -783,9 +787,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 2) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 2; Found: " + e.length)
@@ -798,9 +803,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 3) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 3; Found: " + e.length)
@@ -813,9 +819,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 4) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 4; Found: " + e.length)
@@ -828,9 +835,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 5) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 5; Found: " + e.length)
@@ -843,9 +851,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 6) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 6; Found: " + e.length)
@@ -858,9 +867,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 7) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 7; Found: " + e.length)
@@ -873,9 +883,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 8) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 8; Found: " + e.length)
@@ -888,9 +899,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 9) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 9; Found: " + e.length)
@@ -903,9 +915,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 10) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 10; Found: " + e.length)
@@ -918,9 +931,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 11) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 11; Found: " + e.length)
@@ -933,9 +947,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 12) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 12; Found: " + e.length)
@@ -948,9 +963,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 13) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 13; Found: " + e.length)
@@ -963,9 +979,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 14) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 14; Found: " + e.length)
@@ -978,9 +995,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 15) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 15; Found: " + e.length)
@@ -993,9 +1011,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 16) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 16; Found: " + e.length)
@@ -1008,9 +1027,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 17) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 17; Found: " + e.length)
@@ -1023,9 +1043,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 18) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 18; Found: " + e.length)
@@ -1038,9 +1059,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 19) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 19; Found: " + e.length)
@@ -1053,9 +1075,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 20) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 20; Found: " + e.length)
@@ -1068,9 +1091,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 21) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 21; Found: " + e.length)
@@ -1083,9 +1107,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 22) {
- ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+ ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 22; Found: " + e.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 216fb02..f997e57b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
@@ -761,10 +761,7 @@ object JdbcUtils extends Logging {
schema: StructType,
caseSensitive: Boolean,
createTableColumnTypes: String): Map[String, String] = {
- val parsedSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
- val userSchema = StructType(parsedSchema.map { field =>
- field.copy(dataType = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType))
- })
+ val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
val nameEquality = if (caseSensitive) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9861d21..5b1ee2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
-import org.apache.spark.sql.catalyst.util.TimestampFormatter
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TimestampFormatter}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.SQLConf
@@ -4009,7 +4009,7 @@ object functions {
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
- JsonToStructs(schema, options, e.expr)
+ JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), options, e.expr)
}
/**
@@ -4040,8 +4040,9 @@ object functions {
* @group collection_funcs
* @since 2.2.0
*/
- def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column =
- from_json(e, schema, options.asScala.toMap)
+ def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column = {
+ from_json(e, CharVarcharUtils.failIfHasCharVarchar(schema), options.asScala.toMap)
+ }
/**
* Parses a column containing a JSON string into a `StructType` with the specified schema.
@@ -4393,7 +4394,8 @@ object functions {
* @since 3.0.0
*/
def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
- CsvToStructs(schema, options, e.expr)
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+ CsvToStructs(replaced, options, e.expr)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index abb1327..fcd334b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.SimpleInsertSource
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
-import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
// The base trait for char/varchar tests that need to be run with different table implementations.
trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
@@ -435,55 +435,91 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession {
assert(df.schema.map(_.dataType) == Seq(StringType))
}
- assertNoCharType(spark.range(1).select($"id".cast("char(5)")))
- assertNoCharType(spark.range(1).select($"id".cast(CharType(5))))
- assertNoCharType(spark.range(1).selectExpr("CAST(id AS CHAR(5))"))
- assertNoCharType(sql("SELECT CAST(id AS CHAR(5)) FROM range(1)"))
+ val logAppender = new LogAppender("The Spark cast operator does not support char/varchar" +
+ " type and simply treats them as string type. Please use string type directly to avoid" +
+ " confusion.")
+ withLogAppender(logAppender) {
+ assertNoCharType(spark.range(1).select($"id".cast("char(5)")))
+ assertNoCharType(spark.range(1).select($"id".cast(CharType(5))))
+ assertNoCharType(spark.range(1).selectExpr("CAST(id AS CHAR(5))"))
+ assertNoCharType(sql("SELECT CAST(id AS CHAR(5)) FROM range(1)"))
+ }
}
- test("user-specified schema in functions") {
- val df = sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')""")
- checkAnswer(df, Row(Row("str")))
- val schema = df.schema.head.dataType.asInstanceOf[StructType]
- assert(schema.map(_.dataType) == Seq(StringType))
+ def failWithInvalidCharUsage[T](fn: => T): Unit = {
+ val e = intercept[AnalysisException](fn)
+ assert(e.getMessage contains "char/varchar type can only be used in the table schema")
}
- test("user-specified schema in DataFrameReader: file source from Dataset") {
- val ds = spark.range(10).map(_.toString)
- val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
- assert(df1.schema.map(_.dataType) == Seq(StringType))
- val df2 = spark.read.schema("id char(5)").csv(ds)
- assert(df2.schema.map(_.dataType) == Seq(StringType))
+ test("invalidate char/varchar in functions") {
+ failWithInvalidCharUsage(sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')"""))
+ withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+ val df = sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')""")
+ checkAnswer(df, Row(Row("str")))
+ val schema = df.schema.head.dataType.asInstanceOf[StructType]
+ assert(schema.map(_.dataType) == Seq(StringType))
+ }
}
- test("user-specified schema in DataFrameReader: DSV1") {
- def checkSchema(df: DataFrame): Unit = {
- val relations = df.queryExecution.analyzed.collect {
- case l: LogicalRelation => l.relation
- }
- assert(relations.length == 1)
- assert(relations.head.schema.map(_.dataType) == Seq(StringType))
+ test("invalidate char/varchar in SparkSession createDataframe") {
+ val df = spark.range(10).map(_.toString).toDF()
+ val schema = new StructType().add("id", CharType(5))
+ failWithInvalidCharUsage(spark.createDataFrame(df.collectAsList(), schema))
+ failWithInvalidCharUsage(spark.createDataFrame(df.rdd, schema))
+ failWithInvalidCharUsage(spark.createDataFrame(df.toJavaRDD, schema))
+ withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+ val df1 = spark.createDataFrame(df.collectAsList(), schema)
+ checkAnswer(df1, df)
+ assert(df1.schema.head.dataType === StringType)
}
-
- checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
- .format(classOf[SimpleInsertSource].getName).load())
- checkSchema(spark.read.schema("id char(5)")
- .format(classOf[SimpleInsertSource].getName).load())
}
- test("user-specified schema in DataFrameReader: DSV2") {
- def checkSchema(df: DataFrame): Unit = {
- val tables = df.queryExecution.analyzed.collect {
- case d: DataSourceV2Relation => d.table
+ test("invalidate char/varchar in spark.read.schema") {
+ failWithInvalidCharUsage(spark.read.schema(new StructType().add("id", CharType(5))))
+ failWithInvalidCharUsage(spark.read.schema("id char(5)"))
+ withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+ val ds = spark.range(10).map(_.toString)
+ val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
+ assert(df1.schema.map(_.dataType) == Seq(StringType))
+ val df2 = spark.read.schema("id char(5)").csv(ds)
+ assert(df2.schema.map(_.dataType) == Seq(StringType))
+
+ def checkSchema(df: DataFrame): Unit = {
+ val schemas = df.queryExecution.analyzed.collect {
+ case l: LogicalRelation => l.relation.schema
+ case d: DataSourceV2Relation => d.table.schema()
+ }
+ assert(schemas.length == 1)
+ assert(schemas.head.map(_.dataType) == Seq(StringType))
}
- assert(tables.length == 1)
- assert(tables.head.schema.map(_.dataType) == Seq(StringType))
- }
- checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
- .format(classOf[SchemaRequiredDataSource].getName).load())
- checkSchema(spark.read.schema("id char(5)")
- .format(classOf[SchemaRequiredDataSource].getName).load())
+ // user-specified schema in DataFrameReader: DSV1
+ checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
+ .format(classOf[SimpleInsertSource].getName).load())
+ checkSchema(spark.read.schema("id char(5)")
+ .format(classOf[SimpleInsertSource].getName).load())
+
+ // user-specified schema in DataFrameReader: DSV2
+ checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
+ .format(classOf[SchemaRequiredDataSource].getName).load())
+ checkSchema(spark.read.schema("id char(5)")
+ .format(classOf[SchemaRequiredDataSource].getName).load())
+ }
+ }
+
+ test("invalidate char/varchar in udf's result type") {
+ failWithInvalidCharUsage(spark.udf.register("testchar", () => "B", VarcharType(1)))
+ failWithInvalidCharUsage(spark.udf.register("testchar2", (x: String) => x, VarcharType(1)))
+ withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+ spark.udf.register("testchar", () => "B", VarcharType(1))
+ spark.udf.register("testchar2", (x: String) => x, VarcharType(1))
+ val df1 = spark.sql("select testchar()")
+ checkAnswer(df1, Row("B"))
+ assert(df1.schema.head.dataType === StringType)
+ val df2 = spark.sql("select testchar2('abc')")
+ checkAnswer(df2, Row("abc"))
+ assert(df2.schema.head.dataType === StringType)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index f02d204..ea276bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -384,9 +384,6 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
override def parseDataType(sqlText: String): DataType =
delegate.parseDataType(sqlText)
-
- override def parseRawDataType(sqlText: String): DataType =
- delegate.parseRawDataType(sqlText)
}
object MyExtensions {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index fb46c2f..1a28523 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -390,14 +390,13 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
.foldLeft(new StructType())((schema, colType) => schema.add(colType._1, colType._2))
val createTableColTypes =
colTypes.map { case (col, dataType) => s"$col $dataType" }.mkString(", ")
- val df = spark.createDataFrame(sparkContext.parallelize(Seq(Row.empty)), schema)
val expectedSchemaStr =
colTypes.map { case (col, dataType) => s""""$col" $dataType """ }.mkString(", ")
assert(JdbcUtils.schemaString(
- df.schema,
- df.sqlContext.conf.caseSensitiveAnalysis,
+ schema,
+ spark.sqlContext.conf.caseSensitiveAnalysis,
url1,
Option(createTableColTypes)) == expectedSchemaStr)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index bada131..34befb8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -985,7 +985,7 @@ private[hive] object HiveClientImpl extends Logging {
/** Get the Spark SQL native DataType from Hive's FieldSchema. */
private def getSparkSQLDataType(hc: FieldSchema): DataType = {
try {
- CatalystSqlParser.parseRawDataType(hc.getType)
+ CatalystSqlParser.parseDataType(hc.getType)
} catch {
case e: ParseException =>
throw new SparkException(