You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/25 04:29:09 UTC

spark git commit: [SPARK-9330][SQL] Create specialized getStruct getter in InternalRow.

Repository: spark
Updated Branches:
  refs/heads/master a400ab516 -> f99cb5615


[SPARK-9330][SQL] Create specialized getStruct getter in InternalRow.

Also took the chance to rearrange some of the methods in UnsafeRow to group static/private/public things together.

Author: Reynold Xin <rx...@databricks.com>

Closes #7654 from rxin/getStruct and squashes the following commits:

b491a09 [Reynold Xin] Fixed typo.
48d77e5 [Reynold Xin] [SPARK-9330][SQL] Create specialized getStruct getter in InternalRow.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f99cb561
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f99cb561
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f99cb561

Branch: refs/heads/master
Commit: f99cb5615cbc0b469d52af6bd08f8bf888af58f3
Parents: a400ab5
Author: Reynold Xin <rx...@databricks.com>
Authored: Fri Jul 24 19:29:01 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Jul 24 19:29:01 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java     | 87 +++++++++++++-------
 .../sql/catalyst/CatalystTypeConverters.scala   |  2 +-
 .../apache/spark/sql/catalyst/InternalRow.scala | 22 +++--
 .../catalyst/expressions/BoundAttribute.scala   |  1 +
 .../expressions/codegen/CodeGenerator.scala     |  5 +-
 5 files changed, 77 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index a898660..225f6e6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -51,28 +51,9 @@ import static org.apache.spark.sql.types.DataTypes.*;
  */
 public final class UnsafeRow extends MutableRow {
 
-  private Object baseObject;
-  private long baseOffset;
-
-  public Object getBaseObject() { return baseObject; }
-  public long getBaseOffset() { return baseOffset; }
-  public int getSizeInBytes() { return sizeInBytes; }
-
-  /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
-  private int numFields;
-
-  /** The size of this row's backing data, in bytes) */
-  private int sizeInBytes;
-
-  @Override
-  public int numFields() { return numFields; }
-
-  /** The width of the null tracking bit set, in bytes */
-  private int bitSetWidthInBytes;
-
-  private long getFieldOffset(int ordinal) {
-   return baseOffset + bitSetWidthInBytes + ordinal * 8L;
-  }
+  //////////////////////////////////////////////////////////////////////////////
+  // Static methods
+  //////////////////////////////////////////////////////////////////////////////
 
   public static int calculateBitSetWidthInBytes(int numFields) {
     return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
@@ -103,7 +84,7 @@ public final class UnsafeRow extends MutableRow {
           DoubleType,
           DateType,
           TimestampType
-    })));
+        })));
 
     // We support get() on a superset of the types for which we support set():
     final Set<DataType> _readableFieldTypes = new HashSet<>(
@@ -115,12 +96,48 @@ public final class UnsafeRow extends MutableRow {
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
   }
 
+  //////////////////////////////////////////////////////////////////////////////
+  // Private fields and methods
+  //////////////////////////////////////////////////////////////////////////////
+
+  private Object baseObject;
+  private long baseOffset;
+
+  /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
+  private int numFields;
+
+  /** The size of this row's backing data, in bytes */
+  private int sizeInBytes;
+
+  private void setNotNullAt(int i) {
+    assertIndexIsValid(i);
+    BitSetMethods.unset(baseObject, baseOffset, i);
+  }
+
+  /** The width of the null tracking bit set, in bytes */
+  private int bitSetWidthInBytes;
+
+  private long getFieldOffset(int ordinal) {
+    return baseOffset + bitSetWidthInBytes + ordinal * 8L;
+  }
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Public methods
+  //////////////////////////////////////////////////////////////////////////////
+
   /**
    * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
    * since the value returned by this constructor is equivalent to a null pointer.
    */
   public UnsafeRow() { }
 
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
+  public int getSizeInBytes() { return sizeInBytes; }
+
+  @Override
+  public int numFields() { return numFields; }
+
   /**
    * Update this UnsafeRow to point to different backing data.
    *
@@ -130,7 +147,7 @@ public final class UnsafeRow extends MutableRow {
    * @param sizeInBytes the size of this row's backing data, in bytes
    */
   public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) {
-    assert numFields >= 0 : "numFields should >= 0";
+    assert numFields >= 0 : "numFields (" + numFields + ") should >= 0";
     this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
@@ -153,11 +170,6 @@ public final class UnsafeRow extends MutableRow {
     PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0);
   }
 
