Posted to commits@spark.apache.org by we...@apache.org on 2017/10/23 22:02:43 UTC

spark git commit: [SPARK-22285][SQL] Change implementation of ApproxCountDistinctForIntervals to TypedImperativeAggregate

Repository: spark
Updated Branches:
  refs/heads/master 5a5b6b785 -> f6290aea2


[SPARK-22285][SQL] Change implementation of ApproxCountDistinctForIntervals to TypedImperativeAggregate

## What changes were proposed in this pull request?

The current implementation of `ApproxCountDistinctForIntervals` is an `ImperativeAggregate`. The number of `aggBufferAttributes` equals the total number of words in the hllppHelper array, and each hllppHelper holds 52 words under the default relativeSD.

Since this aggregate function is used in equi-height histogram generation, and a histogram usually has hundreds of buckets, the number of `aggBufferAttributes` can easily reach tens of thousands or more, as the illustrative sketch below shows.
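
As a rough sizing sketch (the bucket count here is hypothetical, chosen only to show the order of magnitude):

```
// Hypothetical sizing sketch: the old implementation keeps one HyperLogLogPlusPlusHelper
// per histogram bucket, and each helper needs ~52 long words at the default relativeSD.
object BufferSizeSketch extends App {
  val wordsPerHllpp = 52   // words per HLL++ sketch at relativeSD = 0.05
  val numBuckets = 500     // hypothetical bucket count for an equi-height histogram
  val totalNumWords = wordsPerHllpp * numBuckets
  println(s"aggBufferAttributes in the old implementation: $totalNumWords")  // 26000
}
```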

This leads to a huge generated method in codegen and causes an error:
```
org.codehaus.janino.JaninoRuntimeException: Code of method "apply(Lorg/apache/spark/sql/catalyst/InternalRow;)Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection" grows beyond 64 KB.
```
Besides, such huge generated methods also cause a performance regression.

In this PR, we change its implementation to `TypedImperativeAggregate`. After the fix, `ApproxCountDistinctForIntervals` can handle thousands of endpoints without throwing a codegen error, and performance improves from `20 sec` to `2 sec` in a test case with 500 endpoints. A simplified sketch of the new buffer contract is shown below.
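
For context, here is a minimal sketch of the buffer contract that `TypedImperativeAggregate` imposes. It is not the actual Spark trait (which has more members), but it mirrors the methods this PR overrides: the whole HLL++ word array becomes one opaque buffer object that is serialized into a single binary slot, instead of tens of thousands of `LongType` attributes.

```
import org.apache.spark.sql.catalyst.InternalRow

// Simplified sketch of the TypedImperativeAggregate[T] contract, reduced to the
// methods overridden in this PR; the real trait lives in Spark's catalyst module.
abstract class TypedAggregateSketch[T] {
  def createAggregationBuffer(): T              // allocate the in-memory buffer object
  def update(buffer: T, input: InternalRow): T  // fold one input row into the buffer
  def merge(buffer: T, other: T): T             // combine two partial buffers
  def eval(buffer: T): Any                      // compute the final result
  def serialize(buffer: T): Array[Byte]         // flatten the buffer for shuffle/spill
  def deserialize(bytes: Array[Byte]): T        // restore the buffer object from bytes
}
```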

## How was this patch tested?

Tested by an added test case and existing tests.

Author: Zhenhua Wang <wa...@huawei.com>

Closes #19506 from wzhfy/change_forIntervals_typedAgg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6290aea
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6290aea
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6290aea

Branch: refs/heads/master
Commit: f6290aea24efeb238db88bdaef4e24d50740ca4c
Parents: 5a5b6b7
Author: Zhenhua Wang <wa...@huawei.com>
Authored: Mon Oct 23 23:02:36 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Oct 23 23:02:36 2017 +0100

----------------------------------------------------------------------
 .../ApproxCountDistinctForIntervals.scala       | 97 +++++++++++---------
 .../ApproxCountDistinctForIntervalsSuite.scala  | 34 +++----
 ...roxCountDistinctForIntervalsQuerySuite.scala | 61 ++++++++++++
 3 files changed, 130 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f6290aea/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index 096d1b3..d4421ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -22,9 +22,10 @@ import java.util
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 
 /**
  * This function counts the approximate number of distinct values (ndv) in
@@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals(
     relativeSD: Double = 0.05,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with ExpectsInputTypes {
-
-  def this(child: Expression, endpointsExpression: Expression) = {
-    this(
-      child = child,
-      endpointsExpression = endpointsExpression,
-      relativeSD = 0.05,
-      mutableAggBufferOffset = 0,
-      inputAggBufferOffset = 0)
-  }
+  extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes {
 
   def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
     this(
@@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals(
   private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length
 
   /** Allocate enough words to store all registers. */
-  override lazy val aggBufferAttributes: Seq[AttributeReference] = {
-    Seq.tabulate(totalNumWords) { i =>
-      AttributeReference(s"MS[$i]", LongType)()
-    }
+  override def createAggregationBuffer(): Array[Long] = {
+    Array.fill(totalNumWords)(0L)
   }
 
-  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
-
-  /** Fill all words with zeros. */
-  override def initialize(buffer: InternalRow): Unit = {
-    var word = 0
-    while (word < totalNumWords) {
-      buffer.setLong(mutableAggBufferOffset + word, 0)
-      word += 1
-    }
-  }
-
-  override def update(buffer: InternalRow, input: InternalRow): Unit = {
+  override def update(buffer: Array[Long], input: InternalRow): Array[Long] = {
     val value = child.eval(input)
     // Ignore empty rows
     if (value != null) {
@@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals(
       // endpoints are sorted into ascending order already
       if (endpoints.head > doubleValue || endpoints.last < doubleValue) {
         // ignore if the value is out of the whole range
-        return
+        return buffer
       }
 
       val hllppIndex = findHllppIndex(doubleValue)
-      val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp
-      hllppArray(hllppIndex).update(buffer, offset, value, child.dataType)
+      val offset = hllppIndex * numWordsPerHllpp
+      hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType)
     }
+    buffer
   }
 
   // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value.
@@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals(
     }
   }
 
-  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
+  override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = {
     for (i <- hllppArray.indices) {
       hllppArray(i).merge(
-        buffer1 = buffer1,
-        buffer2 = buffer2,
-        offset1 = mutableAggBufferOffset + i * numWordsPerHllpp,
-        offset2 = inputAggBufferOffset + i * numWordsPerHllpp)
+        buffer1 = LongArrayInternalRow(buffer1),
+        buffer2 = LongArrayInternalRow(buffer2),
+        offset1 = i * numWordsPerHllpp,
+        offset2 = i * numWordsPerHllpp)
     }
+    buffer1
   }
 
-  override def eval(buffer: InternalRow): Any = {
+  override def eval(buffer: Array[Long]): Any = {
     val ndvArray = hllppResults(buffer)
     // If the endpoints contains multiple elements with the same value,
     // we set ndv=1 for intervals between these elements.
@@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals(
     new GenericArrayData(ndvArray)
   }
 
-  def hllppResults(buffer: InternalRow): Array[Long] = {
+  def hllppResults(buffer: Array[Long]): Array[Long] = {
     val ndvArray = new Array[Long](hllppArray.length)
     for (i <- ndvArray.indices) {
-      ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp)
+      ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp)
     }
     ndvArray
   }
 
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int)
+    : ApproxCountDistinctForIntervals = {
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+  }
 
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int)
+    : ApproxCountDistinctForIntervals = {
     copy(inputAggBufferOffset = newInputAggBufferOffset)
+  }
 
   override def children: Seq[Expression] = Seq(child, endpointsExpression)
 
@@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals(
   override def dataType: DataType = ArrayType(LongType)
 
   override def prettyName: String = "approx_count_distinct_for_intervals"
+
+  override def serialize(obj: Array[Long]): Array[Byte] = {
+    val byteArray = new Array[Byte](obj.length * 8)
+    var i = 0
+    while (i < obj.length) {
+      Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i))
+      i += 1
+    }
+    byteArray
+  }
+
+  override def deserialize(bytes: Array[Byte]): Array[Long] = {
+    assert(bytes.length % 8 == 0)
+    val length = bytes.length / 8
+    val longArray = new Array[Long](length)
+    var i = 0
+    while (i < length) {
+      longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8)
+      i += 1
+    }
+    longArray
+  }
+
+  private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow {
+    override def getLong(offset: Int): Long = array(offset)
+    override def setLong(offset: Int, value: Long): Unit = { array(offset) = value }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6290aea/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index d6c38c3..73f18d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType),
       MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType))))
     wrongColumnTypes.foreach { dataType =>
-      val wrongColumn = new ApproxCountDistinctForIntervals(
+      val wrongColumn = ApproxCountDistinctForIntervals(
         AttributeReference("a", dataType)(),
         endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_))))
       assert(
@@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
         })
     }
 
-    var wrongEndpoints = new ApproxCountDistinctForIntervals(
+    var wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = Literal(0.5d))
     assert(
@@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
         case _ => false
       })
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
     assert(wrongEndpoints.checkInputDataTypes() ==
       TypeCheckFailure("The endpoints provided must be constant literals"))
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
       TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
@@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
   private def createEstimator[T](
       endpoints: Array[T],
       dt: DataType,
-      rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = {
+      rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = {
     val input = new SpecificInternalRow(Seq(dt))
     val aggFunc = ApproxCountDistinctForIntervals(
       BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd)
-    val buffer = createBuffer(aggFunc)
-    (aggFunc, input, buffer)
-  }
-
-  private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = {
-    val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType))
-    aggFunc.initialize(buffer)
-    buffer
+    (aggFunc, input, aggFunc.createAggregationBuffer())
   }
 
   test("merging ApproxCountDistinctForIntervals instances") {
     val (aggFunc, input, buffer1a) =
       createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType)
-    val buffer1b = createBuffer(aggFunc)
-    val buffer2 = createBuffer(aggFunc)
+    val buffer1b = aggFunc.createAggregationBuffer()
+    val buffer2 = aggFunc.createAggregationBuffer()
 
     // Add the lower half to `buffer1a`.
     var i = 0
@@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     }
 
     // Check if the buffers are equal.
-    assert(buffer2 == buffer1a, "Buffers should be equal")
+    assert(buffer2.sameElements(buffer1a), "Buffers should be equal")
   }
 
   test("test findHllppIndex(value) for values in the range") {
@@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2)
   }
 
+  test("round trip serialization") {
+    val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType)
+    val longArray = (1L to 100L).toArray
+    val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray))
+    assert(roundtrip.sameElements(longArray))
+  }
+
   test("basic operations: update, merge, eval...") {
     val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0)
     val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33)

http://git-wip-us.apache.org/repos/asf/spark/blob/f6290aea/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
new file mode 100644
index 0000000..c7d86bc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
+  // histogram usually contains hundreds of buckets. So we need to test
+  // ApproxCountDistinctForIntervals with large number of endpoints
+  // (the number of endpoints == the number of buckets + 1).
+  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
+    val table = "approx_count_distinct_for_intervals_tbl"
+    withTable(table) {
+      (1 to 100000).toDF("col").createOrReplaceTempView(table)
+      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
+      val endpoints = (0 to 1000).map(_ * 100000 / 1000)
+
+      // Since approx_count_distinct_for_intervals is not a public function, here we do
+      // the computation by constructing logical plan.
+      val relation = spark.table(table).logicalPlan
+      val attr = relation.output.find(_.name == "col").get
+      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
+      val aggExpr = aggFunc.toAggregateExpression()
+      val namedExpr = Alias(aggExpr, aggExpr.toString)()
+      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
+        .executedPlan.executeTake(1).head
+      val ndvArray = ndvsRow.getArray(0).toLongArray()
+      assert(endpoints.length == ndvArray.length + 1)
+
+      // Each bucket has 100 distinct values.
+      val expectedNdv = 100
+      for (i <- ndvArray.indices) {
+        val ndv = ndvArray(i)
+        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
+        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
+      }
+    }
+  }
+}

