You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by WeichenXu123 <gi...@git.apache.org> on 2018/04/09 09:19:56 UTC
[GitHub] spark pull request #20235: [Spark-22887][ML][TESTS][WIP] ML test for Structu...
Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20235#discussion_r180027926
--- Diff: mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala ---
@@ -34,86 +35,122 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("FPGrowth fit and transform with different data types") {
- Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt =>
- val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
- val model = new FPGrowth().setMinSupport(0.5).fit(data)
- val generatedRules = model.setMinConfidence(0.5).associationRules
- val expectedRules = spark.createDataFrame(Seq(
- (Array("2"), Array("1"), 1.0),
- (Array("1"), Array("2"), 0.75)
- )).toDF("antecedent", "consequent", "confidence")
- .withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
- .withColumn("consequent", col("consequent").cast(ArrayType(dt)))
- assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
- generatedRules.sort("antecedent").rdd.collect()))
-
- val transformed = model.transform(data)
- val expectedTransformed = spark.createDataFrame(Seq(
- (0, Array("1", "2"), Array.emptyIntArray),
- (0, Array("1", "2"), Array.emptyIntArray),
- (0, Array("1", "2"), Array.emptyIntArray),
- (0, Array("1", "3"), Array(2))
- )).toDF("id", "items", "prediction")
- .withColumn("items", col("items").cast(ArrayType(dt)))
- .withColumn("prediction", col("prediction").cast(ArrayType(dt)))
- assert(expectedTransformed.collect().toSet.equals(
- transformed.collect().toSet))
+ class DataTypeWithEncoder[A](val a: DataType)
+ (implicit val encoder: Encoder[(Int, Array[A], Array[A])])
+
+ Array(
+ new DataTypeWithEncoder[Int](IntegerType),
+ new DataTypeWithEncoder[String](StringType),
+ new DataTypeWithEncoder[Short](ShortType),
+ new DataTypeWithEncoder[Long](LongType)
+ // , new DataTypeWithEncoder[Byte](ByteType)
+ // TODO: using ByteType produces error, as Array[Byte] is handled as Binary
+ // cannot resolve 'CAST(`items` AS BINARY)' due to data type mismatch:
+ // cannot cast array<tinyint> to binary;
+ ).foreach { dt => {
+ val data = dataset.withColumn("items", col("items").cast(ArrayType(dt.a)))
+ val model = new FPGrowth().setMinSupport(0.5).fit(data)
+ val generatedRules = model.setMinConfidence(0.5).associationRules
+ val expectedRules = Seq(
+ (Array("2"), Array("1"), 1.0),
+ (Array("1"), Array("2"), 0.75)
+ ).toDF("antecedent", "consequent", "confidence")
+ .withColumn("antecedent", col("antecedent").cast(ArrayType(dt.a)))
+ .withColumn("consequent", col("consequent").cast(ArrayType(dt.a)))
+ assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
+ generatedRules.sort("antecedent").rdd.collect()))
+
+ val expectedTransformed = Seq(
+ (0, Array("1", "2"), Array.emptyIntArray),
--- End diff --
I think the "id" column should take distinct values 0, 1, 2, 3 (one per row) instead of 0 for every row.
As written, the id column carries no information, so alternatively we can just remove it.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org