Posted to commits@spark.apache.org by rx...@apache.org on 2016/02/01 22:56:20 UTC

spark git commit: [SPARK-13043][SQL] Implement remaining catalyst types in ColumnarBatch.

Repository: spark
Updated Branches:
  refs/heads/master c9b89a0a0 -> 064b029c6


[SPARK-13043][SQL] Implement remaining catalyst types in ColumnarBatch.

This includes float, boolean, short, decimal, and calendar interval.

Decimal is mapped to a long or a byte array depending on its precision, and
calendar interval is mapped to a struct of an int and a long.
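
For illustration, a minimal, self-contained Java sketch of that decimal
mapping (the class and helper names are invented for this example; it assumes
Decimal.MAX_LONG_DIGITS is 18, the most digits an unscaled value can have and
still always fit in a signed long):

    import java.math.BigDecimal;
    import java.math.BigInteger;

    public class DecimalEncodingSketch {
      static final int MAX_LONG_DIGITS = 18;

      // Small precision: store the unscaled value in a long column.
      // Large precision: store the big-endian two's-complement bytes.
      static Object encode(BigDecimal d, int precision) {
        if (precision <= MAX_LONG_DIGITS) {
          return d.unscaledValue().longValueExact();
        }
        return d.unscaledValue().toByteArray();
      }

      static BigDecimal decode(Object stored, int scale) {
        if (stored instanceof Long) {
          return BigDecimal.valueOf((Long) stored, scale);
        }
        return new BigDecimal(new BigInteger((byte[]) stored), scale);
      }

      public static void main(String[] args) {
        BigDecimal small = new BigDecimal("12345.6789");               // precision 9
        BigDecimal big = new BigDecimal("123456789012345678901.2345"); // precision 25
        System.out.println(decode(encode(small, small.precision()), small.scale()));
        System.out.println(decode(encode(big, big.precision()), big.scale()));
      }
    }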

The only remaining type is map. Its schema mapping is straightforward, but we
might want to revisit how the rest of the execution engine deals with it.

Author: Nong Li <no...@databricks.com>

Closes #10961 from nongli/spark-13043.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/064b029c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/064b029c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/064b029c

Branch: refs/heads/master
Commit: 064b029c6a15481fc4dfb147100c19a68cd1cc95
Parents: c9b89a0
Author: Nong Li <no...@databricks.com>
Authored: Mon Feb 1 13:56:14 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Feb 1 13:56:14 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/sql/types/DecimalType.scala    |  22 +++
 .../sql/execution/vectorized/ColumnVector.java  | 180 ++++++++++++++++++-
 .../execution/vectorized/ColumnVectorUtils.java |  34 +++-
 .../sql/execution/vectorized/ColumnarBatch.java |  46 ++---
 .../vectorized/OffHeapColumnVector.java         |  98 +++++++++-
 .../vectorized/OnHeapColumnVector.java          |  94 +++++++++-
 .../vectorized/ColumnarBatchSuite.scala         |  44 ++++-
 .../java/org/apache/spark/unsafe/Platform.java  |   8 +
 8 files changed, 484 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index cf53221..5dd661e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType {
     }
   }
 
+  /**
+   * Returns whether dt is a DecimalType that fits inside a long.
+   */
+  def is64BitDecimalType(dt: DataType): Boolean = {
+    dt match {
+      case t: DecimalType =>
+        t.precision <= Decimal.MAX_LONG_DIGITS
+      case _ => false
+    }
+  }
+
+  /**
+   * Returns whether dt is a DecimalType that does not fit inside a long.
+   */
+  def isByteArrayDecimalType(dt: DataType): Boolean = {
+    dt match {
+      case t: DecimalType =>
+        t.precision > Decimal.MAX_LONG_DIGITS
+      case _ => false
+    }
+  }
+
   def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
 
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
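
The 18-digit cutoff used by is64BitDecimalType works because Long.MAX_VALUE
has 19 digits, so every 18-digit unscaled value fits in a long while some
19-digit values do not:

    public class LongDigits {
      public static void main(String[] args) {
        System.out.println(Long.MAX_VALUE);      // 9223372036854775807 (19 digits)
        System.out.println(999999999999999999L); // largest 18-digit value; always fits
        // 9999999999999999999L would not even compile: it exceeds Long.MAX_VALUE.
      }
    }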

http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index a0bf873..a5bc506 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -16,6 +16,9 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
 import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ArrayData;
@@ -102,18 +105,36 @@ public abstract class ColumnVector {
       DataType dt = data.dataType();
       Object[] list = new Object[length];
 
-      if (dt instanceof ByteType) {
+      if (dt instanceof BooleanType) {
+        for (int i = 0; i < length; i++) {
+          if (!data.getIsNull(offset + i)) {
+            list[i] = data.getBoolean(offset + i);
+          }
+        }
+      } else if (dt instanceof ByteType) {
         for (int i = 0; i < length; i++) {
           if (!data.getIsNull(offset + i)) {
             list[i] = data.getByte(offset + i);
           }
         }
+      } else if (dt instanceof ShortType) {
+        for (int i = 0; i < length; i++) {
+          if (!data.getIsNull(offset + i)) {
+            list[i] = data.getShort(offset + i);
+          }
+        }
       } else if (dt instanceof IntegerType) {
         for (int i = 0; i < length; i++) {
           if (!data.getIsNull(offset + i)) {
             list[i] = data.getInt(offset + i);
           }
         }
+      } else if (dt instanceof FloatType) {
+        for (int i = 0; i < length; i++) {
+          if (!data.getIsNull(offset + i)) {
+            list[i] = data.getFloat(offset + i);
+          }
+        }
       } else if (dt instanceof DoubleType) {
         for (int i = 0; i < length; i++) {
           if (!data.getIsNull(offset + i)) {
@@ -126,12 +147,25 @@ public abstract class ColumnVector {
             list[i] = data.getLong(offset + i);
           }
         }
+      } else if (dt instanceof DecimalType) {
+        DecimalType decType = (DecimalType)dt;
+        for (int i = 0; i < length; i++) {
+          if (!data.getIsNull(offset + i)) {
+            list[i] = getDecimal(i, decType.precision(), decType.scale());
+          }
+        }
       } else if (dt instanceof StringType) {
         for (int i = 0; i < length; i++) {
           if (!data.getIsNull(offset + i)) {
             list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i));
           }
         }
+      } else if (dt instanceof CalendarIntervalType) {
+        for (int i = 0; i < length; i++) {
+          if (!data.getIsNull(offset + i)) {
+            list[i] = getInterval(i);
+          }
+        }
       } else {
         throw new NotImplementedException("Type " + dt);
       }
@@ -170,7 +204,14 @@ public abstract class ColumnVector {
 
     @Override
     public Decimal getDecimal(int ordinal, int precision, int scale) {
-      throw new NotImplementedException();
+      if (precision <= Decimal.MAX_LONG_DIGITS()) {
+        return Decimal.apply(getLong(ordinal), precision, scale);
+      } else {
+        byte[] bytes = getBinary(ordinal);
+        BigInteger bigInteger = new BigInteger(bytes);
+        BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+        return Decimal.apply(javaDecimal, precision, scale);
+      }
     }
 
     @Override
@@ -181,17 +222,22 @@ public abstract class ColumnVector {
 
     @Override
     public byte[] getBinary(int ordinal) {
-      throw new NotImplementedException();
+      ColumnVector.Array array = data.getByteArray(offset + ordinal);
+      byte[] bytes = new byte[array.length];
+      System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+      return bytes;
     }
 
     @Override
     public CalendarInterval getInterval(int ordinal) {
-      throw new NotImplementedException();
+      int month = data.getChildColumn(0).getInt(offset + ordinal);
+      long microseconds = data.getChildColumn(1).getLong(offset + ordinal);
+      return new CalendarInterval(month, microseconds);
     }
 
     @Override
     public InternalRow getStruct(int ordinal, int numFields) {
-      throw new NotImplementedException();
+      return data.getStruct(offset + ordinal);
     }
 
     @Override
@@ -282,6 +328,21 @@ public abstract class ColumnVector {
   /**
    * Sets the value at rowId to `value`.
    */
+  public abstract void putBoolean(int rowId, boolean value);
+
+  /**
+   * Sets values from [rowId, rowId + count) to value.
+   */
+  public abstract void putBooleans(int rowId, int count, boolean value);
+
+  /**
+   * Returns the value for rowId.
+   */
+  public abstract boolean getBoolean(int rowId);
+
+  /**
+   * Sets the value at rowId to `value`.
+   */
   public abstract void putByte(int rowId, byte value);
 
   /**
@@ -302,6 +363,26 @@ public abstract class ColumnVector {
   /**
    * Sets the value at rowId to `value`.
    */
+  public abstract void putShort(int rowId, short value);
+
+  /**
+   * Sets values from [rowId, rowId + count) to value.
+   */
+  public abstract void putShorts(int rowId, int count, short value);
+
+  /**
+   * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+   */
+  public abstract void putShorts(int rowId, int count, short[] src, int srcIndex);
+
+  /**
+   * Returns the value for rowId.
+   */
+  public abstract short getShort(int rowId);
+
+  /**
+   * Sets the value at rowId to `value`.
+   */
   public abstract void putInt(int rowId, int value);
 
   /**
@@ -354,6 +435,33 @@ public abstract class ColumnVector {
   /**
    * Sets the value at rowId to `value`.
    */
+  public abstract void putFloat(int rowId, float value);
+
+  /**
+   * Sets values from [rowId, rowId + count) to value.
+   */
+  public abstract void putFloats(int rowId, int count, float value);
+
+  /**
+   * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+   * src should contain `count` floats written as ieee format.
+   */
+  public abstract void putFloats(int rowId, int count, float[] src, int srcIndex);
+
+  /**
+   * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+   * The data in src must be ieee formatted floats.
+   */
+  public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex);
+
+  /**
+   * Returns the value for rowId.
+   */
+  public abstract float getFloat(int rowId);
+
+  /**
+   * Sets the value at rowId to `value`.
+   */
   public abstract void putDouble(int rowId, double value);
 
   /**
@@ -369,7 +477,7 @@ public abstract class ColumnVector {
 
   /**
    * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
-   * The data in src must be ieee formated doubles.
+   * The data in src must be ieee formatted doubles.
    */
   public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex);
 
@@ -469,6 +577,20 @@ public abstract class ColumnVector {
     return result;
   }
 
+  public final int appendBoolean(boolean v) {
+    reserve(elementsAppended + 1);
+    putBoolean(elementsAppended, v);
+    return elementsAppended++;
+  }
+
+  public final int appendBooleans(int count, boolean v) {
+    reserve(elementsAppended + count);
+    int result = elementsAppended;
+    putBooleans(elementsAppended, count, v);
+    elementsAppended += count;
+    return result;
+  }
+
   public final int appendByte(byte v) {
     reserve(elementsAppended + 1);
     putByte(elementsAppended, v);
@@ -491,6 +613,28 @@ public abstract class ColumnVector {
     return result;
   }
 
+  public final int appendShort(short v) {
+    reserve(elementsAppended + 1);
+    putShort(elementsAppended, v);
+    return elementsAppended++;
+  }
+
+  public final int appendShorts(int count, short v) {
+    reserve(elementsAppended + count);
+    int result = elementsAppended;
+    putShorts(elementsAppended, count, v);
+    elementsAppended += count;
+    return result;
+  }
+
+  public final int appendShorts(int length, short[] src, int offset) {
+    reserve(elementsAppended + length);
+    int result = elementsAppended;
+    putShorts(elementsAppended, length, src, offset);
+    elementsAppended += length;
+    return result;
+  }
+
   public final int appendInt(int v) {
     reserve(elementsAppended + 1);
     putInt(elementsAppended, v);
@@ -535,6 +679,20 @@ public abstract class ColumnVector {
     return result;
   }
 
+  public final int appendFloat(float v) {
+    reserve(elementsAppended + 1);
+    putFloat(elementsAppended, v);
+    return elementsAppended++;
+  }
+
+  public final int appendFloats(int count, float v) {
+    reserve(elementsAppended + count);
+    int result = elementsAppended;
+    putFloats(elementsAppended, count, v);
+    elementsAppended += count;
+    return result;
+  }
+
   public final int appendDouble(double v) {
     reserve(elementsAppended + 1);
     putDouble(elementsAppended, v);
@@ -661,7 +819,8 @@ public abstract class ColumnVector {
     this.capacity = capacity;
     this.type = type;
 
-    if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) {
+    if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
+        || DecimalType.isByteArrayDecimalType(type)) {
       DataType childType;
       int childCapacity = capacity;
       if (type instanceof ArrayType) {
@@ -682,6 +841,13 @@ public abstract class ColumnVector {
       }
       this.resultArray = null;
       this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+    } else if (type instanceof CalendarIntervalType) {
+      // Two child columns: months as int, microseconds as long.
+      this.childColumns = new ColumnVector[2];
+      this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode);
+      this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode);
+      this.resultArray = null;
+      this.resultStruct = new ColumnarBatch.Row(this.childColumns);
     } else {
       this.childColumns = null;
       this.resultArray = null;
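
The CalendarIntervalType branch above backs an interval column with two
fixed-width child columns. A rough on-heap analogue of that layout (the class
and method names here are illustrative, not Spark API):

    public class IntervalColumnSketch {
      private final int[] months;        // child column 0: months as int
      private final long[] microseconds; // child column 1: microseconds as long

      IntervalColumnSketch(int capacity) {
        months = new int[capacity];
        microseconds = new long[capacity];
      }

      void put(int rowId, int m, long us) {
        months[rowId] = m;
        microseconds[rowId] = us;
      }

      // Reading a row recombines the two children, as getInterval does above.
      int getMonths(int rowId) { return months[rowId]; }
      long getMicroseconds(int rowId) { return microseconds[rowId]; }
    }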

http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 6c651a7..453bc15 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -16,12 +16,15 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
+import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.util.Iterator;
 import java.util.List;
 
 import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
 
 import org.apache.commons.lang.NotImplementedException;
 
@@ -59,19 +62,44 @@ public class ColumnVectorUtils {
 
   private static void appendValue(ColumnVector dst, DataType t, Object o) {
     if (o == null) {
-      dst.appendNull();
+      if (t instanceof CalendarIntervalType) {
+        dst.appendStruct(true);
+      } else {
+        dst.appendNull();
+      }
     } else {
-      if (t == DataTypes.ByteType) {
-        dst.appendByte(((Byte)o).byteValue());
+      if (t == DataTypes.BooleanType) {
+        dst.appendBoolean(((Boolean)o).booleanValue());
+      } else if (t == DataTypes.ByteType) {
+        dst.appendByte(((Byte) o).byteValue());
+      } else if (t == DataTypes.ShortType) {
+        dst.appendShort(((Short)o).shortValue());
       } else if (t == DataTypes.IntegerType) {
         dst.appendInt(((Integer)o).intValue());
       } else if (t == DataTypes.LongType) {
         dst.appendLong(((Long)o).longValue());
+      } else if (t == DataTypes.FloatType) {
+        dst.appendFloat(((Float)o).floatValue());
       } else if (t == DataTypes.DoubleType) {
         dst.appendDouble(((Double)o).doubleValue());
       } else if (t == DataTypes.StringType) {
         byte[] b =((String)o).getBytes();
         dst.appendByteArray(b, 0, b.length);
+      } else if (t instanceof DecimalType) {
+        DecimalType dt = (DecimalType)t;
+        Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
+        if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+          dst.appendLong(d.toUnscaledLong());
+        } else {
+          final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
+          byte[] bytes = integer.toByteArray();
+          dst.appendByteArray(bytes, 0, bytes.length);
+        }
+      } else if (t instanceof CalendarIntervalType) {
+        CalendarInterval c = (CalendarInterval)o;
+        dst.appendStruct(false);
+        dst.getChildColumn(0).appendInt(c.months);
+        dst.getChildColumn(1).appendLong(c.microseconds);
       } else {
         throw new NotImplementedException("Type " + t);
       }
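
The byte-array branch above round-trips with the read path in ColumnVector,
including negative values, because BigInteger.toByteArray() emits the minimal
big-endian two's-complement encoding and the BigInteger(byte[]) constructor
is its exact inverse. A quick standalone check:

    import java.math.BigDecimal;
    import java.math.BigInteger;

    public class UnscaledBytesRoundTrip {
      public static void main(String[] args) {
        BigDecimal d = new BigDecimal("-12345678901234567890.123456789");
        byte[] bytes = d.unscaledValue().toByteArray();       // write path
        BigDecimal back =
            new BigDecimal(new BigInteger(bytes), d.scale()); // read path
        System.out.println(d.equals(back));                   // true
      }
    }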

http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 5a57581..dbad5e0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
+import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.util.Arrays;
 import java.util.Iterator;
 
@@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
 
@@ -150,44 +153,40 @@ public final class ColumnarBatch {
     }
 
     @Override
-    public final boolean isNullAt(int ordinal) {
-      return columns[ordinal].getIsNull(rowId);
-    }
+    public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); }
 
     @Override
-    public final boolean getBoolean(int ordinal) {
-      throw new NotImplementedException();
-    }
+    public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
 
     @Override
     public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
 
     @Override
-    public final short getShort(int ordinal) {
-      throw new NotImplementedException();
-    }
+    public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
 
     @Override
-    public final int getInt(int ordinal) {
-      return columns[ordinal].getInt(rowId);
-    }
+    public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
 
     @Override
     public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
 
     @Override
-    public final float getFloat(int ordinal) {
-      throw new NotImplementedException();
-    }
+    public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
 
     @Override
-    public final double getDouble(int ordinal) {
-      return columns[ordinal].getDouble(rowId);
-    }
+    public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
 
     @Override
     public final Decimal getDecimal(int ordinal, int precision, int scale) {
-      throw new NotImplementedException();
+      if (precision <= Decimal.MAX_LONG_DIGITS()) {
+        return Decimal.apply(getLong(ordinal), precision, scale);
+      } else {
+        // TODO: best perf?
+        byte[] bytes = getBinary(ordinal);
+        BigInteger bigInteger = new BigInteger(bytes);
+        BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+        return Decimal.apply(javaDecimal, precision, scale);
+      }
     }
 
     @Override
@@ -198,12 +197,17 @@ public final class ColumnarBatch {
 
     @Override
     public final byte[] getBinary(int ordinal) {
-      throw new NotImplementedException();
+      ColumnVector.Array array = columns[ordinal].getByteArray(rowId);
+      byte[] bytes = new byte[array.length];
+      System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+      return bytes;
     }
 
     @Override
     public final CalendarInterval getInterval(int ordinal) {
-      throw new NotImplementedException();
+      final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+      final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+      return new CalendarInterval(months, microseconds);
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 335124f..22c5e5f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -19,11 +19,15 @@ package org.apache.spark.sql.execution.vectorized;
 import java.nio.ByteOrder;
 
 import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.types.BooleanType;
 import org.apache.spark.sql.types.ByteType;
 import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
 import org.apache.spark.sql.types.IntegerType;
 import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.ShortType;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.types.UTF8String;
 
@@ -122,6 +126,26 @@ public final class OffHeapColumnVector extends ColumnVector {
   }
 
   //
+  // APIs dealing with Booleans
+  //
+
+  @Override
+  public final void putBoolean(int rowId, boolean value) {
+    Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0));
+  }
+
+  @Override
+  public final void putBooleans(int rowId, int count, boolean value) {
+    byte v = (byte)((value) ? 1 : 0);
+    for (int i = 0; i < count; ++i) {
+      Platform.putByte(null, data + rowId + i, v);
+    }
+  }
+
+  @Override
+  public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; }
+
+  //
   // APIs dealing with Bytes
   //
 
@@ -149,6 +173,34 @@ public final class OffHeapColumnVector extends ColumnVector {
   }
 
   //
+  // APIs dealing with shorts
+  //
+
+  @Override
+  public final void putShort(int rowId, short value) {
+    Platform.putShort(null, data + 2 * rowId, value);
+  }
+
+  @Override
+  public final void putShorts(int rowId, int count, short value) {
+    long offset = data + 2 * rowId;
+    for (int i = 0; i < count; ++i, offset += 2) {
+      Platform.putShort(null, offset, value);
+    }
+  }
+
+  @Override
+  public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+    Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2,
+        null, data + 2 * rowId, count * 2);
+  }
+
+  @Override
+  public final short getShort(int rowId) {
+    return Platform.getShort(null, data + 2 * rowId);
+  }
+
+  //
   // APIs dealing with ints
   //
 
@@ -217,6 +269,41 @@ public final class OffHeapColumnVector extends ColumnVector {
   }
 
   //
+  // APIs dealing with floats
+  //
+
+  @Override
+  public final void putFloat(int rowId, float value) {
+    Platform.putFloat(null, data + rowId * 4, value);
+  }
+
+  @Override
+  public final void putFloats(int rowId, int count, float value) {
+    long offset = data + 4 * rowId;
+    for (int i = 0; i < count; ++i, offset += 4) {
+      Platform.putFloat(null, offset, value);
+    }
+  }
+
+  @Override
+  public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+    Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4,
+        null, data + 4 * rowId, count * 4);
+  }
+
+  @Override
+  public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+    Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
+        null, data + rowId * 4, count * 4);
+  }
+
+  @Override
+  public final float getFloat(int rowId) {
+    return Platform.getFloat(null, data + rowId * 4);
+  }
+
+  //
   // APIs dealing with doubles
   //
 
@@ -241,7 +328,7 @@ public final class OffHeapColumnVector extends ColumnVector {
 
   @Override
   public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
-    Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex,
+    Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
         null, data + rowId * 8, count * 8);
   }
 
@@ -300,11 +387,14 @@ public final class OffHeapColumnVector extends ColumnVector {
           Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4);
       this.offsetData =
           Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4);
-    } else if (type instanceof ByteType) {
+    } else if (type instanceof ByteType || type instanceof BooleanType) {
       this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity);
-    } else if (type instanceof IntegerType) {
+    } else if (type instanceof ShortType) {
+      this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
+    } else if (type instanceof IntegerType || type instanceof FloatType) {
       this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
-    } else if (type instanceof LongType || type instanceof DoubleType) {
+    } else if (type instanceof LongType || type instanceof DoubleType ||
+        DecimalType.is64BitDecimalType(type)) {
       this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8);
     } else if (resultStruct != null) {
       // Nothing to store.
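
The off-heap putters above address element i of a fixed-width column at
base + i * elementSize: 2 bytes for shorts, 4 for ints and floats, 8 for
longs and doubles (which is why a loop over shorts must advance its offset by
2 per element). A safe-Java sketch of the same addressing, with ByteBuffer
standing in for Platform/Unsafe and one buffer per column:

    import java.nio.ByteBuffer;

    public class FixedWidthAddressing {
      public static void main(String[] args) {
        ByteBuffer shorts = ByteBuffer.allocateDirect(100 * 2); // a short column
        ByteBuffer floats = ByteBuffer.allocateDirect(100 * 4); // a float column
        int rowId = 10;
        shorts.putShort(rowId * 2, (short) 42); // cf. putShort(rowId, value)
        floats.putFloat(rowId * 4, 3.5f);       // cf. putFloat(rowId, value)
        System.out.println(shorts.getShort(rowId * 2)); // 42
        System.out.println(floats.getFloat(rowId * 4)); // 3.5
      }
    }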

http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 8197fa1..3235633 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector {
 
   // Array for each type. Only 1 is populated for any type.
   private byte[] byteData;
+  private short[] shortData;
   private int[] intData;
   private long[] longData;
+  private float[] floatData;
   private double[] doubleData;
 
   // Only set if type is Array.
@@ -105,6 +107,30 @@ public final class OnHeapColumnVector extends ColumnVector {
   }
 
   //
+  // APIs dealing with Booleans
+  //
+
+  @Override
+  public final void putBoolean(int rowId, boolean value) {
+    byteData[rowId] = (byte)((value) ? 1 : 0);
+  }
+
+  @Override
+  public final void putBooleans(int rowId, int count, boolean value) {
+    byte v = (byte)((value) ? 1 : 0);
+    for (int i = 0; i < count; ++i) {
+      byteData[i + rowId] = v;
+    }
+  }
+
+  @Override
+  public final boolean getBoolean(int rowId) {
+    return byteData[rowId] == 1;
+  }
+
+  //
   // APIs dealing with Bytes
   //
 
@@ -131,6 +157,33 @@ public final class OnHeapColumnVector extends ColumnVector {
   }
 
   //
+  // APIs dealing with Shorts
+  //
+
+  @Override
+  public final void putShort(int rowId, short value) {
+    shortData[rowId] = value;
+  }
+
+  @Override
+  public final void putShorts(int rowId, int count, short value) {
+    for (int i = 0; i < count; ++i) {
+      shortData[i + rowId] = value;
+    }
+  }
+
+  @Override
+  public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+    System.arraycopy(src, srcIndex, shortData, rowId, count);
+  }
+
+  @Override
+  public final short getShort(int rowId) {
+    return shortData[rowId];
+  }
+
+  //
   // APIs dealing with Ints
   //
 
@@ -202,6 +255,31 @@ public final class OnHeapColumnVector extends ColumnVector {
     return longData[rowId];
   }
 
+  //
+  // APIs dealing with floats
+  //
+
+  @Override
+  public final void putFloat(int rowId, float value) { floatData[rowId] = value; }
+
+  @Override
+  public final void putFloats(int rowId, int count, float value) {
+    Arrays.fill(floatData, rowId, rowId + count, value);
+  }
+
+  @Override
+  public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+    System.arraycopy(src, srcIndex, floatData, rowId, count);
+  }
+
+  @Override
+  public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+    Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
+        floatData, Platform.FLOAT_ARRAY_OFFSET + rowId * 4, count * 4);
+  }
+
+  @Override
+  public final float getFloat(int rowId) { return floatData[rowId]; }
 
   //
   // APIs dealing with doubles
@@ -277,7 +355,7 @@ public final class OnHeapColumnVector extends ColumnVector {
 
  // Split this function out since it is the slow path.
   private final void reserveInternal(int newCapacity) {
-    if (this.resultArray != null) {
+    if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) {
       int[] newLengths = new int[newCapacity];
       int[] newOffsets = new int[newCapacity];
       if (this.arrayLengths != null) {
@@ -286,18 +364,30 @@ public final class OnHeapColumnVector extends ColumnVector {
       }
       arrayLengths = newLengths;
       arrayOffsets = newOffsets;
+    } else if (type instanceof BooleanType) {
+      byte[] newData = new byte[newCapacity];
+      if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+      byteData = newData;
     } else if (type instanceof ByteType) {
       byte[] newData = new byte[newCapacity];
       if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
       byteData = newData;
+    } else if (type instanceof ShortType) {
+      short[] newData = new short[newCapacity];
+      if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
+      shortData = newData;
     } else if (type instanceof IntegerType) {
       int[] newData = new int[newCapacity];
       if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
       intData = newData;
-    } else if (type instanceof LongType) {
+    } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) {
       long[] newData = new long[newCapacity];
       if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended);
       longData = newData;
+    } else if (type instanceof FloatType) {
+      float[] newData = new float[newCapacity];
+      if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended);
+      floatData = newData;
     } else if (type instanceof DoubleType) {
       double[] newData = new double[newCapacity];
       if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended);

http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 67cc08b..445f311 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.types.CalendarInterval
 
 class ColumnarBatchSuite extends SparkFunSuite {
   test("Null Apis") {
@@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
     }}
   }
 
-
   private def doubleEquals(d1: Double, d2: Double): Boolean = {
     if (d1.isNaN && d2.isNaN) {
       true
@@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite {
       assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed)
       if (!r1.isNullAt(v._2)) {
         v._1.dataType match {
+          case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed)
           case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed)
+          case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed)
           case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed)
           case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed)
+          case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)),
+            "Seed = " + seed)
           case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)),
             "Seed = " + seed)
+          case t: DecimalType =>
+            val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal
+            val d2 = r2.getDecimal(v._2)
+            assert(d1.compare(d2) == 0, "Seed = " + seed)
           case StringType =>
             assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed)
+          case CalendarIntervalType =>
+            assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval])
           case ArrayType(childType, n) =>
             val a1 = r1.getArray(v._2).array
             val a2 = r2.getList(v._2).toArray
@@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite {
                   i += 1
                 }
               }
+              case FloatType => {
+                var i = 0
+                while (i < a1.length) {
+                  assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]),
+                    "Seed = " + seed)
+                  i += 1
+                }
+              }
+
+              case t: DecimalType =>
+                var i = 0
+                while (i < a1.length) {
+                  assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+                  if (a1(i) != null) {
+                    val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal
+                    val d2 = a2(i).asInstanceOf[java.math.BigDecimal]
+                    assert(d1.compare(d2) == 0, "Seed = " + seed)
+                  }
+                  i += 1
+                }
+
               case _ => assert(a1 === a2, "Seed = " + seed)
             }
           case StructType(childFields) =>
@@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
    * results.
    */
   def testRandomRows(flatSchema: Boolean, numFields: Int) {
-    // TODO: add remaining types. Figure out why StringType doesn't work on jenkins.
-    val types = Array(ByteType, IntegerType, LongType, DoubleType)
+    // TODO: Figure out why StringType doesn't work on jenkins.
+    val types = Array(
+      BooleanType, ByteType, FloatType, DoubleType,
+      IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10),
+      CalendarIntervalType)
     val seed = System.nanoTime()
-    val NUM_ROWS = 500
+    val NUM_ROWS = 200
     val NUM_ITERS = 1000
     val random = new Random(seed)
     var i = 0
@@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
   }
 
   test("Random flat schema") {
-    testRandomRows(true, 10)
+    testRandomRows(true, 15)
   }
 
   test("Random nested schema") {

http://git-wip-us.apache.org/repos/asf/spark/blob/064b029c/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index b29bf6a..18761bf 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -27,10 +27,14 @@ public final class Platform {
 
   public static final int BYTE_ARRAY_OFFSET;
 
+  public static final int SHORT_ARRAY_OFFSET;
+
   public static final int INT_ARRAY_OFFSET;
 
   public static final int LONG_ARRAY_OFFSET;
 
+  public static final int FLOAT_ARRAY_OFFSET;
+
   public static final int DOUBLE_ARRAY_OFFSET;
 
   public static int getInt(Object object, long offset) {
@@ -168,13 +172,17 @@ public final class Platform {
 
     if (_UNSAFE != null) {
       BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
+      SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class);
       INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
       LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
+      FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class);
       DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
     } else {
       BYTE_ARRAY_OFFSET = 0;
+      SHORT_ARRAY_OFFSET = 0;
       INT_ARRAY_OFFSET = 0;
       LONG_ARRAY_OFFSET = 0;
+      FLOAT_ARRAY_OFFSET = 0;
       DOUBLE_ARRAY_OFFSET = 0;
     }
   }
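
The new SHORT_ARRAY_OFFSET and FLOAT_ARRAY_OFFSET constants come from
Unsafe.arrayBaseOffset, which reports where element 0 lives inside the JVM's
array object layout; Platform then adds srcIndex * elementSize to that base
in its copyMemory calls. A standalone sketch of the same lookup (the printed
values are VM-dependent; 16 is typical on 64-bit HotSpot):

    import java.lang.reflect.Field;
    import sun.misc.Unsafe;

    public class ArrayOffsets {
      public static void main(String[] args) throws Exception {
        Field f = Unsafe.class.getDeclaredField("theUnsafe");
        f.setAccessible(true);
        Unsafe unsafe = (Unsafe) f.get(null);
        System.out.println(unsafe.arrayBaseOffset(short[].class)); // e.g. 16
        System.out.println(unsafe.arrayBaseOffset(float[].class)); // e.g. 16
      }
    }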

