You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/26 15:20:46 UTC
[ignite] branch master updated: IGNITE-10904: [ML] Refactor all
examples with regression to use RegressionMetrics
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new fb898be IGNITE-10904: [ML] Refactor all examples with regression to use RegressionMetrics
fb898be is described below
commit fb898be1be7ebada0929838b81eb45f31bf59833
Author: zaleslaw <za...@gmail.com>
AuthorDate: Tue Feb 26 18:18:49 2019 +0300
IGNITE-10904: [ML] Refactor all examples with regression
to use RegressionMetrics
This closes #6164
---
.../examples/ml/knn/KNNRegressionExample.java | 59 +++++++---------------
.../linear/LinearRegressionLSQRTrainerExample.java | 35 +++++--------
...gressionLSQRTrainerWithMinMaxScalerExample.java | 33 ++++++------
.../linear/LinearRegressionSGDTrainerExample.java | 39 +++++++-------
4 files changed, 65 insertions(+), 101 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index e1791df..fad238d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -17,21 +17,22 @@
package org.apache.ignite.examples.ml.knn;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run kNN regression trainer ({@link KNNRegressionTrainer}) over distributed dataset.
* <p>
@@ -61,51 +62,27 @@ public class KNNRegressionExample {
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+ final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+ final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ featureExtractor,
+ lbExtractor
).withK(5)
.withDistanceMeasure(new ManhattanDistance())
.withStrategy(NNStrategy.WEIGHTED);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- int totalAmount = 0;
- // Calculate mean squared error (MSE)
- double mse = 0.0;
- // Calculate mean absolute error (MAE)
- double mae = 0.0;
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = knnMdl.predict(inputs);
-
- mse += Math.pow(prediction - groundTruth, 2.0);
- mae += Math.abs(prediction - groundTruth);
-
- totalAmount++;
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- mse /= totalAmount;
- System.out.println("\n>>> Mean squared error (MSE) " + mse);
-
- mae /= totalAmount;
- System.out.println("\n>>> Mean absolute error (MAE) " + mae);
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ knnMdl,
+ featureExtractor,
+ lbExtractor,
+ new RegressionMetrics()
+ );
- System.out.println(">>> kNN regression over cached dataset usage example completed.");
- }
+ System.out.println("\n>>> Rmse = " + rmse);
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 772a35b..6f1fe4c 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -17,21 +17,22 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a>
* ({@link LinearRegressionLSQRTrainer}) over cached dataset.
@@ -79,25 +80,15 @@ public class LinearRegressionLSQRTrainerExample {
extractor
);
- System.out.println(">>> Linear regression model: " + mdl);
-
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = mdl.predict(inputs);
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
- }
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ CompositionUtils.asFeatureExtractor(extractor),
+ CompositionUtils.asLabelExtractor(extractor),
+ new RegressionMetrics()
+ );
- System.out.println(">>> ---------------------------------");
+ System.out.println("\n>>> Rmse = " + rmse);
System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
index c00a3bb..6c7ec85 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
@@ -17,22 +17,22 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a>
* ({@link LinearRegressionLSQRTrainer}) over cached dataset that was created using
@@ -75,25 +75,22 @@ public class LinearRegressionLSQRTrainerWithMinMaxScalerExample {
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v.get(0));
- System.out.println(">>> Linear regression model: " + mdl);
+ final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, lbExtractor);
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Integer key = observation.getKey();
- Vector val = observation.getValue();
- double groundTruth = val.get(0);
+ System.out.println(">>> Linear regression model: " + mdl);
- double prediction = mdl.predict(preprocessor.apply(key, val));
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ preprocessor,
+ lbExtractor,
+ new RegressionMetrics()
+ );
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
- }
+ System.out.println("\n>>> Rmse = " + rmse);
System.out.println(">>> ---------------------------------");
System.out.println(">>> Linear regression model with MinMaxScaler preprocessor over cache based dataset usage example completed.");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
index cb764c5..cb868b2 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
@@ -17,22 +17,23 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import java.io.FileNotFoundException;
+
/**
* Run linear regression model based on
* <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent">stochastic gradient descent</a> algorithm
@@ -68,30 +69,28 @@ public class LinearRegressionSGDTrainerExample {
), 100000, 10, 100, 123L);
System.out.println(">>> Perform the training to get the model.");
+
+ final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+ final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
LinearRegressionModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ featureExtractor,
+ lbExtractor
);
System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
-
- double prediction = mdl.predict(inputs);
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new RegressionMetrics()
+ );
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
- }
+ System.out.println("\n>>> Rmse = " + rmse);
System.out.println(">>> ---------------------------------");
System.out.println(">>> Linear regression model over cache based dataset usage example completed.");