Posted to commits@datafu.apache.org by ey...@apache.org on 2022/11/10 08:15:14 UTC

[datafu] branch main updated: DATAFU-170 improve broadcast join

This is an automated email from the ASF dual-hosted git repository.

eyal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafu.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b4d132  DATAFU-170 improve broadcast join
0b4d132 is described below

commit 0b4d1327bcedee9da12c4c74412de2595f69ec27
Author: Shaked <sa...@paypal.com>
AuthorDate: Wed Nov 9 16:02:55 2022 +0200

    DATAFU-170 improve broadcast join
---
 .../src/main/resources/pyspark_utils/df_utils.py   |  5 +-
 .../src/main/scala/datafu/spark/DataFrameOps.scala |  6 ++-
 .../src/main/scala/datafu/spark/SparkDFUtils.scala | 20 ++++----
 .../test/resources/python_tests/df_utils_tests.py  |  2 +-
 .../test/scala/datafu/spark/TestSparkDFUtils.scala | 55 +++++++++-------------
 5 files changed, 41 insertions(+), 47 deletions(-)

diff --git a/datafu-spark/src/main/resources/pyspark_utils/df_utils.py b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
index 31b989d..adf4784 100644
--- a/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
+++ b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
@@ -99,7 +99,7 @@ def join_skewed(df_left, df_right, join_exprs, num_shards = 30, join_type="inner
     return DataFrame(jdf, df_left.sql_ctx)
 
 
-def broadcast_join_skewed(not_skewed_df, skewed_df, join_col, number_of_custs_to_broadcast, filter_cnt):
+def broadcast_join_skewed(not_skewed_df, skewed_df, join_col, number_of_custs_to_broadcast, filter_cnt, join_type):
     """
     Suitable for performing a join when one DF is skewed and the other is not.
     Splits both DFs into two parts according to the skewed keys.
@@ -110,9 +110,10 @@ def broadcast_join_skewed(not_skewed_df, skewed_df, join_col, number_of_custs_to
     :param join_col: join column
     :param number_of_custs_to_broadcast: number of most frequent skewed keys ("custs") to broadcast
     :param filter_cnt: filter out unskewed rows from the broadcast to ease the limit calculation
+    :param join_type: join type to perform (e.g. "inner", "left", "right")
     :return: DataFrame representing the data after the operation
     """
-    jdf = _get_utils(skewed_df).broadcastJoinSkewed(not_skewed_df._jdf, skewed_df._jdf, join_col, number_of_custs_to_broadcast, filter_cnt)
+    jdf = _get_utils(skewed_df).broadcastJoinSkewed(not_skewed_df._jdf, skewed_df._jdf, join_col, number_of_custs_to_broadcast, filter_cnt, join_type)
     return DataFrame(jdf, not_skewed_df.sql_ctx)
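
For reference, the extended wrapper can be called as in the sketch below. This is a minimal illustration, not part of the patch: the DataFrames and the import path are assumed (only the keyword arguments come from the signature above), and it requires the datafu-spark jar on the Spark classpath.

    from pyspark.sql import SparkSession
    from pyspark_utils import df_utils  # import path assumed from this module's location

    spark = SparkSession.builder.getOrCreate()

    # hypothetical inputs: `customers` is heavily skewed on "id", `purchases` is not
    customers = spark.createDataFrame(
        [(1, "a"), (1, "b"), (1, "c"), (2, "k")], ["id", "val_skewed"])
    purchases = spark.createDataFrame(
        [(1, "str1"), (2, "str2"), (3, "str3")], ["id", "val"])

    # broadcast the single most frequent key; keep unmatched purchases via "left"
    joined = df_utils.broadcast_join_skewed(not_skewed_df=purchases,
                                            skewed_df=customers,
                                            join_col="id",
                                            number_of_custs_to_broadcast=1,
                                            filter_cnt=0,
                                            join_type="left")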
 
 
diff --git a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
index 5b4f42d..64c1f8b 100644
--- a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
@@ -88,12 +88,14 @@ object DataFrameOps {
     def broadcastJoinSkewed(skewed: DataFrame,
                             joinCol: String,
                             numberCustsToBroadcast: Int,
-                            filterCnt: Option[Long] = None): DataFrame =
+                            filterCnt: Option[Long] = None,
+                            joinType: String = "inner"): DataFrame =
       SparkDFUtils.broadcastJoinSkewed(df,
                                        skewed,
                                        joinCol,
                                        numberCustsToBroadcast,
-                                       filterCnt)
+                                       filterCnt,
+                                       joinType)
 
     def joinSkewed(notSkewed: DataFrame,
                    joinExprs: Column,
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
index 81bdb45..240bcbc 100644
--- a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
@@ -89,12 +89,14 @@ class SparkDFUtilsBridge {
                           skewed: DataFrame,
                           joinCol: String,
                           numRowsToBroadcast: Int,
-                          filterCnt: Long): DataFrame = {
+                          filterCnt: Long,
+                          joinType: String): DataFrame = {
     SparkDFUtils.broadcastJoinSkewed(notSkewed = notSkewed,
-                                     skewed = skewed,
-                                     joinCol = joinCol,
-                                     numRowsToBroadcast = numRowsToBroadcast,
-                                     filterCnt = Option(filterCnt))
+      skewed = skewed,
+      joinCol = joinCol,
+      numRowsToBroadcast = numRowsToBroadcast,
+      filterCnt = Option(filterCnt),
+      joinType)
   }
 
   def joinWithRange(dfSingle: DataFrame,
@@ -335,7 +337,8 @@ object SparkDFUtils {
                           skewed: DataFrame,
                           joinCol: String,
                           numRowsToBroadcast: Int,
-                          filterCnt: Option[Long] = None): DataFrame = {
+                          filterCnt: Option[Long] = None,
+                          joinType: String = "inner"): DataFrame = {
     val ss = notSkewed.sparkSession
     import ss.implicits._
     val keyCount = skewed
@@ -355,12 +358,11 @@ object SparkDFUtils {
       .join(broadcast(skewedKeys), $"skew_join_key" === col(joinCol), "left")
       .withColumn("is_skewed_record", col("skew_join_key").isNotNull)
       .drop("skew_join_key")
-      .persist(StorageLevel.DISK_ONLY)
 
     // broadcast map-join, sending the notSkewed data
     val bigRecordsJnd =
       broadcast(notSkewedWithSkewIndicator.filter("is_skewed_record"))
-        .join(skewed, joinCol)
+        .join(skewed.join(broadcast(skewedKeys), $"skew_join_key" === col(joinCol)).drop("skew_join_key"), List(joinCol), joinType)
 
     // regular join for the rest
     val skewedWithoutSkewedKeys = skewed
@@ -369,7 +371,7 @@ object SparkDFUtils {
       .drop("skew_join_key")
     val smallRecordsJnd = notSkewedWithSkewIndicator
       .filter("not is_skewed_record")
-      .join(skewedWithoutSkewedKeys, joinCol)
+      .join(skewedWithoutSkewedKeys, List(joinCol), joinType)
 
     smallRecordsJnd
       .union(bigRecordsJnd)
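
To summarize the patched flow: the most frequent keys are computed from the skewed side; not-skewed rows carrying those keys take a broadcast map-join, now against a skewed side that is itself pre-filtered to those keys (instead of persisting the tagged DataFrame to disk); the remaining rows take a regular shuffle join; and the two results are unioned, with the caller's joinType applied on both paths. The following is a self-contained PySpark sketch of that idea, with illustrative names and data, not the library's exact code:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()

    # illustrative stand-ins for the real inputs
    not_skewed = spark.createDataFrame([(i % 50, "v%d" % i) for i in range(500)], ["key", "val"])
    skewed = spark.createDataFrame([(i % 10, "s%d" % i) for i in range(500)], ["key", "val_skewed"])
    num_keys, join_type = 2, "left"

    # 1. the most frequent join keys on the skewed side
    hot_keys = (skewed.groupBy("key").count()
                .orderBy(F.desc("count")).limit(num_keys).select("key"))

    # 2. tag each not-skewed row by whether its key is one of them
    tagged = (not_skewed
              .join(F.broadcast(hot_keys.withColumnRenamed("key", "hot_key")),
                    F.col("key") == F.col("hot_key"), "left")
              .withColumn("is_hot", F.col("hot_key").isNotNull())
              .drop("hot_key"))

    # 3. broadcast the small hot slice; pre-filter the skewed side to the hot
    #    keys, as the patch above now does
    hot_joined = (F.broadcast(tagged.filter("is_hot"))
                  .join(skewed.join(F.broadcast(hot_keys), ["key"]), ["key"], join_type))

    # 4. regular shuffle join for everything else
    cold_joined = (tagged.filter("not is_hot")
                   .join(skewed.join(F.broadcast(hot_keys), ["key"], "left_anti"),
                         ["key"], join_type))

    result = hot_joined.union(cold_joined).drop("is_hot")
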
diff --git a/datafu-spark/src/test/resources/python_tests/df_utils_tests.py b/datafu-spark/src/test/resources/python_tests/df_utils_tests.py
index 1e58176..eec47ec 100644
--- a/datafu-spark/src/test/resources/python_tests/df_utils_tests.py
+++ b/datafu-spark/src/test/resources/python_tests/df_utils_tests.py
@@ -73,7 +73,7 @@ func_joinSkewed_res = df_utils.join_skewed(df_left=df_people2.alias("df1"), df_r
 func_joinSkewed_res.registerTempTable("joinSkewed")
 
 func_broadcastJoinSkewed_res = df_utils.broadcast_join_skewed(not_skewed_df=df_people2, skewed_df=simpleDF, join_col="id",
-                                                              number_of_custs_to_broadcast=5, filter_cnt=0)
+                                                              number_of_custs_to_broadcast=5, filter_cnt=0, join_type="inner")
 func_broadcastJoinSkewed_res.registerTempTable("broadcastJoinSkewed")
 
 dfRange = sqlContext.createDataFrame([
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
index 91458ee..fa4060c 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
@@ -253,39 +253,28 @@ class DataFrameOpsTests extends FunSuite with DataFrameSuiteBase {
     assertDataFrameEquals(expected, actual)
   }
 
-  test("broadcastJoinSkewed") {
-    val skewedList = List(("1", "a"),
-                          ("1", "b"),
-                          ("1", "c"),
-                          ("1", "d"),
-                          ("1", "e"),
-                          ("2", "k"),
-                          ("0", "k"))
-    val skewed =
-      sqlContext.createDataFrame(skewedList).toDF("key", "val_skewed")
+  test("randomJoinSkewedTests") {
+    def makeSkew(i: Int): Int = {
+      if (i < 200) 10 else 50
+    }
+
+    val skewed = sqlContext.createDataFrame((1 to 500).map(i => ((Math.random * makeSkew(i)).toInt, s"str$i")))
+      .toDF("key", "val_skewed")
     val notSkewed = sqlContext
-      .createDataFrame((1 to 10).map(i => (i.toString, s"str$i")))
+      .createDataFrame((1 to 500).map(i => ((Math.random * 50).toInt, s"str$i")))
       .toDF("key", "val")
 
-    val expected = sqlContext
-      .createDataFrame(
-        List(
-          ("1", "str1", "a"),
-          ("1", "str1", "b"),
-          ("1", "str1", "c"),
-          ("1", "str1", "d"),
-          ("1", "str1", "e"),
-          ("2", "str2", "k")
-        ))
-      .toDF("key", "val", "val_skewed")
-
-    val actual1 = notSkewed.broadcastJoinSkewed(skewed, "key", 1)
-
-    assertDataFrameEquals(expected, actual1.sort($"val_skewed"))
+    val expected = notSkewed.join(skewed, Seq("key")).sort($"key", $"val", $"val_skewed")
+    val actual1 = notSkewed.broadcastJoinSkewed(skewed, "key", 1).sort($"key", $"val", $"val_skewed")
+    assertDataFrameEquals(expected, actual1)
 
-    val actual2 = notSkewed.broadcastJoinSkewed(skewed, "key", 2)
+    val leftExpected = notSkewed.join(skewed, Seq("key"), "left").sort($"key", $"val", $"val_skewed")
+    val actual2 = notSkewed.broadcastJoinSkewed(skewed, "key", 1, joinType = "left").sort($"key", $"val", $"val_skewed")
+    assertDataFrameEquals(leftExpected, actual2)
 
-    assertDataFrameEquals(expected, actual2.sort($"val_skewed"))
+    val rightExpected = notSkewed.join(skewed, Seq("key"), "right").sort($"key", $"val", $"val_skewed")
+    val actual3 = notSkewed.broadcastJoinSkewed(skewed, "key", 2, joinType = "right").sort($"key", $"val", $"val_skewed")
+    assertDataFrameEquals(rightExpected, actual3)
   }
 
   // because of nulls in expected data, an actual schema needs to be used
@@ -375,9 +364,9 @@ class DataFrameOpsTests extends FunSuite with DataFrameSuiteBase {
 
     assertDataFrameEquals(expected, actual)
   }
-  
+
   test("test_explode_array") {
- 
+
     val input = spark.createDataFrame(Seq(
       (0.0, Seq("Hi", "I heard", "about", "Spark")),
       (0.0, Seq("I wish", "Java", "could use", "case", "classes")),
@@ -385,9 +374,9 @@ class DataFrameOpsTests extends FunSuite with DataFrameSuiteBase {
       (0.0, Seq()),
       (1.0, null)
     )).toDF("label", "sentence_arr")
-    
+
     val actual = input.explodeArray($"sentence_arr", "token")
-    
+
     val expected = spark.createDataFrame(Seq(
       (0.0, Seq("Hi", "I heard", "about", "Spark"),"Hi", "I heard", "about", "Spark",null),
       (0.0, Seq("I wish", "Java", "could use", "case", "classes"),"I wish", "Java", "could use", "case", "classes"),
@@ -395,7 +384,7 @@ class DataFrameOpsTests extends FunSuite with DataFrameSuiteBase {
       (0.0, Seq(),null,null,null,null,null),
       (1.0, null,null,null,null,null,null)
     )).toDF("label", "sentence_arr","token0","token1","token2","token3","token4")
-    
+
     assertDataFrameEquals(expected, actual)
   }
 }