-  private void setNotNullAt(int i) {
-    assertIndexIsValid(i);
-    BitSetMethods.unset(baseObject, baseOffset, i);
-  }
-
   @Override
   public void update(int ordinal, Object value) {
     throw new UnsupportedOperationException();
@@ -316,6 +328,21 @@ public final class UnsafeRow extends MutableRow {
     return getUTF8String(i).toString();
   }
 
+  @Override
+  public UnsafeRow getStruct(int i, int numFields) {
+    if (isNullAt(i)) {
+      return null;
+    } else {
+      assertIndexIsValid(i);
+      final long offsetAndSize = getLong(i);  // one long packs both: offset (high), size (low)
+      final int offset = (int) (offsetAndSize >> 32);  // upper 32 bits
+      final int size = (int) (offsetAndSize & ((1L << 32) - 1));  // lower 32 bits
+      final UnsafeRow row = new UnsafeRow();
+      row.pointTo(baseObject, baseOffset + offset, numFields, size);  // same baseObject
+      return row;
+    }
+  }
+
   /**
    * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
    * byte array rather than referencing data stored in a data page.
@@ -388,7 +415,7 @@ public final class UnsafeRow extends MutableRow {
    */
   public byte[] getBytes() {
     if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
-        && (((byte[]) baseObject).length == sizeInBytes)) {
+      && (((byte[]) baseObject).length == sizeInBytes)) {
       return (byte[]) baseObject;
     } else {
       byte[] bytes = new byte[sizeInBytes];

http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 5c3072a..7416ddb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -271,7 +271,7 @@ object CatalystTypeConverters {
     }
 
     override def toScalaImpl(row: InternalRow, column: Int): Row =
-      toScala(row.get(column).asInstanceOf[InternalRow])
+      toScala(row.getStruct(column, structType.size))
   }
 
   private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {

http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index efc4fae..f248b1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -52,6 +52,21 @@ abstract class InternalRow extends Serializable {
 
   def getDouble(i: Int): Double = getAs[Double](i)
 
+  def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i)
+
+  def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i)
+
+  // This is only used for tests
+  def getString(i: Int): String = getAs[UTF8String](i).toString
+
+  /**
+   * Returns a struct from ordinal position.
+   *
+   * @param ordinal position to get the struct from.
+   * @param numFields number of fields the struct type has
+   */
+  def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal)
+
   override def toString: String = s"[${this.mkString(",")}]"
 
   /**
@@ -145,13 +160,6 @@ abstract class InternalRow extends Serializable {
    */
   def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
 
-  def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i)
-
-  def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i)
-
-  // This is only use for test
-  def getString(i: Int): String = getAs[UTF8String](i).toString
-
   // Custom hashCode function that matches the efficient code generated version.
   override def hashCode: Int = {
     var result: Int = 37

http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 6aa4930..1f7adcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
         case DoubleType => input.getDouble(ordinal)
         case StringType => input.getUTF8String(ordinal)
         case BinaryType => input.getBinary(ordinal)
+        case t: StructType => input.getStruct(ordinal, t.size)
         case _ => input.get(ordinal)
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 48225e1..4a90f1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -109,6 +109,7 @@ class CodeGenContext {
       case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
       case StringType => s"$row.getUTF8String($ordinal)"
       case BinaryType => s"$row.getBinary($ordinal)"
+      case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
       case _ => s"($jt)$row.apply($ordinal)"
     }
   }
@@ -249,13 +250,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
   protected val mutableRowType: String = classOf[MutableRow].getName
   protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
 
-  protected def declareMutableStates(ctx: CodeGenContext) = {
+  protected def declareMutableStates(ctx: CodeGenContext): String = {
     ctx.mutableStates.map { case (javaType, variableName, _) =>
       s"private $javaType $variableName;"
     }.mkString("\n      ")
   }
 
-  protected def initMutableStates(ctx: CodeGenContext) = {
+  protected def initMutableStates(ctx: CodeGenContext): String = {
     ctx.mutableStates.map(_._3).mkString("\n        ")
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org