You are viewing a plain text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/11/02 03:29:54 UTC
[2/3] [SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 5657bc5..6bfa0db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.scalatest.FunSuite
import org.scalatest.Matchers._
import org.scalactic.TripleEqualsSupport.Spread
@@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
- actual.asInstanceOf[Double] shouldBe expected
+ actual.asInstanceOf[Double] shouldBe expected
}
test("IN") {
@@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(InSet(three, nS, three +: nullS), false)
checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
}
-
+
test("MaxOf") {
checkEvaluation(MaxOf(1, 2), 2)
checkEvaluation(MaxOf(2, 1), 2)
@@ -265,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite {
val ts = Timestamp.valueOf(nts)
checkEvaluation("abdef" cast StringType, "abdef")
- checkEvaluation("abdef" cast DecimalType, null)
+ checkEvaluation("abdef" cast DecimalType.Unlimited, null)
checkEvaluation("abdef" cast TimestampType, null)
- checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
+ checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65))
checkEvaluation(Literal(1) cast LongType, 1)
checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong)
@@ -289,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Cast(Cast(Cast(Cast(
Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0)
checkEvaluation(Literal(true) cast IntegerType, 1)
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
@@ -302,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("23" cast DoubleType, 23d)
checkEvaluation("23" cast IntegerType, 23)
checkEvaluation("23" cast FloatType, 23f)
- checkEvaluation("23" cast DecimalType, 23: BigDecimal)
+ checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23))
checkEvaluation("23" cast ByteType, 23.toByte)
checkEvaluation("23" cast ShortType, 23.toShort)
checkEvaluation("2012-12-11" cast DoubleType, null)
@@ -311,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d)
checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f)
- checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal)
+ checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24))
checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte)
checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort)
@@ -325,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(("abcdef" cast IntegerType).nullable === true)
assert(("abcdef" cast ShortType).nullable === true)
assert(("abcdef" cast ByteType).nullable === true)
- assert(("abcdef" cast DecimalType).nullable === true)
+ assert(("abcdef" cast DecimalType.Unlimited).nullable === true)
+ assert(("abcdef" cast DecimalType(4, 2)).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)
@@ -338,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(d1) < Literal(d2), true)
}
+ test("casting to fixed-precision decimals") {
+ // Overflow and rounding for casting to fixed-precision decimals:
+ // - Values should round with HALF_UP mode by default when you lower scale
+ // - Values that would overflow the target precision should turn into null
+ // - Because of this, casts to fixed-precision decimals should be nullable,
+ //   while casts to DecimalType.Unlimited are not
+ assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false)
+ // Fixed-precision targets are nullable, because out-of-range values are cast to null
+ assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true)
+ // Integer input: 123 fits DecimalType(3, 0) but not (3, 1) or (2, 0)
+ checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123))
+ checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123))
+ checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null)
+ checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null)
+ // 10.03: lowering the scale rounds it (to 10.0 or 10); too few integer digits gives null
+ checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null)
+ checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null)
+ // 10.05: HALF_UP rounds the dropped 5 up, giving 10.1 at scale 1
+ checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null)
+ checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null)
+ // 9.95: rounding carries into an extra integer digit (10.0), which must still fit
+ checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null)
+ // -9.95: negative values round away from zero in the same way, giving -10.0
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null)
+ }
+
test("timestamp") {
val ts1 = new Timestamp(12)
val ts2 = new Timestamp(123)
@@ -374,7 +434,7 @@ class ExpressionEvaluationSuite extends FunSuite {
millis.toFloat / 1000)
checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType),
millis.toDouble / 1000)
- checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
+ checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1))
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
@@ -673,7 +733,7 @@ class ExpressionEvaluationSuite extends FunSuite {
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
val d = 'a.double.at(0)
-
+
for ((row, expected) <- rowSequence zip expectedResults) {
checkEvaluation(Sqrt(d), expected, row)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala
new file mode 100644
index 0000000..5aa2634
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.types.decimal
+
+import org.scalatest.{PrivateMethodTester, FunSuite}
+
+import scala.language.postfixOps
+
+/**
+ * Tests for [[Decimal]]: construction with and without an explicit precision/scale, conversion
+ * to double and long, the compact unscaled-long representation, hashing, equality and arithmetic.
+ */
+class DecimalSuite extends FunSuite with PrivateMethodTester {
+  test("creating decimals") {
+    /** Check that a Decimal has the given string representation, precision and scale */
+    def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
+      assert(d.toString === string)
+      assert(d.precision === precision)
+      assert(d.scale === scale)
+    }
+
+    checkDecimal(new Decimal(), "0", 1, 0)
+    checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
+    checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
+    checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1)
+    checkDecimal(Decimal("10.030"), "10.030", 5, 3)
+    checkDecimal(Decimal(10.03), "10.03", 4, 2)
+    checkDecimal(Decimal(17L), "17", 20, 0)
+    checkDecimal(Decimal(17), "17", 10, 0)
+    checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1)
+    checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2)
+    checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1)
+    checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0)
+    checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
+    checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
+    // Values that do not fit in the requested precision are rejected
+    intercept[IllegalArgumentException](Decimal(170L, 2, 1))
+    intercept[IllegalArgumentException](Decimal(170L, 2, 0))
+    intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
+    intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
+    intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
+  }
+
+  test("double and long values") {
+    /** Check that a Decimal converts to the given double and long values */
+    def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {
+      assert(d.toDouble === doubleValue)
+      assert(d.toLong === longValue)
+    }
+
+    checkValues(new Decimal(), 0.0, 0L)
+    checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L)
+    checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L)
+    checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L)
+    checkValues(Decimal(10.03), 10.03, 10L)
+    checkValues(Decimal(17L), 17.0, 17L)
+    checkValues(Decimal(17), 17.0, 17L)
+    checkValues(Decimal(17L, 2, 1), 1.7, 1L)
+    checkValues(Decimal(170L, 4, 2), 1.7, 1L)
+    checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong)
+    checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong)
+    checkValues(Decimal(1e18.toLong), 1e18, 1e18.toLong)
+    checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong)
+    checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue)
+    checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue)
+    // Double.MaxValue / Double.MinValue are far outside the Long range; toLong gives 0L for them
+    checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L)
+    checkValues(Decimal(Double.MinValue), Double.MinValue, 0L)
+  }
+
+  // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs
+  private val decimalVal = PrivateMethod[BigDecimal]('decimalVal)
+
+  /** Check whether a decimal is represented compactly (passing whether we expect it to be) */
+  private def checkCompact(d: Decimal, expected: Boolean): Unit = {
+    val isCompact = d.invokePrivate(decimalVal()).eq(null)
+    assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact")
+  }
+
+  test("small decimals represented as unscaled long") {
+    checkCompact(new Decimal(), true)
+    checkCompact(Decimal(BigDecimal(10.03)), false)
+    checkCompact(Decimal(BigDecimal(1e20)), false)
+    checkCompact(Decimal(17L), true)
+    checkCompact(Decimal(17), true)
+    checkCompact(Decimal(17L, 2, 1), true)
+    checkCompact(Decimal(170L, 4, 2), true)
+    checkCompact(Decimal(17L, 24, 1), true)
+    checkCompact(Decimal(1e16.toLong), true)
+    checkCompact(Decimal(1e17.toLong), true)
+    // Unscaled longs strictly between -1e18 and 1e18 stay compact, even with a larger precision;
+    // a magnitude of 1e18 or more is stored as a BigDecimal instead
+    checkCompact(Decimal(1e18.toLong - 1), true)
+    checkCompact(Decimal(- 1e18.toLong + 1), true)
+    checkCompact(Decimal(1e18.toLong - 1, 30, 10), true)
+    checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true)
+    checkCompact(Decimal(1e18.toLong), false)
+    checkCompact(Decimal(-1e18.toLong), false)
+    checkCompact(Decimal(1e18.toLong, 30, 10), false)
+    checkCompact(Decimal(-1e18.toLong, 30, 10), false)
+    checkCompact(Decimal(Long.MaxValue), false)
+    checkCompact(Decimal(Long.MinValue), false)
+  }
+
+  test("hash code") {
+    assert(Decimal(123).hashCode() === (123).##)
+    assert(Decimal(-123).hashCode() === (-123).##)
+    assert(Decimal(123.312).hashCode() === (123.312).##)
+    assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##)
+    assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##)
+    assert(Decimal(BigDecimal(123)).hashCode() === (123).##)
+
+    val reallyBig = BigDecimal("123182312312313232112312312123.1231231231")
+    assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode)
+  }
+
+  test("equals") {
+    // The decimals on the left are stored compactly, while the ones on the right aren't
+    checkCompact(Decimal(123), true)
+    checkCompact(Decimal(BigDecimal(123)), false)
+    checkCompact(Decimal("123"), false)
+    assert(Decimal(123) === Decimal(BigDecimal(123)))
+    assert(Decimal(123) === Decimal(BigDecimal("123.00")))
+    assert(Decimal(-123) === Decimal(BigDecimal(-123)))
+    assert(Decimal(-123) === Decimal(BigDecimal("-123.00")))
+  }
+
+  test("isZero") {
+    assert(Decimal(0).isZero)
+    assert(Decimal(0, 4, 2).isZero)
+    assert(Decimal("0").isZero)
+    assert(Decimal("0.000").isZero)
+    assert(!Decimal(1).isZero)
+    assert(!Decimal(1, 4, 2).isZero)
+    assert(!Decimal("1").isZero)
+    assert(!Decimal("0.001").isZero)
+  }
+
+  test("arithmetic") {
+    assert(Decimal(100) + Decimal(-100) === Decimal(0))
+    assert(Decimal(100) - Decimal(-100) === Decimal(200))
+    assert(Decimal(100) * Decimal(-100) === Decimal(-10000))
+    assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26))
+    assert(Decimal(100) / Decimal(-100) === Decimal(-1))
+    assert(Decimal(100) % Decimal(-100) === Decimal(0))
+    assert(Decimal(100) % Decimal(3) === Decimal(1))
+    assert(Decimal(-100) % Decimal(3) === Decimal(-1))
+    // Division and remainder by zero return null instead of throwing
+    assert(Decimal(100) / Decimal(0) === null)
+    assert(Decimal(100) % Decimal(0) === null)
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
index 0c85cdc..c383540 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
@@ -53,11 +53,6 @@ public abstract class DataType {
public static final TimestampType TimestampType = new TimestampType();
/**
- * Gets the DecimalType object.
- */
- public static final DecimalType DecimalType = new DecimalType();
-
- /**
* Gets the DoubleType object.
*/
public static final DoubleType DoubleType = new DoubleType();
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
index bc54c07..6075245 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
@@ -19,9 +19,74 @@ package org.apache.spark.sql.api.java;
/**
* The data type representing java.math.BigDecimal values.
- *
- * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}.
+ *
+ * A {@code DecimalType} either has a fixed precision and scale, or is unlimited (no precision
+ * or scale); see {@link #isFixed()} and {@link #isUnlimited()}.
*/
public class DecimalType extends DataType {
- protected DecimalType() {}
+  // Final: equals() and hashCode() are computed from these fields, so they must never change.
+  private final boolean hasPrecisionInfo;
+  private final int precision;
+  private final int scale;
+
+  /**
+   * Creates a decimal type with a fixed precision and scale.
+   *
+   * @param precision the value returned by {@link #getPrecision()}
+   * @param scale the value returned by {@link #getScale()}
+   */
+  public DecimalType(int precision, int scale) {
+    this.hasPrecisionInfo = true;
+    this.precision = precision;
+    this.scale = scale;
+  }
+
+  /** Creates an unlimited decimal type, whose precision and scale are reported as -1. */
+  public DecimalType() {
+    this.hasPrecisionInfo = false;
+    this.precision = -1;
+    this.scale = -1;
+  }
+
+  /** Returns true if this type was created without a precision and scale. */
+  public boolean isUnlimited() {
+    return !hasPrecisionInfo;
+  }
+
+  /** Returns true if this type was created with an explicit precision and scale. */
+  public boolean isFixed() {
+    return hasPrecisionInfo;
+  }
+
+  /** Return the precision, or -1 if no precision is set */
+  public int getPrecision() {
+    return precision;
+  }
+
+  /** Return the scale, or -1 if no precision is set */
+  public int getScale() {
+    return scale;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    DecimalType that = (DecimalType) o;
+
+    if (hasPrecisionInfo != that.hasPrecisionInfo) return false;
+    if (precision != that.precision) return false;
+    if (scale != that.scale) return false;
+
+    return true;
+  }
+
+  @Override
+  public int hashCode() {
+    int result = (hasPrecisionInfo ? 1 : 0);
+    result = 31 * result + precision;
+    result = 31 * result + scale;
+    return result;
+  }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8b96df1..018a18c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.util.{Map => JMap, List => JList}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.storage.StorageLevel
import scala.collection.JavaConversions._
@@ -113,7 +114,7 @@ class SchemaRDD(
// =========================================================================================
override def compute(split: Partition, context: TaskContext): Iterator[Row] =
- firstParent[Row].compute(split, context).map(_.copy())
+ firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala)
override def getPartitions: Array[Partition] = firstParent[Row].partitions
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 082ae03..876b1c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -230,7 +230,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
case c: Class[_] if c == classOf[java.lang.Boolean] =>
(org.apache.spark.sql.BooleanType, true)
case c: Class[_] if c == classOf[java.math.BigDecimal] =>
- (org.apache.spark.sql.DecimalType, true)
+ (org.apache.spark.sql.DecimalType(), true)
case c: Class[_] if c == classOf[java.sql.Date] =>
(org.apache.spark.sql.DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] =>
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index df01411..401798e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.api.java
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.annotation.varargs
import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
import scala.collection.JavaConversions
@@ -106,6 +108,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable {
}
override def hashCode(): Int = row.hashCode()
+
+ override def toString: String = row.toString
}
object Row {
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b3edd50..087b0ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -70,16 +70,29 @@ case class GeneratedAggregate(
val computeFunctions = aggregatesToCompute.map {
case c @ Count(expr) =>
+ // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+ // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+ val toCount = expr match {
+ case UnscaledValue(e) => e
+ case _ => expr
+ }
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
val initialValue = Literal(0L)
- val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val result = currentCount
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case Sum(expr) =>
- val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
- val initialValue = Cast(Literal(0L), expr.dataType)
+ val resultType = expr.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale)
+ case _ =>
+ expr.dataType
+ }
+
+ val currentSum = AttributeReference("currentSum", resultType, nullable = false)()
+ val initialValue = Cast(Literal(0L), resultType)
// Coalasce avoids double calculation...
// but really, common sub expression elimination would be better....
@@ -93,10 +106,26 @@ case class GeneratedAggregate(
val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
val initialCount = Literal(0L)
val initialSum = Cast(Literal(0L), expr.dataType)
- val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+
+ // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+ // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+ val toCount = expr match {
+ case UnscaledValue(e) => e
+ case _ => expr
+ }
+
+ val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
- val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType))
+ val resultType = expr.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4)
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ DoubleType
+ }
+ val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType))
AggregateEvaluation(
currentCount :: currentSum :: Nil,
@@ -142,7 +171,7 @@ case class GeneratedAggregate(
val computationSchema = computeFunctions.flatMap(_.schema)
- val resultMap: Map[TreeNodeRef, Expression] =
+ val resultMap: Map[TreeNodeRef, Expression] =
aggregatesToCompute.zip(computeFunctions).map {
case (agg, func) => new TreeNodeRef(agg) -> func.result
}.toMap
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b1a7948..aafcce0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.{ScalaReflection, trees}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -82,7 +82,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
- def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
+ def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect()
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 077e6eb..84d96e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -29,6 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils
@@ -51,6 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
new OpenHashSetSerializer)
+ kryo.register(classOf[Decimal])
kryo.setReferences(false)
kryo.setClassLoader(Utils.getSparkClassLoader)
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 977f3c9..e6cd1a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan)
partsScanned += numPartsToTry
}
- buf.toArray
+ buf.toArray.map(ScalaReflection.convertRowToScala)
}
override def execute() = {
@@ -176,10 +176,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
override def output = child.output
override def outputPartitioning = SinglePartition
- val ordering = new RowOrdering(sortOrder, child.output)
+ val ord = new RowOrdering(sortOrder, child.output)
// TODO: Is this copying for no reason?
- override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
+ override def executeCollect() =
+ child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 8fd3588..5cf2a78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -49,7 +49,9 @@ case class BroadcastHashJoin(
@transient
private val broadcastFuture = future {
- val input: Array[Row] = buildPlan.executeCollect()
+ // Collect via execute() rather than executeCollect(), since executeCollect() converts rows to
+ // Scala types and we want to keep the internal representation here; rows are copied first.
+ val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
sparkContext.broadcast(hashed)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index a1961bb..9976690 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
import java.util.{List => JList, Map => JMap}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -116,7 +118,7 @@ object EvaluatePython {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (row: Row, struct: StructType) =>
+ case (row: Seq[Any], struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
row.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
@@ -133,6 +135,8 @@ object EvaluatePython {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava
+ case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal
+
// Pyrolite can handle Timestamp
case (other, _) => other
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index eabe312..5bb6f6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.json
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
@@ -175,9 +177,9 @@ private[sql] object JsonRDD extends Logging {
ScalaReflection.typeOfObject orElse {
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
- case value: java.math.BigInteger => DecimalType
+ case value: java.math.BigInteger => DecimalType.Unlimited
// DecimalType's JVMType is scala BigDecimal.
- case value: java.math.BigDecimal => DecimalType
+ case value: java.math.BigDecimal => DecimalType.Unlimited
// Unexpected data type.
case _ => StringType
}
@@ -319,13 +321,13 @@ private[sql] object JsonRDD extends Logging {
}
}
- private def toDecimal(value: Any): BigDecimal = {
+ private def toDecimal(value: Any): Decimal = {
value match {
- case value: java.lang.Integer => BigDecimal(value)
- case value: java.lang.Long => BigDecimal(value)
- case value: java.math.BigInteger => BigDecimal(value)
- case value: java.lang.Double => BigDecimal(value)
- case value: java.math.BigDecimal => BigDecimal(value)
+ case value: java.lang.Integer => Decimal(value)
+ case value: java.lang.Long => Decimal(value)
+ case value: java.math.BigInteger => Decimal(BigDecimal(value))
+ case value: java.lang.Double => Decimal(value)
+ case value: java.math.BigDecimal => Decimal(BigDecimal(value))
}
}
@@ -391,7 +393,7 @@ private[sql] object JsonRDD extends Logging {
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
case LongType => toLong(value)
case DoubleType => toDouble(value)
- case DecimalType => toDecimal(value)
+ case DecimalType() => toDecimal(value)
case BooleanType => value.asInstanceOf[BooleanType.JvmType]
case NullType => null
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index f0e57e2..05926a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -183,6 +183,20 @@ package object sql {
*
* The data type representing `scala.math.BigDecimal` values.
*
+ * TODO(matei): explain precision and scale
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ type DecimalType = catalyst.types.DecimalType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `scala.math.BigDecimal` values.
+ *
+ * TODO(matei): explain precision and scale
+ *
* @group dataType
*/
@DeveloperApi
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 2fc7e1c..08feced 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.parquet
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap}
import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
@@ -117,6 +119,12 @@ private[sql] object CatalystConverter {
parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType])
}
}
+ case d: DecimalType => {
+ new CatalystPrimitiveConverter(parent, fieldIndex) {
+ override def addBinary(value: Binary): Unit =
+ parent.updateDecimal(fieldIndex, value, d)
+ }
+ }
// All other primitive types use the default converter
case ctype: PrimitiveType => { // note: need the type tag here!
new CatalystPrimitiveConverter(parent, fieldIndex)
@@ -191,6 +199,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, value.toStringUsingUTF8)
+ protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = {
+ updateField(fieldIndex, readDecimal(new Decimal(), value, ctype))
+ }
+
protected[parquet] def isRootConverter: Boolean = parent == null
protected[parquet] def clearBuffer(): Unit
@@ -201,6 +213,27 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
* @return
*/
def getCurrentRecord: Row = throw new UnsupportedOperationException
+
+ /**
+ * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in
+ * a long (i.e. precision <= 18)
+ */
+ protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = {
+ val precision = ctype.precisionInfo.get.precision
+ val scale = ctype.precisionInfo.get.scale
+ val bytes = value.getBytes
+ require(bytes.length <= 16, "Decimal field too large to read")
+ var unscaled = 0L
+ var i = 0
+ while (i < bytes.length) {
+ unscaled = (unscaled << 8) | (bytes(i) & 0xFF)
+ i += 1
+ }
+ // Make sure unscaled has the right sign, by sign-extending the first bit
+ val numBits = 8 * bytes.length
+ unscaled = (unscaled << (64 - numBits)) >> (64 - numBits)
+ dest.set(unscaled, precision, scale)
+ }
}
/**
@@ -352,6 +385,16 @@ private[parquet] class CatalystPrimitiveRowConverter(
override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit =
current.setString(fieldIndex, value.toStringUsingUTF8)
+
+ override protected[parquet] def updateDecimal(
+ fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = {
+ var decimal = current(fieldIndex).asInstanceOf[Decimal]
+ if (decimal == null) {
+ decimal = new Decimal
+ current(fieldIndex) = decimal
+ }
+ readDecimal(decimal, value, ctype)
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index bdf0240..2a5f23b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet
import java.util.{HashMap => JHashMap}
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import parquet.column.ParquetProperties
import parquet.hadoop.ParquetOutputFormat
import parquet.hadoop.api.ReadSupport.ReadContext
@@ -204,6 +205,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
case DoubleType => writer.addDouble(value.asInstanceOf[Double])
case FloatType => writer.addFloat(value.asInstanceOf[Float])
case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
+ case d: DecimalType =>
+ if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ sys.error(s"Unsupported datatype $d, cannot write to consumer")
+ }
+ writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision)
case _ => sys.error(s"Do not know how to writer $schema to consumer")
}
}
@@ -283,6 +289,23 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
}
writer.endGroup()
}
+
+ // Scratch array used to write decimals as fixed-length binary
+ private val scratchBytes = new Array[Byte](8)
+
+ private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
+ val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
+ val unscaledLong = decimal.toUnscaledLong
+ var i = 0
+ var shift = 8 * (numBytes - 1)
+ while (i < numBytes) {
+ scratchBytes(i) = (unscaledLong >> shift).toByte
+ i += 1
+ shift -= 8
+ }
+ writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+ }
+
}
// Optimized for non-nested rows
@@ -326,6 +349,11 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
case DoubleType => writer.addDouble(record.getDouble(index))
case FloatType => writer.addFloat(record.getFloat(index))
case BooleanType => writer.addBoolean(record.getBoolean(index))
+ case d: DecimalType =>
+ if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ sys.error(s"Unsupported datatype $d, cannot write to consumer")
+ }
+ writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index e6389cf..e5077de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -29,8 +29,8 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter}
import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData}
import parquet.hadoop.util.ContextUtil
-import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType}
-import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns}
+import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType}
+import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata}
import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
import parquet.schema.Type.Repetition
@@ -41,17 +41,25 @@ import org.apache.spark.sql.catalyst.types._
// Implicits
import scala.collection.JavaConversions._
+/** A class representing Parquet info fields we care about, for passing back to Parquet */
+private[parquet] case class ParquetTypeInfo(
+ primitiveType: ParquetPrimitiveTypeName,
+ originalType: Option[ParquetOriginalType] = None,
+ decimalMetadata: Option[DecimalMetadata] = None,
+ length: Option[Int] = None)
+
private[parquet] object ParquetTypesConverter extends Logging {
def isPrimitiveType(ctype: DataType): Boolean =
classOf[PrimitiveType] isAssignableFrom ctype.getClass
def toPrimitiveDataType(
parquetType: ParquetPrimitiveType,
- binayAsString: Boolean): DataType =
+ binaryAsString: Boolean): DataType = {
+ val originalType = parquetType.getOriginalType
+ val decimalInfo = parquetType.getDecimalMetadata
parquetType.getPrimitiveTypeName match {
case ParquetPrimitiveTypeName.BINARY
- if (parquetType.getOriginalType == ParquetOriginalType.UTF8 ||
- binayAsString) => StringType
+ if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType
case ParquetPrimitiveTypeName.BINARY => BinaryType
case ParquetPrimitiveTypeName.BOOLEAN => BooleanType
case ParquetPrimitiveTypeName.DOUBLE => DoubleType
@@ -61,9 +69,14 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetPrimitiveTypeName.INT96 =>
// TODO: add BigInteger type? TODO(andre) use DecimalType instead????
sys.error("Potential loss of precision: cannot convert INT96")
+ case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
+ if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) =>
+ // TODO: for now, our reader only supports decimals that fit in a Long
+ DecimalType(decimalInfo.getPrecision, decimalInfo.getScale)
case _ => sys.error(
s"Unsupported parquet datatype $parquetType")
}
+ }
/**
* Converts a given Parquet `Type` into the corresponding
@@ -183,24 +196,41 @@ private[parquet] object ParquetTypesConverter extends Logging {
* is not primitive.
*
* @param ctype The type to convert
- * @return The name of the corresponding Parquet primitive type
+ * @return The name of the corresponding Parquet type properties
*/
- def fromPrimitiveDataType(ctype: DataType):
- Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match {
- case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))
- case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None)
- case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None)
- case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None)
- case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None)
- case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None)
+ def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match {
+ case StringType => Some(ParquetTypeInfo(
+ ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)))
+ case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY))
+ case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN))
+ case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE))
+ case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT))
+ case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
// There is no type for Byte or Short so we promote them to INT32.
- case ShortType => Some(ParquetPrimitiveTypeName.INT32, None)
- case ByteType => Some(ParquetPrimitiveTypeName.INT32, None)
- case LongType => Some(ParquetPrimitiveTypeName.INT64, None)
+ case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
+ case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
+ case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64))
+ case DecimalType.Fixed(precision, scale) if precision <= 18 =>
+ // TODO: for now, our writer only supports decimals that fit in a Long
+ Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY,
+ Some(ParquetOriginalType.DECIMAL),
+ Some(new DecimalMetadata(precision, scale)),
+ Some(BYTES_FOR_PRECISION(precision))))
case _ => None
}
/**
+ * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision.
+ */
+ private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision =>
+ var length = 1
+ while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) {
+ length += 1
+ }
+ length
+ }
+
+ /**
* Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into
* the corresponding Parquet `Type`.
*
@@ -247,10 +277,17 @@ private[parquet] object ParquetTypesConverter extends Logging {
} else {
if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED
}
- val primitiveType = fromPrimitiveDataType(ctype)
- primitiveType.map {
- case (primitiveType, originalType) =>
- new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull)
+ val typeInfo = fromPrimitiveDataType(ctype)
+ typeInfo.map {
+ case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) =>
+ val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull)
+ for (len <- length) {
+ builder.length(len)
+ }
+ for (metadata <- decimalMetadata) {
+ builder.precision(metadata.getPrecision).scale(metadata.getScale)
+ }
+ builder.named(name)
}.getOrElse {
ctype match {
case ArrayType(elementType, false) => {
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 142598c..7564bf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.types.util
import org.apache.spark.sql._
import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder}
+import org.apache.spark.sql.api.java.{DecimalType => JDecimalType}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.collection.JavaConverters._
@@ -44,7 +46,8 @@ protected[sql] object DataTypeConversions {
case BooleanType => JDataType.BooleanType
case DateType => JDataType.DateType
case TimestampType => JDataType.TimestampType
- case DecimalType => JDataType.DecimalType
+ case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale)
+ case DecimalType.Unlimited => new JDecimalType()
case DoubleType => JDataType.DoubleType
case FloatType => JDataType.FloatType
case ByteType => JDataType.ByteType
@@ -88,7 +91,11 @@ protected[sql] object DataTypeConversions {
case timestampType: org.apache.spark.sql.api.java.TimestampType =>
TimestampType
case decimalType: org.apache.spark.sql.api.java.DecimalType =>
- DecimalType
+ if (decimalType.isFixed) {
+ DecimalType(decimalType.getPrecision, decimalType.getScale)
+ } else {
+ DecimalType.Unlimited
+ }
case doubleType: org.apache.spark.sql.api.java.DoubleType =>
DoubleType
case floatType: org.apache.spark.sql.api.java.FloatType =>
@@ -115,7 +122,7 @@ protected[sql] object DataTypeConversions {
/** Converts Java objects to catalyst rows / types */
def convertJavaToCatalyst(a: Any): Any = a match {
- case d: java.math.BigDecimal => BigDecimal(d)
+ case d: java.math.BigDecimal => Decimal(BigDecimal(d))
case other => other
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
index 9435a88..a04b806 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -118,7 +118,7 @@ public class JavaApplySchemaSuite implements Serializable {
"\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
"\"boolean\":false, \"null\":null}"));
List<StructField> fields = new ArrayList<StructField>(7);
- fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true));
+ fields.add(DataType.createStructField("bigInteger", new DecimalType(), true));
fields.add(DataType.createStructField("boolean", DataType.BooleanType, true));
fields.add(DataType.createStructField("double", DataType.DoubleType, true));
fields.add(DataType.createStructField("integer", DataType.IntegerType, true));
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
index d04396a..8396a29 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
@@ -41,7 +41,8 @@ public class JavaSideDataTypeConversionSuite {
checkDataType(DataType.BooleanType);
checkDataType(DataType.DateType);
checkDataType(DataType.TimestampType);
- checkDataType(DataType.DecimalType);
+ checkDataType(new DecimalType());
+ checkDataType(new DecimalType(10, 4));
checkDataType(DataType.DoubleType);
checkDataType(DataType.FloatType);
checkDataType(DataType.ByteType);
@@ -59,7 +60,7 @@ public class JavaSideDataTypeConversionSuite {
// Simple StructType.
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false));
@@ -128,7 +129,7 @@ public class JavaSideDataTypeConversionSuite {
// StructType
try {
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
simpleFields.add(null);
@@ -138,7 +139,7 @@ public class JavaSideDataTypeConversionSuite {
}
try {
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
DataType.createStructType(simpleFields);
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 6c9db63..e9740d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -69,7 +69,7 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(LongType)
checkDataTypeJsonRepr(FloatType)
checkDataTypeJsonRepr(DoubleType)
- checkDataTypeJsonRepr(DecimalType)
+ checkDataTypeJsonRepr(DecimalType.Unlimited)
checkDataTypeJsonRepr(TimestampType)
checkDataTypeJsonRepr(StringType)
checkDataTypeJsonRepr(BinaryType)
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index bfa9ea4..cf3a59e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions._
@@ -81,7 +82,9 @@ class ScalaReflectionRelationSuite extends FunSuite {
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectData")
- assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
+ assert(sql("SELECT * FROM reflectData").collect().head ===
+ Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)))
}
test("query case class RDD with nulls") {
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
index d83f3e2..c9012c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.api.java
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.beans.BeanProperty
import org.scalatest.FunSuite
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
index e0e0ff9..62fe59d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -38,7 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
checkDataType(org.apache.spark.sql.BooleanType)
checkDataType(org.apache.spark.sql.DateType)
checkDataType(org.apache.spark.sql.TimestampType)
- checkDataType(org.apache.spark.sql.DecimalType)
+ checkDataType(org.apache.spark.sql.DecimalType.Unlimited)
checkDataType(org.apache.spark.sql.DoubleType)
checkDataType(org.apache.spark.sql.FloatType)
checkDataType(org.apache.spark.sql.ByteType)
@@ -58,7 +58,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
// Simple StructType.
val simpleScalaStructType = SStructType(
- SStructField("a", org.apache.spark.sql.DecimalType, false) ::
+ SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) ::
SStructField("b", org.apache.spark.sql.BooleanType, true) ::
SStructField("c", org.apache.spark.sql.LongType, true) ::
SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil)
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index ce6184f..1cb6c23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.json
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
import org.apache.spark.sql.QueryTest
@@ -44,19 +45,22 @@ class JsonSuite extends QueryTest {
checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
- checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType))
+ checkTypePromotion(
+ Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited))
val longNumber: Long = 9223372036854775807L
checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
- checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType))
+ checkTypePromotion(
+ Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited))
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
- checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType))
-
+ checkTypePromotion(
+ Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
+
checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
- checkTypePromotion(new Timestamp(intNumber.toLong),
+ checkTypePromotion(new Timestamp(intNumber.toLong),
enforceCorrectType(intNumber.toLong, TimestampType))
val strTime = "2014-09-30 12:34:56"
checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))
@@ -80,7 +84,7 @@ class JsonSuite extends QueryTest {
checkDataType(NullType, IntegerType, IntegerType)
checkDataType(NullType, LongType, LongType)
checkDataType(NullType, DoubleType, DoubleType)
- checkDataType(NullType, DecimalType, DecimalType)
+ checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(NullType, StringType, StringType)
checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType))
checkDataType(NullType, StructType(Nil), StructType(Nil))
@@ -91,7 +95,7 @@ class JsonSuite extends QueryTest {
checkDataType(BooleanType, IntegerType, StringType)
checkDataType(BooleanType, LongType, StringType)
checkDataType(BooleanType, DoubleType, StringType)
- checkDataType(BooleanType, DecimalType, StringType)
+ checkDataType(BooleanType, DecimalType.Unlimited, StringType)
checkDataType(BooleanType, StringType, StringType)
checkDataType(BooleanType, ArrayType(IntegerType), StringType)
checkDataType(BooleanType, StructType(Nil), StringType)
@@ -100,7 +104,7 @@ class JsonSuite extends QueryTest {
checkDataType(IntegerType, IntegerType, IntegerType)
checkDataType(IntegerType, LongType, LongType)
checkDataType(IntegerType, DoubleType, DoubleType)
- checkDataType(IntegerType, DecimalType, DecimalType)
+ checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(IntegerType, StringType, StringType)
checkDataType(IntegerType, ArrayType(IntegerType), StringType)
checkDataType(IntegerType, StructType(Nil), StringType)
@@ -108,23 +112,23 @@ class JsonSuite extends QueryTest {
// LongType
checkDataType(LongType, LongType, LongType)
checkDataType(LongType, DoubleType, DoubleType)
- checkDataType(LongType, DecimalType, DecimalType)
+ checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(LongType, StringType, StringType)
checkDataType(LongType, ArrayType(IntegerType), StringType)
checkDataType(LongType, StructType(Nil), StringType)
// DoubleType
checkDataType(DoubleType, DoubleType, DoubleType)
- checkDataType(DoubleType, DecimalType, DecimalType)
+ checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(DoubleType, StringType, StringType)
checkDataType(DoubleType, ArrayType(IntegerType), StringType)
checkDataType(DoubleType, StructType(Nil), StringType)
// DoubleType
- checkDataType(DecimalType, DecimalType, DecimalType)
- checkDataType(DecimalType, StringType, StringType)
- checkDataType(DecimalType, ArrayType(IntegerType), StringType)
- checkDataType(DecimalType, StructType(Nil), StringType)
+ checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited)
+ checkDataType(DecimalType.Unlimited, StringType, StringType)
+ checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType)
+ checkDataType(DecimalType.Unlimited, StructType(Nil), StringType)
// StringType
checkDataType(StringType, StringType, StringType)
@@ -178,7 +182,7 @@ class JsonSuite extends QueryTest {
checkDataType(
StructType(
StructField("f1", IntegerType, true) :: Nil),
- DecimalType,
+ DecimalType.Unlimited,
StringType)
}
@@ -186,7 +190,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(primitiveFieldAndType)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
@@ -216,7 +220,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) ::
- StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) ::
+ StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
@@ -230,7 +234,7 @@ class JsonSuite extends QueryTest {
StructField("field3", StringType, true) :: Nil), false), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
- StructField("field2", DecimalType, true) :: Nil), true) ::
+ StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(IntegerType, false), true) ::
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
@@ -331,7 +335,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
StructField("num_num_1", LongType, true) ::
- StructField("num_num_2", DecimalType, true) ::
+ StructField("num_num_2", DecimalType.Unlimited, true) ::
StructField("num_num_3", DoubleType, true) ::
StructField("num_str", StringType, true) ::
StructField("str_bool", StringType, true) :: Nil)
@@ -521,7 +525,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonFile(path)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
@@ -551,7 +555,7 @@ class JsonSuite extends QueryTest {
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val schema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 9979ab4..08d9da2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -77,6 +77,8 @@ case class AllDataTypesWithNonPrimitiveType(
case class BinaryData(binaryData: Array[Byte])
+case class NumericData(i: Int, d: Double)
+
class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
TestData // Load test data tables.
@@ -560,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(stringResult.size === 1)
assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
assert(stringResult(0).getInt(1) === 100)
-
+
val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40")
assert(
query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
@@ -869,4 +871,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(a.dataType === b.dataType)
}
}
+
+ test("read/write fixed-length decimals") {
+ for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType(precision, scale))
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+
+ // Decimals with precision above 18 are not yet supported
+ intercept[RuntimeException] {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType(19, 10))
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+
+ // Unlimited-length decimals are not yet supported
+ intercept[RuntimeException] {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType.Unlimited)
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index 2a4f241..99c4f46 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -47,7 +47,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)(
hiveContext, sessionToActivePool)
- handleToOperation.put(operation.getHandle, operation)
- operation
+ handleToOperation.put(operation.getHandle, operation)
+ operation
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
index bbd727c..8077d0e 100644
--- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
+++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
@@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation(
to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal)))
case FloatType =>
to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
- case DecimalType =>
+ case DecimalType() =>
val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
case LongType =>
@@ -156,7 +156,7 @@ private[hive] class SparkExecuteStatementOperation(
to.addColumnValue(ColumnValue.doubleValue(null))
case FloatType =>
to.addColumnValue(ColumnValue.floatValue(null))
- case DecimalType =>
+ case DecimalType() =>
to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
case LongType =>
to.addColumnValue(ColumnValue.longValue(null))
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
index e59681b..2c1983d 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
@@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation(
to += from.getDouble(ordinal)
case FloatType =>
to += from.getFloat(ordinal)
- case DecimalType =>
+ case DecimalType() =>
to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
case LongType =>
to += from.getLong(ordinal)
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index ff8fa44..2e27817 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -21,6 +21,10 @@ import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
import java.sql.{Date, Timestamp}
import java.util.{ArrayList => JArrayList}
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.spark.sql.catalyst.types.DecimalType
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.JavaConversions._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
@@ -370,7 +374,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
- ShortType, DecimalType, DateType, TimestampType, BinaryType)
+ ShortType, DateType, TimestampType, BinaryType)
protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
@@ -388,6 +392,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
case (d: Date, DateType) => new DateWritable(d).toString
case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString
case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8")
+ case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString
+ HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
@@ -406,6 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}.toSeq.sorted.mkString("{", ",", "}")
case (null, _) => "null"
case (s: String, StringType) => "\"" + s + "\""
+ case (decimal, DecimalType()) => decimal.toString
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 0439ab9..1e2bf5c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -38,7 +39,7 @@ private[hive] trait HiveInspectors {
// writable
case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
- case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+ case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited
case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType
@@ -54,8 +55,8 @@ private[hive] trait HiveInspectors {
case c: Class[_] if c == classOf[java.lang.String] => StringType
case c: Class[_] if c == classOf[java.sql.Date] => DateType
case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
- case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
- case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
+ case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited
case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
case c: Class[_] if c == classOf[java.lang.Short] => ShortType
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
@@ -90,7 +91,7 @@ private[hive] trait HiveInspectors {
case hvoi: HiveVarcharObjectInspector =>
if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
case hdoi: HiveDecimalObjectInspector =>
- if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+ if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data)
// org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
// if next timestamp is null, so Timestamp object is cloned
case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
@@ -137,8 +138,9 @@ private[hive] trait HiveInspectors {
case l: Short => l: java.lang.Short
case l: Byte => l: java.lang.Byte
case b: BigDecimal => HiveShim.createDecimal(b.underlying())
+ case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying())
case b: Array[Byte] => b
- case d: java.sql.Date => d
+ case d: java.sql.Date => d
case t: java.sql.Timestamp => t
}
case x: StructObjectInspector =>
@@ -200,7 +202,7 @@ private[hive] trait HiveInspectors {
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
- case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
+ case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
case StructType(fields) =>
ObjectInspectorFactory.getStandardStructObjectInspector(
fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
@@ -229,8 +231,10 @@ private[hive] trait HiveInspectors {
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
case Literal(value: java.sql.Timestamp, TimestampType) =>
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
- case Literal(value: BigDecimal, DecimalType) =>
+ case Literal(value: BigDecimal, DecimalType()) =>
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Decimal, DecimalType()) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal)
case Literal(_, NullType) =>
HiveShim.getPrimitiveNullWritableConstantObjectInspector
case Literal(value: Seq[_], ArrayType(dt, _)) =>
@@ -277,8 +281,8 @@ private[hive] trait HiveInspectors {
case _: JavaFloatObjectInspector => FloatType
case _: WritableBinaryObjectInspector => BinaryType
case _: JavaBinaryObjectInspector => BinaryType
- case _: WritableHiveDecimalObjectInspector => DecimalType
- case _: JavaHiveDecimalObjectInspector => DecimalType
+ case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w)
+ case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j)
case _: WritableDateObjectInspector => DateType
case _: JavaDateObjectInspector => DateType
case _: WritableTimestampObjectInspector => TimestampType
@@ -307,7 +311,7 @@ private[hive] trait HiveInspectors {
case LongType => longTypeInfo
case ShortType => shortTypeInfo
case StringType => stringTypeInfo
- case DecimalType => decimalTypeInfo
+ case d: DecimalType => HiveShim.decimalTypeInfo(d)
case DateType => dateTypeInfo
case TimestampType => timestampTypeInfo
case NullType => voidTypeInfo
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 2dd2c88..096b4a0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
import java.io.IOException
import java.util.{List => JList}
+import scala.util.matching.Regex
import scala.util.parsing.combinator.RegexParsers
import org.apache.hadoop.util.ReflectionUtils
@@ -321,11 +322,18 @@ object HiveMetastoreTypes extends RegexParsers {
"bigint" ^^^ LongType |
"binary" ^^^ BinaryType |
"boolean" ^^^ BooleanType |
- HiveShim.metastoreDecimal ^^^ DecimalType |
+ fixedDecimalType | // Hive 0.13+ decimal with precision/scale
+ "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale
"date" ^^^ DateType |
"timestamp" ^^^ TimestampType |
"varchar\\((\\d+)\\)".r ^^^ StringType
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
+ case precision ~ scale =>
+ DecimalType(precision.toInt, scale.toInt)
+ }
+
protected lazy val arrayType: Parser[DataType] =
"array" ~> "<" ~> dataType <~ ">" ^^ {
case tpe => ArrayType(tpe)
@@ -373,7 +381,7 @@ object HiveMetastoreTypes extends RegexParsers {
case BinaryType => "binary"
case BooleanType => "boolean"
case DateType => "date"
- case DecimalType => "decimal"
+ case d: DecimalType => HiveShim.decimalMetastoreString(d)
case TimestampType => "timestamp"
case NullType => "void"
}
@@ -441,7 +449,7 @@ private[hive] case class MetastoreRelation
val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute)
/** Non-partitionKey attributes */
- val attributes = hiveQlTable.getCols.map(_.toAttribute)
+ val attributes = hiveQlTable.getCols.map(_.toAttribute)
val output = attributes ++ partitionKeys
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org