You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/10/23 22:02:43 UTC
spark git commit: [SPARK-22285][SQL] Change implementation of
ApproxCountDistinctForIntervals to TypedImperativeAggregate
Repository: spark
Updated Branches:
refs/heads/master 5a5b6b785 -> f6290aea2
[SPARK-22285][SQL] Change implementation of ApproxCountDistinctForIntervals to TypedImperativeAggregate
## What changes were proposed in this pull request?
The current implementation of `ApproxCountDistinctForIntervals` is an `ImperativeAggregate`. The number of `aggBufferAttributes` equals the total number of words in the hllppHelper array. Each hllppHelper has 52 words with the default relativeSD.
Since this aggregate function is used in equi-height histogram generation, and the number of buckets in a histogram is usually in the hundreds, the number of `aggBufferAttributes` can easily reach tens of thousands or even more.
This leads to a huge method in codegen and causes the following error:
```
org.codehaus.janino.JaninoRuntimeException: Code of method "apply(Lorg/apache/spark/sql/catalyst/InternalRow;)Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection" grows beyond 64 KB.
```
Besides, huge generated methods also result in performance regression.
In this PR, we change its implementation to `TypedImperativeAggregate`. After the fix, `ApproxCountDistinctForIntervals` can handle thousands of endpoints without throwing a codegen error, and performance improves from `20 sec` to `2 sec` in a test case with 500 endpoints.
## How was this patch tested?
Tested by an added test case and existing tests.
Author: Zhenhua Wang <wa...@huawei.com>
Closes #19506 from wzhfy/change_forIntervals_typedAgg.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6290aea
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6290aea
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6290aea
Branch: refs/heads/master
Commit: f6290aea24efeb238db88bdaef4e24d50740ca4c
Parents: 5a5b6b7
Author: Zhenhua Wang <wa...@huawei.com>
Authored: Mon Oct 23 23:02:36 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Oct 23 23:02:36 2017 +0100
----------------------------------------------------------------------
.../ApproxCountDistinctForIntervals.scala | 97 +++++++++++---------
.../ApproxCountDistinctForIntervalsSuite.scala | 34 +++----
...roxCountDistinctForIntervalsQuerySuite.scala | 61 ++++++++++++
3 files changed, 130 insertions(+), 62 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f6290aea/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index 096d1b3..d4421ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -22,9 +22,10 @@ import java.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
/**
* This function counts the approximate number of distinct values (ndv) in
@@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals(
relativeSD: Double = 0.05,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends ImperativeAggregate with ExpectsInputTypes {
-
- def this(child: Expression, endpointsExpression: Expression) = {
- this(
- child = child,
- endpointsExpression = endpointsExpression,
- relativeSD = 0.05,
- mutableAggBufferOffset = 0,
- inputAggBufferOffset = 0)
- }
+ extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes {
def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
this(
@@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals(
private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length
/** Allocate enough words to store all registers. */
- override lazy val aggBufferAttributes: Seq[AttributeReference] = {
- Seq.tabulate(totalNumWords) { i =>
- AttributeReference(s"MS[$i]", LongType)()
- }
+ override def createAggregationBuffer(): Array[Long] = {
+ Array.fill(totalNumWords)(0L)
}
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
- // Note: although this simply copies aggBufferAttributes, this common code can not be placed
- // in the superclass because that will lead to initialization ordering issues.
- override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
- aggBufferAttributes.map(_.newInstance())
-
- /** Fill all words with zeros. */
- override def initialize(buffer: InternalRow): Unit = {
- var word = 0
- while (word < totalNumWords) {
- buffer.setLong(mutableAggBufferOffset + word, 0)
- word += 1
- }
- }
-
- override def update(buffer: InternalRow, input: InternalRow): Unit = {
+ override def update(buffer: Array[Long], input: InternalRow): Array[Long] = {
val value = child.eval(input)
// Ignore empty rows
if (value != null) {
@@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals(
// endpoints are sorted into ascending order already
if (endpoints.head > doubleValue || endpoints.last < doubleValue) {
// ignore if the value is out of the whole range
- return
+ return buffer
}
val hllppIndex = findHllppIndex(doubleValue)
- val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp
- hllppArray(hllppIndex).update(buffer, offset, value, child.dataType)
+ val offset = hllppIndex * numWordsPerHllpp
+ hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType)
}
+ buffer
}
// Find which interval (HyperLogLogPlusPlusHelper) should receive the given value.
@@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals(
}
}
- override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
+ override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = {
for (i <- hllppArray.indices) {
hllppArray(i).merge(
- buffer1 = buffer1,
- buffer2 = buffer2,
- offset1 = mutableAggBufferOffset + i * numWordsPerHllpp,
- offset2 = inputAggBufferOffset + i * numWordsPerHllpp)
+ buffer1 = LongArrayInternalRow(buffer1),
+ buffer2 = LongArrayInternalRow(buffer2),
+ offset1 = i * numWordsPerHllpp,
+ offset2 = i * numWordsPerHllpp)
}
+ buffer1
}
- override def eval(buffer: InternalRow): Any = {
+ override def eval(buffer: Array[Long]): Any = {
val ndvArray = hllppResults(buffer)
// If the endpoints contains multiple elements with the same value,
// we set ndv=1 for intervals between these elements.
@@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals(
new GenericArrayData(ndvArray)
}
- def hllppResults(buffer: InternalRow): Array[Long] = {
+ def hllppResults(buffer: Array[Long]): Array[Long] = {
val ndvArray = new Array[Long](hllppArray.length)
for (i <- ndvArray.indices) {
- ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp)
+ ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp)
}
ndvArray
}
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int)
+ : ApproxCountDistinctForIntervals = {
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+ }
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int)
+ : ApproxCountDistinctForIntervals = {
copy(inputAggBufferOffset = newInputAggBufferOffset)
+ }
override def children: Seq[Expression] = Seq(child, endpointsExpression)
@@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals(
override def dataType: DataType = ArrayType(LongType)
override def prettyName: String = "approx_count_distinct_for_intervals"
+
+ override def serialize(obj: Array[Long]): Array[Byte] = {
+ val byteArray = new Array[Byte](obj.length * 8)
+ var i = 0
+ while (i < obj.length) {
+ Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i))
+ i += 1
+ }
+ byteArray
+ }
+
+ override def deserialize(bytes: Array[Byte]): Array[Long] = {
+ assert(bytes.length % 8 == 0)
+ val length = bytes.length / 8
+ val longArray = new Array[Long](length)
+ var i = 0
+ while (i < length) {
+ longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8)
+ i += 1
+ }
+ longArray
+ }
+
+ private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow {
+ override def getLong(offset: Int): Long = array(offset)
+ override def setLong(offset: Int, value: Long): Unit = { array(offset) = value }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6290aea/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index d6c38c3..73f18d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType),
MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType))))
wrongColumnTypes.foreach { dataType =>
- val wrongColumn = new ApproxCountDistinctForIntervals(
+ val wrongColumn = ApproxCountDistinctForIntervals(
AttributeReference("a", dataType)(),
endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_))))
assert(
@@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
})
}
- var wrongEndpoints = new ApproxCountDistinctForIntervals(
+ var wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = Literal(0.5d))
assert(
@@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
case _ => false
})
- wrongEndpoints = new ApproxCountDistinctForIntervals(
+ wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("The endpoints provided must be constant literals"))
- wrongEndpoints = new ApproxCountDistinctForIntervals(
+ wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
- wrongEndpoints = new ApproxCountDistinctForIntervals(
+ wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
assert(wrongEndpoints.checkInputDataTypes() ==
@@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
private def createEstimator[T](
endpoints: Array[T],
dt: DataType,
- rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = {
+ rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = {
val input = new SpecificInternalRow(Seq(dt))
val aggFunc = ApproxCountDistinctForIntervals(
BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd)
- val buffer = createBuffer(aggFunc)
- (aggFunc, input, buffer)
- }
-
- private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = {
- val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType))
- aggFunc.initialize(buffer)
- buffer
+ (aggFunc, input, aggFunc.createAggregationBuffer())
}
test("merging ApproxCountDistinctForIntervals instances") {
val (aggFunc, input, buffer1a) =
createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType)
- val buffer1b = createBuffer(aggFunc)
- val buffer2 = createBuffer(aggFunc)
+ val buffer1b = aggFunc.createAggregationBuffer()
+ val buffer2 = aggFunc.createAggregationBuffer()
// Add the lower half to `buffer1a`.
var i = 0
@@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
}
// Check if the buffers are equal.
- assert(buffer2 == buffer1a, "Buffers should be equal")
+ assert(buffer2.sameElements(buffer1a), "Buffers should be equal")
}
test("test findHllppIndex(value) for values in the range") {
@@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2)
}
+ test("round trip serialization") {
+ val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType)
+ val longArray = (1L to 100L).toArray
+ val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray))
+ assert(roundtrip.sameElements(longArray))
+ }
+
test("basic operations: update, merge, eval...") {
val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0)
val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33)
http://git-wip-us.apache.org/repos/asf/spark/blob/f6290aea/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
new file mode 100644
index 0000000..c7d86bc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
+ // histogram usually contains hundreds of buckets. So we need to test
+ // ApproxCountDistinctForIntervals with large number of endpoints
+ // (the number of endpoints == the number of buckets + 1).
+ test("test ApproxCountDistinctForIntervals with large number of endpoints") {
+ val table = "approx_count_distinct_for_intervals_tbl"
+ withTable(table) {
+ (1 to 100000).toDF("col").createOrReplaceTempView(table)
+ // percentiles of 0, 0.001, 0.002 ... 0.999, 1
+ val endpoints = (0 to 1000).map(_ * 100000 / 1000)
+
+ // Since approx_count_distinct_for_intervals is not a public function, here we do
+ // the computation by constructing logical plan.
+ val relation = spark.table(table).logicalPlan
+ val attr = relation.output.find(_.name == "col").get
+ val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
+ val aggExpr = aggFunc.toAggregateExpression()
+ val namedExpr = Alias(aggExpr, aggExpr.toString)()
+ val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
+ .executedPlan.executeTake(1).head
+ val ndvArray = ndvsRow.getArray(0).toLongArray()
+ assert(endpoints.length == ndvArray.length + 1)
+
+ // Each bucket has 100 distinct values.
+ val expectedNdv = 100
+ for (i <- ndvArray.indices) {
+ val ndv = ndvArray(i)
+ val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
+ assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
+ }
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org