Posted to commits@datafu.apache.org by ey...@apache.org on 2018/12/02 12:25:51 UTC

datafu git commit: Spark UDAFs and tests

Repository: datafu
Updated Branches:
  refs/heads/spark-tmp 986af4540 -> 8c2d55d8e


Spark UDAFs and tests

Signed-off-by: Eyal Allweil <ey...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/datafu/repo
Commit: http://git-wip-us.apache.org/repos/asf/datafu/commit/8c2d55d8
Tree: http://git-wip-us.apache.org/repos/asf/datafu/tree/8c2d55d8
Diff: http://git-wip-us.apache.org/repos/asf/datafu/diff/8c2d55d8

Branch: refs/heads/spark-tmp
Commit: 8c2d55d8e12e8ff50ef1f184866d3b19382adbac
Parents: 986af45
Author: Ohad Raviv <or...@paypal.com>
Authored: Sun Dec 2 14:22:46 2018 +0200
Committer: Eyal Allweil <ey...@apache.org>
Committed: Sun Dec 2 14:24:41 2018 +0200

----------------------------------------------------------------------
 .../main/scala/datafu/spark/SparkUDAFs.scala    | 230 +++++++++++++++++++
 .../scala/datafu/spark/TestSparkUDAFs.scala     | 141 ++++++++++++
 2 files changed, 371 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/datafu/blob/8c2d55d8/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
----------------------------------------------------------------------
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala b/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
new file mode 100644
index 0000000..3c4782f
--- /dev/null
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.types.{ArrayType, _}
+
+import scala.collection.{Map, mutable}
+
+// TODO: add documentation and tests for all the functions. Consider also exposing them in Python.
+
+object SparkUDAFs {
+
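+  // Aggregates a string column into a map of value -> number of occurrences (a multiset).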
+  class MultiSet() extends UserDefinedAggregateFunction {
+
+    def inputSchema: StructType = new StructType().add("key", StringType)
+
+    def bufferSchema = new StructType().add("mp", MapType(StringType, IntegerType))
+
+    def dataType: DataType = MapType(StringType, IntegerType, false)
+
+    def deterministic = true
+
+    // Called to initialize a new aggregation buffer (i.e. when a new group is started)
+    def initialize(buffer: MutableAggregationBuffer) = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Iterate over each entry of a group
+    def update(buffer: MutableAggregationBuffer, input: Row) = {
+      val key = input.getString(0)
+      if (key != null)
+        buffer(0) = buffer.getMap(0) + (key -> (buffer.getMap(0).getOrElse(key, 0) + 1))
+    }
+
+    // Merge two partial aggregates
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
+      val mp = mutable.Map[String, Int]() ++= buffer1.getMap(0)
+      buffer2.getMap(0).keys.foreach((key: String) =>
+        if (key != null)
+          mp.put(key, mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
+      )
+      buffer1(0) = mp
+    }
+
+    // Called after all the entries are exhausted.
+    def evaluate(buffer: Row) = {
+      buffer(0)
+    }
+
+  }
+
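+  // Aggregates an array column into a map of element -> number of occurrences, optionally
+  // keeping only the maxKeys most frequent elements (maxKeys <= 0 means no limit).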
+  class MultiArraySet[T : Ordering](dt: DataType = StringType, maxKeys: Int = -1) extends UserDefinedAggregateFunction {
+
+    def inputSchema: StructType = new StructType().add("key", ArrayType(dt))
+
+    def bufferSchema = new StructType().add("mp", dataType)
+
+    def dataType: DataType = MapType(dt, IntegerType, false)
+
+    def deterministic = true
+
+    // Called to initialize a new aggregation buffer (i.e. when a new group is started)
+    def initialize(buffer: MutableAggregationBuffer) = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Iterate over each entry of a group
+    def update(buffer: MutableAggregationBuffer, input: Row) = {
+      val mp = mutable.Map[T, Int]() ++= buffer.getMap(0)
+      val keyArr: Seq[T] = Option(input.getAs[Seq[T]](0)).getOrElse(Nil)
+      for (key <- keyArr; if key != null)
+        mp.put(key, mp.getOrElse(key, 0) + 1)
+
+      buffer(0) = limitKeys(mp, 3)
+    }
+
+    // Merge two partial aggregates
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
+      val mp = mutable.Map[T, Int]() ++= buffer1.getMap(0)
+      buffer2.getMap(0).keys.foreach((key: T) =>
+        if (key != null)
+          mp.put(key, mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
+      )
+
+      buffer1(0) = limitKeys(mp, 3)
+    }
+
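+    // Trims the map down to the maxKeys most frequent keys. During partial aggregation the
+    // map may grow to maxKeys * factor entries before it is trimmed, which reduces the chance
+    // of prematurely dropping keys that are frequent overall.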
+    private def limitKeys(mp: Map[T, Int], factor: Int = 1): Map[T, Int] = {
+      if (maxKeys > 0 && maxKeys * factor < mp.size) {
+        val k = mp.toList.map(_.swap).sorted.reverse(maxKeys - 1)._1
+        var mp2 = mutable.Map[T, Int]() ++= mp.filter((t: (T, Int)) => t._2 >= k)
+        var toRemove = mp2.size - maxKeys
+        if (toRemove > 0)
+          mp2 = mp2.filter((t: (T, Int)) => {
+            if (t._2 > k)
+              true
+            else {
+              if (toRemove >= 0)
+                toRemove = toRemove - 1
+              toRemove < 0
+            }
+          })
+        mp2
+      } else mp
+    }
+
+    // Called after all the entries are exhausted.
+    def evaluate(buffer: Row) = {
+      limitKeys(buffer.getMap(0).asInstanceOf[Map[T, Int]])
+    }
+
+  }
+
+  // Merges maps of string -> array<string>, treating each value array as a set (duplicates are removed)
+  class MapSetMerge extends UserDefinedAggregateFunction {
+
+    def inputSchema: StructType = new StructType().add("key", dataType)
+
+    def bufferSchema = inputSchema
+
+    def dataType: DataType = MapType(StringType, ArrayType(StringType))
+
+    def deterministic = true
+
+    // Called to initialize a new aggregation buffer (i.e. when a new group is started)
+    def initialize(buffer: MutableAggregationBuffer) = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Iterate over each entry of a group
+    def update(buffer: MutableAggregationBuffer, input: Row) = {
+      val mp0 = input.getMap(0)
+      if (mp0 != null) {
+        val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= input.getMap(0)
+        buffer(0) = merge(mp, buffer.getMap[String, mutable.WrappedArray[String]](0))
+      }
+    }
+
+    // Merge two partial aggregates
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
+      val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= buffer1.getMap(0)
+      buffer1(0) = merge(mp, buffer2.getMap[String, mutable.WrappedArray[String]](0))
+    }
+
+    def merge(mpBuffer: mutable.Map[String, mutable.WrappedArray[String]], mp: Map[String, mutable.WrappedArray[String]]): mutable.Map[String, mutable.WrappedArray[String]] = {
+      if (mp != null)
+        mp.keys.foreach((key: String) => {
+          val existing: mutable.WrappedArray[String] = mpBuffer.getOrElse(key, mutable.WrappedArray.empty)
+          val incoming: mutable.WrappedArray[String] = mp.getOrElse(key, mutable.WrappedArray.empty)
+          mpBuffer.put(key, mutable.WrappedArray.make((Option(existing).getOrElse(mutable.WrappedArray.empty) ++ Option(incoming).getOrElse(mutable.WrappedArray.empty)).toSet.toArray))
+        })
+
+      mpBuffer
+    }
+
+    // Called after all the entries are exhausted.
+    def evaluate(buffer: Row) = {
+      buffer(0)
+    }
+
+  }
+
+  /**
+   * Counts the number of distinct records, but only up to a preset maximum - more efficient than an unbounded distinct count
+   */
+  class CountDistinctUpTo(maxItems: Int = -1) extends UserDefinedAggregateFunction {
+
+    def inputSchema: StructType = new StructType().add("key", StringType)
+
+    def bufferSchema = new StructType().add("mp", MapType(StringType, BooleanType))
+
+    def dataType: DataType = IntegerType
+
+    def deterministic = true
+
+    // Called to initialize a new aggregation buffer (i.e. when a new group is started)
+    def initialize(buffer: MutableAggregationBuffer) = {
+      buffer(0) = mutable.Map()
+    }
+
+    // Iterate over each entry of a group
+    def update(buffer: MutableAggregationBuffer, input: Row) = {
+      if (buffer.getMap(0).size < maxItems)
+      {
+        val key = input.getString(0)
+        if (key != null)
+          buffer(0) = buffer.getMap(0) + (key -> true)
+      }
+    }
+
+    // Merge two partial aggregates
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
+      if (buffer1.getMap(0).size < maxItems)
+      {
+        val mp = mutable.Map[String, Boolean]() ++= buffer1.getMap(0)
+        buffer2.getMap(0).keys.foreach((key: String) =>
+          if (key != null)
+            mp.put(key, true)
+        )
+        buffer1(0) = mp
+      }
+
+    }
+
+    // Called after all the entries are exhausted.
+    def evaluate(buffer: Row) = {
+      math.min(buffer.getMap(0).size, maxItems)
+    }
+
+  }
+
+}
\ No newline at end of file
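
The test suite below does not yet cover CountDistinctUpTo (see the TODO above). As a minimal
usage sketch - assuming the "df" and column names defined in the tests below, and a hypothetical
cap of 10 - it could be applied like this:

    val countDistinctUpTo = new SparkUDAFs.CountDistinctUpTo(maxItems = 10)

    // distinct col_str values per col_grp, counted only up to the cap of 10
    df.groupBy($"col_grp")
      .agg(countDistinctUpTo($"col_str").as("distinct_str_capped"))
      .show()

Because evaluate returns math.min(size, maxItems), any group with more than 10 distinct values
simply reports 10.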

http://git-wip-us.apache.org/repos/asf/datafu/blob/8c2d55d8/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
----------------------------------------------------------------------
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
new file mode 100644
index 0000000..7a1eac9
--- /dev/null
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import com.holdenkarau.spark.testing.DataFrameSuiteBase
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.types.SparkOverwriteUDAFs
+import org.junit.Assert
+import org.junit.runner.RunWith
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
+import org.slf4j.LoggerFactory
+
+@RunWith(classOf[JUnitRunner])
+class UdafTests extends FunSuite with DataFrameSuiteBase {
+
+  import spark.implicits._
+  
+  val logger = LoggerFactory.getLogger(this.getClass)
+
+  lazy val df = sc.parallelize(List(("a", 1, "asd1"), ("a", 2, "asd2"), ("a", 3, "asd3"), ("b", 1, "asd4"))).toDF("col_grp", "col_ord", "col_str")
+
+  case class mapExp(map_col: Map[String, Int])
+  case class mapArrExp(map_col: Map[String, Array[String]])
+
+  test("test multiset simple") {
+    val ms = new SparkUDAFs.MultiSet()
+    val expected : DataFrame = sqlContext.createDataFrame(List(mapExp(Map("b" -> 1, "a" -> 3))))
+    assertDataFrameEquals(expected, df.agg(ms($"col_grp").as("map_col")))
+  }
+
+  val mas = new SparkUDAFs.MultiArraySet[String]()
+
+  test("test multiarrayset simple") {
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map("tre" -> 1, "asd" -> 2)))),
+      spark.sql("select array('asd','tre','asd') arr").groupBy().agg(mas($"arr").as("map_col")))
+  }
+
+  test("test multiarrayset all nulls") {
+    // edge case: all input arrays are null
+    spark.sql("drop table if exists mas_table")
+    spark.sql("create table mas_table (arr array<string>)")
+    spark.sql("insert overwrite table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+    spark.sql("insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+    spark.sql("insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+    spark.sql("insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+    spark.sql("insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map()))),
+      spark.table("mas_table").groupBy().agg(mas($"arr").as("map_col")))
+  }
+
+  test("test multiarrayset max keys") {
+    // max keys case
+    spark.sql("drop table if exists mas_table2")
+    spark.sql("create table mas_table2 (arr array<string>)")
+    spark.sql("insert overwrite table mas_table2 select array('asd','dsa') from (select 1)z")
+    spark.sql("insert into table mas_table2 select array('asd','abc') from (select 1)z")
+    spark.sql("insert into table mas_table2 select array('asd') from (select 1)z")
+    spark.sql("insert into table mas_table2 select array('asd') from (select 1)z")
+    spark.sql("insert into table mas_table2 select array('asd') from (select 1)z")
+    spark.sql("insert into table mas_table2 select array('asd2') from (select 1)z")
+
+    val mas2 = new SparkUDAFs.MultiArraySet[String](maxKeys = 2)
+
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map("dsa" -> 1, "asd" -> 5)))),
+      spark.table("mas_table2").groupBy().agg(mas2($"arr").as("map_col")))
+
+    val mas1 = new SparkUDAFs.MultiArraySet[String](maxKeys = 1)
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapExp(Map("asd" -> 5)))),
+      spark.table("mas_table2").groupBy().agg(mas1($"arr").as("map_col")))
+  }
+
+  test("test multiarrayset big input") {
+    val N = 100000
+    val input = spark.sparkContext.parallelize(1 to N, 20).toDF("num").selectExpr("array('asd',concat('dsa',num)) as arr")
+    val mas = new SparkUDAFs.MultiArraySet[String](maxKeys = 3)
+    val time1 = System.currentTimeMillis()
+    val mp = input.groupBy().agg(mas($"arr")).collect().map(_.getMap[String,Int](0)).head
+    Assert.assertEquals(3, mp.size)
+    Assert.assertEquals("asd", mp.maxBy(_._2)._1)
+    Assert.assertEquals(N, mp.maxBy(_._2)._2)
+    val time2 = System.currentTimeMillis()
+    logger.info("time taken: " + (time2 - time1) / 1000 + " secs")
+  }
+
+  test("test mapmerge") {
+    val mapMerge = new SparkUDAFs.MapSetMerge()
+
+    spark.sql("drop table if exists mapmerge_table")
+    spark.sql("create table mapmerge_table (c map<string, array<string>>)")
+    spark.sql("insert overwrite table mapmerge_table select map('k1', array('v1')) from (select 1) z")
+    spark.sql("insert into table mapmerge_table select map('k1', array('v1')) from (select 1) z")
+    spark.sql("insert into table mapmerge_table select map('k2', array('v3')) from (select 1) z")
+
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(mapArrExp(Map("k1" -> Array("v1"), "k2" -> Array("v3"))))),
+      spark.table("mapmerge_table").groupBy().agg(mapMerge($"c").as("map_col")))
+  }
+
+  test("minKeyValue") {
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(("b", "asd4"), ("a", "asd1"))),
+      df.groupBy($"col_grp".as("_1")).agg(SparkOverwriteUDAFs.minValueByKey($"col_ord", $"col_str").as("_2")))
+  }
+
+  case class exp4(col_grp: String, col_ord: Int, col_str: String, asd: String)
+
+  test("minKeyValue window") {
+    assertDataFrameEquals(
+      sqlContext.createDataFrame(List(
+        exp4("b", 1, "asd4", "asd4"),
+        exp4("a", 1, "asd1", "asd1"),
+        exp4("a", 2, "asd2", "asd1"),
+        exp4("a", 3, "asd3", "asd1")
+      )),
+      df.withColumn("asd", SparkOverwriteUDAFs.minValueByKey($"col_ord", $"col_str").over(Window.partitionBy("col_grp"))))
+  }
+
+}