You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/01/08 05:59:17 UTC

spark git commit: [SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava)

Repository: spark
Updated Branches:
  refs/heads/master 3e40eb3f1 -> 8fdeb4b99


[SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava)

## What changes were proposed in this pull request?

It seems we can avoid type dispatch for each value when converting Java objects (from Pyrolite) to Spark's internal data format, because we know the schema ahead of time.

I manually ran the benchmark below:

```scala
  test("EvaluatePython.fromJava / EvaluatePython.makeFromJava") {
    // Compares per-value type dispatch (fromJava) against a converter that is built once
    // for the schema (makeFromJava), both run over the same pre-generated input rows.
    val numRows = 1000 * 1000
    val numFields = 30

    val rng = new Random(System.nanoTime())
    val candidateTypes = Array(
      BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType,
      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
      new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType)
    val schema = RandomDataGenerator.randomSchema(rng, numFields, candidateTypes)

    // Materialize the input once, so both cases convert exactly the same data and the
    // generation cost stays out of the timed sections.
    val inputRows = mutable.ArrayBuffer.fill(numRows) {
      RandomDataGenerator.randomRow(rng, schema).toSeq.toArray
    }

    val benchmark = new Benchmark("EvaluatePython.fromJava / EvaluatePython.makeFromJava", numRows)

    benchmark.addCase("Before - EvaluatePython.fromJava", 3) { _ =>
      // Matches on (value, data type) for every value of every row.
      var idx = 0
      while (idx < numRows) {
        EvaluatePython.fromJava(inputRows(idx), schema)
        idx += 1
      }
    }

    benchmark.addCase("After - EvaluatePython.makeFromJava", 3) { _ =>
      // Converter construction is included in the measured time, mirroring how
      // applySchemaToPythonRDD builds one converter per partition.
      val converter = EvaluatePython.makeFromJava(schema)
      var idx = 0
      while (idx < numRows) {
        converter(inputRows(idx))
        idx += 1
      }
    }

    benchmark.run()
  }
```

```
EvaluatePython.fromJava / EvaluatePython.makeFromJava: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Before - EvaluatePython.fromJava              1265 / 1346          0.8        1264.8       1.0X
After - EvaluatePython.makeFromJava            571 /  649          1.8         570.8       2.2X
```

For nested structures, I expect the advantage to be even larger than this.

## How was this patch tested?

Existing tests should cover this. I also manually checked, via `assert` while running the benchmarks, that the values produced before and after the change are actually the same.

Author: hyukjinkwon <gu...@gmail.com>

Closes #20172 from HyukjinKwon/type-dispatch-python-eval.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8fdeb4b9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8fdeb4b9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8fdeb4b9

Branch: refs/heads/master
Commit: 8fdeb4b9946bd9be045abb919da2e531708b3bd4
Parents: 3e40eb3
Author: hyukjinkwon <gu...@gmail.com>
Authored: Mon Jan 8 13:59:08 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jan 8 13:59:08 2018 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/SparkSession.scala     |   5 +-
 .../execution/python/BatchEvalPythonExec.scala  |   7 +-
 .../sql/execution/python/EvaluatePython.scala   | 166 ++++++++++++-------
 3 files changed, 118 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8fdeb4b9/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 272eb84..734573b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -742,7 +742,10 @@ class SparkSession private(
   private[sql] def applySchemaToPythonRDD(
       rdd: RDD[Array[Any]],
       schema: StructType): DataFrame = {
-    val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
+    val rowRdd = rdd.mapPartitions { iter =>
+      val fromJava = python.EvaluatePython.makeFromJava(schema)
+      iter.map(r => fromJava(r).asInstanceOf[InternalRow])
+    }
     internalCreateDataFrame(rowRdd, schema)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8fdeb4b9/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 26ee25f..f4d83e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
     } else {
       StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
     }
+
+    val fromJava = EvaluatePython.makeFromJava(resultType)
+
     outputIterator.flatMap { pickedResult =>
       val unpickledBatch = unpickle.loads(pickedResult)
       unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
     }.map { result =>
       if (udfs.length == 1) {
         // fast path for single UDF
-        mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+        mutableRow(0) = fromJava(result)
         mutableRow
       } else {
-        EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        fromJava(result).asInstanceOf[InternalRow]
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8fdeb4b9/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 9bbfa60..520afad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -83,82 +83,134 @@ object EvaluatePython {
   }
 
   /**
-   * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
-   * unexpected. Because Python doesn't enforce the type.
+   * Make a converter that converts `obj` to the type specified by the data type, or returns
+   * null if the type of obj is unexpected. Because Python doesn't enforce the type.
    */
-  def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-    case (null, _) => null
-
-    case (c: Boolean, BooleanType) => c
+  def makeFromJava(dataType: DataType): Any => Any = dataType match {
+    case BooleanType => (obj: Any) => nullSafeConvert(obj) {
+      case b: Boolean => b
+    }
 
-    case (c: Byte, ByteType) => c
-    case (c: Short, ByteType) => c.toByte
-    case (c: Int, ByteType) => c.toByte
-    case (c: Long, ByteType) => c.toByte
+    case ByteType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Byte => c
+      case c: Short => c.toByte
+      case c: Int => c.toByte
+      case c: Long => c.toByte
+    }
 
-    case (c: Byte, ShortType) => c.toShort
-    case (c: Short, ShortType) => c
-    case (c: Int, ShortType) => c.toShort
-    case (c: Long, ShortType) => c.toShort
+    case ShortType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Byte => c.toShort
+      case c: Short => c
+      case c: Int => c.toShort
+      case c: Long => c.toShort
+    }
 
-    case (c: Byte, IntegerType) => c.toInt
-    case (c: Short, IntegerType) => c.toInt
-    case (c: Int, IntegerType) => c
-    case (c: Long, IntegerType) => c.toInt
+    case IntegerType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Byte => c.toInt
+      case c: Short => c.toInt
+      case c: Int => c
+      case c: Long => c.toInt
+    }
 
-    case (c: Byte, LongType) => c.toLong
-    case (c: Short, LongType) => c.toLong
-    case (c: Int, LongType) => c.toLong
-    case (c: Long, LongType) => c
+    case LongType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Byte => c.toLong
+      case c: Short => c.toLong
+      case c: Int => c.toLong
+      case c: Long => c
+    }
 
