You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@ignite.apache.org by GitBox <gi...@apache.org> on 2019/01/11 11:48:34 UTC

[GitHub] asfgit closed pull request #5787: IGNITE-10713:[ML] Refactor examples with accuracy calculation and another metrics usage

asfgit closed pull request #5787: IGNITE-10713:[ML] Refactor examples with accuracy calculation and another metrics usage
URL: https://github.com/apache/ignite/pull/5787
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, the diff is reproduced
below (GitHub would not otherwise display it after the merge):

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index 4a475a0d49bb..ec250062067a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -18,18 +18,16 @@
 package org.apache.ignite.examples.ml.knn;
 
 import java.io.FileNotFoundException;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -57,48 +55,30 @@ public static void main(String[] args) throws FileNotFoundException {
             System.out.println(">>> Ignite grid started.");
 
             IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
-                .fillCacheWith(MLSandboxDatasets.IRIS);
+                .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
 
             KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-            NNClassificationModel knnMdl = trainer.fit(
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+            NNClassificationModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             ).withK(3)
                 .withDistanceMeasure(new EuclideanDistance())
                 .withStrategy(NNStrategy.WEIGHTED);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = knnMdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-                System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed.");
-            }
+            System.out.println("\n>>> Accuracy " + accuracy);
         }
     }
 }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index 8615b6cd11e0..e1791df46ff5 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -98,10 +98,10 @@ public static void main(String[] args) throws FileNotFoundException {
 
                 System.out.println(">>> ---------------------------------");
 
-                mse = mse / totalAmount;
+                mse /= totalAmount;
                 System.out.println("\n>>> Mean squared error (MSE) " + mse);
 
-                mae = mae / totalAmount;
+                mae /= totalAmount;
                 System.out.println("\n>>> Mean absolute error (MAE) " + mae);
 
                 System.out.println(">>> kNN regression over cached dataset usage example completed.");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
index 54c9ce02fe24..fff298b57cc6 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
@@ -18,16 +18,14 @@
 package org.apache.ignite.examples.ml.naivebayes;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
 import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -64,49 +62,26 @@ public static void main(String[] args) throws FileNotFoundException {
                 .setBucketThresholds(thresholds);
 
             System.out.println(">>> Perform the training to get the model.");
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             DiscreteNaiveBayesModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Discrete Naive Bayes model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (groundTruth != prediction)
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Discrete Naive bayes model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
index 74e0bfdc271d..44595669da3b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
@@ -18,17 +18,14 @@
 package org.apache.ignite.examples.ml.naivebayes;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
 import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -63,49 +60,27 @@ public static void main(String[] args) throws FileNotFoundException {
             GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
 
             System.out.println(">>> Perform the training to get the model.");
+
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             GaussianNaiveBayesModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Naive Bayes model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Naive bayes model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 059f810a22d6..2dab1afae731 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -18,20 +18,17 @@
 package org.apache.ignite.examples.ml.regression.logistic.binary;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -75,49 +72,26 @@ public static void main(String[] args) throws FileNotFoundException {
                 .withSeed(123L);
 
             System.out.println(">>> Perform the training to get the model.");
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             LogisticRegressionModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Logistic regression model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
index a6a989b4a787..c556e1102112 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
@@ -77,7 +77,15 @@ public static void main(String[] args) throws FileNotFoundException {
             );
 
             System.out.println("\n>>> Accuracy " + accuracy);
-            System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+            double f1Score = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).f1Score();
+
+            System.out.println("\n>>> F1-Score " + f1Score);
         }
     }
 }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
index f0573863b99d..291c7f841c6f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
@@ -18,15 +18,12 @@
 package org.apache.ignite.examples.ml.svm;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
@@ -60,54 +57,28 @@ public static void main(String[] args) throws FileNotFoundException {
 
             SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
 
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             SVMLinearClassificationModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> SVM model " + mdl);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = prediction == 0.0 ? 0 : 1;
-                    int idx2 = groundTruth == 0.0 ? 0 : 1;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+            System.out.println("\n>>> Accuracy " + accuracy);
 
-            System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+            System.out.println(">>> SVM Binary classification model over cache based dataset usage example completed.");
         }
     }
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services