You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/13 01:53:49 UTC

spark git commit: [SPARK-7528] [MLLIB] make RankingMetrics Java-friendly

Repository: spark
Updated Branches:
  refs/heads/master 00e7b09a0 -> 2713bc65a


[SPARK-7528] [MLLIB] make RankingMetrics Java-friendly

`RankingMetrics` contains a ClassTag, which is hard to create in Java. This PR adds a factory method `of` for Java users. coderxiang

Author: Xiangrui Meng <me...@databricks.com>

Closes #6098 from mengxr/SPARK-7528 and squashes the following commits:

e5d57ae [Xiangrui Meng] make RankingMetrics Java-friendly


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2713bc65
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2713bc65
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2713bc65

Branch: refs/heads/master
Commit: 2713bc65af1e0e81edd5fad0338e34fd127391f9
Parents: 00e7b09
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue May 12 16:53:47 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 12 16:53:47 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/evaluation/RankingMetrics.scala | 27 +++++++--
 .../evaluation/JavaRankingMetricsSuite.java     | 64 ++++++++++++++++++++
 2 files changed, 87 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2713bc65/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 93a7353..b9b54b9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.mllib.evaluation
 
+import java.{lang => jl}
+
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -71,7 +74,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
         logWarning("Empty ground truth set, check input data")
         0.0
       }
-    }.mean
+    }.mean()
   }
 
   /**
@@ -100,7 +103,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
         logWarning("Empty ground truth set, check input data")
         0.0
       }
-    }.mean
+    }.mean()
   }
 
   /**
@@ -146,7 +149,23 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
         logWarning("Empty ground truth set, check input data")
         0.0
       }
-    }.mean
+    }.mean()
   }
 
 }
+
+@Experimental
+object RankingMetrics {
+
+  /**
+   * Creates a [[RankingMetrics]] instance from a JavaRDD of Java iterables (for Java users).
+   * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs
+   */
+  def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = {
+    implicit val tag = JavaSparkContext.fakeClassTag[E]  // ClassTag required by RankingMetrics
+    val rdd = predictionAndLabels.rdd.map { case (predictions, labels) =>
+      (predictions.asScala.toArray, labels.asScala.toArray)  // Java iterables -> Scala arrays
+    }
+    new RankingMetrics(rdd)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2713bc65/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
new file mode 100644
index 0000000..effc8a1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import scala.Tuple2;
+import scala.Tuple2$;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaRankingMetricsSuite implements Serializable { // tests RankingMetrics.of
+  private transient JavaSparkContext sc; // created in setUp(), stopped in tearDown()
+  private transient JavaRDD<Tuple2<ArrayList<Integer>, ArrayList<Integer>>> predictionAndLabels;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
+    predictionAndLabels = sc.parallelize(Lists.newArrayList( // (predicted ranking, labels) pairs
+      Tuple2$.MODULE$.apply(
+        Lists.newArrayList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Lists.newArrayList(1, 2, 3, 4, 5)),
+      Tuple2$.MODULE$.apply(
+        Lists.newArrayList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Lists.newArrayList(1, 2, 3)),
+      Tuple2$.MODULE$.apply(
+        Lists.newArrayList(1, 2, 3, 4, 5), Lists.<Integer>newArrayList())), 2); // empty labels
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null; // drop the reference to the stopped context
+  }
+
+  @Test
+  public void rankingMetrics() {
+    @SuppressWarnings("unchecked")
+    RankingMetrics<?> metrics = RankingMetrics.of(predictionAndLabels); // Java-friendly factory
+    Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
+    Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5); // presumably mean of 3 pairs
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org