You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this text.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/20 11:06:22 UTC
[ignite] branch master updated: IGNITE-10902: [ML] Implement a few regression metrics in one RegressionMetrics class
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 481e5a5 IGNITE-10902: [ML] Implement a few regression metrics in one RegressionMetrics class
481e5a5 is described below
commit 481e5a5f61d30fd159093a6a74dd9dd718c9bab8
Author: zaleslaw <za...@gmail.com>
AuthorDate: Wed Feb 20 14:06:04 2019 +0300
IGNITE-10902: [ML] Implement a few regression metrics in one RegressionMetrics class
This closes #5827
---
.../spark/LogRegFromSparkThroughPMMLExample.java | 15 +--
.../modelparser/DecisionTreeFromSparkExample.java | 9 +-
.../spark/modelparser/GBTFromSparkExample.java | 9 +-
.../spark/modelparser/LogRegFromSparkExample.java | 9 +-
.../modelparser/RandomForestFromSparkExample.java | 9 +-
.../spark/modelparser/SVMFromSparkExample.java | 9 +-
.../examples/ml/knn/KNNClassificationExample.java | 7 +-
.../DiscreteNaiveBayesTrainerExample.java | 7 +-
.../GaussianNaiveBayesTrainerExample.java | 7 +-
.../BaggedLogisticRegressionSGDTrainerExample.java | 7 +-
.../LogisticRegressionSGDTrainerExample.java | 7 +-
.../ml/selection/cv/CrossValidationExample.java | 13 +-
.../ml/selection/scoring/EvaluatorExample.java | 11 +-
.../selection/scoring/MultipleMetricsExample.java | 9 +-
.../ml/svm/SVMBinaryClassificationExample.java | 7 +-
.../ml/tutorial/Step_10_Scaling_With_Stacking.java | 11 +-
.../ml/tutorial/Step_1_Read_and_Learn.java | 11 +-
.../examples/ml/tutorial/Step_2_Imputing.java | 11 +-
.../examples/ml/tutorial/Step_3_Categorial.java | 11 +-
.../Step_3_Categorial_with_One_Hot_Encoder.java | 11 +-
.../examples/ml/tutorial/Step_4_Add_age_fare.java | 11 +-
.../examples/ml/tutorial/Step_5_Scaling.java | 11 +-
.../ml/tutorial/Step_5_Scaling_with_Pipeline.java | 11 +-
.../ignite/examples/ml/tutorial/Step_6_KNN.java | 11 +-
.../ml/tutorial/Step_7_Split_train_test.java | 11 +-
.../ignite/examples/ml/tutorial/Step_8_CV.java | 13 +-
.../ml/tutorial/Step_8_CV_with_Param_Grid.java | 13 +-
.../Step_8_CV_with_Param_Grid_and_metrics.java | 19 +--
.../examples/ml/tutorial/Step_9_Go_to_LogReg.java | 13 +-
.../ml/sparkmodelparser/SparkModelParser.java | 100 +++++++++-------
...ClassificationEvaluator.java => Evaluator.java} | 11 +-
.../metric/{Accuracy.java => AbstractMetrics.java} | 51 ++++----
.../metric/{ClassMetric.java => MetricValues.java} | 30 ++---
.../metric/{ => classification}/Accuracy.java | 6 +-
.../BinaryClassificationMetricValues.java | 20 +---
.../BinaryClassificationMetrics.java | 30 ++---
.../metric/{ => classification}/ClassMetric.java | 4 +-
.../metric/{ => classification}/Fmeasure.java | 5 +-
.../metric/{ => classification}/Precision.java | 5 +-
.../metric/{ => classification}/Recall.java | 5 +-
.../package-info.java} | 21 +---
.../UnknownClassLabelException.java | 2 +-
.../package-info.java} | 21 +---
.../metric/regression/RegressionMetricValues.java | 71 +++++++++++
.../metric/regression/RegressionMetrics.java | 66 +++++++++++
.../package-info.java} | 21 +---
.../ignite/ml/selection/SelectionTestSuite.java | 9 +-
.../ml/selection/cv/CrossValidationTest.java | 15 +--
.../BinaryClassificationEvaluatorTest.java | 13 +-
.../selection/scoring/evaluator/EvaluatorTest.java | 21 ++--
.../scoring/evaluator/RegressionEvaluatorTest.java | 131 +++++++++++++++++++++
.../metric/{ => classification}/AccuracyTest.java | 6 +-
.../BinaryClassificationMetricsTest.java | 14 ++-
.../BinaryClassificationMetricsValuesTest.java | 2 +-
.../metric/{ => classification}/FmeasureTest.java | 5 +-
.../metric/{ => classification}/PrecisionTest.java | 5 +-
.../metric/{ => classification}/RecallTest.java | 5 +-
.../metric/regression/RegressionMetricsTest.java | 103 ++++++++++++++++
58 files changed, 729 insertions(+), 372 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java
index acb5cf1..fdcea6a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java
@@ -17,20 +17,14 @@
package org.apache.ignite.examples.ml.inference.spark;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.io.InputStream;
-import javax.xml.bind.JAXBException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
import org.dmg.pmml.PMML;
@@ -39,6 +33,9 @@ import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;
+import javax.xml.bind.JAXBException;
+import java.io.*;
+
/**
* Run logistic regression model loaded from PMML file. The PMML file was generated by Spark MLLib toPMML operator.
* <p>
@@ -63,7 +60,7 @@ public class LogRegFromSparkThroughPMMLExample {
System.out.println(">>> Logistic regression model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
(k, v) -> v.copyOfRange(1, v.size()),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
index 4e8a2d3..3af9916 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,12 +24,14 @@ import org.apache.ignite.examples.ml.tutorial.TitanicUtils;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* Run Decision Tree model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -69,7 +70,7 @@ public class DecisionTreeFromSparkExample {
System.out.println(">>> DT: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java
index b85129d..33e5cca 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,11 +25,13 @@ import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+import java.io.FileNotFoundException;
+
/**
* Run Gradient Boosted trees model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -69,7 +70,7 @@ public class GBTFromSparkExample {
System.out.println(">>> GBT: " + mdl.toString(true));
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java
index 2a416da..c927f44 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,11 +25,13 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+import java.io.FileNotFoundException;
+
/**
* Run logistic regression model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -69,7 +70,7 @@ public class LogRegFromSparkExample {
System.out.println(">>> Logistic regression model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
index 4c040b3..1bfe41f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,11 +25,13 @@ import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+import java.io.FileNotFoundException;
+
/**
* Run Random Forest model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -69,7 +70,7 @@ public class RandomForestFromSparkExample {
System.out.println(">>> Random Forest model: " + mdl.toString(true));
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java
index 5ce177d..888bd54 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,12 +24,14 @@ import org.apache.ignite.examples.ml.tutorial.TitanicUtils;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
+import java.io.FileNotFoundException;
+
/**
* Run SVM model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -69,7 +70,7 @@ public class SVMFromSparkExample {
System.out.println(">>> SVM: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index ec25006..8a2e095 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.knn;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -27,10 +26,12 @@ import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run kNN multi-class classification trainer ({@link KNNClassificationTrainer}) over distributed dataset.
* <p>
@@ -71,7 +72,7 @@ public class KNNClassificationExample {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(NNStrategy.WEIGHTED);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
index fff298b..4114f2d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.naivebayes;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,10 +24,12 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run naive Bayes classification model based on <a href=https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes">
* naive Bayes classifier</a> algorithm ({@link DiscreteNaiveBayesTrainer}) over distributed cache.
@@ -74,7 +75,7 @@ public class DiscreteNaiveBayesTrainerExample {
System.out.println(">>> Discrete Naive Bayes model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
index 4459566..c98ad62 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.naivebayes;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,10 +24,12 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run naive Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier"> naive
* Bayes classifier</a> algorithm ({@link GaussianNaiveBayesTrainer}) over distributed cache.
@@ -73,7 +74,7 @@ public class GaussianNaiveBayesTrainerExample {
System.out.println(">>> Naive Bayes model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
index c9b10b1..8de06a6 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.regression.logistic.bagged;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -31,11 +29,14 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpda
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.trainers.TrainerTransformers;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+
/**
* This example shows how bagging technique may be applied to arbitrary trainer.
* As an example (a bit synthetic) logistic regression is considered.
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 2dab1af..a7c5ba9 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.regression.logistic.binary;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -28,10 +27,12 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpda
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run logistic regression model based on <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent">
* stochastic gradient descent</a> algorithm ({@link LogisticRegressionSGDTrainer}) over distributed cache.
@@ -84,7 +85,7 @@ public class LogisticRegressionSGDTrainerExample {
System.out.println(">>> Logistic regression model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
index 462186c..8f06e0f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.selection.cv;
-import java.util.Arrays;
-import java.util.Random;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,12 +24,15 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.selection.cv.CrossValidation;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetrics;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.util.Arrays;
+import java.util.Random;
+
/**
* Run <a href="https://en.wikipedia.org/wiki/Decision_tree">decision tree</a> classification with
* <a href="https://en.wikipedia.org/wiki/Cross-validation_(statistics)">cross validation</a> ({@link CrossValidation}).
@@ -85,7 +86,7 @@ public class CrossValidationExample {
System.out.println(">>> Accuracy: " + Arrays.toString(accuracyScores));
- BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
+ BinaryClassificationMetrics metrics = (BinaryClassificationMetrics) new BinaryClassificationMetrics()
.withNegativeClsLb(0.0)
.withPositiveClsLb(1.0)
.withMetric(BinaryClassificationMetricValues::balancedAccuracy);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
index c556e11..b5a6f89 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
@@ -17,19 +17,20 @@
package org.apache.ignite.examples.ml.selection.scoring;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run SVM classification trainer ({@link SVMLinearClassificationTrainer}) over distributed dataset.
* <p>
@@ -68,7 +69,7 @@ public class EvaluatorExample {
lbExtractor
);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
@@ -78,7 +79,7 @@ public class EvaluatorExample {
System.out.println("\n>>> Accuracy " + accuracy);
- double f1Score = BinaryClassificationEvaluator.evaluate(
+ double f1Score = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java
index b8c76e0..934fb32 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java
@@ -17,20 +17,21 @@
package org.apache.ignite.examples.ml.selection.scoring;
-import java.io.FileNotFoundException;
-import java.util.Map;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+import java.util.Map;
+
/**
* Run kNN multi-class classification trainer ({@link KNNClassificationTrainer}) over distributed dataset.
* <p>
@@ -67,7 +68,7 @@ public class MultipleMetricsExample {
lbExtractor
);
- Map<String, Double> scores = BinaryClassificationEvaluator.evaluate(
+ Map<String, Double> scores = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
index 291c7f8..3d9c8ab 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
@@ -17,18 +17,19 @@
package org.apache.ignite.examples.ml.svm;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run SVM binary-class classification model ({@link SVMLinearClassificationModel}) over distributed dataset.
* <p>
@@ -69,7 +70,7 @@ public class SVMBinaryClassificationExample {
System.out.println(">>> SVM model " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
index ec64764..cb88ebf 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -35,10 +34,12 @@ import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import java.io.FileNotFoundException;
+
/**
* {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values
* distribution in columns and rows.
@@ -50,7 +51,7 @@ import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_10_Scaling_With_Stacking {
/** Run example. */
@@ -121,7 +122,7 @@ public class Step_10_Scaling_With_Stacking {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
normalizationPreprocessor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
index adb7e44..34d6fe8 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
@@ -17,18 +17,19 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* Usage of {@link DecisionTreeClassificationTrainer} to predict death in the disaster.
* <p>
@@ -38,7 +39,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* After that it trains the model based on the specified data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_1_Read_and_Learn {
/** Run example. */
@@ -72,7 +73,7 @@ public class Step_1_Read_and_Learn {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
featureExtractor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
index 6fe41ab..72ae0cb 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,11 +24,13 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* Usage of {@link ImputerTrainer} to fill missed data ({@code Double.NaN}) values in the chosen columns.
* <p>
@@ -40,7 +41,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_2_Imputing {
/** Run example. */
@@ -75,7 +76,7 @@ public class Step_2_Imputing {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
imputingPreprocessor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
index f9bd014..337421e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,11 +25,13 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* Let's add two categorial features "sex", "embarked" to predict more precisely than in {@link Step_1_Read_and_Learn}.
* <p>
@@ -43,7 +44,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_3_Categorial {
/** Run example. */
@@ -88,7 +89,7 @@ public class Step_3_Categorial {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
imputingPreprocessor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
index 0b3e235..d390fec 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,11 +25,13 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* Let's add two categorial features "sex", "embarked" to predict more precisely than in {@link Step_1_Read_and_Learn}..
* <p>
@@ -44,7 +45,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_3_Categorial_with_One_Hot_Encoder {
/** Run example. */
@@ -91,7 +92,7 @@ public class Step_3_Categorial_with_One_Hot_Encoder {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
imputingPreprocessor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
index 7576cd6..6b7f6be 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,11 +25,13 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* Add yet two numerical features "age", "fare" to improve our model over {@link Step_3_Categorial}.
* <p>
@@ -41,7 +42,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_4_Add_age_fare {
/** Run example. */
@@ -87,7 +88,7 @@ public class Step_4_Add_age_fare {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
imputingPreprocessor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
index 065eb90..ca595ef 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -28,11 +27,13 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values
* distribution in columns and rows.
@@ -44,7 +45,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_5_Scaling {
/** Run example. */
@@ -105,7 +106,7 @@ public class Step_5_Scaling {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
normalizationPreprocessor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
index 2e97ccb..e65a9a6 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -29,10 +28,12 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import java.io.FileNotFoundException;
+
/**
* {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values
* distribution in columns and rows.
@@ -44,7 +45,7 @@ import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_5_Scaling_with_Pipeline {
/** Run example. */
@@ -79,7 +80,7 @@ public class Step_5_Scaling_with_Pipeline {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
mdl.getFeatureExtractor(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
index d7a0b88..a4ba699 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -31,8 +30,10 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
+
+import java.io.FileNotFoundException;
/**
* Change classification algorithm that was used in {@link Step_5_Scaling} from decision tree to kNN
@@ -45,7 +46,7 @@ import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
* <p>
* Then, it trains the model based on the processed data using kNN classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_6_KNN {
/** Run example. */
@@ -106,7 +107,7 @@ public class Step_6_KNN {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
mdl,
normalizationPreprocessor,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
index a988abe..350145f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -28,13 +27,15 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+
/**
* The highest accuracy in the previous example ({@link Step_6_KNN}) is the result of
* <a href="https://en.wikipedia.org/wiki/Overfitting">overfitting</a>.
@@ -47,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* Then, it trains the model based on the processed data using decision tree classification.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_7_Split_train_test {
/** Run example. */
@@ -112,7 +113,7 @@ public class Step_7_Split_train_test {
System.out.println("\n>>> Trained model: " + mdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
split.getTestFilter(),
mdl,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
index 8a962f3..175133fc 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -30,13 +28,16 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+
/**
* To choose the best hyperparameters the cross-validation will be used in this example.
* <p>
@@ -48,7 +49,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on
* the processed data using decision tree classification and the obtained hyperparams.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
* <p>
* The purpose of cross-validation is model checking, not model building.</p>
* <p>
@@ -175,7 +176,7 @@ public class Step_8_CV {
System.out.println("\n>>> Trained model: " + bestMdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
split.getTestFilter(),
bestMdl,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
index 19951a2..325b656 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,13 +30,16 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.cv.CrossValidationResult;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+
/**
* To choose the best hyperparameters the cross-validation with {@link ParamGrid} will be used in this example.
* <p>
@@ -50,7 +51,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on
* the processed data using decision tree classification and the obtained hyperparams.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
* <p>
* The purpose of cross-validation is model checking, not model building.</p>
* <p>
@@ -163,7 +164,7 @@ public class Step_8_CV_with_Param_Grid {
System.out.println("\n>>> Trained model: " + bestMdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
split.getTestFilter(),
bestMdl,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
index 0ea0ca2..a12dcc2 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,15 +30,18 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.cv.CrossValidationResult;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetrics;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+
/**
* To choose the best hyperparameters the cross-validation with {@link ParamGrid} will be used in this example.
* <p>
@@ -52,7 +53,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on
* the processed data using decision tree classification and the obtained hyperparams.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
* <p>
* The purpose of cross-validation is model checking, not model building.</p>
* <p>
@@ -126,7 +127,7 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
.addHyperParam("maxDeep", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 10.0})
.addHyperParam("minImpurityDecrease", new Double[] {0.0, 0.25, 0.5});
- BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
+ BinaryClassificationMetrics metrics = (BinaryClassificationMetrics) new BinaryClassificationMetrics()
.withNegativeClsLb(0.0)
.withPositiveClsLb(1.0)
.withMetric(BinaryClassificationMetricValues::accuracy);
@@ -170,7 +171,7 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
System.out.println("\n>>> Trained model: " + bestMdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
split.getTestFilter(),
bestMdl,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
index eb12e58..5c0ad57 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -35,11 +33,14 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
-import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+
/**
* Change classification algorithm that was used in {@link Step_8_CV_with_Param_Grid} from decision tree to logistic
* regression ({@link LogisticRegressionSGDTrainer}) because sometimes this can give the higher accuracy.
@@ -52,7 +53,7 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit;
* Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on
* the processed data using logistic regression and the obtained hyperparams.</p>
* <p>
- * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
*/
public class Step_9_Go_to_LogReg {
/** Run example. */
@@ -207,7 +208,7 @@ public class Step_9_Go_to_LogReg {
System.out.println("\n>>> Trained model: " + bestMdl);
- double accuracy = BinaryClassificationEvaluator.evaluate(
+ double accuracy = Evaluator.evaluate(
dataCache,
split.getTestFilter(),
bestMdl,
diff --git a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
index b1293de..4ec693e 100644
--- a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
+++ b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
@@ -62,7 +62,7 @@ public class SparkModelParser {
/**
* Load model from parquet (presented as a directory).
*
- * @param pathToMdl Path to directory with saved model.
+ * @param pathToMdl Path to directory with saved model.
* @param parsedSparkMdl Parsed spark model.
*/
public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl) throws IllegalArgumentException {
@@ -109,7 +109,8 @@ public class SparkModelParser {
try {
validateMetadata(pathToMetadata, parsedSparkMdl);
- } catch (FileNotFoundException e) {
+ }
+ catch (FileNotFoundException e) {
throw new IllegalArgumentException("Directory should contain json file with model metadata " +
"with name part-00000 [directory_path=" + pathToMetadata + "]");
}
@@ -133,7 +134,8 @@ public class SparkModelParser {
String pathToTreesMetadataFile = treesMetadataParquetFiles[0].getPath();
return parseDataWithMetadata(pathToMdlFile, pathToTreesMetadataFile, parsedSparkMdl);
- } else
+ }
+ else
return parseData(pathToMdlFile, parsedSparkMdl);
}
@@ -145,14 +147,16 @@ public class SparkModelParser {
* @param pathToMetadata Path to metadata.
* @param parsedSparkMdl Parsed spark model.
*/
- private static void validateMetadata(String pathToMetadata, SupportedSparkModels parsedSparkMdl) throws FileNotFoundException {
+ private static void validateMetadata(String pathToMetadata,
+ SupportedSparkModels parsedSparkMdl) throws FileNotFoundException {
File metadataFile = IgniteUtils.resolveIgnitePath(pathToMetadata + File.separator + "part-00000");
if (metadataFile != null) {
Scanner sc = new Scanner(metadataFile);
boolean isInvalid = true;
while (sc.hasNextLine()) {
final String line = sc.nextLine();
- if (line.contains(parsedSparkMdl.getMdlClsNameInSpark())) isInvalid = false;
+ if (line.contains(parsedSparkMdl.getMdlClsNameInSpark()))
+ isInvalid = false;
}
if (isInvalid)
@@ -169,11 +173,10 @@ public class SparkModelParser {
|| parsedSparkMdl == SupportedSparkModels.GRADIENT_BOOSTED_TREES_REGRESSION;
}
-
/**
* Load model from parquet file.
*
- * @param pathToMdl Hadoop path to model saved from Spark.
+ * @param pathToMdl Hadoop path to model saved from Spark.
* @param parsedSparkMdl One of supported Spark models to parse it.
* @return Instance of parsedSparkMdl model.
*/
@@ -209,13 +212,13 @@ public class SparkModelParser {
/**
* Load model and its metadata from parquet files.
*
- * @param pathToMdl Hadoop path to model saved from Spark.
+ * @param pathToMdl Hadoop path to model saved from Spark.
* @param pathToMetaData Hadoop path to metadata saved from Spark.
* @param parsedSparkMdl One of supported Spark models to parse it.
* @return Instance of parsedSparkMdl model.
*/
private static Model parseDataWithMetadata(String pathToMdl, String pathToMetaData,
- SupportedSparkModels parsedSparkMdl) {
+ SupportedSparkModels parsedSparkMdl) {
File mdlRsrc1 = IgniteUtils.resolveIgnitePath(pathToMdl);
if (mdlRsrc1 == null)
throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMdl + "]");
@@ -273,12 +276,12 @@ public class SparkModelParser {
final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
while (null != (pages = r.readNextRowGroup())) {
- final int rows = (int) pages.getRowCount();
+ final int rows = (int)pages.getRowCount();
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
centers = new DenseVector[rows];
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
// final int clusterIdx = g.getInteger(0, 0);
Group clusterCenterCoeff = g.getGroup(1, 0).getGroup(3, 0);
@@ -294,7 +297,8 @@ public class SparkModelParser {
}
}
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
@@ -305,7 +309,7 @@ public class SparkModelParser {
/**
* Load GDB Regression model.
*
- * @param pathToMdl Path to model.
+ * @param pathToMdl Path to model.
* @param pathToMdlMetaData Path to model meta data.
*/
private static Model loadGBTRegressionModel(String pathToMdl, String pathToMdlMetaData) {
@@ -317,7 +321,7 @@ public class SparkModelParser {
/**
* Load GDB Classification model.
*
- * @param pathToMdl Path to model.
+ * @param pathToMdl Path to model.
* @param pathToMdlMetaData Path to model meta data.
*/
private static Model loadGBTClassifierModel(String pathToMdl, String pathToMdlMetaData) {
@@ -329,12 +333,12 @@ public class SparkModelParser {
/**
* Parse and build common GDB model with the custom label mapper.
*
- * @param pathToMdl Path to model.
+ * @param pathToMdl Path to model.
* @param pathToMdlMetaData Path to model meta data.
- * @param lbMapper Label mapper.
+ * @param lbMapper Label mapper.
*/
@Nullable private static Model parseAndBuildGDBModel(String pathToMdl, String pathToMdlMetaData,
- IgniteFunction<Double, Double> lbMapper) {
+ IgniteFunction<Double, Double> lbMapper) {
double[] treeWeights = null;
final Map<Integer, Double> treeWeightsByTreeID = new HashMap<>();
@@ -347,13 +351,14 @@ public class SparkModelParser {
final long rows = pagesMetaData.getRowCount();
final RecordReader recordReader = colIO.getRecordReader(pagesMetaData, new GroupRecordConverter(schema));
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
int treeId = g.getInteger(0, 0);
double treeWeight = g.getDouble(2, 0);
treeWeightsByTreeID.put(treeId, treeWeight);
}
}
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file with MetaData by the path: " + pathToMdlMetaData);
e.printStackTrace();
}
@@ -371,15 +376,16 @@ public class SparkModelParser {
final long rows = pages.getRowCount();
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
final int treeID = g.getInteger(0, 0);
- final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
+ final SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);
NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);
if (nodesByTreeId.containsKey(treeID)) {
Map<Integer, NodeData> nodesByNodeId = nodesByTreeId.get(treeID);
nodesByNodeId.put(nodeData.id, nodeData);
- } else {
+ }
+ else {
TreeMap<Integer, NodeData> nodesByNodeId = new TreeMap<>();
nodesByNodeId.put(nodeData.id, nodeData);
nodesByTreeId.put(treeID, nodesByNodeId);
@@ -391,7 +397,8 @@ public class SparkModelParser {
nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
return new GDBTrainer.GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
@@ -428,16 +435,17 @@ public class SparkModelParser {
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
final int treeID = g.getInteger(0, 0);
- final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
+ final SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);
NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);
if (nodesByTreeId.containsKey(treeID)) {
Map<Integer, NodeData> nodesByNodeId = nodesByTreeId.get(treeID);
nodesByNodeId.put(nodeData.id, nodeData);
- } else {
+ }
+ else {
TreeMap<Integer, NodeData> nodesByNodeId = new TreeMap<>();
nodesByNodeId.put(nodeData.id, nodeData);
nodesByTreeId.put(treeID, nodesByNodeId);
@@ -447,7 +455,8 @@ public class SparkModelParser {
List<IgniteModel<Vector, Double>> models = new ArrayList<>();
nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
return models;
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
@@ -472,13 +481,14 @@ public class SparkModelParser {
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
NodeData nodeData = extractNodeDataFromParquetRow(g);
nodes.put(nodeData.id, nodeData);
}
}
return buildDecisionTreeModel(nodes);
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
@@ -493,7 +503,7 @@ public class SparkModelParser {
private static DecisionTreeNode buildDecisionTreeModel(Map<Integer, NodeData> nodes) {
DecisionTreeNode mdl = null;
if (!nodes.isEmpty()) {
- NodeData rootNodeData = (NodeData) ((NavigableMap) nodes).firstEntry().getValue();
+ NodeData rootNodeData = (NodeData)((NavigableMap)nodes).firstEntry().getValue();
mdl = buildTree(nodes, rootNodeData);
return mdl;
}
@@ -503,11 +513,11 @@ public class SparkModelParser {
/**
* Build tree or sub-tree based on indices and nodes sorted map as a dictionary.
*
- * @param nodes The sorted map of nodes.
+ * @param nodes The sorted map of nodes.
* @param rootNodeData Root node data.
*/
@NotNull private static DecisionTreeNode buildTree(Map<Integer, NodeData> nodes,
- NodeData rootNodeData) {
+ NodeData rootNodeData) {
return rootNodeData.isLeafNode ? new DecisionTreeLeafNode(rootNodeData.prediction) : new DecisionTreeConditionalNode(rootNodeData.featureIdx,
rootNodeData.threshold,
buildTree(nodes, nodes.get(rootNodeData.rightChildId)),
@@ -532,8 +542,9 @@ public class SparkModelParser {
nodeData.featureIdx = -1;
nodeData.threshold = -1;
nodeData.isLeafNode = true;
- } else {
- final SimpleGroup splitGrp = (SimpleGroup) g.getGroup(7, 0);
+ }
+ else {
+ final SimpleGroup splitGrp = (SimpleGroup)g.getGroup(7, 0);
nodeData.featureIdx = splitGrp.getInteger(0, 0);
nodeData.threshold = splitGrp.getGroup(1, 0).getGroup(0, 0).getDouble(0, 0);
}
@@ -582,12 +593,13 @@ public class SparkModelParser {
final long rows = pages.getRowCount();
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
interceptor = readSVMInterceptor(g);
coefficients = readSVMCoefficients(g);
}
}
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
@@ -614,13 +626,14 @@ public class SparkModelParser {
final long rows = pages.getRowCount();
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
interceptor = readLinRegInterceptor(g);
coefficients = readLinRegCoefficients(g);
}
}
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
@@ -647,13 +660,14 @@ public class SparkModelParser {
final long rows = pages.getRowCount();
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int i = 0; i < rows; i++) {
- final SimpleGroup g = (SimpleGroup) recordReader.read();
+ final SimpleGroup g = (SimpleGroup)recordReader.read();
interceptor = readInterceptor(g);
coefficients = readCoefficients(g);
}
}
- } catch (IOException e) {
+ }
+ catch (IOException e) {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
@@ -728,9 +742,9 @@ public class SparkModelParser {
private static double readInterceptor(SimpleGroup g) {
double interceptor;
- final SimpleGroup interceptVector = (SimpleGroup) g.getGroup(2, 0);
- final SimpleGroup interceptVectorVal = (SimpleGroup) interceptVector.getGroup(3, 0);
- final SimpleGroup interceptVectorValElement = (SimpleGroup) interceptVectorVal.getGroup(0, 0);
+ final SimpleGroup interceptVector = (SimpleGroup)g.getGroup(2, 0);
+ final SimpleGroup interceptVectorVal = (SimpleGroup)interceptVector.getGroup(3, 0);
+ final SimpleGroup interceptVectorValElement = (SimpleGroup)interceptVectorVal.getGroup(0, 0);
interceptor = interceptVectorValElement.getDouble(0, 0);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java
similarity index 97%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java
index 7b9698f..c8f8a27 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java
@@ -17,7 +17,6 @@
package org.apache.ignite.ml.selection.scoring.evaluator;
-import java.util.Map;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.IgniteModel;
@@ -26,14 +25,16 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetrics;
+
+import java.util.Map;
/**
- * Binary classification evaluator that computes metrics from predictions and ground truth values.
+ * Evaluator that computes metrics from predictions and ground truth values.
*/
-public class BinaryClassificationEvaluator {
+public class Evaluator {
/**
* Computes the given metric on the given cache.
*
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/AbstractMetrics.java
similarity index 52%
copy from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/AbstractMetrics.java
index fd0656c..0adc406 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/AbstractMetrics.java
@@ -17,37 +17,38 @@
package org.apache.ignite.ml.selection.scoring.metric;
-import java.util.Iterator;
import org.apache.ignite.ml.selection.scoring.LabelPair;
+import java.util.Iterator;
+import java.util.function.Function;
+
/**
- * Accuracy score calculator.
- *
- * @param <L> Type of a label (truth or prediction).
+ * Abstract metrics calculator that computes metric values of type {@code M}.
+ * It could be used in two ways: to calculate all metric values at once or a single selected metric.
*/
-public class Accuracy<L> implements Metric<L> {
- /** {@inheritDoc} */
- @Override public double score(Iterator<LabelPair<L>> iter) {
- long totalCnt = 0;
- long correctCnt = 0;
-
- while (iter.hasNext()) {
- LabelPair<L> e = iter.next();
-
- L prediction = e.getPrediction();
- L truth = e.getTruth();
-
- if (prediction.equals(truth))
- correctCnt++;
-
- totalCnt++;
- }
-
- return 1.0 * correctCnt / totalCnt;
+public abstract class AbstractMetrics<M extends MetricValues> implements Metric<Double> {
+ /** The main metric to get individual score. */
+ protected Function<M, Double> metric;
+
+ /**
+ * Calculates metrics values.
+ *
+ * @param iter Iterator that supplies pairs of truth values and predictions.
+ * @return Values of all metrics.
+ */
+ public abstract M scoreAll(Iterator<LabelPair<Double>> iter);
+
+ /**
+ * Sets the metric returned by {@link #score(Iterator)} ({@code null} keeps the current one); returns {@code this}.
+ */
+ public AbstractMetrics<M> withMetric(Function<M, Double> metric) {
+ if (metric != null)
+ this.metric = metric;
+ return this;
}
/** {@inheritDoc} */
- @Override public String name() {
- return "accuracy";
+ @Override public double score(Iterator<LabelPair<Double>> iter) {
+ return metric.apply(scoreAll(iter));
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/MetricValues.java
similarity index 58%
copy from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/MetricValues.java
index f89e683..7f63e89 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/MetricValues.java
@@ -17,21 +17,28 @@
package org.apache.ignite.ml.selection.scoring.metric;
+import java.lang.reflect.Field;
+import java.util.HashMap;
+import java.util.Map;
+
/**
- * Metric calculator for one class label.
- *
- * @param <L> Type of a label (truth or prediction).
+ * Common interface to present metric values for different ML tasks.
*/
-public abstract class ClassMetric<L> implements Metric<L> {
- /** Class label. */
- protected L clsLb;
-
- /**
- * The class of interest or positive class.
- *
- * @param clsLb The label.
- */
- public ClassMetric(L clsLb) {
- this.clsLb = clsLb;
+public interface MetricValues {
+ /** Returns a map from metric name (field name declared by the implementing class) to metric value. */
+ public default Map<String, Double> toMap() {
+ Map<String, Double> metricValues = new HashMap<>();
+ for (Field field : getClass().getDeclaredFields()) {
+ try {
+ // Implementations keep metric values in private fields and this default method is declared in
+ // MetricValues, so reflective access must be enabled explicitly, otherwise getDouble() throws.
+ field.setAccessible(true);
+ metricValues.put(field.getName(), field.getDouble(this));
+ }
+ catch (IllegalAccessException e) {
+ e.printStackTrace();
+ }
+ }
+ return metricValues;
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Accuracy.java
similarity index 92%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Accuracy.java
index fd0656c..0ab5cd4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Accuracy.java
@@ -15,10 +15,12 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Iterator;
import org.apache.ignite.ml.selection.scoring.LabelPair;
+import org.apache.ignite.ml.selection.scoring.metric.Metric;
+
+import java.util.Iterator;
/**
* Accuracy score calculator.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricValues.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java
similarity index 87%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricValues.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java
index 04cd981..e04ddc0 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricValues.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java
@@ -15,16 +15,14 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.lang.reflect.Field;
-import java.util.HashMap;
-import java.util.Map;
+import org.apache.ignite.ml.selection.scoring.metric.MetricValues;
/**
* Provides access to binary metric values.
*/
-public class BinaryClassificationMetricValues {
+public class BinaryClassificationMetricValues implements MetricValues {
/** True Positive (TP). */
private double tp;
@@ -170,16 +168,4 @@ public class BinaryClassificationMetricValues {
public double f1Score() {
return f1Score;
}
-
- /** Returns the pair of metric name and metric value. */
- public Map<String, Double> toMap() {
- Map<String, Double> metricValues = new HashMap<>();
- for (Field field : getClass().getDeclaredFields())
- try {
- metricValues.put(field.getName(), field.getDouble(this));
- } catch (IllegalAccessException e) {
- e.printStackTrace();
- }
- return metricValues;
- }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java
similarity index 80%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java
index 35da9fa..bfa1cf3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java
@@ -15,33 +15,35 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Iterator;
-import java.util.function.Function;
import org.apache.ignite.ml.selection.scoring.LabelPair;
+import org.apache.ignite.ml.selection.scoring.metric.AbstractMetrics;
+import org.apache.ignite.ml.selection.scoring.metric.exceptions.UnknownClassLabelException;
+
+import java.util.Iterator;
/**
* Binary classification metrics calculator.
* It could be used in two ways: to caculate all binary classification metrics or specific metric.
*/
-public class BinaryClassificationMetrics implements Metric<Double> {
+public class BinaryClassificationMetrics extends AbstractMetrics<BinaryClassificationMetricValues> {
/** Positive class label. */
private double positiveClsLb = 1.0;
/** Negative class label. Default value is 0.0. */
private double negativeClsLb;
- /** The main metric to get individual score. */
- private Function<BinaryClassificationMetricValues, Double> metric = BinaryClassificationMetricValues::accuracy;
-
+ {
+ metric = BinaryClassificationMetricValues::accuracy;
+ }
/**
* Calculates binary metrics values.
*
* @param iter Iterator that supplies pairs of truth values and predicated.
* @return Scores for all binary metrics.
*/
- public BinaryClassificationMetricValues scoreAll(Iterator<LabelPair<Double>> iter) {
+ @Override public BinaryClassificationMetricValues scoreAll(Iterator<LabelPair<Double>> iter) {
long tp = 0;
long tn = 0;
long fp = 0;
@@ -91,18 +93,6 @@ public class BinaryClassificationMetrics implements Metric<Double> {
return this;
}
- /** */
- public BinaryClassificationMetrics withMetric(Function<BinaryClassificationMetricValues, Double> metric) {
- if (metric != null)
- this.metric = metric;
- return this;
- }
-
- /** {@inheritDoc} */
- @Override public double score(Iterator<LabelPair<Double>> iter) {
- return metric.apply(scoreAll(iter));
- }
-
/** {@inheritDoc} */
@Override public String name() {
return "Binary classification metrics";
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ClassMetric.java
similarity index 90%
copy from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ClassMetric.java
index f89e683..2a0a7eb 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ClassMetric.java
@@ -15,7 +15,9 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
+
+import org.apache.ignite.ml.selection.scoring.metric.Metric;
/**
* Metric calculator for one class label.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Fmeasure.java
similarity index 97%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Fmeasure.java
index fe36f51..10f5c5c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Fmeasure.java
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Iterator;
import org.apache.ignite.ml.selection.scoring.LabelPair;
+import java.util.Iterator;
+
/**
* F-measure calculator.
*
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Precision.java
similarity index 96%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Precision.java
index 482a027..660b288 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Precision.java
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Iterator;
import org.apache.ignite.ml.selection.scoring.LabelPair;
+import java.util.Iterator;
+
/**
* Precision calculator.
*
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Recall.java
similarity index 96%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Recall.java
index d459e94..0b5c8ae 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/Recall.java
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Iterator;
import org.apache.ignite.ml.selection.scoring.LabelPair;
+import java.util.Iterator;
+
/**
* Recall calculator.
*
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/package-info.java
similarity index 65%
copy from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/package-info.java
index f89e683..86f504f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/package-info.java
@@ -15,23 +15,8 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
-
/**
- * Metric calculator for one class label.
- *
- * @param <L> Type of a label (truth or prediction).
+ * <!-- Package description. -->
+ * Root package for classification metrics.
*/
-public abstract class ClassMetric<L> implements Metric<L> {
- /** Class label. */
- protected L clsLb;
-
- /**
- * The class of interest or positive class.
- *
- * @param clsLb The label.
- */
- public ClassMetric(L clsLb) {
- this.clsLb = clsLb;
- }
-}
+package org.apache.ignite.ml.selection.scoring.metric.classification;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/UnknownClassLabelException.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/exceptions/UnknownClassLabelException.java
similarity index 95%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/UnknownClassLabelException.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/exceptions/UnknownClassLabelException.java
index 0531f2e..6c10f32 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/UnknownClassLabelException.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/exceptions/UnknownClassLabelException.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.exceptions;
import org.apache.ignite.IgniteException;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/exceptions/package-info.java
similarity index 65%
copy from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/exceptions/package-info.java
index f89e683..c3e515f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/exceptions/package-info.java
@@ -15,23 +15,8 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
-
/**
- * Metric calculator for one class label.
- *
- * @param <L> Type of a label (truth or prediction).
+ * <!-- Package description. -->
+ * Root package for exceptions.
*/
-public abstract class ClassMetric<L> implements Metric<L> {
- /** Class label. */
- protected L clsLb;
-
- /**
- * The class of interest or positive class.
- *
- * @param clsLb The label.
- */
- public ClassMetric(L clsLb) {
- this.clsLb = clsLb;
- }
-}
+package org.apache.ignite.ml.selection.scoring.metric.exceptions;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetricValues.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetricValues.java
new file mode 100644
index 0000000..d6483ef
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetricValues.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.metric.regression;
+
+import org.apache.ignite.ml.selection.scoring.metric.MetricValues;
+
+/**
+ * Provides access to regression metric values.
+ * Instances are immutable: all values are computed once, in the constructor.
+ */
+public class RegressionMetricValues implements MetricValues {
+    /** Mean absolute error. */
+    private final double mae;
+
+    /** Mean squared error. */
+    private final double mse;
+
+    /** Residual sum of squares. */
+    private final double rss;
+
+    /** Root mean squared error. */
+    private final double rmse;
+
+    /**
+     * Initializes an instance; MSE and RMSE are derived from {@code rss} and {@code totalAmount}.
+     * If {@code totalAmount} is zero, MSE and RMSE are not finite ({@code NaN} when {@code rss} is zero too).
+     *
+     * @param totalAmount Total amount of observations.
+     * @param rss Residual sum of squares.
+     * @param mae Mean absolute error.
+     */
+    public RegressionMetricValues(int totalAmount, double rss, double mae) {
+        this.rss = rss;
+        this.mse = rss / totalAmount;
+        this.rmse = Math.sqrt(mse);
+        this.mae = mae;
+    }
+
+    /** @return Mean absolute error. */
+    public double mae() {
+        return mae;
+    }
+
+    /** @return Mean squared error. */
+    public double mse() {
+        return mse;
+    }
+
+    /** @return Residual sum of squares. */
+    public double rss() {
+        return rss;
+    }
+
+    /** @return Root mean squared error. */
+    public double rmse() {
+        return rmse;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetrics.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetrics.java
new file mode 100644
index 0000000..9cec6e1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetrics.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.metric.regression;
+
+import org.apache.ignite.ml.selection.scoring.LabelPair;
+import org.apache.ignite.ml.selection.scoring.metric.AbstractMetrics;
+
+import java.util.Iterator;
+
+/**
+ * Regression metrics calculator.
+ * It can be used in two ways: to calculate all regression metrics at once with {@link #scoreAll(Iterator)},
+ * or to calculate a single metric selected via {@code withMetric(...)}; RMSE is the default metric.
+ */
+public class RegressionMetrics extends AbstractMetrics<RegressionMetricValues> {
+    /** Creates a calculator whose default metric is the root mean squared error. */
+    public RegressionMetrics() {
+        metric = RegressionMetricValues::rmse;
+    }
+
+    /**
+     * Calculates all regression metric values in a single pass over the data.
+     *
+     * @param iter Iterator that supplies pairs of truth values and predictions.
+     * @return Scores for all regression metrics; MSE, RMSE and MAE are {@code NaN} if {@code iter} is empty.
+     */
+    @Override public RegressionMetricValues scoreAll(Iterator<LabelPair<Double>> iter) {
+        int totalAmount = 0;
+        double rss = 0.0;
+        double absErrSum = 0.0;
+
+        while (iter.hasNext()) {
+            LabelPair<Double> e = iter.next();
+
+            double residual = e.getPrediction() - e.getTruth();
+
+            rss += residual * residual;
+            absErrSum += Math.abs(residual);
+
+            totalAmount++;
+        }
+
+        // 0.0 / 0 is NaN, so an empty input yields NaN rather than an exception.
+        double mae = absErrSum / totalAmount;
+
+        return new RegressionMetricValues(totalAmount, rss, mae);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String name() {
+        return "Regression metrics";
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/package-info.java
similarity index 65%
rename from modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/package-info.java
index f89e683..0531f92 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/regression/package-info.java
@@ -15,23 +15,8 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
-
/**
- * Metric calculator for one class label.
- *
- * @param <L> Type of a label (truth or prediction).
+ * <!-- Package description. -->
+ * Root package for regression metrics.
*/
-public abstract class ClassMetric<L> implements Metric<L> {
- /** Class label. */
- protected L clsLb;
-
- /**
- * The class of interest or positive class.
- *
- * @param clsLb The label.
- */
- public ClassMetric(L clsLb) {
- this.clsLb = clsLb;
- }
-}
+package org.apache.ignite.ml.selection.scoring.metric.regression;
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
index 135ee64..6c5f101 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
@@ -23,12 +23,8 @@ import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursorTe
import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursorTest;
import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluatorTest;
import org.apache.ignite.ml.selection.scoring.evaluator.EvaluatorTest;
-import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricsTest;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricsValuesTest;
-import org.apache.ignite.ml.selection.scoring.metric.FmeasureTest;
-import org.apache.ignite.ml.selection.scoring.metric.PrecisionTest;
-import org.apache.ignite.ml.selection.scoring.metric.RecallTest;
+import org.apache.ignite.ml.selection.scoring.metric.classification.*;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetricsTest;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitterTest;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapperTest;
import org.junit.runner.RunWith;
@@ -53,6 +49,7 @@ import org.junit.runners.Suite;
BinaryClassificationMetricsTest.class,
BinaryClassificationMetricsValuesTest.class,
BinaryClassificationEvaluatorTest.class,
+ RegressionMetricsTest.class
})
public class SelectionTestSuite {
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
index 26401345..f63b2eb 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
@@ -17,18 +17,19 @@
package org.apache.ignite.ml.selection.cv;
-import java.util.HashMap;
-import java.util.Map;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
-import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetrics;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.junit.Test;
-import static org.junit.Assert.assertTrue;
+import java.util.HashMap;
+import java.util.Map;
+
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
/**
* Tests for {@link CrossValidation}.
@@ -86,7 +87,7 @@ public class CrossValidationTest {
int folds = 4;
- BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
+ BinaryClassificationMetrics metrics = (BinaryClassificationMetrics) new BinaryClassificationMetrics()
.withMetric(BinaryClassificationMetricValues::accuracy);
verifyScores(folds, scoreCalculator.score(
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java
index c6222c8..67b9cd2 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java
@@ -17,23 +17,24 @@
package org.apache.ignite.ml.selection.scoring.evaluator;
-import java.util.HashMap;
-import java.util.Map;
import org.apache.ignite.ml.common.TrainerTest;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.junit.Test;
+import java.util.HashMap;
+import java.util.Map;
+
import static org.junit.Assert.assertEquals;
/**
- * Tests for {@link BinaryClassificationEvaluator}.
+ * Tests for {@link Evaluator}.
*/
public class BinaryClassificationEvaluatorTest extends TrainerTest {
/**
@@ -58,7 +59,7 @@ public class BinaryClassificationEvaluatorTest extends TrainerTest {
lbExtractor
).withK(3);
- double score = BinaryClassificationEvaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>());
+ double score = Evaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>());
assertEquals(0.9839357429718876, score, 1e-12);
}
@@ -89,7 +90,7 @@ public class BinaryClassificationEvaluatorTest extends TrainerTest {
lbExtractor
).withK(3);
- double score = BinaryClassificationEvaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>());
+ double score = Evaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>());
assertEquals(0.9, score, 1);
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java
index 97b1dcf..7817c16 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java
@@ -17,14 +17,6 @@
package org.apache.ignite.ml.selection.scoring.evaluator;
-import java.text.NumberFormat;
-import java.text.ParseException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Locale;
-import java.util.UUID;
-import java.util.concurrent.atomic.AtomicReference;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
@@ -40,7 +32,7 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.cv.CrossValidationResult;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
-import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
@@ -49,11 +41,16 @@ import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
import org.apache.ignite.thread.IgniteThread;
import org.junit.Test;
+import java.text.NumberFormat;
+import java.text.ParseException;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicReference;
+
import static org.apache.ignite.ml.TestUtils.testEnvBuilder;
import static org.junit.Assert.assertArrayEquals;
/**
- * Tests for {@link BinaryClassificationEvaluator} that require to start the whole Ignite infrastructure. IMPL NOTE based on
+ * Tests for {@link Evaluator} that require to start the whole Ignite infrastructure. IMPL NOTE based on
* Step_8_CV_with_Param_Grid example.
*/
public class EvaluatorTest extends GridCommonAbstractTest {
@@ -259,7 +256,7 @@ public class EvaluatorTest extends GridCommonAbstractTest {
lbExtractor
);
- actualAccuracy.set(BinaryClassificationEvaluator.evaluate(
+ actualAccuracy.set(Evaluator.evaluate(
cache,
split.getTestFilter(),
bestMdl,
@@ -268,7 +265,7 @@ public class EvaluatorTest extends GridCommonAbstractTest {
new Accuracy<>()
));
- actualAccuracy2.set(BinaryClassificationEvaluator.evaluate(
+ actualAccuracy2.set(Evaluator.evaluate(
cache,
bestMdl,
preprocessor,
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/RegressionEvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/RegressionEvaluatorTest.java
new file mode 100644
index 0000000..4b55f5b
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/RegressionEvaluatorTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.evaluator;
+
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
+import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
+import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
+import org.apache.ignite.ml.selection.split.TrainTestSplit;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link Evaluator} with regression metrics.
+ */
+public class RegressionEvaluatorTest extends TrainerTest {
+    /**
+     * Test evaluator and trainer.
+     */
+    @Test
+    public void testEvaluatorWithoutFilter() {
+        Map<Integer, Vector> data = sampleData();
+
+        KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+
+        IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+        IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+        KNNRegressionModel mdl = (KNNRegressionModel) trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            featureExtractor,
+            lbExtractor
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance());
+
+        double score = Evaluator.evaluate(data, mdl, featureExtractor, lbExtractor,
+            new RegressionMetrics()
+                .withMetric(RegressionMetricValues::rss)
+        );
+
+        assertEquals(1068809.6666666653, score, 1e-4);
+    }
+
+    /**
+     * Test evaluator and trainer with test-train splitting.
+     */
+    @Test
+    public void testEvaluatorWithFilter() {
+        Map<Integer, Vector> data = sampleData();
+
+        KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+
+        IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+        IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+        TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
+            .split(0.5);
+
+        // NOTE(review): the model is fitted on the test filter and scored on the train filter. The two filters
+        // presumably select disjoint subsets, so the score should still be out-of-sample, but the expected value
+        // below depends on this exact assignment and must be recomputed if the filters are swapped.
+        KNNRegressionModel mdl = (KNNRegressionModel) trainer.fit(
+            data,
+            split.getTestFilter(),
+            parts,
+            featureExtractor,
+            lbExtractor
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance());
+
+        double score = Evaluator.evaluate(data, split.getTrainFilter(), mdl, featureExtractor, lbExtractor,
+            new RegressionMetrics()
+                .withMetric(RegressionMetricValues::rss)
+        );
+
+        assertEquals(4800164.444444457, score, 1e-4);
+    }
+
+    /**
+     * Builds the data set shared by both tests. The tests use column 0 of each row as the label and
+     * columns 1..6 as the features.
+     *
+     * @return New mutable map from row index to row vector.
+     */
+    private static Map<Integer, Vector> sampleData() {
+        Map<Integer, Vector> data = new HashMap<>();
+        data.put(0, VectorUtils.of(60323, 83.0, 234289, 2356, 1590, 107608, 1947));
+        data.put(1, VectorUtils.of(61122, 88.5, 259426, 2325, 1456, 108632, 1948));
+        data.put(2, VectorUtils.of(60171, 88.2, 258054, 3682, 1616, 109773, 1949));
+        data.put(3, VectorUtils.of(61187, 89.5, 284599, 3351, 1650, 110929, 1950));
+        data.put(4, VectorUtils.of(63221, 96.2, 328975, 2099, 3099, 112075, 1951));
+        data.put(5, VectorUtils.of(63639, 98.1, 346999, 1932, 3594, 113270, 1952));
+        data.put(6, VectorUtils.of(64989, 99.0, 365385, 1870, 3547, 115094, 1953));
+        data.put(7, VectorUtils.of(63761, 100.0, 363112, 3578, 3350, 116219, 1954));
+        data.put(8, VectorUtils.of(66019, 101.2, 397469, 2904, 3048, 117388, 1955));
+        data.put(9, VectorUtils.of(68169, 108.4, 442769, 2936, 2798, 120445, 1957));
+        data.put(10, VectorUtils.of(66513, 110.8, 444546, 4681, 2637, 121950, 1958));
+        data.put(11, VectorUtils.of(68655, 112.6, 482704, 3813, 2552, 123366, 1959));
+        data.put(12, VectorUtils.of(69564, 114.2, 502601, 3931, 2514, 125368, 1960));
+        data.put(13, VectorUtils.of(69331, 115.7, 518173, 4806, 2572, 127852, 1961));
+        data.put(14, VectorUtils.of(70551, 116.9, 554894, 4007, 2827, 130081, 1962));
+
+        return data;
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/AccuracyTest.java
similarity index 91%
rename from modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java
rename to modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/AccuracyTest.java
index de7c68a..ed0aaa8 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/AccuracyTest.java
@@ -15,13 +15,15 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Arrays;
import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.junit.Test;
+import java.util.Arrays;
+
import static org.junit.Assert.assertEquals;
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsTest.java
similarity index 91%
rename from modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java
rename to modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsTest.java
index a173f5e..69a0039 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsTest.java
@@ -15,17 +15,21 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Arrays;
import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.metric.Metric;
+import org.apache.ignite.ml.selection.scoring.metric.exceptions.UnknownClassLabelException;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
import org.junit.Test;
+import java.util.Arrays;
+
import static org.junit.Assert.assertEquals;
/**
 * Tests for {@link BinaryClassificationMetrics}.
*/
public class BinaryClassificationMetricsTest {
/** */
@@ -131,7 +135,7 @@ public class BinaryClassificationMetricsTest {
}
/** */
- @Test(expected = org.apache.ignite.ml.selection.scoring.metric.UnknownClassLabelException.class)
+ @Test(expected = UnknownClassLabelException.class)
public void testFailWithIncorrectClassLabelsInData() {
Metric scoreCalculator = new BinaryClassificationMetrics();
@@ -144,7 +148,7 @@ public class BinaryClassificationMetricsTest {
}
/** */
- @Test(expected = org.apache.ignite.ml.selection.scoring.metric.UnknownClassLabelException.class)
+ @Test(expected = UnknownClassLabelException.class)
public void testFailWithIncorrectClassLabelsInMetrics() {
Metric scoreCalculator = new BinaryClassificationMetrics()
.withPositiveClsLb(42);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java
similarity index 96%
rename from modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java
rename to modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java
index 75a8183..f513ce3 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
import org.junit.Test;
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/FmeasureTest.java
similarity index 95%
rename from modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java
rename to modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/FmeasureTest.java
index 835d08d..00289cd 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/FmeasureTest.java
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Arrays;
import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.junit.Test;
+import java.util.Arrays;
+
import static org.junit.Assert.assertEquals;
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/PrecisionTest.java
similarity index 95%
rename from modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java
rename to modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/PrecisionTest.java
index d7821d5..78f3605 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/PrecisionTest.java
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Arrays;
import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.junit.Test;
+import java.util.Arrays;
+
import static org.junit.Assert.assertEquals;
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/RecallTest.java
similarity index 95%
rename from modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java
rename to modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/RecallTest.java
index 8c92acd..3948575 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/RecallTest.java
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.selection.scoring.metric;
+package org.apache.ignite.ml.selection.scoring.metric.classification;
-import java.util.Arrays;
import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.junit.Test;
+import java.util.Arrays;
+
import static org.junit.Assert.assertEquals;
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetricsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetricsTest.java
new file mode 100644
index 0000000..2129d21
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/regression/RegressionMetricsTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.metric.regression;
+
+import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.metric.Metric;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Unit tests for {@link RegressionMetrics}. They cover the default score, the values returned by
+ * {@link RegressionMetrics#scoreAll}, and scoring after a custom or {@code null} metric is configured
+ * through {@code withMetric}.
+ */
+public class RegressionMetricsTest {
+    /** Absolute tolerance used for every floating-point comparison in this class. */
+    private static final double EPS = 1e-12;
+
+    /**
+     * Scores through the plain {@link Metric} interface without configuring any metric.
+     * With a single unit error over four samples the expected result is 0.5, i.e. sqrt(1 / 4).
+     */
+    @Test
+    public void testDefaultBehaviour() {
+        Metric metric = new RegressionMetrics();
+
+        LabelPairCursor<Double> data = pairs(
+            new Double[] {1.0, 1.0, 1.0, 1.0},
+            new Double[] {1.0, 1.0, 0.0, 1.0}
+        );
+
+        double actual = metric.score(data.iterator());
+        assertEquals(0.5, actual, EPS);
+    }
+
+    /**
+     * Collects all regression metric values at once and checks the residual sum of squares.
+     * A single unit error contributes 1^2 = 1.0.
+     */
+    @Test
+    public void testDefaultBehaviourForScoreAll() {
+        RegressionMetrics metrics = new RegressionMetrics();
+
+        LabelPairCursor<Double> data = pairs(
+            new Double[] {1.0, 1.0, 1.0, 1.0},
+            new Double[] {1.0, 1.0, 0.0, 1.0}
+        );
+
+        RegressionMetricValues values = metrics.scoreAll(data.iterator());
+        assertEquals(1.0, values.rss(), EPS);
+    }
+
+    /**
+     * Selects mean absolute error as the scored metric. A single unit error over four samples
+     * gives 1 / 4 = 0.25.
+     */
+    @Test
+    public void testCustomMetric() {
+        RegressionMetrics metrics = (RegressionMetrics) new RegressionMetrics()
+            .withMetric(RegressionMetricValues::mae);
+
+        LabelPairCursor<Double> data = pairs(
+            new Double[] {2.0, 2.0, 2.0, 2.0},
+            new Double[] {2.0, 2.0, 1.0, 2.0}
+        );
+
+        double actual = metrics.score(data.iterator());
+        assertEquals(0.25, actual, EPS);
+    }
+
+    /**
+     * Passes {@code null} to {@code withMetric}. Scoring is still expected to produce the default
+     * result (RMSE): 0.5 for a single unit error over four samples.
+     */
+    @Test
+    public void testNullCustomMetric() {
+        RegressionMetrics metrics = (RegressionMetrics) new RegressionMetrics()
+            .withMetric(null);
+
+        LabelPairCursor<Double> data = pairs(
+            new Double[] {2.0, 2.0, 2.0, 2.0},
+            new Double[] {2.0, 2.0, 1.0, 2.0}
+        );
+
+        double actual = metrics.score(data.iterator());
+        // The default metric is RMSE, so the null selector must not change the result.
+        assertEquals(0.5, actual, EPS);
+    }
+
+    /**
+     * Wraps two label arrays into a test cursor. The arrays are passed to
+     * {@link TestLabelPairCursor} in the same order as given.
+     *
+     * @param truth First label list (ground truth).
+     * @param predicted Second label list (predictions).
+     * @return Cursor over the paired labels.
+     */
+    private static LabelPairCursor<Double> pairs(Double[] truth, Double[] predicted) {
+        return new TestLabelPairCursor<>(Arrays.asList(truth), Arrays.asList(predicted));
+    }
+}