Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/31 04:22:41 UTC

spark git commit: [SPARK-8176] [SPARK-8197] [SQL] functions to_date / trunc

Repository: spark
Updated Branches:
  refs/heads/master 9307f5653 -> 83670fc9e


[SPARK-8176] [SPARK-8197] [SQL] functions to_date / trunc

This PR is based on #6988, thanks to adrian-wang.

This brings two SQL functions: to_date() and trunc().

Closes #6988
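
A minimal usage sketch of the two new functions (Scala, 1.5 DataFrame API; the
`sqlContext` fixture, column name, and temp table name are illustrative, not
part of the patch):

    import org.apache.spark.sql.functions.{col, to_date, trunc}
    import sqlContext.implicits._

    val df = Seq(Tuple1("2015-07-30 19:22:38")).toDF("t")
    df.select(to_date(col("t"))).show()      // 2015-07-30
    df.select(trunc(col("t"), "mon")).show() // 2015-07-01

    // Both functions are also registered for SQL use:
    df.registerTempTable("tbl")
    sqlContext.sql("SELECT to_date(t), trunc(t, 'year') FROM tbl").show()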

Author: Daoyuan Wang <da...@intel.com>
Author: Davies Liu <da...@databricks.com>

Closes #7805 from davies/to_date and squashes the following commits:

2c7beba [Davies Liu] Merge branch 'master' of github.com:apache/spark into to_date
310dd55 [Daoyuan Wang] remove dup test in rebase
980b092 [Daoyuan Wang] resolve rebase conflict
a476c5a [Daoyuan Wang] address comments from davies
d44ea5f [Daoyuan Wang] function to_date, trunc


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/83670fc9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/83670fc9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/83670fc9

Branch: refs/heads/master
Commit: 83670fc9e6fc9c7a6ae68dfdd3f9335ea72f4ab0
Parents: 9307f56
Author: Daoyuan Wang <da...@intel.com>
Authored: Thu Jul 30 19:22:38 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jul 30 19:22:38 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 30 +++++++
 .../catalyst/analysis/FunctionRegistry.scala    |  2 +
 .../expressions/datetimeFunctions.scala         | 88 +++++++++++++++++++-
 .../spark/sql/catalyst/util/DateTimeUtils.scala | 34 ++++++++
 .../expressions/DateExpressionsSuite.scala      | 29 ++++++-
 .../expressions/NonFoldableLiteral.scala        |  4 +
 .../scala/org/apache/spark/sql/functions.scala  | 16 ++++
 .../apache/spark/sql/DateFunctionsSuite.scala   | 44 ++++++++++
 8 files changed, 245 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index a7295e2..8024a8d 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -889,6 +889,36 @@ def months_between(date1, date2):
 
 
 @since(1.5)
+def to_date(col):
+    """
+    Converts the column of StringType or TimestampType into DateType.
+
+    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+    >>> df.select(to_date(df.t).alias('date')).collect()
+    [Row(date=datetime.date(1997, 2, 28))]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.to_date(_to_java_column(col)))
+
+
+@since(1.5)
+def trunc(date, format):
+    """
+    Returns the date truncated to the unit specified by the format.
+
+    :param format: 'year', 'YYYY', 'yy' to truncate to the year, or 'month', 'mon', 'mm' to the month
+
+    >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d'])
+    >>> df.select(trunc(df.d, 'year').alias('year')).collect()
+    [Row(year=datetime.date(1997, 1, 1))]
+    >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
+    [Row(month=datetime.date(1997, 2, 1))]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
+
+
+@since(1.5)
 def size(col):
     """
     Collection function: returns the length of the array or map stored in the column.

http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6c7c481..1bf7204 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -223,6 +223,8 @@ object FunctionRegistry {
     expression[NextDay]("next_day"),
     expression[Quarter]("quarter"),
     expression[Second]("second"),
+    expression[ToDate]("to_date"),
+    expression[TruncDate]("trunc"),
     expression[UnixTimestamp]("unix_timestamp"),
     expression[WeekOfYear]("weekofyear"),
     expression[Year]("year"),

http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
index 9795673..6e76133 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
@@ -507,7 +507,6 @@ case class FromUnixTime(sec: Expression, format: Expression)
       })
     }
   }
-
 }
 
 /**
@@ -696,3 +695,90 @@ case class MonthsBetween(date1: Expression, date2: Expression)
     })
   }
 }
+
+/**
+ * Returns the date part of a timestamp or string.
+ */
+case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  // Spark's implicit casting accepts strings in either date or timestamp format,
+  // as well as TimestampType, so requiring DateType alone suffices here.
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
+
+  override def dataType: DataType = DateType
+
+  override def eval(input: InternalRow): Any = child.eval(input)
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, d => d)
+  }
+}
+
+/**
+ * Returns the date truncated to the unit specified by the format.
+ */
+case class TruncDate(date: Expression, format: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+  override def left: Expression = date
+  override def right: Expression = format
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+  override def dataType: DataType = DateType
+  override def prettyName: String = "trunc"
+
+  lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
+
+  override def eval(input: InternalRow): Any = {
+    val minItem = if (format.foldable) {
+      minItemConst
+    } else {
+      DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
+    }
+    if (minItem == -1) {
+      // unknown format
+      null
+    } else {
+      val d = date.eval(input)
+      if (d == null) {
+        null
+      } else {
+        DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem)
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+
+    if (format.foldable) {
+      if (minItemConst == -1) {
+        s"""
+          boolean ${ev.isNull} = true;
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+        """
+      } else {
+        val d = date.gen(ctx)
+        s"""
+          ${d.code}
+          boolean ${ev.isNull} = ${d.isNull};
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          if (!${ev.isNull}) {
+            ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst);
+          }
+        """
+      }
+    } else {
+      nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
+        val form = ctx.freshName("form")
+        s"""
+          int $form = $dtu.parseTruncLevel($fmt);
+          if ($form == -1) {
+            ${ev.isNull} = true;
+          } else {
+            ${ev.primitive} = $dtu.truncDate($dateVal, $form);
+          }
+        """
+      })
+    }
+  }
+}
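
A design note on the TruncDate hunk above: when the format argument is foldable,
the truncation level is parsed once into minItemConst and reused by both eval()
and the generated code; an unrecognized format yields NULL instead of an error.
A minimal sketch of the two paths agreeing, written in the style of the
DateExpressionsSuite tests further down (checkEvaluation, Literal,
NonFoldableLiteral, Date, and StringType are the helpers used there):

    // Constant format: the level is parsed once via minItemConst.
    checkEvaluation(
      TruncDate(Literal(Date.valueOf("2015-07-22")), Literal("year")),
      Date.valueOf("2015-01-01"))
    // Non-constant format: re-parsed on every evaluation, same result.
    checkEvaluation(
      TruncDate(Literal(Date.valueOf("2015-07-22")),
        NonFoldableLiteral.create("year", StringType)),
      Date.valueOf("2015-01-01"))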

http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 53abdf6..5a7c25b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -779,4 +779,38 @@ object DateTimeUtils {
     }
     date + (lastDayOfMonthInYear - dayInYear)
   }
+
+  private val TRUNC_TO_YEAR = 1
+  private val TRUNC_TO_MONTH = 2
+  private val TRUNC_INVALID = -1
+
+  /**
+   * Returns the truncated date for the given date and truncation level.
+   * The level must come from `parseTruncLevel()` and be TRUNC_TO_YEAR (1) or TRUNC_TO_MONTH (2).
+   */
+  def truncDate(d: Int, level: Int): Int = {
+    if (level == TRUNC_TO_YEAR) {
+      d - DateTimeUtils.getDayInYear(d) + 1
+    } else if (level == TRUNC_TO_MONTH) {
+      d - DateTimeUtils.getDayOfMonth(d) + 1
+    } else {
+      throw new Exception(s"Invalid trunc level: $level")
+    }
+  }
+
+  /**
+   * Returns the truncation level: TRUNC_TO_YEAR, TRUNC_TO_MONTH, or TRUNC_INVALID.
+   * TRUNC_INVALID means the truncation level is unsupported.
+   */
+  def parseTruncLevel(format: UTF8String): Int = {
+    if (format == null) {
+      TRUNC_INVALID
+    } else {
+      format.toString.toUpperCase match {
+        case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
+        case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
+        case _ => TRUNC_INVALID
+      }
+    }
+  }
 }
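
For context on the arithmetic above: Catalyst stores a date as an Int counting
days since the Unix epoch, so truncation is just subtracting the day-in-year or
day-of-month offset. A self-contained sketch of that arithmetic, using java.time
in place of Spark's internal getDayInYear/getDayOfMonth helpers (an illustrative
substitution; the diff itself relies on the DateTimeUtils helpers):

    import java.time.LocalDate
    import java.time.temporal.ChronoUnit

    val epoch = LocalDate.of(1970, 1, 1)
    val day   = LocalDate.of(2015, 7, 22)
    val d     = ChronoUnit.DAYS.between(epoch, day).toInt // days since epoch

    val toYear  = d - day.getDayOfYear + 1   // back to 2015-01-01
    val toMonth = d - day.getDayOfMonth + 1  // back to 2015-07-01

    assert(epoch.plusDays(toYear)  == LocalDate.of(2015, 1, 1))
    assert(epoch.plusDays(toMonth) == LocalDate.of(2015, 7, 1))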

http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 887e436..6c15c05 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -351,6 +351,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
   }
 
+  test("function to_date") {
+    checkEvaluation(
+      ToDate(Literal(Date.valueOf("2015-07-22"))),
+      DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22")))
+    checkEvaluation(ToDate(Literal.create(null, DateType)), null)
+  }
+
+  test("function trunc") {
+    def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
+      checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
+        expected)
+      checkEvaluation(
+        TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
+        expected)
+    }
+    val date = Date.valueOf("2015-07-22")
+    Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt =>
+      testTrunc(date, fmt, Date.valueOf("2015-01-01"))
+    }
+    Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
+      testTrunc(date, fmt, Date.valueOf("2015-07-01"))
+    }
+    testTrunc(date, "DD", null)
+    testTrunc(date, null, null)
+    testTrunc(null, "MON", null)
+    testTrunc(null, null, null)
+  }
+
   test("from_unixtime") {
     val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
     val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
@@ -405,5 +433,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
       UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
index 0559fb8..31ecf4a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
@@ -47,4 +47,8 @@ object NonFoldableLiteral {
     val lit = Literal(value)
     NonFoldableLiteral(lit.value, lit.dataType)
   }
+  def create(value: Any, dataType: DataType): NonFoldableLiteral = {
+    val lit = Literal.create(value, dataType)
+    NonFoldableLiteral(lit.value, lit.dataType)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 168894d..46dc460 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2181,6 +2181,22 @@ object functions {
    */
   def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p))
 
+  /**
+   * Converts the column of StringType or TimestampType into DateType.
+   *
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def to_date(e: Column): Column = ToDate(e.expr)
+
+  /**
+   * Returns the date truncated to the unit specified by the format.
+   *
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format))
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Collection functions
   //////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/83670fc9/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index b7267c4..8c596fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -345,6 +345,50 @@ class DateFunctionsSuite extends QueryTest {
       Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30"))))
   }
 
+  test("function to_date") {
+    val d1 = Date.valueOf("2015-07-22")
+    val d2 = Date.valueOf("2015-07-01")
+    val t1 = Timestamp.valueOf("2015-07-22 10:00:00")
+    val t2 = Timestamp.valueOf("2014-12-31 23:59:59")
+    val s1 = "2015-07-22 10:00:00"
+    val s2 = "2014-12-31"
+    val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s")
+
+    checkAnswer(
+      df.select(to_date(col("t"))),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+    checkAnswer(
+      df.select(to_date(col("d"))),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
+    checkAnswer(
+      df.select(to_date(col("s"))),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+
+    checkAnswer(
+      df.selectExpr("to_date(t)"),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+    checkAnswer(
+      df.selectExpr("to_date(d)"),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
+    checkAnswer(
+      df.selectExpr("to_date(s)"),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+  }
+
+  test("function trunc") {
+    val df = Seq(
+      (1, Timestamp.valueOf("2015-07-22 10:00:00")),
+      (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t")
+
+    checkAnswer(
+      df.select(trunc(col("t"), "YY")),
+      Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01"))))
+
+    checkAnswer(
+      df.selectExpr("trunc(t, 'Month')"),
+      Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01"))))
+  }
+
   test("from_unixtime") {
     val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
     val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"

