You are viewing a plain-text version of this content. The canonical version is linked from the original archive page; that link was not preserved in this plain-text copy.
Posted to commits@predictionio.apache.org by sh...@apache.org on 2019/06/10 01:47:56 UTC
[predictionio-template-text-classifier] branch develop updated:
Mark override methods
This is an automated email from the ASF dual-hosted git repository.
shimamoto pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/predictionio-template-text-classifier.git
The following commit(s) were added to refs/heads/develop by this push:
new a389657 Mark override methods
new 4e29572 Merge pull request #19 from takezoe/mark-override
a389657 is described below
commit a389657203dd5e3dbfe6c7b538fb296039af9804
Author: Naoki Takezoe <ta...@gmail.com>
AuthorDate: Thu May 30 20:31:06 2019 +0900
Mark override methods
---
src/main/scala/DataSource.scala | 1 +
src/main/scala/Evaluation.scala | 1 +
src/main/scala/LRAlgorithm.scala | 2 ++
src/main/scala/NBAlgorithm.scala | 2 ++
src/main/scala/Preparator.scala | 5 ++---
5 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/src/main/scala/DataSource.scala b/src/main/scala/DataSource.scala
index f4dd11c..ceb46f2 100644
--- a/src/main/scala/DataSource.scala
+++ b/src/main/scala/DataSource.scala
@@ -118,6 +118,7 @@ class TrainingData(
) extends Serializable with SanityCheck {
/** Sanity check to make sure your data is being fed in correctly. */
+ override
def sanityCheck(): Unit = {
try {
val obs : Array[Double] = data.takeSample(false, 5).map(_.label)
diff --git a/src/main/scala/Evaluation.scala b/src/main/scala/Evaluation.scala
index 60c0d49..8b3e673 100644
--- a/src/main/scala/Evaluation.scala
+++ b/src/main/scala/Evaluation.scala
@@ -11,6 +11,7 @@ case class Accuracy()
extends AverageMetric[EmptyEvaluationInfo, Query, PredictedResult, ActualResult] {
/** Method for calculating prediction accuracy. */
+ override
def calculate(
query: Query,
predicted: PredictedResult,
diff --git a/src/main/scala/LRAlgorithm.scala b/src/main/scala/LRAlgorithm.scala
index e296851..7f16cb9 100644
--- a/src/main/scala/LRAlgorithm.scala
+++ b/src/main/scala/LRAlgorithm.scala
@@ -18,6 +18,7 @@ class LRAlgorithm(val ap: LRAlgorithmParams)
@transient lazy val logger = Logger[this.type]
+ override
def train(sc: SparkContext, pd: PreparedData): LRModel = {
// Import SQLContext for creating DataFrame.
@@ -66,6 +67,7 @@ class LRAlgorithm(val ap: LRAlgorithmParams)
)
}
+ override
def predict(model: LRModel, query: Query): PredictedResult = {
model.predict(query.text)
}
diff --git a/src/main/scala/NBAlgorithm.scala b/src/main/scala/NBAlgorithm.scala
index 6d5c164..4915f92 100644
--- a/src/main/scala/NBAlgorithm.scala
+++ b/src/main/scala/NBAlgorithm.scala
@@ -21,6 +21,7 @@ class NBAlgorithm(
) extends P2LAlgorithm[PreparedData, NBModel, Query, PredictedResult] {
/** Train your model. */
+ override
def train(sc: SparkContext, pd: PreparedData): NBModel = {
// Fit a Naive Bayes model using the prepared data.
val nb: NaiveBayesModel = NaiveBayes.train(pd.transformedData, ap.lambda)
@@ -32,6 +33,7 @@ class NBAlgorithm(
}
/** Prediction method for trained model. */
+ override
def predict(model: NBModel, query: Query): PredictedResult = {
model.predict(query.text)
}
diff --git a/src/main/scala/Preparator.scala b/src/main/scala/Preparator.scala
index 1f4d51d..98d1129 100644
--- a/src/main/scala/Preparator.scala
+++ b/src/main/scala/Preparator.scala
@@ -4,10 +4,8 @@ import org.apache.predictionio.controller.PPreparator
import org.apache.predictionio.controller.Params
import org.apache.spark.SparkContext
-import org.apache.spark.SparkContext._
import org.apache.spark.mllib.feature.{IDF, IDFModel, HashingTF}
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -31,6 +29,7 @@ case class PreparatorParams(
class Preparator(pp: PreparatorParams)
extends PPreparator[TrainingData, PreparedData] {
+ override
def prepare(sc: SparkContext, td: TrainingData): PreparedData = {
val tfHasher = new TFHasher(pp.numFeatures, pp.nGram, td.stopWords)
@@ -106,7 +105,7 @@ class TFIDFModel(
val idf: IDFModel
) extends Serializable {
- /** trasform text to tf-idf vector. */
+ /** transform text to tf-idf vector. */
def transform(text: String): Vector = {
// Map(n-gram -> document tf)
idf.transform(hasher.hashTF(text))