You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by "zhengruifeng (via GitHub)" <gi...@apache.org> on 2023/03/13 07:54:39 UTC

[GitHub] [spark] zhengruifeng commented on pull request #40263: [SPARK-42659][ML] Reimplement `FPGrowthModel.transform` with dataframe operations

zhengruifeng commented on PR #40263:
URL: https://github.com/apache/spark/pull/40263#issuecomment-1465662299

   I ran a quick benchmark with the `T10I4D100K` dataset from http://fimi.uantwerpen.be/data/
   
   fit:
   ```
   scala> val df = sc.textFile("/Users/ruifeng.zheng/.dev/data/T10I4D100K.dat").map(_.split(" ")).toDF("items")
   df: org.apache.spark.sql.DataFrame = [items: array<string>]
   
   scala> df.count
   res16: Long = 100000
   
   scala> val model = new FPGrowth().setMinSupport(0.01).setMinConfidence(0.01).fit(df)
   model: org.apache.spark.ml.fpm.FPGrowthModel = FPGrowthModel: uid=fpgrowth_92901252345a, numTrainingRecords=100000
   
   scala> model.freqItemsets.count
   res17: Long = 385                                                               
   
   scala> model.associationRules.count
   res18: Long = 21                                                                
   
   scala> model.save("/tmp/fpm.model")
   ```
   
   
   transformation:
   ```
   import org.apache.spark.ml.fpm._
   val df = sc.textFile("/Users/ruifeng.zheng/.dev/data/T10I4D100K.dat").map(_.split(" ")).toDF("items")
   df.cache()
   df.count()
   
   val model = FPGrowthModel.load("/tmp/fpm.model")
   model.transform(df).explain("extended")
   Seq.range(0, 100).foreach{i => model.transform(df).count()} // warms up
   val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i => model.transform(df).count()}; val end = System.currentTimeMillis; end - start
   val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i => model.transform(df).schema}; val end = System.currentTimeMillis; end - start
   ```
   
   master:
   ```
   scala> val model = FPGrowthModel.load("/tmp/fpm.model")
   model: org.apache.spark.ml.fpm.FPGrowthModel = FPGrowthModel: uid=fpgrowth_92901252345a, numTrainingRecords=100000
   
   scala> model.transform(df).explain("extended")
   == Parsed Logical Plan ==
   'Project [items#5, UDF('items) AS prediction#70]
   +- Project [value#2 AS items#5]
      +- SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
         +- ExternalRDD [obj#1]
   
   == Analyzed Logical Plan ==
   items: array<string>, prediction: array<string>
   Project [items#5, UDF(items#5) AS prediction#70]
   +- Project [value#2 AS items#5]
      +- SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
         +- ExternalRDD [obj#1]
   
   == Optimized Logical Plan ==
   Project [items#5, UDF(items#5) AS prediction#70]
   +- InMemoryRelation [items#5], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- *(1) Project [value#2 AS items#5]
            +- *(1) SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
               +- Scan[obj#1]
   
   == Physical Plan ==
   *(1) Project [items#5, UDF(items#5) AS prediction#70]
   +- InMemoryTableScan [items#5]
         +- InMemoryRelation [items#5], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [value#2 AS items#5]
                  +- *(1) SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
                     +- Scan[obj#1]
   
   
   scala> Seq.range(0, 100).foreach{i => model.transform(df).count()} // warms up
   
   scala> val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i => model.transform(df).count()}; val end = System.currentTimeMillis; end - start
   start: Long = 1678692855532
   end: Long = 1678692860098
   res4: Long = 4566
   
   scala> val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i => model.transform(df).schema}; val end = System.currentTimeMillis; end - start
   start: Long = 1678692860277
   end: Long = 1678692862372
   res5: Long = 2095
   ```
   
   this PR:
   ```
   scala> model.transform(df).explain("extended")
   == Parsed Logical Plan ==
   'Project [items#5, CASE WHEN NOT isnull('items) THEN aggregate('prediction, cast(array() as array<string>), lambdafunction(CASE WHEN forall(lambda 'y_1[antecedent], lambdafunction(array_contains('items, lambda 'x_2), lambda 'x_2, false)) THEN array_union(lambda 'x_0, array_except(lambda 'y_1[consequent], 'items)) ELSE lambda 'x_0 END, lambda 'x_0, lambda 'y_1, false), lambdafunction(lambda 'x_3, lambda 'x_3, false)) ELSE cast(array() as array<string>) END AS prediction#72]
   +- Join Cross
      :- Project [value#2 AS items#5]
      :  +- SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
      :     +- ExternalRDD [obj#1]
      +- ResolvedHint (strategy=broadcast)
         +- Aggregate [collect_list(struct(antecedent, antecedent#57, consequent, consequent#58), 0, 0) AS prediction#68]
            +- Filter (NOT isnull(antecedent#57) AND NOT isnull(consequent#58))
               +- LogicalRDD [antecedent#57, consequent#58, confidence#59, lift#60, support#61], false
   
   == Analyzed Logical Plan ==
   items: array<string>, prediction: array<string>
   Project [items#5, CASE WHEN NOT isnull(items#5) THEN aggregate(prediction#68, cast(array() as array<string>), lambdafunction(CASE WHEN forall(lambda y_1#74.antecedent, lambdafunction(array_contains(items#5, lambda x_2#76), lambda x_2#76, false)) THEN array_union(lambda x_0#73, array_except(lambda y_1#74.consequent, items#5)) ELSE lambda x_0#73 END, lambda x_0#73, lambda y_1#74, false), lambdafunction(lambda x_3#75, lambda x_3#75, false)) ELSE cast(array() as array<string>) END AS prediction#72]
   +- Join Cross
      :- Project [value#2 AS items#5]
      :  +- SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
      :     +- ExternalRDD [obj#1]
      +- ResolvedHint (strategy=broadcast)
         +- Aggregate [collect_list(struct(antecedent, antecedent#57, consequent, consequent#58), 0, 0) AS prediction#68]
            +- Filter (NOT isnull(antecedent#57) AND NOT isnull(consequent#58))
               +- LogicalRDD [antecedent#57, consequent#58, confidence#59, lift#60, support#61], false
   
   == Optimized Logical Plan ==
   Project [items#5, CASE WHEN isnotnull(items#5) THEN aggregate(prediction#68, [], lambdafunction(CASE WHEN forall(lambda y_1#74.antecedent, lambdafunction(array_contains(items#5, lambda x_2#76), lambda x_2#76, false)) THEN array_union(lambda x_0#73, array_except(lambda y_1#74.consequent, items#5)) ELSE lambda x_0#73 END, lambda x_0#73, lambda y_1#74, false), lambdafunction(lambda x_3#75, lambda x_3#75, false)) ELSE [] END AS prediction#72]
   +- Join Cross, rightHint=(strategy=broadcast)
      :- InMemoryRelation [items#5], StorageLevel(disk, memory, deserialized, 1 replicas)
      :     +- *(1) Project [value#2 AS items#5]
      :        +- *(1) SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
      :           +- Scan[obj#1]
      +- Aggregate [collect_list(struct(antecedent, antecedent#57, consequent, consequent#58), 0, 0) AS prediction#68]
         +- Project [antecedent#57, consequent#58]
            +- LogicalRDD [antecedent#57, consequent#58, confidence#59, lift#60, support#61], false
   
   == Physical Plan ==
   AdaptiveSparkPlan isFinalPlan=false
   +- Project [items#5, CASE WHEN isnotnull(items#5) THEN aggregate(prediction#68, [], lambdafunction(CASE WHEN forall(lambda y_1#74.antecedent, lambdafunction(array_contains(items#5, lambda x_2#76), lambda x_2#76, false)) THEN array_union(lambda x_0#73, array_except(lambda y_1#74.consequent, items#5)) ELSE lambda x_0#73 END, lambda x_0#73, lambda y_1#74, false), lambdafunction(lambda x_3#75, lambda x_3#75, false)) ELSE [] END AS prediction#72]
      +- BroadcastNestedLoopJoin BuildRight, Cross
         :- InMemoryTableScan [items#5]
         :     +- InMemoryRelation [items#5], StorageLevel(disk, memory, deserialized, 1 replicas)
         :           +- *(1) Project [value#2 AS items#5]
         :              +- *(1) SerializeFromObject [mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), StringType, ObjectType(class java.lang.String)), true, false, true), input[0, [Ljava.lang.String;, true], None) AS value#2]
         :                 +- Scan[obj#1]
         +- BroadcastExchange IdentityBroadcastMode, [plan_id=117]
            +- ObjectHashAggregate(keys=[], functions=[collect_list(struct(antecedent, antecedent#57, consequent, consequent#58), 0, 0)], output=[prediction#68])
               +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=114]
                  +- ObjectHashAggregate(keys=[], functions=[partial_collect_list(struct(antecedent, antecedent#57, consequent, consequent#58), 0, 0)], output=[buf#95])
                     +- Project [antecedent#57, consequent#58]
                        +- Scan ExistingRDD[antecedent#57,consequent#58,confidence#59,lift#60,support#61]
   
   
   scala> Seq.range(0, 100).foreach{i => model.transform(df).count()} // warms up
   
   scala> val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i => model.transform(df).count()}; val end = System.currentTimeMillis; end - start
   start: Long = 1678693708534
   end: Long = 1678693713436
   res6: Long = 4902
   
   scala> val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i => model.transform(df).schema}; val end = System.currentTimeMillis; end - start
   start: Long = 1678693713596
   end: Long = 1678693713807
   res7: Long = 211
   ```
   
   
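   The transformation itself is slightly slower: 4566 ms -> 4902 ms for 100 `transform(df).count()` calls. Analyzing the resulting DataFrame, however, is much faster: 2095 ms -> 211 ms for 100 `transform(df).schema` calls. This is because the `collect` of the association rules is no longer executed eagerly when `transform` is called; it is deferred until the query actually runs.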


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org