-    case (c: Float, FloatType) => c
-    case (c: Double, FloatType) => c.toFloat
+    case FloatType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Float => c
+      case c: Double => c.toFloat
+    }
 
-    case (c: Float, DoubleType) => c.toDouble
-    case (c: Double, DoubleType) => c
+    case DoubleType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Float => c.toDouble
+      case c: Double => c
+    }
 
-    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
+    case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) {
+      case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale)
+    }
 
-    case (c: Int, DateType) => c
+    case DateType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Int => c
+    }
 
-    case (c: Long, TimestampType) => c
-    // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
-    case (c: Int, TimestampType) => c.toLong
+    case TimestampType => (obj: Any) => nullSafeConvert(obj) {
+      case c: Long => c
+      // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
+      case c: Int => c.toLong
+    }
 
-    case (c, StringType) => UTF8String.fromString(c.toString)
+    case StringType => (obj: Any) => nullSafeConvert(obj) {
+      case _ => UTF8String.fromString(obj.toString)
+    }
 
-    case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8)
-    case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+    case BinaryType => (obj: Any) => nullSafeConvert(obj) {
+      case c: String => c.getBytes(StandardCharsets.UTF_8)
+      case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+    }
 
-    case (c: java.util.List[_], ArrayType(elementType, _)) =>
-      new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray)
+    case ArrayType(elementType, _) =>
+      val elementFromJava = makeFromJava(elementType)
 
-    case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
-      new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
+      (obj: Any) => nullSafeConvert(obj) {
+        case c: java.util.List[_] =>
+          new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray)
+        case c if c.getClass.isArray =>
+          new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e)))
+      }
 
-    case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
-      ArrayBasedMapData(
-        javaMap,
-        (key: Any) => fromJava(key, keyType),
-        (value: Any) => fromJava(value, valueType))
+    case MapType(keyType, valueType, _) =>
+      val keyFromJava = makeFromJava(keyType)
+      val valueFromJava = makeFromJava(valueType)
+
+      (obj: Any) => nullSafeConvert(obj) {
+        case javaMap: java.util.Map[_, _] =>
+          ArrayBasedMapData(
+            javaMap,
+            (key: Any) => keyFromJava(key),
+            (value: Any) => valueFromJava(value))
+      }
 
-    case (c, StructType(fields)) if c.getClass.isArray =>
-      val array = c.asInstanceOf[Array[_]]
-      if (array.length != fields.length) {
-        throw new IllegalStateException(
-          s"Input row doesn't have expected number of values required by the schema. " +
-            s"${fields.length} fields are required while ${array.length} values are provided."
-        )
+    case StructType(fields) =>
+      val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray
+
+      (obj: Any) => nullSafeConvert(obj) {
+        case c if c.getClass.isArray =>
+          val array = c.asInstanceOf[Array[_]]
+          if (array.length != fields.length) {
+            throw new IllegalStateException(
+              s"Input row doesn't have expected number of values required by the schema. " +
+                s"${fields.length} fields are required while ${array.length} values are provided."
+            )
+          }
+
+          val row = new GenericInternalRow(fields.length)
+          var i = 0
+          while (i < fields.length) {
+            row(i) = fieldsFromJava(i)(array(i))
+            i += 1
+          }
+          row
       }
-      new GenericInternalRow(array.zip(fields).map {
-        case (e, f) => fromJava(e, f.dataType)
-      })
 
-    case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType)
+    case udt: UserDefinedType[_] => makeFromJava(udt.sqlType)
+
+    case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty)
+  }
 
-    // all other unexpected type should be null, or we will have runtime exception
-    // TODO(davies): we could improve this by try to cast the object to expected type
-    case (c, _) => null
+  private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = {
+    if (input == null) {
+      null
+    } else {
+      f.applyOrElse(input, {
+        // all other unexpected type should be null, or we will have runtime exception
+        // TODO(davies): we could improve this by try to cast the object to expected type
+        _: Any => null
+      })
+    }
   }
 
   private val module = "pyspark.sql.types"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org