You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/17 18:38:19 UTC

spark git commit: [SPARK-8945][SQL] Add add and subtract expressions for IntervalType

Repository: spark
Updated Branches:
  refs/heads/master 305e77cd8 -> eba6a1af4


[SPARK-8945][SQL] Add add and subtract expressions for IntervalType

JIRA: https://issues.apache.org/jira/browse/SPARK-8945

Add addition and subtraction expressions for IntervalType, and allow unary minus and unary plus on interval values.

Author: Liang-Chi Hsieh <vi...@appier.com>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <rx...@databricks.com>

Closes #7398 from viirya/interval_add_subtract and squashes the following commits:

acd1f1e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract
5abae28 [Liang-Chi Hsieh] For comments.
6f5b72e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract
dbe3906 [Liang-Chi Hsieh] For comments.
13a2fc5 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract
83ec129 [Liang-Chi Hsieh] Remove intervalMethod.
acfe1ab [Liang-Chi Hsieh] Fix scala style.
d3e9d0e [Liang-Chi Hsieh] Add add and subtract expressions for IntervalType.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eba6a1af
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eba6a1af
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eba6a1af

Branch: refs/heads/master
Commit: eba6a1af4c8ffb21934a59a61a419d625f37cceb
Parents: 305e77c
Author: Liang-Chi Hsieh <vi...@appier.com>
Authored: Fri Jul 17 09:38:08 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Jul 17 09:38:08 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 60 +++++++++++++++++---
 .../expressions/codegen/CodeGenerator.scala     |  4 +-
 .../sql/catalyst/expressions/literals.scala     |  3 +-
 .../spark/sql/types/AbstractDataType.scala      |  6 ++
 .../analysis/ExpressionTypeCheckingSuite.scala  |  6 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 17 ++++++
 .../org/apache/spark/unsafe/types/Interval.java | 16 ++++++
 .../spark/unsafe/types/IntervalSuite.java       | 38 +++++++++++++
 8 files changed, 136 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 382cbe3..1616d1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.Interval
 
 
 case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
 
   override def dataType: DataType = child.dataType
 
@@ -36,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
     case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
     case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
+    case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
   }
 
-  protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
+  protected override def nullSafeEval(input: Any): Any = {
+    if (dataType.isInstanceOf[IntervalType]) {
+      input.asInstanceOf[Interval].negate()
+    } else {
+      numeric.negate(input)
+    }
+  }
 }
 
 case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
   override def prettyName: String = "positive"
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
 
   override def dataType: DataType = child.dataType
 
