Posted to commits@ignite.apache.org by ch...@apache.org on 2019/01/11 11:45:33 UTC

[ignite] branch master updated: IGNITE-10713: [ML] Refactor examples with accuracy calculation and another metrics usage

This is an automated email from the ASF dual-hosted git repository.

chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new 2026cf7  IGNITE-10713: [ML] Refactor examples with accuracy calculation and another metrics usage
2026cf7 is described below

commit 2026cf75dc58f674e7bca890334eb77041e0afb9
Author: zaleslaw <za...@gmail.com>
AuthorDate: Fri Jan 11 13:56:16 2019 +0300

    IGNITE-10713: [ML] Refactor examples with accuracy calculation and
    another metrics usage
    
    This closes #5787
---
 .../examples/ml/knn/KNNClassificationExample.java  | 52 ++++++-------------
 .../examples/ml/knn/KNNRegressionExample.java      |  4 +-
 .../DiscreteNaiveBayesTrainerExample.java          | 53 +++++--------------
 .../GaussianNaiveBayesTrainerExample.java          | 55 ++++++--------------
 .../LogisticRegressionSGDTrainerExample.java       | 54 +++++---------------
 .../ml/selection/scoring/EvaluatorExample.java     | 10 +++-
 .../ml/svm/SVMBinaryClassificationExample.java     | 59 ++++++----------------
 7 files changed, 85 insertions(+), 202 deletions(-)

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index 4a475a0..ec25006 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -18,18 +18,16 @@
 package org.apache.ignite.examples.ml.knn;
 
 import java.io.FileNotFoundException;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -57,48 +55,30 @@ public class KNNClassificationExample {
             System.out.println(">>> Ignite grid started.");
 
             IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
-                .fillCacheWith(MLSandboxDatasets.IRIS);
+                .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
 
             KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-            NNClassificationModel knnMdl = trainer.fit(
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+            NNClassificationModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             ).withK(3)
                 .withDistanceMeasure(new EuclideanDistance())
                 .withStrategy(NNStrategy.WEIGHTED);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = knnMdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-                System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed.");
-            }
+            System.out.println("\n>>> Accuracy " + accuracy);
         }
     }
 }
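
The remaining classification examples below apply the same replacement: the inline lambdas are lifted into named IgniteBiFunction extractors that can be shared between fit() and scoring, and the hand-rolled ScanQuery/Precision loop is dropped in favour of BinaryClassificationEvaluator. A condensed sketch of the resulting pattern, assuming the imports added in the hunk above (trainer setup shown for kNN; the other examples differ only in the trainer and model types):

    // Column 0 holds the label, the remaining columns hold the features.
    IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
    IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);

    // Train with the shared extractors...
    NNClassificationModel mdl = trainer.fit(ignite, dataCache, featureExtractor, lbExtractor)
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withStrategy(NNStrategy.WEIGHTED);

    // ...and score on the same cache with the evaluator instead of a manual query loop.
    double accuracy = BinaryClassificationEvaluator
        .evaluate(dataCache, mdl, featureExtractor, lbExtractor)
        .accuracy();

    System.out.println("\n>>> Accuracy " + accuracy);

Reusing the same extractor instances for training and evaluation also guarantees the evaluator sees exactly the feature/label split the model was trained on.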
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index 8615b6c..e1791df 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -98,10 +98,10 @@ public class KNNRegressionExample {
 
                 System.out.println(">>> ---------------------------------");
 
-                mse = mse / totalAmount;
+                mse /= totalAmount;
                 System.out.println("\n>>> Mean squared error (MSE) " + mse);
 
-                mae = mae / totalAmount;
+                mae /= totalAmount;
                 System.out.println("\n>>> Mean absolute error (MAE) " + mae);
 
                 System.out.println(">>> kNN regression over cached dataset usage example completed.");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
index 54c9ce0..fff298b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
@@ -18,16 +18,14 @@
 package org.apache.ignite.examples.ml.naivebayes;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
 import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -64,49 +62,26 @@ public class DiscreteNaiveBayesTrainerExample {
                 .setBucketThresholds(thresholds);
 
             System.out.println(">>> Perform the training to get the model.");
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             DiscreteNaiveBayesModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Discrete Naive Bayes model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (groundTruth != prediction)
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Discrete Naive bayes model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
index 74e0bfd..4459566 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
@@ -18,17 +18,14 @@
 package org.apache.ignite.examples.ml.naivebayes;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
 import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -63,49 +60,27 @@ public class GaussianNaiveBayesTrainerExample {
             GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
 
             System.out.println(">>> Perform the training to get the model.");
+
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             GaussianNaiveBayesModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Naive Bayes model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Naive bayes model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 059f810..2dab1af 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -18,20 +18,17 @@
 package org.apache.ignite.examples.ml.regression.logistic.binary;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -75,49 +72,26 @@ public class LogisticRegressionSGDTrainerExample {
                 .withSeed(123L);
 
             System.out.println(">>> Perform the training to get the model.");
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             LogisticRegressionModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Logistic regression model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
index a6a989b..c556e11 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
@@ -77,7 +77,15 @@ public class EvaluatorExample {
             );
 
             System.out.println("\n>>> Accuracy " + accuracy);
-            System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+            double f1Score = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).f1Score();
+
+            System.out.println("\n>>> F1-Score " + f1Score);
         }
     }
 }
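
EvaluatorExample now prints the F1 score next to the accuracy by calling BinaryClassificationEvaluator.evaluate(...) a second time. Since the diff shows the returned metrics object exposing both accuracy() and f1Score(), the two prints could presumably share one evaluation pass; a hedged sketch of that variant ('var' is used only to avoid guessing the concrete metrics type, which the diff does not name, and requires Java 10+):

    // Evaluate once, read both metrics from the same result.
    var metrics = BinaryClassificationEvaluator.evaluate(
        dataCache,
        mdl,
        featureExtractor,
        lbExtractor
    );

    System.out.println("\n>>> Accuracy " + metrics.accuracy());
    System.out.println("\n>>> F1-Score " + metrics.f1Score());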
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
index f057386..291c7f8 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
@@ -18,15 +18,12 @@
 package org.apache.ignite.examples.ml.svm;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
@@ -60,54 +57,28 @@ public class SVMBinaryClassificationExample {
 
             SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
 
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             SVMLinearClassificationModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> SVM model " + mdl);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = prediction == 0.0 ? 0 : 1;
-                    int idx2 = groundTruth == 0.0 ? 0 : 1;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+            System.out.println("\n>>> Accuracy " + accuracy);
 
-            System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+            System.out.println(">>> SVM Binary classification model over cache based dataset usage example completed.");
         }
     }
 }