You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/26 15:20:46 UTC

[ignite] branch master updated: IGNITE-10904: [ML] Refactor all examples with regression to use RegressionMetrics

This is an automated email from the ASF dual-hosted git repository.

chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new fb898be  IGNITE-10904: [ML] Refactor all examples with regression to use RegressionMetrics
fb898be is described below

commit fb898be1be7ebada0929838b81eb45f31bf59833
Author: zaleslaw <za...@gmail.com>
AuthorDate: Tue Feb 26 18:18:49 2019 +0300

    IGNITE-10904: [ML] Refactor all examples with regression
    to use RegressionMetrics
    
    This closes #6164
---
 .../examples/ml/knn/KNNRegressionExample.java      | 59 +++++++---------------
 .../linear/LinearRegressionLSQRTrainerExample.java | 35 +++++--------
 ...gressionLSQRTrainerWithMinMaxScalerExample.java | 33 ++++++------
 .../linear/LinearRegressionSGDTrainerExample.java  | 39 +++++++-------
 4 files changed, 65 insertions(+), 101 deletions(-)

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index e1791df..fad238d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -17,21 +17,22 @@
 
 package org.apache.ignite.examples.ml.knn;
 
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
 import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
 import org.apache.ignite.ml.math.distances.ManhattanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
+import java.io.FileNotFoundException;
+
 /**
  * Run kNN regression trainer ({@link KNNRegressionTrainer}) over distributed dataset.
  * <p>
@@ -61,51 +62,27 @@ public class KNNRegressionExample {
 
             KNNRegressionTrainer trainer = new KNNRegressionTrainer();
 
+            final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             ).withK(5)
                 .withDistanceMeasure(new ManhattanDistance())
                 .withStrategy(NNStrategy.WEIGHTED);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            int totalAmount = 0;
-            // Calculate mean squared error (MSE)
-            double mse = 0.0;
-            // Calculate mean absolute error (MAE)
-            double mae = 0.0;
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = knnMdl.predict(inputs);
-
-                    mse += Math.pow(prediction - groundTruth, 2.0);
-                    mae += Math.abs(prediction - groundTruth);
-
-                    totalAmount++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                mse /= totalAmount;
-                System.out.println("\n>>> Mean squared error (MSE) " + mse);
-
-                mae /= totalAmount;
-                System.out.println("\n>>> Mean absolute error (MAE) " + mae);
+            double rmse = Evaluator.evaluate(
+                dataCache,
+                knnMdl,
+                featureExtractor,
+                lbExtractor,
+                new RegressionMetrics()
+            );
 
-                System.out.println(">>> kNN regression over cached dataset usage example completed.");
-            }
+            System.out.println("\n>>> Rmse = " + rmse);
         }
     }
 }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 772a35b..6f1fe4c 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -17,21 +17,22 @@
 
 package org.apache.ignite.examples.ml.regression.linear;
 
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.composition.CompositionUtils;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
+import java.io.FileNotFoundException;
+
 /**
  * Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a>
  * ({@link LinearRegressionLSQRTrainer}) over cached dataset.
@@ -79,25 +80,15 @@ public class LinearRegressionLSQRTrainerExample {
                 extractor
             );
 
-            System.out.println(">>> Linear regression model: " + mdl);
-
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-            }
+            double rmse = Evaluator.evaluate(
+                dataCache,
+                mdl,
+                CompositionUtils.asFeatureExtractor(extractor),
+                CompositionUtils.asLabelExtractor(extractor),
+                new RegressionMetrics()
+            );
 
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Rmse = " + rmse);
 
             System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
index c00a3bb..6c7ec85 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
@@ -17,22 +17,22 @@
 
 package org.apache.ignite.examples.ml.regression.linear;
 
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor;
 import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
+import java.io.FileNotFoundException;
+
 /**
  * Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a>
  * ({@link LinearRegressionLSQRTrainer}) over cached dataset that was created using
@@ -75,25 +75,22 @@ public class LinearRegressionLSQRTrainerWithMinMaxScalerExample {
             LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
 
             System.out.println(">>> Perform the training to get the model.");
-            LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v.get(0));
 
-            System.out.println(">>> Linear regression model: " + mdl);
+            final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
+            LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, lbExtractor);
 
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Integer key = observation.getKey();
-                    Vector val = observation.getValue();
-                    double groundTruth = val.get(0);
+            System.out.println(">>> Linear regression model: " + mdl);
 
-                    double prediction = mdl.predict(preprocessor.apply(key, val));
+            double rmse = Evaluator.evaluate(
+                dataCache,
+                mdl,
+                preprocessor,
+                lbExtractor,
+                new RegressionMetrics()
+            );
 
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-            }
+            System.out.println("\n>>> Rmse = " + rmse);
 
             System.out.println(">>> ---------------------------------");
             System.out.println(">>> Linear regression model with MinMaxScaler preprocessor over cache based dataset usage example completed.");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
index cb764c5..cb868b2 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
@@ -17,22 +17,23 @@
 
 package org.apache.ignite.examples.ml.regression.linear;
 
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
 import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
+import java.io.FileNotFoundException;
+
 /**
  * Run linear regression model based on
  * <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent">stochastic gradient descent</a> algorithm
@@ -68,30 +69,28 @@ public class LinearRegressionSGDTrainerExample {
             ), 100000,  10, 100, 123L);
 
             System.out.println(">>> Perform the training to get the model.");
+
+            final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             LinearRegressionModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Linear regression model: " + mdl);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
+            double rmse = Evaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor,
+                new RegressionMetrics()
+            );
 
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-            }
+            System.out.println("\n>>> Rmse = " + rmse);
 
             System.out.println(">>> ---------------------------------");
             System.out.println(">>> Linear regression model over cache based dataset usage example completed.");