You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2017/03/05 11:53:25 UTC

spark git commit: [SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit

Repository: spark
Updated Branches:
  refs/heads/master f48461ab2 -> 14bb398fa


[SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit

## What changes were proposed in this pull request?
This PR adds support for Seq, Map, and Struct in functions.lit; it introduces a new API named `typedLit` that takes a `TypeTag` to avoid type erasure.

## How was this patch tested?
Added tests in `LiteralExpressionSuite` and `ColumnExpressionSuite`.

Author: Takeshi Yamamuro <ya...@apache.org>
Author: Takeshi YAMAMURO <li...@gmail.com>

Closes #16610 from maropu/SPARK-19254.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/14bb398f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/14bb398f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/14bb398f

Branch: refs/heads/master
Commit: 14bb398fae974137c3e38162cefc088e12838258
Parents: f48461a
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Sun Mar 5 03:53:19 2017 -0800
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Sun Mar 5 03:53:19 2017 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/literals.scala     | 12 ++-
 .../expressions/LiteralExpressionSuite.scala    | 90 +++++++++++++++++---
 .../scala/org/apache/spark/sql/functions.scala  | 25 ++++--
 .../spark/sql/ColumnExpressionSuite.scala       | 14 +++
 4 files changed, 121 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/14bb398f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e66fb89..eaeaf08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -32,11 +32,13 @@ import java.util.Objects
 import javax.xml.bind.DatatypeConverter
 
 import scala.math.{BigDecimal, BigInt}
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.Try
 
 import org.json4s.JsonAST._
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -153,6 +155,14 @@ object Literal {
     Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
   }
 
+  def create[T : TypeTag](v: T): Literal = Try {
+    val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
+    val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
+    Literal(convert(v), dataType)
+  }.getOrElse {
+    Literal(v)
+  }
+
   /**
    * Create a literal with default value for given DataType
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/14bb398f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 15e8e6c..a9e0eb0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.nio.charset.StandardCharsets
 
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
 import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("boolean literals") {
     checkEvaluation(Literal(true), true)
     checkEvaluation(Literal(false), false)
+
+    checkEvaluation(Literal.create(true), true)
+    checkEvaluation(Literal.create(false), false)
   }
 
   test("int literals") {
@@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(Literal(d.toLong), d.toLong)
       checkEvaluation(Literal(d.toShort), d.toShort)
       checkEvaluation(Literal(d.toByte), d.toByte)
+
+      checkEvaluation(Literal.create(d), d)
+      checkEvaluation(Literal.create(d.toLong), d.toLong)
+      checkEvaluation(Literal.create(d.toShort), d.toShort)
+      checkEvaluation(Literal.create(d.toByte), d.toByte)
     }
     checkEvaluation(Literal(Long.MinValue), Long.MinValue)
     checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
+
+    checkEvaluation(Literal.create(Long.MinValue), Long.MinValue)
+    checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue)
   }
 
   test("double literals") {
     List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
       checkEvaluation(Literal(d), d)
       checkEvaluation(Literal(d.toFloat), d.toFloat)
+
+      checkEvaluation(Literal.create(d), d)
+      checkEvaluation(Literal.create(d.toFloat), d.toFloat)
     }
     checkEvaluation(Literal(Double.MinValue), Double.MinValue)
     checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
     checkEvaluation(Literal(Float.MinValue), Float.MinValue)
     checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
 
+    checkEvaluation(Literal.create(Double.MinValue), Double.MinValue)
+    checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue)
+    checkEvaluation(Literal.create(Float.MinValue), Float.MinValue)
+    checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue)
+
   }
 
   test("string literals") {
     checkEvaluation(Literal(""), "")
     checkEvaluation(Literal("test"), "test")
     checkEvaluation(Literal("\u0000"), "\u0000")
+
+    checkEvaluation(Literal.create(""), "")
+    checkEvaluation(Literal.create("test"), "test")
+    checkEvaluation(Literal.create("\u0000"), "\u0000")
   }
 
   test("sum two literals") {
     checkEvaluation(Add(Literal(1), Literal(1)), 2)
+    checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2)
   }
 
   test("binary literals") {
     checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
     checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
+
+    checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0))
+    checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2))
   }
 
   test("decimal") {
@@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
         Decimal((d * 1000L).toLong, 10, 3))
       checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d))
       checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d))
+
+      checkEvaluation(Literal.create(Decimal(d)), Decimal(d))
+      checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt))
+      checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong))
+      checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)),
+        Decimal((d * 1000L).toLong, 10, 3))
+      checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d))
+      checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d))
+
     }
   }
 
+  private def toCatalyst[T: TypeTag](value: T): Any = {
+    val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
+    CatalystTypeConverters.createToCatalystConverter(dataType)(value)
+  }
+
   test("array") {
-    def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = {
-      val toCatalyst = (a: Array[_], elementType: DataType) => {
-        CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a)
-      }
-      checkEvaluation(Literal(a), toCatalyst(a, elementType))
+    def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = {
+      checkEvaluation(Literal(a), toCatalyst(a))
+      checkEvaluation(Literal.create(a), toCatalyst(a))
+    }
+    checkArrayLiteral(Array(1, 2, 3))
+    checkArrayLiteral(Array("a", "b", "c"))
+    checkArrayLiteral(Array(1.0, 4.0))
+    checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR))
+  }
+
+  test("seq") {
+    def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = {
+      checkEvaluation(Literal.create(a), toCatalyst(a))
     }
-    checkArrayLiteral(Array(1, 2, 3), IntegerType)
-    checkArrayLiteral(Array("a", "b", "c"), StringType)
-    checkArrayLiteral(Array(1.0, 4.0), DoubleType)
-    checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
+    checkSeqLiteral(Seq(1, 2, 3), IntegerType)
+    checkSeqLiteral(Seq("a", "b", "c"), StringType)
+    checkSeqLiteral(Seq(1.0, 4.0), DoubleType)
+    checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
       CalendarIntervalType)
   }
 
-  test("unsupported types (map and struct) in literals") {
+  test("map") {
+    def checkMapLiteral[T: TypeTag](m: T): Unit = {
+      checkEvaluation(Literal.create(m), toCatalyst(m))
+    }
+    checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3))
+    checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0))
+  }
+
+  test("struct") {
+    def checkStructLiteral[T: TypeTag](s: T): Unit = {
+      checkEvaluation(Literal.create(s), toCatalyst(s))
+    }
+    checkStructLiteral((1, 3.0, "abcde"))
+    checkStructLiteral(("de", 1, 2.0f))
+    checkStructLiteral((1, ("fgh", 3.0)))
+  }
+
+  test("unsupported types (map and struct) in Literal.apply") {
     def checkUnsupportedTypeInLiteral(v: Any): Unit = {
       val errMsgMap = intercept[RuntimeException] {
         Literal(v)

http://git-wip-us.apache.org/repos/asf/spark/blob/14bb398f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 24ed906..2247010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -91,15 +91,24 @@ object functions {
    * @group normal_funcs
    * @since 1.3.0
    */
-  def lit(literal: Any): Column = {
-    literal match {
-      case c: Column => return c
-      case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name)
-      case _ =>  // continue
-    }
+  def lit(literal: Any): Column = typedLit(literal)
 
-    val literalExpr = Literal(literal)
-    Column(literalExpr)
+  /**
+   * Creates a [[Column]] of literal value.
+   *
+   * The passed in object is returned directly if it is already a [[Column]].
+   * If the object is a Scala Symbol, it is converted into a [[Column]] also.
+   * Otherwise, a new [[Column]] is created to represent the literal value.
+   * The difference between this function and [[lit]] is that this function
+   * can handle parameterized scala types e.g.: List, Seq and Map.
+   *
+   * @group normal_funcs
+   * @since 2.2.0
+   */
+  def typedLit[T : TypeTag](literal: T): Column = literal match {
+    case c: Column => c
+    case s: Symbol => new ColumnName(s.name)
+    case _ => Column(Literal.create(literal))
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/14bb398f/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index ee280a3..b0f398d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
       testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
       testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
   }
+
+  test("typedLit") {
+    val df = Seq(Tuple1(0)).toDF("a")
+    // Only check the types `lit` cannot handle
+    checkAnswer(
+      df.select(typedLit(Seq(1, 2, 3))),
+      Row(Seq(1, 2, 3)) :: Nil)
+    checkAnswer(
+      df.select(typedLit(Map("a" -> 1, "b" -> 2))),
+      Row(Map("a" -> 1, "b" -> 2)) :: Nil)
+    checkAnswer(
+      df.select(typedLit(("a", 2, 1.0))),
+      Row(Row("a", 2, 1.0)) :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org