You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/26 15:02:28 UTC

[ignite] branch master updated: IGNITE-10903: [ML] Provide an example with training of regression model and its evaluation

This is an automated email from the ASF dual-hosted git repository.

chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new c5b5be2  IGNITE-10903: [ML] Provide an example with training of regression model and its evaluation
c5b5be2 is described below

commit c5b5be2ddb6674097730d4ed1c82eadcf49676f4
Author: zaleslaw <za...@gmail.com>
AuthorDate: Tue Feb 26 18:00:46 2019 +0300

    IGNITE-10903: [ML] Provide an example with training of regression model
    and its evaluation
    
    This closes #6163
---
 .../BaggedLogisticRegressionSGDTrainerExample.java |  4 +-
 .../selection/scoring/RegressionMetricExample.java | 87 ++++++++++++++++++++++
 .../logistic/LogisticRegressionSGDTrainer.java     |  6 +-
 3 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
index 8de06a6..1e3914a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition;
 import org.apache.ignite.ml.composition.bagging.BaggedModel;
 import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
@@ -82,7 +83,8 @@ public class BaggedLogisticRegressionSGDTrainerExample {
                 0.6,
                 4,
                 3,
-                new OnMajorityPredictionsAggregator());
+                new OnMajorityPredictionsAggregator())
+                .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1));
 
             System.out.println(">>> Perform evaluation of the model.");
 
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java
new file mode 100644
index 0000000..a978078
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.selection.scoring;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
+import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
+import org.apache.ignite.ml.util.MLSandboxDatasets;
+import org.apache.ignite.ml.util.SandboxMLCache;
+
+import java.io.FileNotFoundException;
+
+/**
+ * Run kNN regression trainer ({@link KNNRegressionTrainer}) over distributed dataset.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (the first vector
+ * element is the label, the remaining elements are the features).</p>
+ * <p>
+ * After that it trains the model based on the specified data using
+ * <a href="https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm">kNN</a> regression algorithm.</p>
+ * <p>
+ * Finally, this example computes the mean absolute error (MAE) of the model predictions against the
+ * ground-truth labels. The same data is used for training and evaluation, so this is a training-set score.</p>
+ * <p>
+ * You can change the test data used in this example or trainer object settings and re-run it to explore
+ * this algorithm further.</p>
+ */
+public class RegressionMetricExample {
+    /** Run example. */
+    public static void main(String[] args) throws FileNotFoundException {
+        System.out.println();
+        System.out.println(">>> kNN regression over cached dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+                .fillCacheWith(MLSandboxDatasets.CLEARED_MACHINES);
+            try {
+                KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+
+                final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+                final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+                KNNRegressionModel knnMdl = (KNNRegressionModel)trainer.fit(
+                    ignite, dataCache, featureExtractor, lbExtractor
+                ).withK(5)
+                    .withDistanceMeasure(new ManhattanDistance())
+                    .withStrategy(NNStrategy.WEIGHTED);
+
+                // Scored on the training data itself; use a train/test split for an unbiased estimate.
+                double mae = Evaluator.evaluate(dataCache, knnMdl, featureExtractor, lbExtractor,
+                    new RegressionMetrics().withMetric(RegressionMetricValues::mae));
+
+                System.out.println("\n>>> Mae " + mae);
+            }
+            finally {
+                // Destroy the cache created for this example so its data does not outlive the run.
+                dataCache.destroy();
+            }
+        }
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
index 1070efc..8b2cc3a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
@@ -17,7 +17,6 @@
 
 package org.apache.ignite.ml.regressions.logistic;
 
-import java.util.Arrays;
 import org.apache.ignite.ml.composition.CompositionUtils;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -39,6 +38,8 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.jetbrains.annotations.NotNull;
 
+import java.util.Arrays;
+
 /**
  * Trainer of the logistic regression model based on stochastic gradient descent algorithm.
  */
@@ -89,6 +90,9 @@ public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer<Logi
                 return a;
             });
 
+            if (cols == null)
+                throw new IllegalStateException("Cannot train on empty dataset");
+
             MLPArchitecture architecture = new MLPArchitecture(cols);
             architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);