You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@s2graph.apache.org by st...@apache.org on 2018/05/14 12:29:59 UTC

[15/25] incubator-s2graph git commit: add RMSE evaluation on runALS.

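Add RMSE evaluation to runALS: randomly split the input DataFrame 80/20 into training and test sets, fit the ALS model on the training set only, and print the RMSE of the model's predictions on the test set.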


Project: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/commit/7bdebb5c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/tree/7bdebb5c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/diff/7bdebb5c

Branch: refs/heads/master
Commit: 7bdebb5cc7fac65a451079230caf87f0b0253afa
Parents: 851f62a
Author: DO YUNG YOON <st...@apache.org>
Authored: Tue May 8 12:13:40 2018 +0900
Committer: DO YUNG YOON <st...@apache.org>
Committed: Tue May 8 12:13:40 2018 +0900

----------------------------------------------------------------------
 .../s2jobs/task/custom/process/ALSModelProcess.scala | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/7bdebb5c/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
----------------------------------------------------------------------
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
index 9ffb341..dfbefbf 100644
--- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
@@ -3,6 +3,7 @@ package org.apache.s2graph.s2jobs.task.custom.process
 import java.io.File
 
 import annoy4s._
+import org.apache.spark.ml.evaluation.RegressionEvaluator
 //import org.apache.spark.ml.nn.Annoy
 
 //import annoy4s.{Angular, Annoy}
@@ -17,6 +18,9 @@ object ALSModelProcess {
   def runALS(ss: SparkSession,
              conf: TaskConf,
              dataFrame: DataFrame): DataFrame = {
+    // split
+    val Array(training, test) = dataFrame.randomSplit(Array(0.8, 0.2))
+
     // als model params.
     val rank = conf.options.getOrElse("rank", "10").toInt
     val maxIter = conf.options.getOrElse("maxIter", "5").toInt
@@ -35,7 +39,16 @@ object ALSModelProcess {
       .setItemCol(itemCol)
       .setRatingCol(ratingCol)
 
-    val model = als.fit(dataFrame)
+    val model = als.fit(training)
+
+    val predictions = model.transform(test)
+    val evaluator = new RegressionEvaluator()
+      .setMetricName("rmse")
+      .setLabelCol(ratingCol)
+      .setPredictionCol("prediction")
+
+    val rmse = evaluator.evaluate(predictions)
+    println(s"RMSE: ${rmse}")
 
     model.itemFactors
   }