Add RMSE evaluation to runALS.
Project: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/commit/7bdebb5c Tree: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/tree/7bdebb5c Diff: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/diff/7bdebb5c Branch: refs/heads/master Commit: 7bdebb5cc7fac65a451079230caf87f0b0253afa Parents: 851f62a Author: DO YUNG YOON <[email protected]> Authored: Tue May 8 12:13:40 2018 +0900 Committer: DO YUNG YOON <[email protected]> Committed: Tue May 8 12:13:40 2018 +0900 ---------------------------------------------------------------------- .../s2jobs/task/custom/process/ALSModelProcess.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/7bdebb5c/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala ---------------------------------------------------------------------- diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala index 9ffb341..dfbefbf 100644 --- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala +++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala @@ -3,6 +3,7 @@ package org.apache.s2graph.s2jobs.task.custom.process import java.io.File import annoy4s._ +import org.apache.spark.ml.evaluation.RegressionEvaluator //import org.apache.spark.ml.nn.Annoy //import annoy4s.{Angular, Annoy} @@ -17,6 +18,9 @@ object ALSModelProcess { def runALS(ss: SparkSession, conf: TaskConf, dataFrame: DataFrame): DataFrame = { + // split + val Array(training, test) = dataFrame.randomSplit(Array(0.8, 0.2)) + // als model params. 
val rank = conf.options.getOrElse("rank", "10").toInt val maxIter = conf.options.getOrElse("maxIter", "5").toInt @@ -35,7 +39,16 @@ object ALSModelProcess { .setItemCol(itemCol) .setRatingCol(ratingCol) - val model = als.fit(dataFrame) + val model = als.fit(training) + + val predictions = model.transform(test) + val evaluator = new RegressionEvaluator() + .setMetricName("rmse") + .setLabelCol(ratingCol) + .setPredictionCol("prediction") + + val rmse = evaluator.evaluate(predictions) + println(s"RMSE: ${rmse}") model.itemFactors }
