You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/26 15:02:28 UTC
[ignite] branch master updated: IGNITE-10903: [ML] Provide an
example with training of regression model and its evaluation
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new c5b5be2 IGNITE-10903: [ML] Provide an example with training of regression model and its evaluation
c5b5be2 is described below
commit c5b5be2ddb6674097730d4ed1c82eadcf49676f4
Author: zaleslaw <za...@gmail.com>
AuthorDate: Tue Feb 26 18:00:46 2019 +0300
IGNITE-10903: [ML] Provide an example with training of regression model
and its evaluation
This closes #6163
---
.../BaggedLogisticRegressionSGDTrainerExample.java | 4 +-
.../selection/scoring/RegressionMetricExample.java | 87 ++++++++++++++++++++++
.../logistic/LogisticRegressionSGDTrainer.java | 6 +-
3 files changed, 95 insertions(+), 2 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
index 8de06a6..1e3914a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition;
import org.apache.ignite.ml.composition.bagging.BaggedModel;
import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
@@ -82,7 +83,8 @@ public class BaggedLogisticRegressionSGDTrainerExample {
0.6,
4,
3,
- new OnMajorityPredictionsAggregator());
+ new OnMajorityPredictionsAggregator())
+ .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1));
System.out.println(">>> Perform evaluation of the model.");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java
new file mode 100644
index 0000000..a978078
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.selection.scoring;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
+import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
+import org.apache.ignite.ml.util.MLSandboxDatasets;
+import org.apache.ignite.ml.util.SandboxMLCache;
+
+import java.io.FileNotFoundException;
+
+/**
+ * Run kNN regression trainer ({@link KNNRegressionTrainer}) over distributed dataset and evaluate the trained
+ * model with a regression metric.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with the
+ * {@link MLSandboxDatasets#CLEARED_MACHINES} sandbox dataset. The first coordinate of every cached vector is used
+ * as the label and the remaining coordinates as features.</p>
+ * <p>
+ * After that it trains the model based on the specified data using
+ * <a href="https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm">kNN</a> regression algorithm.</p>
+ * <p>
+ * Finally, this example computes the mean absolute error (MAE) of the trained model with {@link Evaluator} and
+ * prints it. Note that the model is evaluated on the same cache it was trained on (no hold-out split), so the
+ * printed MAE is an optimistic estimate rather than a measure of generalization.</p>
+ * <p>
+ * You can change the test data used in this example or trainer object settings and re-run it to explore
+ * this algorithm further.</p>
+ */
+public class RegressionMetricExample {
+    /**
+     * Run example.
+     *
+     * @param args Command line arguments, none required.
+     * @throws FileNotFoundException If the sandbox dataset file cannot be found.
+     */
+    public static void main(String[] args) throws FileNotFoundException {
+        System.out.println();
+        System.out.println(">>> kNN regression over cached dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+                .fillCacheWith(MLSandboxDatasets.CLEARED_MACHINES);
+
+            try {
+                KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+
+                // Column 0 of each cached vector is the label; the remaining columns are the features.
+                final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+                final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+                KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
+                    ignite,
+                    dataCache,
+                    featureExtractor,
+                    lbExtractor
+                ).withK(5)
+                    .withDistanceMeasure(new ManhattanDistance())
+                    .withStrategy(NNStrategy.WEIGHTED);
+
+                // Evaluation runs over the whole training cache: no train/test split is made in this example.
+                double mae = Evaluator.evaluate(
+                    dataCache,
+                    knnMdl,
+                    featureExtractor,
+                    lbExtractor,
+                    new RegressionMetrics().withMetric(RegressionMetricValues::mae)
+                );
+
+                System.out.println("\n>>> Mae " + mae);
+            }
+            finally {
+                // Remove the cache filled above from the cluster so it does not outlive the example.
+                dataCache.destroy();
+            }
+        }
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
index 1070efc..8b2cc3a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
@@ -17,7 +17,6 @@
package org.apache.ignite.ml.regressions.logistic;
-import java.util.Arrays;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -39,6 +38,8 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;
+import java.util.Arrays;
+
/**
* Trainer of the logistic regression model based on stochastic gradient descent algorithm.
*/
@@ -89,6 +90,9 @@ public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer<Logi
return a;
});
+ if (cols == null)
+ throw new IllegalStateException("Cannot train on empty dataset");
+
MLPArchitecture architecture = new MLPArchitecture(cols);
architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);