@@ -95,32 +103,66 @@ private[sql] object BinaryArithmetic {
 
 case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
 
-  override def inputType: AbstractDataType = NumericType
+  override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
   override def symbol: String = "+"
-  override def decimalMethod: String = "$plus"
 
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
-  protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    if (dataType.isInstanceOf[IntervalType]) {
+      input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval])
+    } else {
+      numeric.plus(input1, input2)
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
+    case dt: DecimalType =>
+      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
+    case ByteType | ShortType =>
+      defineCodeGen(ctx, ev,
+        (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+    case IntervalType =>
+      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
+    case _ =>
+      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
+  }
 }
 
 case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
 
-  override def inputType: AbstractDataType = NumericType
+  override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
   override def symbol: String = "-"
-  override def decimalMethod: String = "$minus"
 
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
-  protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    if (dataType.isInstanceOf[IntervalType]) {
+      input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval])
+    } else {
+      numeric.minus(input1, input2)
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
+    case dt: DecimalType =>
+      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
+    case ByteType | ShortType =>
+      defineCodeGen(ctx, ev,
+        (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+    case IntervalType =>
+      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
+    case _ =>
+      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
+  }
 }
 
 case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {

http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 45dc146..7c388bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -27,7 +27,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types._
 
 
 // These classes are here to avoid issues with serialization and integration with quasiquotes.
@@ -69,6 +69,7 @@ class CodeGenContext {
     mutableStates += ((javaType, variableName, initialValue))
   }
 
+  final val intervalType: String = classOf[Interval].getName
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
@@ -137,6 +138,7 @@ class CodeGenContext {
     case dt: DecimalType => "Decimal"
     case BinaryType => "byte[]"
     case StringType => "UTF8String"
+    case IntervalType => intervalType
     case _: StructType => "InternalRow"
     case _: ArrayType => s"scala.collection.Seq"
     case _: MapType => s"scala.collection.Map"

http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 3a7a7ae..e1fdb29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types._
 
 object Literal {
   def apply(v: Any): Literal = v match {
@@ -42,6 +42,7 @@ object Literal {
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case a: Array[Byte] => Literal(a, BinaryType)
+    case i: Interval => Literal(i, IntervalType)
     case null => Literal(null, NullType)
     case _ =>
       throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)

http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 076d7b5..40bf4b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -91,6 +91,12 @@ private[sql] object TypeCollection {
     TimestampType, DateType,
     StringType, BinaryType)
 
+  /**
+   * Types that include numeric types and interval type. They are only used in unary_minus,
+   * unary_positive, add and subtract operations.
+   */
+  val NumericAndInterval = TypeCollection(NumericType, IntervalType)
+
   def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
 
   def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {

http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ed0d20e..ad15136 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
   }
 
   test("check types for unary arithmetic") {
-    assertError(UnaryMinus('stringField), "expected to be of type numeric")
+    assertError(UnaryMinus('stringField), "type (numeric or interval)")
     assertError(Abs('stringField), "expected to be of type numeric")
     assertError(BitwiseNot('stringField), "expected to be of type integral")
   }
@@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
     assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
 
-    assertError(Add('booleanField, 'booleanField), "accepts numeric type")
-    assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type")
+    assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type")
     assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
     assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
     assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")

http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 2314408..5b8b70e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1492,4 +1492,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     // Currently we don't yet support nanosecond
     checkIntervalParseError("select interval 23 nanosecond")
   }
+
+  test("SPARK-8945: add and subtract expressions for interval type") {
+    import org.apache.spark.unsafe.types.Interval
+
+    val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
+    checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))
+
+    checkAnswer(df.select(df("i") + new Interval(2, 123)),
+      Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123)))
+
+    checkAnswer(df.select(df("i") - new Interval(2, 123)),
+      Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123)))
+
+    // unary minus
+    checkAnswer(df.select(-df("i")),
+      Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123))))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
index 905ea0b..71b1a85 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
@@ -87,6 +87,22 @@ public final class Interval implements Serializable {
     this.microseconds = microseconds;
   }
 
+  public Interval add(Interval that) {
+    int months = this.months + that.months;
+    long microseconds = this.microseconds + that.microseconds;
+    return new Interval(months, microseconds);
+  }
+
+  public Interval subtract(Interval that) {
+    int months = this.months - that.months;
+    long microseconds = this.microseconds - that.microseconds;
+    return new Interval(months, microseconds);
+  }
+
+  public Interval negate() {
+    return new Interval(-this.months, -this.microseconds);
+  }
+
   @Override
   public boolean equals(Object other) {
     if (this == other) return true;

http://git-wip-us.apache.org/repos/asf/spark/blob/eba6a1af/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
index 1832d0b..d29517c 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
@@ -101,6 +101,44 @@ public class IntervalSuite {
     assertEquals(Interval.fromString(input), null);
   }
 
+  @Test
+  public void addTest() {
+    String input = "interval 3 month 1 hour";
+    String input2 = "interval 2 month 100 hour";
+
+    Interval interval = Interval.fromString(input);
+    Interval interval2 = Interval.fromString(input2);
+
+    assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR));
+
+    input = "interval -10 month -81 hour";
+    input2 = "interval 75 month 200 hour";
+
+    interval = Interval.fromString(input);
+    interval2 = Interval.fromString(input2);
+
+    assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR));
+  }
+
+  @Test
+  public void subtractTest() {
+    String input = "interval 3 month 1 hour";
+    String input2 = "interval 2 month 100 hour";
+
+    Interval interval = Interval.fromString(input);
+    Interval interval2 = Interval.fromString(input2);
+
+    assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR));
+
+    input = "interval -10 month -81 hour";
+    input2 = "interval 75 month 200 hour";
+
+    interval = Interval.fromString(input);
+    interval2 = Interval.fromString(input2);
+
+    assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR));
+  }
+
   private void testSingleUnit(String unit, int number, int months, long microseconds) {
     String input1 = "interval " + number + " " + unit;
     String input2 = "interval " + number + " " + unit + "s";


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org