Posted to commits@spark.apache.org by we...@apache.org on 2018/01/08 05:59:17 UTC
spark git commit: [SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava)
Repository: spark
Updated Branches:
refs/heads/master 3e40eb3f1 -> 8fdeb4b99
[SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava)
## What changes were proposed in this pull request?
It seems we can avoid type dispatch for each value when converting Java objects (from Pyrolite) into Spark's internal data format, because we know the schema ahead of time.
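In other words (a minimal sketch of the idea, not the actual implementation; see the full diff below): rather than pattern matching on `(obj, dataType)` for every value, match on the data type once and return a closure that is then applied per value.
```scala
import org.apache.spark.sql.types._

// Before: the (value, type) match runs once per value, for every record.
def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
  case (c: Int, IntegerType) => c
  case _ => null
}

// After: the match on dataType runs once per schema; only the returned
// closure runs per value.
def makeFromJava(dataType: DataType): Any => Any = dataType match {
  case IntegerType => (obj: Any) => obj match {
    case c: Int => c
    case _ => null
  }
  case _ => (_: Any) => null
}
```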
I manually performed the benchmark as below:
```scala
test("EvaluatePython.fromJava / EvaluatePython.makeFromJava") {
val numRows = 1000 * 1000
val numFields = 30
val random = new Random(System.nanoTime())
val types = Array(
BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType,
DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType)
val schema = RandomDataGenerator.randomSchema(random, numFields, types)
val rows = mutable.ArrayBuffer.empty[Array[Any]]
var i = 0
while (i < numRows) {
val row = RandomDataGenerator.randomRow(random, schema)
rows += row.toSeq.toArray
i += 1
}
val benchmark = new Benchmark("EvaluatePython.fromJava / EvaluatePython.makeFromJava", numRows)
benchmark.addCase("Before - EvaluatePython.fromJava", 3) { _ =>
var i = 0
while (i < numRows) {
EvaluatePython.fromJava(rows(i), schema)
i += 1
}
}
benchmark.addCase("After - EvaluatePython.makeFromJava", 3) { _ =>
val fromJava = EvaluatePython.makeFromJava(schema)
var i = 0
while (i < numRows) {
fromJava(rows(i))
i += 1
}
}
benchmark.run()
}
```
```
EvaluatePython.fromJava / EvaluatePython.makeFromJava: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Before - EvaluatePython.fromJava               1265 / 1346          0.8        1264.8       1.0X
After - EvaluatePython.makeFromJava             571 /  649          1.8         570.8       2.2X
```
If the structure is nested, I expect the advantage to be even larger than this, as sketched below.
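To see why (an illustrative sketch; this nested schema is hypothetical, not from the patch): for a nested type, the converter for the inner type is resolved once when the outer converter is built, so the per-element work no longer includes the type dispatch.
```scala
import org.apache.spark.sql.types._
import org.apache.spark.sql.execution.python.EvaluatePython

// Hypothetical nested schema: every row carries an array of ints.
val schema = StructType(Seq(StructField("xs", ArrayType(IntegerType))))

// The ArrayType and IntegerType cases are resolved once, here. Converting a
// row whose array has n elements then applies pre-built closures n times,
// instead of re-matching (value, dataType) for each of those n elements.
val fromJava = EvaluatePython.makeFromJava(schema)
```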
## How was this patch tested?
Existing tests should cover this. Also, while running the benchmarks, I manually checked via `assert` that the values before and after the change are actually the same.
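Concretely, the check was along these lines (a sketch reusing the benchmark's `rows`, `schema`, and `numRows`; the exact assertion is not part of the patch):
```scala
// Sketch of the manual check: both conversion paths must agree on every row.
val fromJava = EvaluatePython.makeFromJava(schema)
var i = 0
while (i < numRows) {
  assert(EvaluatePython.fromJava(rows(i), schema) == fromJava(rows(i)))
  i += 1
}
```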
Author: hyukjinkwon <gu...@gmail.com>
Closes #20172 from HyukjinKwon/type-dispatch-python-eval.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8fdeb4b9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8fdeb4b9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8fdeb4b9
Branch: refs/heads/master
Commit: 8fdeb4b9946bd9be045abb919da2e531708b3bd4
Parents: 3e40eb3
Author: hyukjinkwon <gu...@gmail.com>
Authored: Mon Jan 8 13:59:08 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jan 8 13:59:08 2018 +0800
----------------------------------------------------------------------
.../org/apache/spark/sql/SparkSession.scala | 5 +-
.../execution/python/BatchEvalPythonExec.scala | 7 +-
.../sql/execution/python/EvaluatePython.scala | 166 ++++++++++++-------
3 files changed, 118 insertions(+), 60 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/8fdeb4b9/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 272eb84..734573b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -742,7 +742,10 @@ class SparkSession private(
private[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
schema: StructType): DataFrame = {
- val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
+ val rowRdd = rdd.mapPartitions { iter =>
+ val fromJava = python.EvaluatePython.makeFromJava(schema)
+ iter.map(r => fromJava(r).asInstanceOf[InternalRow])
+ }
internalCreateDataFrame(rowRdd, schema)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/8fdeb4b9/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 26ee25f..f4d83e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
+
+ val fromJava = EvaluatePython.makeFromJava(resultType)
+
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
if (udfs.length == 1) {
// fast path for single UDF
- mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+ mutableRow(0) = fromJava(result)
mutableRow
} else {
- EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+ fromJava(result).asInstanceOf[InternalRow]
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/8fdeb4b9/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 9bbfa60..520afad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -83,82 +83,134 @@ object EvaluatePython {
}
/**
- * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
- * unexpected. Because Python doesn't enforce the type.
+ * Make a converter that converts `obj` to the type specified by the data type, or returns
+ * null if the type of obj is unexpected. Because Python doesn't enforce the type.
*/
- def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
- case (null, _) => null
-
- case (c: Boolean, BooleanType) => c
+ def makeFromJava(dataType: DataType): Any => Any = dataType match {
+ case BooleanType => (obj: Any) => nullSafeConvert(obj) {
+ case b: Boolean => b
+ }
- case (c: Byte, ByteType) => c
- case (c: Short, ByteType) => c.toByte
- case (c: Int, ByteType) => c.toByte
- case (c: Long, ByteType) => c.toByte
+ case ByteType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c
+ case c: Short => c.toByte
+ case c: Int => c.toByte
+ case c: Long => c.toByte
+ }
- case (c: Byte, ShortType) => c.toShort
- case (c: Short, ShortType) => c
- case (c: Int, ShortType) => c.toShort
- case (c: Long, ShortType) => c.toShort
+ case ShortType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c.toShort
+ case c: Short => c
+ case c: Int => c.toShort
+ case c: Long => c.toShort
+ }
- case (c: Byte, IntegerType) => c.toInt
- case (c: Short, IntegerType) => c.toInt
- case (c: Int, IntegerType) => c
- case (c: Long, IntegerType) => c.toInt
+ case IntegerType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c.toInt
+ case c: Short => c.toInt
+ case c: Int => c
+ case c: Long => c.toInt
+ }
- case (c: Byte, LongType) => c.toLong
- case (c: Short, LongType) => c.toLong
- case (c: Int, LongType) => c.toLong
- case (c: Long, LongType) => c
+ case LongType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c.toLong
+ case c: Short => c.toLong
+ case c: Int => c.toLong
+ case c: Long => c
+ }
- case (c: Float, FloatType) => c
- case (c: Double, FloatType) => c.toFloat
+ case FloatType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Float => c
+ case c: Double => c.toFloat
+ }
- case (c: Float, DoubleType) => c.toDouble
- case (c: Double, DoubleType) => c
+ case DoubleType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Float => c.toDouble
+ case c: Double => c
+ }
- case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
+ case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) {
+ case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale)
+ }
- case (c: Int, DateType) => c
+ case DateType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Int => c
+ }
- case (c: Long, TimestampType) => c
- // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
- case (c: Int, TimestampType) => c.toLong
+ case TimestampType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Long => c
+ // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
+ case c: Int => c.toLong
+ }
- case (c, StringType) => UTF8String.fromString(c.toString)
+ case StringType => (obj: Any) => nullSafeConvert(obj) {
+ case _ => UTF8String.fromString(obj.toString)
+ }
- case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8)
- case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+ case BinaryType => (obj: Any) => nullSafeConvert(obj) {
+ case c: String => c.getBytes(StandardCharsets.UTF_8)
+ case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+ }
- case (c: java.util.List[_], ArrayType(elementType, _)) =>
- new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray)
+ case ArrayType(elementType, _) =>
+ val elementFromJava = makeFromJava(elementType)
- case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
+ (obj: Any) => nullSafeConvert(obj) {
+ case c: java.util.List[_] =>
+ new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray)
+ case c if c.getClass.isArray =>
+ new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e)))
+ }
- case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
- ArrayBasedMapData(
- javaMap,
- (key: Any) => fromJava(key, keyType),
- (value: Any) => fromJava(value, valueType))
+ case MapType(keyType, valueType, _) =>
+ val keyFromJava = makeFromJava(keyType)
+ val valueFromJava = makeFromJava(valueType)
+
+ (obj: Any) => nullSafeConvert(obj) {
+ case javaMap: java.util.Map[_, _] =>
+ ArrayBasedMapData(
+ javaMap,
+ (key: Any) => keyFromJava(key),
+ (value: Any) => valueFromJava(value))
+ }
- case (c, StructType(fields)) if c.getClass.isArray =>
- val array = c.asInstanceOf[Array[_]]
- if (array.length != fields.length) {
- throw new IllegalStateException(
- s"Input row doesn't have expected number of values required by the schema. " +
- s"${fields.length} fields are required while ${array.length} values are provided."
- )
+ case StructType(fields) =>
+ val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray
+
+ (obj: Any) => nullSafeConvert(obj) {
+ case c if c.getClass.isArray =>
+ val array = c.asInstanceOf[Array[_]]
+ if (array.length != fields.length) {
+ throw new IllegalStateException(
+ s"Input row doesn't have expected number of values required by the schema. " +
+ s"${fields.length} fields are required while ${array.length} values are provided."
+ )
+ }
+
+ val row = new GenericInternalRow(fields.length)
+ var i = 0
+ while (i < fields.length) {
+ row(i) = fieldsFromJava(i)(array(i))
+ i += 1
+ }
+ row
}
- new GenericInternalRow(array.zip(fields).map {
- case (e, f) => fromJava(e, f.dataType)
- })
- case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType)
+ case udt: UserDefinedType[_] => makeFromJava(udt.sqlType)
+
+ case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty)
+ }
- // all other unexpected type should be null, or we will have runtime exception
- // TODO(davies): we could improve this by try to cast the object to expected type
- case (c, _) => null
+ private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = {
+ if (input == null) {
+ null
+ } else {
+ f.applyOrElse(input, {
+ // all other unexpected type should be null, or we will have runtime exception
+ // TODO(davies): we could improve this by try to cast the object to expected type
+ _: Any => null
+ })
+ }
}
private val module = "pyspark.sql.types"