You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@datafu.apache.org by ey...@apache.org on 2018/12/02 12:25:51 UTC
datafu git commit: Spark UDAFs and tests
Repository: datafu
Updated Branches:
refs/heads/spark-tmp 986af4540 -> 8c2d55d8e
Spark UDAFs and tests
Signed-off-by: Eyal Allweil <ey...@apache.org>
Project: http://git-wip-us.apache.org/repos/asf/datafu/repo
Commit: http://git-wip-us.apache.org/repos/asf/datafu/commit/8c2d55d8
Tree: http://git-wip-us.apache.org/repos/asf/datafu/tree/8c2d55d8
Diff: http://git-wip-us.apache.org/repos/asf/datafu/diff/8c2d55d8
Branch: refs/heads/spark-tmp
Commit: 8c2d55d8e12e8ff50ef1f184866d3b19382adbac
Parents: 986af45
Author: Ohad Raviv <or...@paypal.com>
Authored: Sun Dec 2 14:22:46 2018 +0200
Committer: Eyal Allweil <ey...@apache.org>
Committed: Sun Dec 2 14:24:41 2018 +0200
----------------------------------------------------------------------
.../main/scala/datafu/spark/SparkUDAFs.scala | 230 +++++++++++++++++++
.../scala/datafu/spark/TestSparkUDAFs.scala | 141 ++++++++++++
2 files changed, 371 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/datafu/blob/8c2d55d8/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
----------------------------------------------------------------------
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala b/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
new file mode 100644
index 0000000..3c4782f
--- /dev/null
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.types.{ArrayType, _}
+
+import scala.collection.{Map, mutable}
+
+// TODO: add documentation and tests for all of these functions, and consider also exposing them in Python.
+
+object SparkUDAFs {
+
+  /**
+   * Counts the occurrences of each distinct non-null string in a group.
+   *
+   * Input: a single string column. Output: a map from each non-null input value to the
+   * number of rows in which it appeared. Null inputs are ignored.
+   */
+  class MultiSet() extends UserDefinedAggregateFunction {
+
+    def inputSchema: StructType = new StructType().add("key", StringType)
+
+    def bufferSchema: StructType = new StructType().add("mp", MapType(StringType, IntegerType))
+
+    def dataType: DataType = MapType(StringType, IntegerType, false)
+
+    def deterministic: Boolean = true
+
+    // Starts every (partial) aggregate from an empty map.
+    def initialize(buffer: MutableAggregationBuffer): Unit = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Adds one row of the group: increments the count of its key, skipping nulls.
+    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+      val key = input.getString(0)
+      if (key != null) {
+        val counts = buffer.getMap[String, Int](0)
+        buffer(0) = counts + (key -> (counts.getOrElse(key, 0) + 1))
+      }
+    }
+
+    // Merges two partial aggregates by summing the counts of keys present in both.
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+      val merged = mutable.Map[String, Int]() ++= buffer1.getMap[String, Int](0)
+      for ((key, count) <- buffer2.getMap[String, Int](0) if key != null)
+        merged.put(key, merged.getOrElse(key, 0) + count)
+      buffer1(0) = merged
+    }
+
+    // Called after all rows of the group are processed; the buffer already holds the result.
+    def evaluate(buffer: Row): Any = {
+      buffer(0)
+    }
+  }
+
+  /**
+   * Counts the occurrences of each distinct non-null element across array-valued rows.
+   *
+   * Every element of every input array adds one to its key's count. Null arrays and null
+   * elements are ignored.
+   *
+   * @param dt      element type of the input arrays, and key type of the resulting map
+   * @param maxKeys when positive, only the `maxKeys` most frequent keys are returned. Partial
+   *                aggregates are pruned back to `maxKeys` keys whenever they exceed
+   *                `maxKeys * 3`, so in this mode the returned keys and counts may be
+   *                approximate. Keys with equal counts are ranked by `Ordering[T]`, and the
+   *                greater key wins, so the result is deterministic.
+   * @tparam T      Scala type of the array elements; expected to correspond to `dt`
+   */
+  class MultiArraySet[T : Ordering](dt : DataType = StringType, maxKeys: Int = -1) extends UserDefinedAggregateFunction {
+
+    // Partial aggregates may hold up to maxKeys * bufferFactor keys before being pruned to
+    // maxKeys (presumably as slack, so that keys rare in one partition are not dropped too eagerly).
+    private val bufferFactor = 3
+
+    def inputSchema: StructType = new StructType().add("key", ArrayType(dt))
+
+    def bufferSchema: StructType = new StructType().add("mp", dataType)
+
+    def dataType: DataType = MapType(dt, IntegerType, false)
+
+    def deterministic: Boolean = true
+
+    // Starts every (partial) aggregate from an empty map.
+    def initialize(buffer: MutableAggregationBuffer): Unit = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Adds one row of the group: every non-null element of the (possibly null) array counts once.
+    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+      val counts = mutable.Map[T, Int]() ++= buffer.getMap[T, Int](0)
+      val keys: Seq[T] = Option(input.getAs[Seq[T]](0)).getOrElse(Nil)
+      for (key <- keys if key != null)
+        counts.put(key, counts.getOrElse(key, 0) + 1)
+      buffer(0) = limitKeys(counts, bufferFactor)
+    }
+
+    // Merges two partial aggregates by summing counts, then prunes if the result grew too large.
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+      val counts = mutable.Map[T, Int]() ++= buffer1.getMap[T, Int](0)
+      for ((key, count) <- buffer2.getMap[T, Int](0) if key != null)
+        counts.put(key, counts.getOrElse(key, 0) + count)
+      buffer1(0) = limitKeys(counts, bufferFactor)
+    }
+
+    /**
+     * Prunes `mp` to its `maxKeys` highest-ranked entries once it holds more than
+     * `maxKeys * factor` entries. Otherwise, or when `maxKeys` is not positive, it returns
+     * `mp` unchanged.
+     *
+     * Entries are ranked by count, then by key, both descending. Tie-breaking therefore
+     * does not depend on hash-map iteration order.
+     */
+    private def limitKeys(mp: Map[T, Int], factor: Int = 1): Map[T, Int] = {
+      // Widen to Long so that a large maxKeys cannot overflow the threshold.
+      if (maxKeys > 0 && maxKeys.toLong * factor < mp.size) {
+        mp.toSeq
+          .sortBy { case (key, count) => (count, key) }(Ordering[(Int, T)].reverse)
+          .take(maxKeys)
+          .toMap
+      } else {
+        mp
+      }
+    }
+
+    // Called after all rows of the group are processed; applies the final maxKeys limit.
+    def evaluate(buffer: Row): Map[T, Int] = {
+      limitKeys(buffer.getMap[T, Int](0))
+    }
+  }
+
+  /**
+   * Merges maps of kind `map<string, array<string>>`.
+   *
+   * The result maps every key seen in any input row to the de-duplicated union of all arrays
+   * associated with it. Null input maps are ignored and null arrays are treated as empty.
+   * The order of elements within each resulting array is not specified.
+   */
+  class MapSetMerge extends UserDefinedAggregateFunction {
+
+    def inputSchema: StructType = new StructType().add("key", dataType)
+
+    def bufferSchema: StructType = inputSchema
+
+    def dataType: DataType = MapType(StringType, ArrayType(StringType))
+
+    def deterministic: Boolean = true
+
+    // Starts every (partial) aggregate from an empty map.
+    def initialize(buffer: MutableAggregationBuffer): Unit = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Folds one input map (skipped when null) into the buffer.
+    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+      val inputMap = input.getMap[String, mutable.WrappedArray[String]](0)
+      if (inputMap != null) {
+        val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= inputMap
+        buffer(0) = merge(mp, buffer.getMap[String, mutable.WrappedArray[String]](0))
+      }
+    }
+
+    // Merges two partial aggregates.
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+      val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= buffer1.getMap(0)
+      buffer1(0) = merge(mp, buffer2.getMap[String, mutable.WrappedArray[String]](0))
+    }
+
+    /**
+     * Adds every entry of `mp` into `mpBuffer`, unioning and de-duplicating the arrays of keys
+     * present in both. A null `mp` is ignored and null arrays are treated as empty.
+     *
+     * @return `mpBuffer`, which is mutated in place
+     */
+    def merge(mpBuffer: mutable.Map[String, mutable.WrappedArray[String]], mp: Map[String, mutable.WrappedArray[String]]): mutable.Map[String, mutable.WrappedArray[String]] = {
+      if (mp != null) {
+        for ((key, values) <- mp) {
+          val existing = mpBuffer.get(key).flatMap(Option(_)).getOrElse(mutable.WrappedArray.empty[String])
+          val incoming = Option(values).getOrElse(mutable.WrappedArray.empty[String])
+          mpBuffer.put(key, mutable.WrappedArray.make((existing ++ incoming).toSet.toArray))
+        }
+      }
+      mpBuffer
+    }
+
+    // Called after all rows of the group are processed; the buffer already holds the result.
+    def evaluate(buffer: Row): Any = {
+      buffer(0)
+    }
+  }
+
+  /**
+   * Counts the number of distinct non-null strings in a group, but only up to `maxItems`.
+   *
+   * Once `maxItems` distinct values have been collected, no more are stored. This keeps the
+   * aggregation buffer bounded, which is more efficient than an unbounded count when only a
+   * threshold matters.
+   *
+   * @param maxItems cap on the returned count. A non-positive value (the default) means no
+   *                 cap, i.e. a plain distinct count.
+   */
+  class CountDistinctUpTo(maxItems: Int = -1) extends UserDefinedAggregateFunction {
+
+    def inputSchema: StructType = new StructType().add("key", StringType)
+
+    def bufferSchema: StructType = new StructType().add("mp", MapType(StringType, BooleanType))
+
+    def dataType: DataType = IntegerType
+
+    def deterministic: Boolean = true
+
+    // True when a buffer already holding `size` distinct values may still accept new ones.
+    // A non-positive maxItems means "unbounded".
+    private def hasRoom(size: Int): Boolean = maxItems <= 0 || size < maxItems
+
+    // Starts every (partial) aggregate from an empty set of seen keys.
+    def initialize(buffer: MutableAggregationBuffer): Unit = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Records the row's key (skipping nulls), unless the cap has already been reached.
+    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+      val seen = buffer.getMap[String, Boolean](0)
+      if (hasRoom(seen.size)) {
+        val key = input.getString(0)
+        if (key != null) {
+          buffer(0) = seen + (key -> true)
+        }
+      }
+    }
+
+    // Merges two partial aggregates. Keys from buffer2 are added only while there is room,
+    // so the merged buffer never grows beyond maxItems.
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+      val seen = buffer1.getMap[String, Boolean](0)
+      if (hasRoom(seen.size)) {
+        val merged = mutable.Map[String, Boolean]() ++= seen
+        val incoming = buffer2.getMap[String, Boolean](0).keysIterator
+        while (incoming.hasNext && hasRoom(merged.size)) {
+          val key = incoming.next()
+          if (key != null) merged.put(key, true)
+        }
+        buffer1(0) = merged
+      }
+    }
+
+    // Called after all rows of the group are processed: the number of distinct keys, capped.
+    def evaluate(buffer: Row): Int = {
+      val distinct = buffer.getMap[String, Boolean](0).size
+      if (maxItems > 0) math.min(distinct, maxItems) else distinct
+    }
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/datafu/blob/8c2d55d8/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
----------------------------------------------------------------------
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
new file mode 100644
index 0000000..7a1eac9
--- /dev/null
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import com.holdenkarau.spark.testing.DataFrameSuiteBase
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.types.SparkOverwriteUDAFs
+import org.junit.Assert
+import org.junit.runner.RunWith
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
+import org.slf4j.LoggerFactory
+
+@RunWith(classOf[JUnitRunner])
+class UdafTests extends FunSuite with DataFrameSuiteBase {
+
+  import spark.implicits._
+
+  val logger = LoggerFactory.getLogger(this.getClass)
+
+  lazy val df = sc.parallelize(List(("a", 1, "asd1"), ("a", 2, "asd2"), ("a", 3, "asd3"), ("b", 1, "asd4")))
+    .toDF("col_grp", "col_ord", "col_str")
+
+  case class mapExp(map_col: Map[String, Int])
+  case class mapArrExp(map_col: Map[String, Array[String]])
+
+  /**
+   * Drops `table` if it exists, recreates it with the given column `schema`, and appends the
+   * result of each of the `rowSelects` queries, in order.
+   */
+  private def recreateTable(table: String, schema: String, rowSelects: String*): Unit = {
+    spark.sql(s"drop table if exists $table")
+    spark.sql(s"create table $table ($schema)")
+    rowSelects.foreach(select => spark.sql(s"insert into table $table $select"))
+  }
+
+  test("test multiset simple") {
+    val ms = new SparkUDAFs.MultiSet()
+    val expected : DataFrame = sqlContext.createDataFrame(List(mapExp(Map("b" -> 1, "a" -> 3))))
+    assertDataFrameEquals(expected, df.agg(ms($"col_grp").as("map_col")))
+  }
+
+  val mas = new SparkUDAFs.MultiArraySet[String]()
+
+  test("test multiarrayset simple") {
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map("tre" -> 1, "asd" -> 2)))),
+      spark.sql("select array('asd','tre','asd') arr").groupBy().agg(mas($"arr").as("map_col")))
+  }
+
+  test("test multiarrayset all nulls") {
+    // edge case: every input array is null, so the result must be an empty map
+    recreateTable("mas_table", "arr array<string>",
+      Seq.fill(5)("select case when 1=2 then array('asd') end from (select 1)z"): _*)
+
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map()))),
+      spark.table("mas_table").groupBy().agg(mas($"arr").as("map_col")))
+  }
+
+  test("test multiarrayset max keys") {
+    recreateTable("mas_table2", "arr array<string>",
+      "select array('asd','dsa') from (select 1)z",
+      "select array('asd','abc') from (select 1)z",
+      "select array('asd') from (select 1)z",
+      "select array('asd') from (select 1)z",
+      "select array('asd') from (select 1)z",
+      "select array('asd2') from (select 1)z")
+
+    // NOTE: 'dsa', 'abc' and 'asd2' each appear once, so which of them is kept with
+    // maxKeys = 2 depends on how MultiArraySet breaks ties between equal counts.
+    val mas2 = new SparkUDAFs.MultiArraySet[String](maxKeys = 2)
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map("dsa" -> 1, "asd" -> 5)))),
+      spark.table("mas_table2").groupBy().agg(mas2($"arr").as("map_col")))
+
+    val mas1 = new SparkUDAFs.MultiArraySet[String](maxKeys = 1)
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map("asd" -> 5)))),
+      spark.table("mas_table2").groupBy().agg(mas1($"arr").as("map_col")))
+  }
+
+  test("test multiarrayset big input") {
+    val N = 100000
+    val input = spark.sparkContext.parallelize(1 to N, 20).toDF("num")
+      .selectExpr("array('asd',concat('dsa',num)) as arr")
+    val mas3 = new SparkUDAFs.MultiArraySet[String](maxKeys = 3)
+
+    val startMillis = System.currentTimeMillis()
+    val mp = input.groupBy().agg(mas3($"arr")).collect().map(_.getMap[String, Int](0)).head
+    val elapsedMillis = System.currentTimeMillis() - startMillis
+    logger.info(s"time took: $elapsedMillis ms")
+
+    Assert.assertEquals(3, mp.size)
+    Assert.assertEquals("asd", mp.maxBy(_._2)._1)
+    Assert.assertEquals(N, mp.maxBy(_._2)._2)
+  }
+
+  test("test mapmerge") {
+    val mapMerge = new SparkUDAFs.MapSetMerge()
+
+    recreateTable("mapmerge_table", "c map<string, array<string>>",
+      "select map('k1', array('v1')) from (select 1) z",
+      "select map('k1', array('v1')) from (select 1) z",
+      "select map('k2', array('v3')) from (select 1) z")
+
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapArrExp(Map("k1" -> Array("v1"), "k2" -> Array("v3"))))),
+      spark.table("mapmerge_table").groupBy().agg(mapMerge($"c").as("map_col")))
+  }
+
+  test("count distinct up to") {
+    val countUpTo2 = new SparkUDAFs.CountDistinctUpTo(2)
+    val countUpTo10 = new SparkUDAFs.CountDistinctUpTo(10)
+
+    val row = df.agg(countUpTo2($"col_str"), countUpTo10($"col_str"), countUpTo10($"col_grp")).head()
+
+    // col_str has 4 distinct values and col_grp has 2
+    Assert.assertEquals(2, row.getInt(0))
+    Assert.assertEquals(4, row.getInt(1))
+    Assert.assertEquals(2, row.getInt(2))
+  }
+
+  test("minKeyValue") {
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(("b", "asd4"), ("a", "asd1"))),
+      df.groupBy($"col_grp".as("_1")).agg(SparkOverwriteUDAFs.minValueByKey($"col_ord", $"col_str").as("_2")))
+  }
+
+  case class exp4(col_grp: String, col_ord: Int, col_str: String, asd: String)
+
+  test("minKeyValue window") {
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(
+        exp4("b", 1, "asd4", "asd4"),
+        exp4("a", 1, "asd1", "asd1"),
+        exp4("a", 2, "asd2", "asd1"),
+        exp4("a", 3, "asd3", "asd1")
+      )),
+      df.withColumn("asd", SparkOverwriteUDAFs.minValueByKey($"col_ord", $"col_str").over(Window.partitionBy("col_grp"))))
+  }
+
+}