You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/01/11 11:45:33 UTC
[ignite] branch master updated: IGNITE-10713: [ML] Refactor
examples with accuracy calculation and another metrics usage
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 2026cf7 IGNITE-10713: [ML] Refactor examples with accuracy calculation and another metrics usage
2026cf7 is described below
commit 2026cf75dc58f674e7bca890334eb77041e0afb9
Author: zaleslaw <za...@gmail.com>
AuthorDate: Fri Jan 11 13:56:16 2019 +0300
IGNITE-10713: [ML] Refactor examples with accuracy calculation and
another metrics usage
This closes #5787
---
.../examples/ml/knn/KNNClassificationExample.java | 52 ++++++-------------
.../examples/ml/knn/KNNRegressionExample.java | 4 +-
.../DiscreteNaiveBayesTrainerExample.java | 53 +++++--------------
.../GaussianNaiveBayesTrainerExample.java | 55 ++++++--------------
.../LogisticRegressionSGDTrainerExample.java | 54 +++++---------------
.../ml/selection/scoring/EvaluatorExample.java | 10 +++-
.../ml/svm/SVMBinaryClassificationExample.java | 59 ++++++----------------
7 files changed, 85 insertions(+), 202 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index 4a475a0..ec25006 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -18,18 +18,16 @@
package org.apache.ignite.examples.ml.knn;
import java.io.FileNotFoundException;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
@@ -57,48 +55,30 @@ public class KNNClassificationExample {
System.out.println(">>> Ignite grid started.");
IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.IRIS);
+ .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- NNClassificationModel knnMdl = trainer.fit(
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+ NNClassificationModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ featureExtractor,
+ lbExtractor
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(NNStrategy.WEIGHTED);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- int amountOfErrors = 0;
- int totalAmount = 0;
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = knnMdl.predict(inputs);
-
- totalAmount++;
- if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
- amountOfErrors++;
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
+ double accuracy = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed.");
- }
+ System.out.println("\n>>> Accuracy " + accuracy);
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index 8615b6c..e1791df 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -98,10 +98,10 @@ public class KNNRegressionExample {
System.out.println(">>> ---------------------------------");
- mse = mse / totalAmount;
+ mse /= totalAmount;
System.out.println("\n>>> Mean squared error (MSE) " + mse);
- mae = mae / totalAmount;
+ mae /= totalAmount;
System.out.println("\n>>> Mean absolute error (MAE) " + mae);
System.out.println(">>> kNN regression over cached dataset usage example completed.");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
index 54c9ce0..fff298b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
@@ -18,16 +18,14 @@
package org.apache.ignite.examples.ml.naivebayes;
import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
@@ -64,49 +62,26 @@ public class DiscreteNaiveBayesTrainerExample {
.setBucketThresholds(thresholds);
System.out.println(">>> Perform the training to get the model.");
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
DiscreteNaiveBayesModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ featureExtractor,
+ lbExtractor
);
System.out.println(">>> Discrete Naive Bayes model: " + mdl);
- int amountOfErrors = 0;
- int totalAmount = 0;
-
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0}, {0, 0}};
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = mdl.predict(inputs);
-
- totalAmount++;
- if (groundTruth != prediction)
- amountOfErrors++;
-
- int idx1 = (int)prediction;
- int idx2 = (int)groundTruth;
-
- confusionMtx[idx1][idx2]++;
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- }
+ double accuracy = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
- System.out.println(">>> ---------------------------------");
+ System.out.println("\n>>> Accuracy " + accuracy);
System.out.println(">>> Discrete Naive bayes model over partitioned dataset usage example completed.");
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
index 74e0bfd..4459566 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
@@ -18,17 +18,14 @@
package org.apache.ignite.examples.ml.naivebayes;
import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
@@ -63,49 +60,27 @@ public class GaussianNaiveBayesTrainerExample {
GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
System.out.println(">>> Perform the training to get the model.");
+
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
GaussianNaiveBayesModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ featureExtractor,
+ lbExtractor
);
System.out.println(">>> Naive Bayes model: " + mdl);
- int amountOfErrors = 0;
- int totalAmount = 0;
-
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0}, {0, 0}};
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = mdl.predict(inputs);
-
- totalAmount++;
- if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
- amountOfErrors++;
-
- int idx1 = (int)prediction;
- int idx2 = (int)groundTruth;
-
- confusionMtx[idx1][idx2]++;
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- }
+ double accuracy = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
- System.out.println(">>> ---------------------------------");
+ System.out.println("\n>>> Accuracy " + accuracy);
System.out.println(">>> Naive bayes model over partitioned dataset usage example completed.");
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 059f810..2dab1af 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -18,20 +18,17 @@
package org.apache.ignite.examples.ml.regression.logistic.binary;
import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
@@ -75,49 +72,26 @@ public class LogisticRegressionSGDTrainerExample {
.withSeed(123L);
System.out.println(">>> Perform the training to get the model.");
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
LogisticRegressionModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ featureExtractor,
+ lbExtractor
);
System.out.println(">>> Logistic regression model: " + mdl);
- int amountOfErrors = 0;
- int totalAmount = 0;
-
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0}, {0, 0}};
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = mdl.predict(inputs);
-
- totalAmount++;
- if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
- amountOfErrors++;
-
- int idx1 = (int)prediction;
- int idx2 = (int)groundTruth;
-
- confusionMtx[idx1][idx2]++;
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- }
+ double accuracy = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
- System.out.println(">>> ---------------------------------");
+ System.out.println("\n>>> Accuracy " + accuracy);
System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
index a6a989b..c556e11 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
@@ -77,7 +77,15 @@ public class EvaluatorExample {
);
System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+ double f1Score = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).f1Score();
+
+ System.out.println("\n>>> F1-Score " + f1Score);
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
index f057386..291c7f8 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
@@ -18,15 +18,12 @@
package org.apache.ignite.examples.ml.svm;
import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
import org.apache.ignite.ml.util.MLSandboxDatasets;
@@ -60,54 +57,28 @@ public class SVMBinaryClassificationExample {
SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
SVMLinearClassificationModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ featureExtractor,
+ lbExtractor
);
System.out.println(">>> SVM model " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- int amountOfErrors = 0;
- int totalAmount = 0;
-
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0}, {0, 0}};
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = mdl.predict(inputs);
-
- totalAmount++;
- if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
- amountOfErrors++;
-
- int idx1 = prediction == 0.0 ? 0 : 1;
- int idx2 = groundTruth == 0.0 ? 0 : 1;
-
- confusionMtx[idx1][idx2]++;
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- }
+ double accuracy = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+ System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ System.out.println(">>> SVM Binary classification model over cache based dataset usage example completed.");
}
}
}