Posted to commits@spark.apache.org by gu...@apache.org on 2018/05/11 01:01:48 UTC
spark git commit: [SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver
Repository: spark
Updated Branches:
refs/heads/master d3c426a5b -> a4206d58e
[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver
## What changes were proposed in this pull request?
This is a follow-up of https://github.com/apache/spark/pull/20136. #20136 didn't really work because the tests use the local backend, which shares the driver-side `SparkEnv`, so `SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER` holds even inside task code and the assertion never fires.
This PR changes the check to `TaskContext.get != null`, moves it into `SQLConf.get`, and fixes all the places that violate this check (a short sketch of the shared fix pattern follows the list):
* `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, so we can't access `conf.offHeapColumnVectorEnabled` there. Fixed in https://github.com/apache/spark/pull/21223 (merged).
* `DataType#sameType` may be executed on the executor side, e.g. for JSON schema inference, so we can't call `conf.caseSensitiveAnalysis` there. This accounts for most of the code changes, as we need to add a `caseSensitive` parameter to many methods.
* `ParquetFilters` is used in the file scan function, which is executed on the executor side, so we can't call `conf.parquetFilterPushDownDate` there. Fixed in https://github.com/apache/spark/pull/21224 (merged).
* `WindowExec#createBoundOrdering` is called on the executor side, so we can't use `conf.sessionLocalTimezone` there. Fixed in https://github.com/apache/spark/pull/21225 (merged).
* `JsonToStructs` can be serialized to executors and evaluated there, so we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in its body. Fixed in https://github.com/apache/spark/pull/21226 (merged).
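
For illustration, a minimal sketch of that shared fix pattern (the helper `tagMatches` is hypothetical, not part of this commit): read the conf value once on the driver and let the task closure capture only the resulting primitive.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.internal.SQLConf

def tagMatches(rdd: RDD[String], tag: String): RDD[Boolean] = {
  // Driver side: no task is running here, so SQLConf.get is safe.
  val caseSensitive: Boolean = SQLConf.get.caseSensitiveAnalysis
  // Executor side: only the captured Boolean is shipped with the closure.
  // Calling SQLConf.get inside this map would now throw in test mode,
  // because TaskContext.get != null there.
  rdd.map(name => if (caseSensitive) name == tag else name.equalsIgnoreCase(tag))
}
```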
## How was this patch tested?
Existing tests.
Author: Wenchen Fan <we...@databricks.com>
Closes #21190 from cloud-fan/minor.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a4206d58
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a4206d58
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a4206d58
Branch: refs/heads/master
Commit: a4206d58e05ab9ed6f01fee57e18dee65cbc4efc
Parents: d3c426a
Author: Wenchen Fan <we...@databricks.com>
Authored: Fri May 11 09:01:40 2018 +0800
Committer: hyukjinkwon <gu...@apache.org>
Committed: Fri May 11 09:01:40 2018 +0800
----------------------------------------------------------------------
.../sql/catalyst/analysis/CheckAnalysis.scala | 5 +-
.../catalyst/analysis/ResolveInlineTables.scala | 4 +-
.../sql/catalyst/analysis/TypeCoercion.scala | 156 +++++++++++--------
.../org/apache/spark/sql/internal/SQLConf.scala | 16 +-
.../org/apache/spark/sql/types/DataType.scala | 8 +-
.../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++-----
.../org/apache/spark/sql/SparkSession.scala | 21 ++-
.../datasources/PartitioningUtils.scala | 5 +-
.../datasources/json/JsonInferSchema.scala | 39 +++--
.../execution/datasources/json/JsonSuite.scala | 4 +-
10 files changed, 188 insertions(+), 140 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 90bda2a..94b0561 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
/**
@@ -260,7 +261,9 @@ trait CheckAnalysis extends PredicateHelper {
// Check if the data types match.
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
// SPARK-18058: we shall not care about the nullability of columns
- if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) {
+ val widerType = TypeCoercion.findWiderTypeForTwo(
+ dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis)
+ if (widerType.isEmpty) {
failAnalysis(
s"""
|${operator.nodeName} can only be performed on tables with the compatible
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index f2df3e1..4eb6e64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -83,7 +83,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
// For each column, traverse all the values and find a common data type and nullability.
val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
val inputTypes = column.map(_.dataType)
- val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
+ val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion(
+ inputTypes, conf.caseSensitiveAnalysis)
+ val tpe = wideType.getOrElse {
table.failAnalysis(s"incompatible types found in column $name for inline table")
}
StructField(name, tpe, nullable = column.exists(_.nullable))
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index b2817b0..a7ba201 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -48,18 +48,18 @@ object TypeCoercion {
def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
InConversion(conf) ::
- WidenSetOperationTypes ::
+ WidenSetOperationTypes(conf) ::
PromoteStrings(conf) ::
DecimalPrecision ::
BooleanEquality ::
- FunctionArgumentConversion ::
+ FunctionArgumentConversion(conf) ::
ConcatCoercion(conf) ::
EltCoercion(conf) ::
- CaseWhenCoercion ::
- IfCoercion ::
+ CaseWhenCoercion(conf) ::
+ IfCoercion(conf) ::
StackCoercion ::
Division ::
- new ImplicitTypeCasts(conf) ::
+ ImplicitTypeCasts(conf) ::
DateTimeOperations ::
WindowFrameCoercion ::
Nil
@@ -83,7 +83,10 @@ object TypeCoercion {
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[DecimalPrecision]].
*/
- val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
+ def findTightestCommonType(
+ left: DataType,
+ right: DataType,
+ caseSensitive: Boolean): Option[DataType] = (left, right) match {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
@@ -102,22 +105,32 @@ object TypeCoercion {
case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
Some(TimestampType)
- case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) =>
- Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
- // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType
- // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
- // - Different names: use f1.name
- // - Different nullabilities: `nullable` is true iff one of them is nullable.
- val dataType = findTightestCommonType(f1.dataType, f2.dataType).get
- StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
- }))
+ case (t1 @ StructType(fields1), t2 @ StructType(fields2)) =>
+ val isSameType = if (caseSensitive) {
+ DataType.equalsIgnoreNullability(t1, t2)
+ } else {
+ DataType.equalsIgnoreCaseAndNullability(t1, t2)
+ }
+
+ if (isSameType) {
+ Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
+ // Since t1 is same type of t2, two StructTypes have the same DataType
+ // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
+ // - Different names: use f1.name
+ // - Different nullabilities: `nullable` is true iff one of them is nullable.
+ val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get
+ StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
+ }))
+ } else {
+ None
+ }
case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) =>
- findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
+ findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2))
case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) =>
- val keyType = findTightestCommonType(kt1, kt2)
- val valueType = findTightestCommonType(vt1, vt2)
+ val keyType = findTightestCommonType(kt1, kt2, caseSensitive)
+ val valueType = findTightestCommonType(vt1, vt2, caseSensitive)
Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
case _ => None
@@ -172,13 +185,14 @@ object TypeCoercion {
* i.e. the main difference with [[findTightestCommonType]] is that here we allow some
* loss of precision when widening decimal and double, and promotion to string.
*/
- def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
- findTightestCommonType(t1, t2)
+ def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = {
+ findTightestCommonType(t1, t2, caseSensitive)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse(stringPromotion(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
- findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
+ findWiderTypeForTwo(et1, et2, caseSensitive)
+ .map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}
@@ -193,7 +207,8 @@ object TypeCoercion {
case _ => false
}
- private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
+ private def findWiderCommonType(
+ types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = {
// findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal
// to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType.
// Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance,
@@ -201,7 +216,7 @@ object TypeCoercion {
val (stringTypes, nonStringTypes) = types.partition(hasStringType(_))
(stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) =>
r match {
- case Some(d) => findWiderTypeForTwo(d, c)
+ case Some(d) => findWiderTypeForTwo(d, c, caseSensitive)
case _ => None
})
}
@@ -213,20 +228,22 @@ object TypeCoercion {
*/
private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
t1: DataType,
- t2: DataType): Option[DataType] = {
- findTightestCommonType(t1, t2)
+ t2: DataType,
+ caseSensitive: Boolean): Option[DataType] = {
+ findTightestCommonType(t1, t2, caseSensitive)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
- findWiderTypeWithoutStringPromotionForTwo(et1, et2)
+ findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive)
.map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}
- def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
+ def findWiderTypeWithoutStringPromotion(
+ types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
- case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c)
+ case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive)
case None => None
})
}
@@ -279,29 +296,32 @@ object TypeCoercion {
*
* This rule is only applied to Union/Except/Intersect
*/
- object WidenSetOperationTypes extends Rule[LogicalPlan] {
+ case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ SetOperation(left, right) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
- val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+ val newChildren: Seq[LogicalPlan] =
+ buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis)
assert(newChildren.length == 2)
s.makeCopy(Array(newChildren.head, newChildren.last))
case s: Union if s.childrenResolved &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
- val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+ val newChildren: Seq[LogicalPlan] =
+ buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis)
s.makeCopy(Array(newChildren))
}
/** Build new children with the widest types for each attribute among all the children */
- private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+ private def buildNewChildrenWithWiderTypes(
+ children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = {
require(children.forall(_.output.length == children.head.output.length))
// Get a sequence of data types, each of which is the widest type of this specific attribute
// in all the children
val targetTypes: Seq[DataType] =
- getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+ getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive)
if (targetTypes.nonEmpty) {
// Add an extra Project if the targetTypes are different from the original types.
@@ -316,18 +336,19 @@ object TypeCoercion {
@tailrec private def getWidestTypes(
children: Seq[LogicalPlan],
attrIndex: Int,
- castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
+ castedTypes: mutable.Queue[DataType],
+ caseSensitive: Boolean): Seq[DataType] = {
// Return the result after the widen data types have been found for all the children
if (attrIndex >= children.head.output.length) return castedTypes.toSeq
// For the attrIndex-th attribute, find the widest type
- findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
+ findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match {
// If unable to find an appropriate widen type for this column, return an empty Seq
case None => Seq.empty[DataType]
// Otherwise, record the result in the queue and find the type for the next column
case Some(widenType) =>
castedTypes.enqueue(widenType)
- getWidestTypes(children, attrIndex + 1, castedTypes)
+ getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive)
}
}
@@ -432,7 +453,7 @@ object TypeCoercion {
val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf)
- .orElse(findTightestCommonType(l.dataType, r.dataType))
+ .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis))
}
// The number of columns/expressions must match between LHS and RHS of an
@@ -461,7 +482,7 @@ object TypeCoercion {
}
case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
- findWiderCommonType(i.children.map(_.dataType)) match {
+ findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
}
@@ -515,7 +536,7 @@ object TypeCoercion {
/**
* This ensure that the types for various functions are as expected.
*/
- object FunctionArgumentConversion extends TypeCoercionRule {
+ case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
@@ -523,7 +544,7 @@ object TypeCoercion {
case a @ CreateArray(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderCommonType(types) match {
+ findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}
@@ -531,7 +552,7 @@ object TypeCoercion {
case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderCommonType(types) match {
+ findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
case None => c
}
@@ -542,7 +563,7 @@ object TypeCoercion {
m.keys
} else {
val types = m.keys.map(_.dataType)
- findWiderCommonType(types) match {
+ findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
case None => m.keys
}
@@ -552,7 +573,7 @@ object TypeCoercion {
m.values
} else {
val types = m.values.map(_.dataType)
- findWiderCommonType(types) match {
+ findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
case None => m.values
}
@@ -580,7 +601,7 @@ object TypeCoercion {
// compatible with every child column.
case c @ Coalesce(es) if !haveSameType(es) =>
val types = es.map(_.dataType)
- findWiderCommonType(types) match {
+ findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}
@@ -590,14 +611,14 @@ object TypeCoercion {
// string.g
case g @ Greatest(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderTypeWithoutStringPromotion(types) match {
+ findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case None => g
}
case l @ Least(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderTypeWithoutStringPromotion(types) match {
+ findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
case None => l
}
@@ -637,11 +658,11 @@ object TypeCoercion {
/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
- object CaseWhenCoercion extends TypeCoercionRule {
+ case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
- val maybeCommonType = findWiderCommonType(c.valueTypes)
+ val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis)
maybeCommonType.map { commonType =>
var changed = false
val newBranches = c.branches.map { case (condition, value) =>
@@ -668,16 +689,17 @@ object TypeCoercion {
/**
* Coerces the type of different branches of If statement to a common type.
*/
- object IfCoercion extends TypeCoercionRule {
+ case class IfCoercion(conf: SQLConf) extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e if !e.childrenResolved => e
// Find tightest common type for If, if the true value and false value have different types.
case i @ If(pred, left, right) if left.dataType != right.dataType =>
- findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
- val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
- val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
- If(pred, newLeft, newRight)
+ findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map {
+ widestType =>
+ val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
+ val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
+ If(pred, newLeft, newRight)
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
case If(Literal(null, NullType), left, right) =>
If(Literal.create(null, BooleanType), left, right)
@@ -776,12 +798,11 @@ object TypeCoercion {
/**
* Casts types according to the expected input types for [[Expression]]s.
*/
- class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule {
+ case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule {
private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING)
- override protected def coerceTypes(
- plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -804,17 +825,18 @@ object TypeCoercion {
}
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
- findTightestCommonType(left.dataType, right.dataType).map { commonType =>
- if (b.inputType.acceptsType(commonType)) {
- // If the expression accepts the tightest common type, cast to that.
- val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
- val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
- b.withNewChildren(Seq(newLeft, newRight))
- } else {
- // Otherwise, don't do anything with the expression.
- b
- }
- }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
+ findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map {
+ commonType =>
+ if (b.inputType.acceptsType(commonType)) {
+ // If the expression accepts the tightest common type, cast to that.
+ val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
+ val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
+ b.withNewChildren(Seq(newLeft, newRight))
+ } else {
+ // Otherwise, don't do anything with the expression.
+ b
+ }
+ }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b00edca..0b1965c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -27,7 +27,7 @@ import scala.util.matching.Regex
import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
@@ -107,7 +107,13 @@ object SQLConf {
* run tests in parallel. At the time this feature was implemented, this was a no-op since we
* run unit tests (that does not involve SparkSession) in serial order.
*/
- def get: SQLConf = confGetter.get()()
+ def get: SQLConf = {
+ if (Utils.isTesting && TaskContext.get != null) {
+ // we're accessing it during task execution, fail.
+ throw new IllegalStateException("SQLConf should only be created and accessed on the driver.")
+ }
+ confGetter.get()()
+ }
val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
.internal()
@@ -1274,12 +1280,6 @@ object SQLConf {
class SQLConf extends Serializable with Logging {
import SQLConf._
- if (Utils.isTesting && SparkEnv.get != null) {
- // assert that we're only accessing it on the driver.
- assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
- "SQLConf should only be created and accessed on the driver.")
- }
-
/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 0bef116..4ee12db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -81,11 +81,7 @@ abstract class DataType extends AbstractDataType {
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*/
private[spark] def sameType(other: DataType): Boolean =
- if (SQLConf.get.caseSensitiveAnalysis) {
- DataType.equalsIgnoreNullability(this, other)
- } else {
- DataType.equalsIgnoreCaseAndNullability(this, other)
- }
+ DataType.equalsIgnoreNullability(this, other)
/**
* Returns the same data type but set all nullability fields are true
@@ -218,7 +214,7 @@ object DataType {
/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
- private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+ private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 0acd3b4..f73e045 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest {
}
private def checkWidenType(
- widenFunc: (DataType, DataType) => Option[DataType],
+ widenFunc: (DataType, DataType, Boolean) => Option[DataType],
t1: DataType,
t2: DataType,
expected: Option[DataType],
isSymmetric: Boolean = true): Unit = {
- var found = widenFunc(t1, t2)
+ var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis)
assert(found == expected,
s"Expected $expected as wider common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
if (isSymmetric) {
- found = widenFunc(t2, t1)
+ found = widenFunc(t2, t1, conf.caseSensitiveAnalysis)
assert(found == expected,
s"Expected $expected as wider common type for $t2 and $t1, found $found")
}
@@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for expressions that implement ExpectsInputTypes") {
import TypeCoercionSuite._
- ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeUnaryExpression(Literal.create(null, NullType)),
AnyTypeUnaryExpression(Literal.create(null, NullType)))
- ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeUnaryExpression(Literal.create(null, NullType)),
NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
}
@@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for binary operators") {
import TypeCoercionSuite._
- ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
- ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
}
test("coalesce casts") {
- val rule = TypeCoercion.FunctionArgumentConversion
+ val rule = TypeCoercion.FunctionArgumentConversion(conf)
val intLit = Literal(1)
val longLit = Literal.create(1L)
@@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("CreateArray casts") {
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal("a")
@@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal("a"), StringType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal(1)
:: Nil),
@@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal(1).cast(DecimalType(13, 3))
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal.create(null, DecimalType(22, 10))
:: Literal.create(null, DecimalType(38, 38))
@@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest {
test("CreateMap casts") {
// type coercion for map keys
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal("a")
:: Literal.create(2.0, FloatType)
@@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal.create(2.0, FloatType), FloatType)
:: Literal("b")
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal.create(null, DecimalType(5, 3))
:: Literal("a")
:: Literal.create(2.0, FloatType)
@@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal("b")
:: Nil))
// type coercion for map values
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2)
@@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal(2)
:: Cast(Literal(3.0), StringType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0))
:: Literal(2)
@@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
:: Nil))
// type coercion for both map keys and values
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2.0)
@@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest {
test("greatest/least cast") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal(1), DecimalType(22, 0))
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal(1.0)
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
@@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
:: Literal(1).cast(DoubleType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal.create(null, DecimalType(15, 0))
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
@@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
:: Literal(1).cast(DecimalType(20, 5))
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal.create(2L, LongType)
:: Literal(1)
:: Literal.create(null, DecimalType(10, 5))
@@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("nanvl casts") {
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion,
+ ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
}
test("type coercion for If") {
- val rule = TypeCoercion.IfCoercion
+ val rule = TypeCoercion.IfCoercion(conf)
val intLit = Literal(1)
val doubleLit = Literal(1.0)
val trueLit = Literal.create(true, BooleanType)
@@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("type coercion for CaseKeyWhen") {
- ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)
- ruleTest(TypeCoercion.CaseWhenCoercion,
+ ruleTest(TypeCoercion.CaseWhenCoercion(conf),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
- ruleTest(TypeCoercion.CaseWhenCoercion,
+ ruleTest(TypeCoercion.CaseWhenCoercion(conf),
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
)
- ruleTest(TypeCoercion.CaseWhenCoercion,
+ ruleTest(TypeCoercion.CaseWhenCoercion(conf),
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
@@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest {
private val timeZoneResolver = ResolveTimeZone(new SQLConf)
private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
- timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
+ timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan))
}
test("WidenSetOperationTypes for except and intersect") {
@@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest {
test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
"in aggregation function like sum") {
- val rules = Seq(FunctionArgumentConversion, Division)
+ val rules = Seq(FunctionArgumentConversion(conf), Division)
// Casts Integer to Double
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
@@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("SPARK-17117 null type coercion in divide") {
- val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf))
+ val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf))
val nullLit = Literal.create(null, NullType)
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index c502e58..e2a1a57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
-import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -898,6 +898,7 @@ object SparkSession extends Logging {
* @since 2.0.0
*/
def getOrCreate(): SparkSession = synchronized {
+ assertOnDriver()
// Get the session from current thread's active session.
var session = activeThreadSession.get()
if ((session ne null) && !session.sparkContext.isStopped) {
@@ -1022,14 +1023,20 @@ object SparkSession extends Logging {
*
* @since 2.2.0
*/
- def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+ def getActiveSession: Option[SparkSession] = {
+ assertOnDriver()
+ Option(activeThreadSession.get)
+ }
/**
* Returns the default SparkSession that is returned by the builder.
*
* @since 2.2.0
*/
- def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+ def getDefaultSession: Option[SparkSession] = {
+ assertOnDriver()
+ Option(defaultSession.get)
+ }
/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1062,6 +1069,14 @@ object SparkSession extends Logging {
}
}
+ private def assertOnDriver(): Unit = {
+ if (Utils.isTesting && TaskContext.get != null) {
+ // we're accessing it during task execution, fail.
+ throw new IllegalStateException(
+ "SparkSession should only be created and accessed on the driver.")
+ }
+ }
+
/**
* Helper method to create an instance of `SessionState` based on `className` from conf.
* The result is either `SessionState` or a Hive based `SessionState`.
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index f9a2480..1edf276 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
@@ -521,6 +522,8 @@ object PartitioningUtils {
private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = {
case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType
case (DoubleType, LongType) | (LongType, DoubleType) => StringType
- case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType)
+ case (t1, t2) =>
+ TypeCoercion.findWiderTypeForTwo(
+ t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index a270a64..e0424b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -44,6 +45,7 @@ private[sql] object JsonInferSchema {
createParser: (JsonFactory, T) => JsonParser): StructType = {
val parseMode = configOptions.parseMode
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
+ val caseSensitive = SQLConf.get.caseSensitiveAnalysis
// perform schema inference on each row and merge afterwards
val rootType = json.mapPartitions { iter =>
@@ -53,7 +55,7 @@ private[sql] object JsonInferSchema {
try {
Utils.tryWithResource(createParser(factory, row)) { parser =>
parser.nextToken()
- Some(inferField(parser, configOptions))
+ Some(inferField(parser, configOptions, caseSensitive))
}
} catch {
case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
@@ -68,7 +70,7 @@ private[sql] object JsonInferSchema {
}
}
}.fold(StructType(Nil))(
- compatibleRootType(columnNameOfCorruptRecord, parseMode))
+ compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive))
canonicalizeType(rootType) match {
case Some(st: StructType) => st
@@ -98,14 +100,15 @@ private[sql] object JsonInferSchema {
/**
* Infer the type of a json document from the parser's token stream
*/
- private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+ private def inferField(
+ parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType
case FIELD_NAME =>
parser.nextToken()
- inferField(parser, configOptions)
+ inferField(parser, configOptions, caseSensitive)
case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
@@ -122,7 +125,7 @@ private[sql] object JsonInferSchema {
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(
parser.getCurrentName,
- inferField(parser, configOptions),
+ inferField(parser, configOptions, caseSensitive),
nullable = true)
}
val fields: Array[StructField] = builder.result()
@@ -137,7 +140,7 @@ private[sql] object JsonInferSchema {
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(
- elementType, inferField(parser, configOptions))
+ elementType, inferField(parser, configOptions, caseSensitive), caseSensitive)
}
ArrayType(elementType)
@@ -243,13 +246,14 @@ private[sql] object JsonInferSchema {
*/
private def compatibleRootType(
columnNameOfCorruptRecords: String,
- parseMode: ParseMode): (DataType, DataType) => DataType = {
+ parseMode: ParseMode,
+ caseSensitive: Boolean): (DataType, DataType) => DataType = {
// Since we support array of json objects at the top level,
// we need to check the element type and find the root level data type.
case (ArrayType(ty1, _), ty2) =>
- compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+ compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2)
case (ty1, ArrayType(ty2, _)) =>
- compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+ compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2)
// Discard null/empty documents
case (struct: StructType, NullType) => struct
case (NullType, struct: StructType) => struct
@@ -259,7 +263,7 @@ private[sql] object JsonInferSchema {
withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
// If we get anything else, we call compatibleType.
// Usually, when we reach here, ty1 and ty2 are two StructTypes.
- case (ty1, ty2) => compatibleType(ty1, ty2)
+ case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive)
}
private[this] val emptyStructFieldArray = Array.empty[StructField]
@@ -267,8 +271,8 @@ private[sql] object JsonInferSchema {
/**
* Returns the most general data type for two given data types.
*/
- def compatibleType(t1: DataType, t2: DataType): DataType = {
- TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+ def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = {
+ TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
@@ -303,7 +307,8 @@ private[sql] object JsonInferSchema {
val f2Name = fields2(f2Idx).name
val comp = f1Name.compareTo(f2Name)
if (comp == 0) {
- val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+ val dataType = compatibleType(
+ fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive)
newFields.add(StructField(f1Name, dataType, nullable = true))
f1Idx += 1
f2Idx += 1
@@ -326,15 +331,17 @@ private[sql] object JsonInferSchema {
StructType(newFields.toArray(emptyStructFieldArray))
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
- ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+ ArrayType(
+ compatibleType(elementType1, elementType2, caseSensitive),
+ containsNull1 || containsNull2)
// The case that given `DecimalType` is capable of given `IntegralType` is handled in
// `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
// the given `DecimalType` is not capable of the given `IntegralType`.
case (t1: IntegralType, t2: DecimalType) =>
- compatibleType(DecimalType.forType(t1), t2)
+ compatibleType(DecimalType.forType(t1), t2, caseSensitive)
case (t1: DecimalType, t2: IntegralType) =>
- compatibleType(t1, DecimalType.forType(t2))
+ compatibleType(t1, DecimalType.forType(t2), caseSensitive)
// strings and every string is a Json object.
case (_, _) => StringType
http://git-wip-us.apache.org/repos/asf/spark/blob/a4206d58/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 4b3921c..34d23ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("Get compatible type") {
def checkDataType(t1: DataType, t2: DataType, expected: DataType) {
- var actual = compatibleType(t1, t2)
+ var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis)
assert(actual == expected,
s"Expected $expected as the most general data type for $t1 and $t2, found $actual")
- actual = compatibleType(t2, t1)
+ actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis)
assert(actual == expected,
s"Expected $expected as the most general data type for $t1 and $t2, found $actual")
}