Posted to commits@spark.apache.org by rx...@apache.org on 2016/10/02 22:47:39 UTC

spark git commit: [SPARK-17509][SQL] When wrapping catalyst datatype to Hive data type avoid…

Repository: spark
Updated Branches:
  refs/heads/master b88cb63da -> f8d7fade4


[SPARK-17509][SQL] When wrapping catalyst datatype to Hive data type avoid…

## What changes were proposed in this pull request?

When wrapping Catalyst data types into Hive data types, the `wrap` function performed an expensive pattern match on the object inspector for every value, which consumed around 11% of CPU time. This change avoids the repeated matching by resolving the wrapper function once (in `wrapperFor`) and reusing it for every row.
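
To illustrate the technique (a minimal, self-contained sketch with toy types, not the actual Hive ObjectInspector hierarchy): the match on the inspector is performed once to build an `Any => Any` closure, and it is the closure that gets invoked per row.

    // Toy sketch of the optimization; `Inspector` here is a stand-in
    // for Hive's ObjectInspector hierarchy.
    sealed trait Inspector
    case object IntInspector extends Inspector
    case object StringInspector extends Inspector

    // Before: the pattern match runs on every call, i.e. once per row.
    def wrapSlow(value: Any, oi: Inspector): Any = oi match {
      case IntInspector    => value.asInstanceOf[java.lang.Integer]
      case StringInspector => value.toString
    }

    // After: the match runs once and returns a reusable closure.
    def wrapperFor(oi: Inspector): Any => Any = oi match {
      case IntInspector    => (v: Any) => v.asInstanceOf[java.lang.Integer]
      case StringInspector => (v: Any) => v.toString
    }

    val wrapper = wrapperFor(IntInspector)  // matched once, outside the loop
    // wrapper(v) is then called in the per-row hot path with no match cost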

## How was this patch tested?

Tested by running the job on a cluster; observed around an 8% CPU improvement.

Author: Sital Kedia <sk...@fb.com>

Closes #15064 from sitalkedia/skedia/hive_wrapper.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f8d7fade
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f8d7fade
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f8d7fade

Branch: refs/heads/master
Commit: f8d7fade4b9a78ae87b6012e3d6f71eef3032b22
Parents: b88cb63
Author: Sital Kedia <sk...@fb.com>
Authored: Sun Oct 2 15:47:36 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sun Oct 2 15:47:36 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/hive/HiveInspectors.scala  | 307 ++++++++-----------
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |  15 +-
 2 files changed, 145 insertions(+), 177 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f8d7fade/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index e4b963e..c3c4351 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -238,102 +238,161 @@ private[hive] trait HiveInspectors {
     case c => throw new AnalysisException(s"Unsupported java type $c")
   }
 
+  private def withNullSafe(f: Any => Any): Any => Any = {
+    input => if (input == null) null else f(input)
+  }
+
   /**
    * Wraps with Hive types based on object inspector.
-   * TODO: Consolidate all hive OI/data interface code.
    */
   protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match {
-    case _: JavaHiveVarcharObjectInspector =>
-      (o: Any) =>
-        if (o != null) {
-          val s = o.asInstanceOf[UTF8String].toString
-          new HiveVarchar(s, s.length)
-        } else {
-          null
-        }
-
-    case _: JavaHiveCharObjectInspector =>
-      (o: Any) =>
-        if (o != null) {
-          val s = o.asInstanceOf[UTF8String].toString
-          new HiveChar(s, s.length)
-        } else {
-          null
-        }
-
-    case _: JavaHiveDecimalObjectInspector =>
-      (o: Any) =>
-        if (o != null) {
-          HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)
-        } else {
-          null
-        }
-
-    case _: JavaDateObjectInspector =>
-      (o: Any) =>
-        if (o != null) {
-          DateTimeUtils.toJavaDate(o.asInstanceOf[Int])
-        } else {
-          null
-        }
-
-    case _: JavaTimestampObjectInspector =>
+    case x: ConstantObjectInspector =>
       (o: Any) =>
-        if (o != null) {
-          DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])
-        } else {
-          null
+        x.getWritableConstantValue
+    case x: PrimitiveObjectInspector => x match {
+      // TODO we don't support the HiveVarcharObjectInspector yet.
+      case _: StringObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getStringWritable(o))
+      case _: StringObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[UTF8String].toString())
+      case _: IntObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getIntWritable(o))
+      case _: IntObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[java.lang.Integer])
+      case _: BooleanObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getBooleanWritable(o))
+      case _: BooleanObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[java.lang.Boolean])
+      case _: FloatObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getFloatWritable(o))
+      case _: FloatObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[java.lang.Float])
+      case _: DoubleObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getDoubleWritable(o))
+      case _: DoubleObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[java.lang.Double])
+      case _: LongObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getLongWritable(o))
+      case _: LongObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[java.lang.Long])
+      case _: ShortObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getShortWritable(o))
+      case _: ShortObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[java.lang.Short])
+      case _: ByteObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getByteWritable(o))
+      case _: ByteObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[java.lang.Byte])
+      case _: JavaHiveVarcharObjectInspector =>
+        withNullSafe { o =>
+            val s = o.asInstanceOf[UTF8String].toString
+            new HiveVarchar(s, s.length)
         }
+      case _: JavaHiveCharObjectInspector =>
+        withNullSafe { o =>
+            val s = o.asInstanceOf[UTF8String].toString
+            new HiveChar(s, s.length)
+          }
+      case _: JavaHiveDecimalObjectInspector =>
+        withNullSafe(o =>
+          HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal))
+      case _: JavaDateObjectInspector =>
+        withNullSafe(o =>
+            DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))
+      case _: JavaTimestampObjectInspector =>
+        withNullSafe(o =>
+            DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]))
+      case _: HiveDecimalObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getDecimalWritable(o.asInstanceOf[Decimal]))
+      case _: HiveDecimalObjectInspector =>
+        withNullSafe(o =>
+            HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal))
+      case _: BinaryObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getBinaryWritable(o))
+      case _: BinaryObjectInspector =>
+        withNullSafe(o => o.asInstanceOf[Array[Byte]])
+      case _: DateObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getDateWritable(o))
+      case _: DateObjectInspector =>
+        withNullSafe(o => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))
+      case _: TimestampObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getTimestampWritable(o))
+      case _: TimestampObjectInspector =>
+        withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]))
+    }
 
     case soi: StandardStructObjectInspector =>
       val schema = dataType.asInstanceOf[StructType]
       val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map {
         case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType)
       }
-      (o: Any) => {
-        if (o != null) {
-          val struct = soi.create()
-          val row = o.asInstanceOf[InternalRow]
-          soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
-            case ((field, wrapper), i) =>
-              soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType)))
-          }
-          struct
-        } else {
-          null
+      withNullSafe { o =>
+        val struct = soi.create()
+        val row = o.asInstanceOf[InternalRow]
+        soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
+          case ((field, wrapper), i) =>
+            soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType)))
+        }
+        struct
+      }
+
+    case ssoi: SettableStructObjectInspector =>
+      val structType = dataType.asInstanceOf[StructType]
+      val wrappers = ssoi.getAllStructFieldRefs.asScala.zip(structType).map {
+        case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType)
+      }
+      withNullSafe { o =>
+        val row = o.asInstanceOf[InternalRow]
+        // 1. create the pojo (most likely) object
+        val result = ssoi.create()
+        ssoi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
+          case ((field, wrapper), i) =>
+            val tpe = structType(i).dataType
+            ssoi.setStructFieldData(
+            result,
+            field,
+            wrapper(row.get(i, tpe)).asInstanceOf[AnyRef])
         }
+        result
+      }
+
+    case soi: StructObjectInspector =>
+      val structType = dataType.asInstanceOf[StructType]
+      val wrappers = soi.getAllStructFieldRefs.asScala.zip(structType).map {
+        case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType)
+      }
+      withNullSafe { o =>
+        val row = o.asInstanceOf[InternalRow]
+        val result = new java.util.ArrayList[AnyRef](wrappers.size)
+        soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
+          case ((field, wrapper), i) =>
+          val tpe = structType(i).dataType
+          result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef])
+        }
+        result
       }
 
     case loi: ListObjectInspector =>
       val elementType = dataType.asInstanceOf[ArrayType].elementType
       val wrapper = wrapperFor(loi.getListElementObjectInspector, elementType)
-      (o: Any) => {
-        if (o != null) {
-          val array = o.asInstanceOf[ArrayData]
-          val values = new java.util.ArrayList[Any](array.numElements())
-          array.foreach(elementType, (_, e) => values.add(wrapper(e)))
-          values
-        } else {
-          null
-        }
+      withNullSafe { o =>
+        val array = o.asInstanceOf[ArrayData]
+        val values = new java.util.ArrayList[Any](array.numElements())
+        array.foreach(elementType, (_, e) => values.add(wrapper(e)))
+        values
       }
 
     case moi: MapObjectInspector =>
       val mt = dataType.asInstanceOf[MapType]
       val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector, mt.keyType)
       val valueWrapper = wrapperFor(moi.getMapValueObjectInspector, mt.valueType)
-
-      (o: Any) => {
-        if (o != null) {
+      withNullSafe { o =>
           val map = o.asInstanceOf[MapData]
           val jmap = new java.util.HashMap[Any, Any](map.numElements())
           map.foreach(mt.keyType, mt.valueType, (k, v) =>
             jmap.put(keyWrapper(k), valueWrapper(v)))
           jmap
-        } else {
-          null
         }
-      }
 
     case _ =>
       identity[Any]
@@ -648,119 +707,19 @@ private[hive] trait HiveInspectors {
         (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value)
     }
 
-  /**
-   * Converts native catalyst types to the types expected by Hive
-   * @param a the value to be wrapped
-   * @param oi This ObjectInspector associated with the value returned by this function, and
-   *           the ObjectInspector should also be consistent with those returned from
-   *           toInspector: DataType => ObjectInspector and
-   *           toInspector: Expression => ObjectInspector
-   *
-   * Strictly follows the following order in wrapping (constant OI has the higher priority):
-   *   Constant object inspector => return the bundled value of Constant object inspector
-   *   Check whether the `a` is null => return null if true
-   *   If object inspector prefers writable object => return a Writable for the given data `a`
-   *   Map the catalyst data to the boxed java primitive
-   *
-   *  NOTICE: the complex data type requires recursive wrapping.
-   */
-  def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match {
-    case x: ConstantObjectInspector => x.getWritableConstantValue
-    case _ if a == null => null
-    case x: PrimitiveObjectInspector => x match {
-      // TODO we don't support the HiveVarcharObjectInspector yet.
-      case _: StringObjectInspector if x.preferWritable() => getStringWritable(a)
-      case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString()
-      case _: IntObjectInspector if x.preferWritable() => getIntWritable(a)
-      case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer]
-      case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a)
-      case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean]
-      case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a)
-      case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float]
-      case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a)
-      case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double]
-      case _: LongObjectInspector if x.preferWritable() => getLongWritable(a)
-      case _: LongObjectInspector => a.asInstanceOf[java.lang.Long]
-      case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a)
-      case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short]
-      case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a)
-      case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte]
-      case _: HiveDecimalObjectInspector if x.preferWritable() =>
-        getDecimalWritable(a.asInstanceOf[Decimal])
-      case _: HiveDecimalObjectInspector =>
-        HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal)
-      case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a)
-      case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]]
-      case _: DateObjectInspector if x.preferWritable() => getDateWritable(a)
-      case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int])
-      case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a)
-      case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long])
-    }
-    case x: SettableStructObjectInspector =>
-      val fieldRefs = x.getAllStructFieldRefs
-      val structType = dataType.asInstanceOf[StructType]
-      val row = a.asInstanceOf[InternalRow]
-      // 1. create the pojo (most likely) object
-      val result = x.create()
-      var i = 0
-      val size = fieldRefs.size
-      while (i < size) {
-        // 2. set the property for the pojo
-        val tpe = structType(i).dataType
-        x.setStructFieldData(
-          result,
-          fieldRefs.get(i),
-          wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
-        i += 1
-      }
-
-      result
-    case x: StructObjectInspector =>
-      val fieldRefs = x.getAllStructFieldRefs
-      val structType = dataType.asInstanceOf[StructType]
-      val row = a.asInstanceOf[InternalRow]
-      val result = new java.util.ArrayList[AnyRef](fieldRefs.size)
-      var i = 0
-      val size = fieldRefs.size
-      while (i < size) {
-        val tpe = structType(i).dataType
-        result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
-        i += 1
-      }
-
-      result
-    case x: ListObjectInspector =>
-      val list = new java.util.ArrayList[Object]
-      val tpe = dataType.asInstanceOf[ArrayType].elementType
-      a.asInstanceOf[ArrayData].foreach(tpe, (_, e) =>
-        list.add(wrap(e, x.getListElementObjectInspector, tpe))
-      )
-      list
-    case x: MapObjectInspector =>
-      val keyType = dataType.asInstanceOf[MapType].keyType
-      val valueType = dataType.asInstanceOf[MapType].valueType
-      val map = a.asInstanceOf[MapData]
-
-      // Some UDFs seem to assume we pass in a HashMap.
-      val hashMap = new java.util.HashMap[Any, Any](map.numElements())
-
-      map.foreach(keyType, valueType, (k, v) =>
-        hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType),
-          wrap(v, x.getMapValueObjectInspector, valueType))
-      )
-
-      hashMap
+  def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = {
+    wrapperFor(oi, dataType)(a).asInstanceOf[AnyRef]
   }
 
   def wrap(
       row: InternalRow,
-      inspectors: Seq[ObjectInspector],
+      wrappers: Array[(Any) => Any],
       cache: Array[AnyRef],
       dataTypes: Array[DataType]): Array[AnyRef] = {
     var i = 0
-    val length = inspectors.length
+    val length = wrappers.length
     while (i < length) {
-      cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i))
+      cache(i) = wrappers(i)(row.get(i, dataTypes(i))).asInstanceOf[AnyRef]
       i += 1
     }
     cache
@@ -768,13 +727,13 @@ private[hive] trait HiveInspectors {
 
   def wrap(
       row: Seq[Any],
-      inspectors: Seq[ObjectInspector],
+      wrappers: Array[(Any) => Any],
       cache: Array[AnyRef],
       dataTypes: Array[DataType]): Array[AnyRef] = {
     var i = 0
-    val length = inspectors.length
+    val length = wrappers.length
     while (i < length) {
-      cache(i) = wrap(row(i), inspectors(i), dataTypes(i))
+      cache(i) = wrappers(i)(row(i)).asInstanceOf[AnyRef]
       i += 1
     }
     cache
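
A note on the null handling above: `withNullSafe` hoists the null check out of each case, so every branch of `wrapperFor` only has to specify the non-null conversion, and `wrap` itself now just delegates to the cached closure. A self-contained illustration (toy conversion, not Hive-specific):

    // The helper as introduced in this diff, with a toy usage example.
    def withNullSafe(f: Any => Any): Any => Any =
      input => if (input == null) null else f(input)

    val toUpper = withNullSafe(o => o.toString.toUpperCase)
    assert(toUpper(null) == null)        // null short-circuits
    assert(toUpper("spark") == "SPARK")  // non-null goes through f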

http://git-wip-us.apache.org/repos/asf/spark/blob/f8d7fade/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 962dd5a..d549135 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -71,6 +71,9 @@ private[hive] case class HiveSimpleUDF(
   override lazy val dataType = javaClassToDataType(method.getReturnType)
 
   @transient
+  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+  @transient
   lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
     method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
 
@@ -82,7 +85,7 @@ private[hive] case class HiveSimpleUDF(
 
   // TODO: Finish input output types.
   override def eval(input: InternalRow): Any = {
-    val inputs = wrap(children.map(_.eval(input)), arguments, cached, inputDataTypes)
+    val inputs = wrap(children.map(_.eval(input)), wrappers, cached, inputDataTypes)
     val ret = FunctionRegistry.invoke(
       method,
       function,
@@ -215,6 +218,9 @@ private[hive] case class HiveGenericUDTF(
   private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
 
   @transient
+  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+  @transient
   private lazy val unwrapper = unwrapperFor(outputInspector)
 
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -222,7 +228,7 @@ private[hive] case class HiveGenericUDTF(
 
     val inputProjection = new InterpretedProjection(children)
 
-    function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes))
+    function.process(wrap(inputProjection(input), wrappers, udtInput, inputDataTypes))
     collector.collectRows()
   }
 
@@ -297,6 +303,9 @@ private[hive] case class HiveUDAFFunction(
   private lazy val function = functionAndInspector._1
 
   @transient
+  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+  @transient
   private lazy val returnInspector = functionAndInspector._2
 
   @transient
@@ -322,7 +331,7 @@ private[hive] case class HiveUDAFFunction(
 
   override def update(_buffer: MutableRow, input: InternalRow): Unit = {
     val inputs = inputProjection(input)
-    function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
+    function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes))
   }
 
   override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {


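Putting the two halves together (a hedged, self-contained toy mirroring the pattern in hiveUDFs.scala, reusing the `Inspector`/`wrapperFor` sketch from earlier; none of these names are the real Spark classes): each evaluator resolves its wrappers lazily once, so per-row evaluation is just array indexing and closure application.

    // Toy evaluator mirroring the @transient lazy val `wrappers` fields
    // added in this commit: wrappers are built once per instance,
    // eval() is the per-row hot path.
    class ToyEvaluator(inspectors: Seq[Inspector]) {
      private lazy val wrappers: Array[Any => Any] =
        inspectors.map(wrapperFor).toArray

      def eval(row: Seq[Any]): Array[AnyRef] = {
        val cache = new Array[AnyRef](wrappers.length)
        var i = 0
        while (i < wrappers.length) {
          cache(i) = wrappers(i)(row(i)).asInstanceOf[AnyRef]
          i += 1
        }
        cache
      }
    }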