You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2018/05/21 12:31:32 UTC
ignite git commit: IGNITE-8511: [ML] Add support for Multi-Class
Logistic Regression
Repository: ignite
Updated Branches:
refs/heads/master 436d1230e -> cb8fb7365
IGNITE-8511: [ML] Add support for Multi-Class Logistic Regression
this closes #4008
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/cb8fb736
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/cb8fb736
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/cb8fb736
Branch: refs/heads/master
Commit: cb8fb736597e9b3f25ef6d55a8dc4d8ad0d23b60
Parents: 436d123
Author: zaleslaw <za...@gmail.com>
Authored: Mon May 21 15:31:16 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Mon May 21 15:31:16 2018 +0300
----------------------------------------------------------------------
.../LogisticRegressionSGDTrainerSample.java | 239 ---------------
.../LogisticRegressionSGDTrainerSample.java | 239 +++++++++++++++
.../logistic/binary/package-info.java | 22 ++
...gressionMultiClassClassificationExample.java | 301 +++++++++++++++++++
.../logistic/multiclass/package-info.java | 22 ++
.../LogRegressionMultiClassModel.java | 96 ++++++
.../LogRegressionMultiClassTrainer.java | 222 ++++++++++++++
.../logistic/multiclass/package-info.java | 22 ++
.../ml/regressions/RegressionsTestSuite.java | 6 +-
.../linear/LinearRegressionModelTest.java | 17 ++
.../logistic/LogRegMultiClassTrainerTest.java | 98 ++++++
11 files changed, 1043 insertions(+), 241 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java
deleted file mode 100644
index 0505ddd..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java
+++ /dev/null
@@ -1,239 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.regression.logistic;
-
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.nn.UpdatesStrategy;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
-import org.apache.ignite.thread.IgniteThread;
-
-import javax.cache.Cache;
-import java.util.Arrays;
-import java.util.UUID;
-
-/**
- * Run logistic regression model over distributed cache.
- *
- * @see LogisticRegressionSGDTrainer
- */
-public class LogisticRegressionSGDTrainerSample {
- /** Run example. */
- public static void main(String[] args) throws InterruptedException {
- System.out.println();
- System.out.println(">>> Logistic regression model over partitioned dataset usage example started.");
- // Start ignite grid.
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- LogisticRegressionSGDTrainerSample.class.getSimpleName(), () -> {
-
- IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
-
- System.out.println(">>> Create new logistic regression trainer object.");
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), 100000, 10, 100, 123L);
-
- System.out.println(">>> Perform the training to get the model.");
- LogisticRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> Arrays.copyOfRange(v, 1, v.length),
- (k, v) -> v[0]
- ).withRawLabels(true);
-
- System.out.println(">>> Logistic regression model: " + mdl);
-
- int amountOfErrors = 0;
- int totalAmount = 0;
-
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0}, {0, 0}};
-
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
-
- double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
-
- totalAmount++;
- if(groundTruth != prediction)
- amountOfErrors++;
-
- int idx1 = (int)prediction;
- int idx2 = (int)groundTruth;
-
- confusionMtx[idx1][idx2]++;
-
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- }
-
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
- System.out.println(">>> ---------------------------------");
- });
-
- igniteThread.start();
-
- igniteThread.join();
- }
- }
- /**
- * Fills cache with data and returns it.
- *
- * @param ignite Ignite instance.
- * @return Filled Ignite Cache.
- */
- private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
- CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
- cacheConfiguration.setName("TEST_" + UUID.randomUUID());
- cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
-
- IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
-
- for (int i = 0; i < data.length; i++)
- cache.put(i, data[i]);
-
- return cache;
- }
-
-
- /** The 1st and 2nd classes from the Iris dataset. */
- private static final double[][] data = {
- {0, 5.1, 3.5, 1.4, 0.2},
- {0, 4.9, 3, 1.4, 0.2},
- {0, 4.7, 3.2, 1.3, 0.2},
- {0, 4.6, 3.1, 1.5, 0.2},
- {0, 5, 3.6, 1.4, 0.2},
- {0, 5.4, 3.9, 1.7, 0.4},
- {0, 4.6, 3.4, 1.4, 0.3},
- {0, 5, 3.4, 1.5, 0.2},
- {0, 4.4, 2.9, 1.4, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 5.4, 3.7, 1.5, 0.2},
- {0, 4.8, 3.4, 1.6, 0.2},
- {0, 4.8, 3, 1.4, 0.1},
- {0, 4.3, 3, 1.1, 0.1},
- {0, 5.8, 4, 1.2, 0.2},
- {0, 5.7, 4.4, 1.5, 0.4},
- {0, 5.4, 3.9, 1.3, 0.4},
- {0, 5.1, 3.5, 1.4, 0.3},
- {0, 5.7, 3.8, 1.7, 0.3},
- {0, 5.1, 3.8, 1.5, 0.3},
- {0, 5.4, 3.4, 1.7, 0.2},
- {0, 5.1, 3.7, 1.5, 0.4},
- {0, 4.6, 3.6, 1, 0.2},
- {0, 5.1, 3.3, 1.7, 0.5},
- {0, 4.8, 3.4, 1.9, 0.2},
- {0, 5, 3, 1.6, 0.2},
- {0, 5, 3.4, 1.6, 0.4},
- {0, 5.2, 3.5, 1.5, 0.2},
- {0, 5.2, 3.4, 1.4, 0.2},
- {0, 4.7, 3.2, 1.6, 0.2},
- {0, 4.8, 3.1, 1.6, 0.2},
- {0, 5.4, 3.4, 1.5, 0.4},
- {0, 5.2, 4.1, 1.5, 0.1},
- {0, 5.5, 4.2, 1.4, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 5, 3.2, 1.2, 0.2},
- {0, 5.5, 3.5, 1.3, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 4.4, 3, 1.3, 0.2},
- {0, 5.1, 3.4, 1.5, 0.2},
- {0, 5, 3.5, 1.3, 0.3},
- {0, 4.5, 2.3, 1.3, 0.3},
- {0, 4.4, 3.2, 1.3, 0.2},
- {0, 5, 3.5, 1.6, 0.6},
- {0, 5.1, 3.8, 1.9, 0.4},
- {0, 4.8, 3, 1.4, 0.3},
- {0, 5.1, 3.8, 1.6, 0.2},
- {0, 4.6, 3.2, 1.4, 0.2},
- {0, 5.3, 3.7, 1.5, 0.2},
- {0, 5, 3.3, 1.4, 0.2},
- {1, 7, 3.2, 4.7, 1.4},
- {1, 6.4, 3.2, 4.5, 1.5},
- {1, 6.9, 3.1, 4.9, 1.5},
- {1, 5.5, 2.3, 4, 1.3},
- {1, 6.5, 2.8, 4.6, 1.5},
- {1, 5.7, 2.8, 4.5, 1.3},
- {1, 6.3, 3.3, 4.7, 1.6},
- {1, 4.9, 2.4, 3.3, 1},
- {1, 6.6, 2.9, 4.6, 1.3},
- {1, 5.2, 2.7, 3.9, 1.4},
- {1, 5, 2, 3.5, 1},
- {1, 5.9, 3, 4.2, 1.5},
- {1, 6, 2.2, 4, 1},
- {1, 6.1, 2.9, 4.7, 1.4},
- {1, 5.6, 2.9, 3.6, 1.3},
- {1, 6.7, 3.1, 4.4, 1.4},
- {1, 5.6, 3, 4.5, 1.5},
- {1, 5.8, 2.7, 4.1, 1},
- {1, 6.2, 2.2, 4.5, 1.5},
- {1, 5.6, 2.5, 3.9, 1.1},
- {1, 5.9, 3.2, 4.8, 1.8},
- {1, 6.1, 2.8, 4, 1.3},
- {1, 6.3, 2.5, 4.9, 1.5},
- {1, 6.1, 2.8, 4.7, 1.2},
- {1, 6.4, 2.9, 4.3, 1.3},
- {1, 6.6, 3, 4.4, 1.4},
- {1, 6.8, 2.8, 4.8, 1.4},
- {1, 6.7, 3, 5, 1.7},
- {1, 6, 2.9, 4.5, 1.5},
- {1, 5.7, 2.6, 3.5, 1},
- {1, 5.5, 2.4, 3.8, 1.1},
- {1, 5.5, 2.4, 3.7, 1},
- {1, 5.8, 2.7, 3.9, 1.2},
- {1, 6, 2.7, 5.1, 1.6},
- {1, 5.4, 3, 4.5, 1.5},
- {1, 6, 3.4, 4.5, 1.6},
- {1, 6.7, 3.1, 4.7, 1.5},
- {1, 6.3, 2.3, 4.4, 1.3},
- {1, 5.6, 3, 4.1, 1.3},
- {1, 5.5, 2.5, 4, 1.3},
- {1, 5.5, 2.6, 4.4, 1.2},
- {1, 6.1, 3, 4.6, 1.4},
- {1, 5.8, 2.6, 4, 1.2},
- {1, 5, 2.3, 3.3, 1},
- {1, 5.6, 2.7, 4.2, 1.3},
- {1, 5.7, 3, 4.2, 1.2},
- {1, 5.7, 2.9, 4.2, 1.3},
- {1, 6.2, 2.9, 4.3, 1.3},
- {1, 5.1, 2.5, 3, 1.1},
- {1, 5.7, 2.8, 4.1, 1.3},
- };
-
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java
new file mode 100644
index 0000000..215d7a4
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression.logistic.binary;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.thread.IgniteThread;
+
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
+/**
+ * Run logistic regression model over distributed cache.
+ *
+ * @see LogisticRegressionSGDTrainer
+ */
+public class LogisticRegressionSGDTrainerSample {
+ /** Runs the example: starts Ignite, trains a binary logistic regression model on the cached rows, then prints per-row predictions, accuracy and a confusion matrix. */
+ public static void main(String[] args) throws InterruptedException {
+ System.out.println();
+ System.out.println(">>> Logistic regression model over partitioned dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+ IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+ LogisticRegressionSGDTrainerSample.class.getSimpleName(), () -> {
+ // The whole training/evaluation scenario runs inside the Ignite thread that is started and joined below.
+ IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
+ System.out.println(">>> Create new logistic regression trainer object.");
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ), 100000, 10, 100, 123L);
+ // In each cached row, column 0 is the label and the remaining columns are the features (see the two lambdas below).
+ System.out.println(">>> Perform the training to get the model.");
+ LogisticRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+ (k, v) -> v[0]
+ ).withRawLabels(true);
+ // NOTE(review): withRawLabels(true) presumably makes apply() return a raw score rather than a 0/1 label; if so, the != check and (int) cast below misbehave - confirm.
+ System.out.println(">>> Logistic regression model: " + mdl);
+
+ int amountOfErrors = 0;
+ int totalAmount = 0;
+
+ // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
+ int[][] confusionMtx = {{0, 0}, {0, 0}};
+ // The matrix is indexed as confusionMtx[predicted][actual].
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
+
+ double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
+
+ totalAmount++;
+ if(groundTruth != prediction)
+ amountOfErrors++;
+ // Truncate both values to int so they can be used as matrix indices.
+ int idx1 = (int)prediction;
+ int idx2 = (int)groundTruth;
+
+ confusionMtx[idx1][idx2]++;
+
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
+
+ System.out.println(">>> ---------------------------------");
+
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ }
+
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+ System.out.println(">>> ---------------------------------");
+ });
+
+ igniteThread.start();
+ // Wait for the scenario to finish before the try-with-resources block closes the Ignite instance.
+ igniteThread.join();
+ }
+ }
+ /**
+ * Creates a cache named "TEST_" plus a random UUID and fills it with the rows of {@code data}.
+ *
+ * @param ignite Ignite instance used to create the cache.
+ * @return The created cache, holding every row of {@code data} keyed by its row index.
+ */
+ private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+ CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+ cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+ cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+ IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+ // Key is the row index in the data array; value is the row itself.
+ for (int i = 0; i < data.length; i++)
+ cache.put(i, data[i]);
+
+ return cache;
+ }
+
+
+ /** The 1st and 2nd classes (labels 0 and 1) from the Iris dataset; each row is {label, feature1, feature2, feature3, feature4}. */
+ private static final double[][] data = {
+ {0, 5.1, 3.5, 1.4, 0.2},
+ {0, 4.9, 3, 1.4, 0.2},
+ {0, 4.7, 3.2, 1.3, 0.2},
+ {0, 4.6, 3.1, 1.5, 0.2},
+ {0, 5, 3.6, 1.4, 0.2},
+ {0, 5.4, 3.9, 1.7, 0.4},
+ {0, 4.6, 3.4, 1.4, 0.3},
+ {0, 5, 3.4, 1.5, 0.2},
+ {0, 4.4, 2.9, 1.4, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 5.4, 3.7, 1.5, 0.2},
+ {0, 4.8, 3.4, 1.6, 0.2},
+ {0, 4.8, 3, 1.4, 0.1},
+ {0, 4.3, 3, 1.1, 0.1},
+ {0, 5.8, 4, 1.2, 0.2},
+ {0, 5.7, 4.4, 1.5, 0.4},
+ {0, 5.4, 3.9, 1.3, 0.4},
+ {0, 5.1, 3.5, 1.4, 0.3},
+ {0, 5.7, 3.8, 1.7, 0.3},
+ {0, 5.1, 3.8, 1.5, 0.3},
+ {0, 5.4, 3.4, 1.7, 0.2},
+ {0, 5.1, 3.7, 1.5, 0.4},
+ {0, 4.6, 3.6, 1, 0.2},
+ {0, 5.1, 3.3, 1.7, 0.5},
+ {0, 4.8, 3.4, 1.9, 0.2},
+ {0, 5, 3, 1.6, 0.2},
+ {0, 5, 3.4, 1.6, 0.4},
+ {0, 5.2, 3.5, 1.5, 0.2},
+ {0, 5.2, 3.4, 1.4, 0.2},
+ {0, 4.7, 3.2, 1.6, 0.2},
+ {0, 4.8, 3.1, 1.6, 0.2},
+ {0, 5.4, 3.4, 1.5, 0.4},
+ {0, 5.2, 4.1, 1.5, 0.1},
+ {0, 5.5, 4.2, 1.4, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 5, 3.2, 1.2, 0.2},
+ {0, 5.5, 3.5, 1.3, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 4.4, 3, 1.3, 0.2},
+ {0, 5.1, 3.4, 1.5, 0.2},
+ {0, 5, 3.5, 1.3, 0.3},
+ {0, 4.5, 2.3, 1.3, 0.3},
+ {0, 4.4, 3.2, 1.3, 0.2},
+ {0, 5, 3.5, 1.6, 0.6},
+ {0, 5.1, 3.8, 1.9, 0.4},
+ {0, 4.8, 3, 1.4, 0.3},
+ {0, 5.1, 3.8, 1.6, 0.2},
+ {0, 4.6, 3.2, 1.4, 0.2},
+ {0, 5.3, 3.7, 1.5, 0.2},
+ {0, 5, 3.3, 1.4, 0.2},
+ {1, 7, 3.2, 4.7, 1.4},
+ {1, 6.4, 3.2, 4.5, 1.5},
+ {1, 6.9, 3.1, 4.9, 1.5},
+ {1, 5.5, 2.3, 4, 1.3},
+ {1, 6.5, 2.8, 4.6, 1.5},
+ {1, 5.7, 2.8, 4.5, 1.3},
+ {1, 6.3, 3.3, 4.7, 1.6},
+ {1, 4.9, 2.4, 3.3, 1},
+ {1, 6.6, 2.9, 4.6, 1.3},
+ {1, 5.2, 2.7, 3.9, 1.4},
+ {1, 5, 2, 3.5, 1},
+ {1, 5.9, 3, 4.2, 1.5},
+ {1, 6, 2.2, 4, 1},
+ {1, 6.1, 2.9, 4.7, 1.4},
+ {1, 5.6, 2.9, 3.6, 1.3},
+ {1, 6.7, 3.1, 4.4, 1.4},
+ {1, 5.6, 3, 4.5, 1.5},
+ {1, 5.8, 2.7, 4.1, 1},
+ {1, 6.2, 2.2, 4.5, 1.5},
+ {1, 5.6, 2.5, 3.9, 1.1},
+ {1, 5.9, 3.2, 4.8, 1.8},
+ {1, 6.1, 2.8, 4, 1.3},
+ {1, 6.3, 2.5, 4.9, 1.5},
+ {1, 6.1, 2.8, 4.7, 1.2},
+ {1, 6.4, 2.9, 4.3, 1.3},
+ {1, 6.6, 3, 4.4, 1.4},
+ {1, 6.8, 2.8, 4.8, 1.4},
+ {1, 6.7, 3, 5, 1.7},
+ {1, 6, 2.9, 4.5, 1.5},
+ {1, 5.7, 2.6, 3.5, 1},
+ {1, 5.5, 2.4, 3.8, 1.1},
+ {1, 5.5, 2.4, 3.7, 1},
+ {1, 5.8, 2.7, 3.9, 1.2},
+ {1, 6, 2.7, 5.1, 1.6},
+ {1, 5.4, 3, 4.5, 1.5},
+ {1, 6, 3.4, 4.5, 1.6},
+ {1, 6.7, 3.1, 4.7, 1.5},
+ {1, 6.3, 2.3, 4.4, 1.3},
+ {1, 5.6, 3, 4.1, 1.3},
+ {1, 5.5, 2.5, 4, 1.3},
+ {1, 5.5, 2.6, 4.4, 1.2},
+ {1, 6.1, 3, 4.6, 1.4},
+ {1, 5.8, 2.6, 4, 1.2},
+ {1, 5, 2.3, 3.3, 1},
+ {1, 5.6, 2.7, 4.2, 1.3},
+ {1, 5.7, 3, 4.2, 1.2},
+ {1, 5.7, 2.9, 4.2, 1.3},
+ {1, 6.2, 2.9, 4.3, 1.3},
+ {1, 5.1, 2.5, 3, 1.1},
+ {1, 5.7, 2.8, 4.1, 1.3},
+ };
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java
new file mode 100644
index 0000000..6ea42e7
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML binary logistic regression examples.
+ */
+package org.apache.ignite.examples.ml.regression.logistic.binary;
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
new file mode 100644
index 0000000..f089923
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression.logistic.multiclass;
+
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
+import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer;
+import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Runs the Logistic Regression multi-class classification trainer over a distributed dataset to build two models:
+ * one with normalization and one without normalization.
+ *
+ * @see LogRegressionMultiClassTrainer
+ */
+public class LogRegressionMultiClassClassificationExample {
+ /** Runs the example: trains multi-class logistic regression models with and without normalization and prints predictions, accuracy and confusion matrices for both. */
+ public static void main(String[] args) throws InterruptedException {
+ System.out.println();
+ System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+ LogRegressionMultiClassClassificationExample.class.getSimpleName(), () -> {
+ IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
+ LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ))
+ .withAmountOfIterations(100000)
+ .withAmountOfLocIterations(10)
+ .withBatchSize(100)
+ .withSeed(123L);
+ // In each cached row, column 0 is the class label and the remaining columns are the features.
+ LogRegressionMultiClassModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+ (k, v) -> v[0]
+ );
+
+ System.out.println(">>> Logistic Regression Multi-class model");
+ System.out.println(mdl.toString());
+ // Fit a normalizing preprocessor over the same feature columns and train a second model on its output.
+ NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>();
+
+ IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length)
+ );
+
+ LogRegressionMultiClassModel mdlWithNormalization = trainer.fit(
+ ignite,
+ dataCache,
+ preprocessor,
+ (k, v) -> v[0]
+ );
+
+ System.out.println(">>> Logistic Regression Multi-class model with normalization");
+ System.out.println(mdlWithNormalization.toString());
+
+ System.out.println(">>> ----------------------------------------------------------------");
+ System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|");
+ System.out.println(">>> ----------------------------------------------------------------");
+
+ int amountOfErrors = 0;
+ int amountOfErrorsWithNormalization = 0;
+ int totalAmount = 0;
+
+ // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
+ int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
+ int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
+
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
+ // The second model was trained on preprocessed features, so it is queried with the preprocessor's output, not the raw inputs.
+ double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
+ double predictionWithNormalization = mdlWithNormalization.apply(new DenseLocalOnHeapVector(preprocessor.apply(observation.getKey(), val)));
+
+ totalAmount++;
+
+ // Collect data for model
+ if(groundTruth != prediction)
+ amountOfErrors++;
+ // Map class labels 1, 3 and anything else (7 in this dataset) to indices 0, 1 and 2; matrices are indexed [predicted][actual].
+ int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2);
+ int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
+
+ confusionMtx[idx1][idx2]++;
+
+ // Collect data for model with normalization
+ if(groundTruth != predictionWithNormalization)
+ amountOfErrorsWithNormalization++;
+
+ idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2);
+ idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
+
+ confusionMtxWithNormalization[idx1][idx2]++;
+
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth);
+ }
+ System.out.println(">>> ----------------------------------------------------------------");
+ System.out.println("\n>>> -----------------Logistic Regression model-------------");
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+
+ System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------");
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount));
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization));
+ }
+ });
+
+ igniteThread.start();
+ igniteThread.join();
+ }
+ }
+
+ /**
+ * Creates a cache named "TEST_" plus a random UUID and fills it with the rows of {@code data}.
+ *
+ * @param ignite Ignite instance used to create the cache.
+ * @return The created cache, holding every row of {@code data} keyed by its row index.
+ */
+ private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+ CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+ cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+ cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+ IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+ // Key is the row index in the data array; value is the row itself.
+ for (int i = 0; i < data.length; i++)
+ cache.put(i, data[i]);
+
+ return cache;
+ }
+
+ /** The preprocessed Glass dataset from the UCI Machine Learning Repository: https://archive.ics.uci.edu/ml/datasets/Glass+Identification
+ * There are 3 classes with labels: 1 {building_windows_float_processed}, 3 {vehicle_windows_float_processed}, 7 {headlamps}.
+ * Each row is {label, 'RI-Refractive index', 'Mg-Magnesium', 'Al-Aluminum', 'Ba-Barium', 'Fe-Iron'}.
+ */
+ private static final double[][] data = {
+ {1, 1.52101, 4.49, 1.10, 0.00, 0.00},
+ {1, 1.51761, 3.60, 1.36, 0.00, 0.00},
+ {1, 1.51618, 3.55, 1.54, 0.00, 0.00},
+ {1, 1.51766, 3.69, 1.29, 0.00, 0.00},
+ {1, 1.51742, 3.62, 1.24, 0.00, 0.00},
+ {1, 1.51596, 3.61, 1.62, 0.00, 0.26},
+ {1, 1.51743, 3.60, 1.14, 0.00, 0.00},
+ {1, 1.51756, 3.61, 1.05, 0.00, 0.00},
+ {1, 1.51918, 3.58, 1.37, 0.00, 0.00},
+ {1, 1.51755, 3.60, 1.36, 0.00, 0.11},
+ {1, 1.51571, 3.46, 1.56, 0.00, 0.24},
+ {1, 1.51763, 3.66, 1.27, 0.00, 0.00},
+ {1, 1.51589, 3.43, 1.40, 0.00, 0.24},
+ {1, 1.51748, 3.56, 1.27, 0.00, 0.17},
+ {1, 1.51763, 3.59, 1.31, 0.00, 0.00},
+ {1, 1.51761, 3.54, 1.23, 0.00, 0.00},
+ {1, 1.51784, 3.67, 1.16, 0.00, 0.00},
+ {1, 1.52196, 3.85, 0.89, 0.00, 0.00},
+ {1, 1.51911, 3.73, 1.18, 0.00, 0.00},
+ {1, 1.51735, 3.54, 1.69, 0.00, 0.07},
+ {1, 1.51750, 3.55, 1.49, 0.00, 0.19},
+ {1, 1.51966, 3.75, 0.29, 0.00, 0.00},
+ {1, 1.51736, 3.62, 1.29, 0.00, 0.00},
+ {1, 1.51751, 3.57, 1.35, 0.00, 0.00},
+ {1, 1.51720, 3.50, 1.15, 0.00, 0.00},
+ {1, 1.51764, 3.54, 1.21, 0.00, 0.00},
+ {1, 1.51793, 3.48, 1.41, 0.00, 0.00},
+ {1, 1.51721, 3.48, 1.33, 0.00, 0.00},
+ {1, 1.51768, 3.52, 1.43, 0.00, 0.00},
+ {1, 1.51784, 3.49, 1.28, 0.00, 0.00},
+ {1, 1.51768, 3.56, 1.30, 0.00, 0.14},
+ {1, 1.51747, 3.50, 1.14, 0.00, 0.00},
+ {1, 1.51775, 3.48, 1.23, 0.09, 0.22},
+ {1, 1.51753, 3.47, 1.38, 0.00, 0.06},
+ {1, 1.51783, 3.54, 1.34, 0.00, 0.00},
+ {1, 1.51567, 3.45, 1.21, 0.00, 0.00},
+ {1, 1.51909, 3.53, 1.32, 0.11, 0.00},
+ {1, 1.51797, 3.48, 1.35, 0.00, 0.00},
+ {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
+ {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
+ {1, 1.51793, 3.50, 1.12, 0.00, 0.00},
+ {1, 1.51755, 3.42, 1.20, 0.00, 0.00},
+ {1, 1.51779, 3.39, 1.33, 0.00, 0.00},
+ {1, 1.52210, 3.84, 0.72, 0.00, 0.00},
+ {1, 1.51786, 3.43, 1.19, 0.00, 0.30},
+ {1, 1.51900, 3.48, 1.35, 0.00, 0.00},
+ {1, 1.51869, 3.37, 1.18, 0.00, 0.16},
+ {1, 1.52667, 3.70, 0.71, 0.00, 0.10},
+ {1, 1.52223, 3.77, 0.79, 0.00, 0.00},
+ {1, 1.51898, 3.35, 1.23, 0.00, 0.00},
+ {1, 1.52320, 3.72, 0.51, 0.00, 0.16},
+ {1, 1.51926, 3.33, 1.28, 0.00, 0.11},
+ {1, 1.51808, 2.87, 1.19, 0.00, 0.00},
+ {1, 1.51837, 2.84, 1.28, 0.00, 0.00},
+ {1, 1.51778, 2.81, 1.29, 0.00, 0.09},
+ {1, 1.51769, 2.71, 1.29, 0.00, 0.24},
+ {1, 1.51215, 3.47, 1.12, 0.00, 0.31},
+ {1, 1.51824, 3.48, 1.29, 0.00, 0.00},
+ {1, 1.51754, 3.74, 1.17, 0.00, 0.00},
+ {1, 1.51754, 3.66, 1.19, 0.00, 0.11},
+ {1, 1.51905, 3.62, 1.11, 0.00, 0.00},
+ {1, 1.51977, 3.58, 1.32, 0.69, 0.00},
+ {1, 1.52172, 3.86, 0.88, 0.00, 0.11},
+ {1, 1.52227, 3.81, 0.78, 0.00, 0.00},
+ {1, 1.52172, 3.74, 0.90, 0.00, 0.07},
+ {1, 1.52099, 3.59, 1.12, 0.00, 0.00},
+ {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
+ {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
+ {1, 1.52152, 3.58, 0.90, 0.00, 0.16},
+ {1, 1.52300, 3.58, 0.82, 0.00, 0.03},
+ {3, 1.51769, 3.66, 1.11, 0.00, 0.00},
+ {3, 1.51610, 3.53, 1.34, 0.00, 0.00},
+ {3, 1.51670, 3.57, 1.38, 0.00, 0.10},
+ {3, 1.51643, 3.52, 1.35, 0.00, 0.00},
+ {3, 1.51665, 3.45, 1.76, 0.00, 0.17},
+ {3, 1.52127, 3.90, 0.83, 0.00, 0.00},
+ {3, 1.51779, 3.65, 0.65, 0.00, 0.00},
+ {3, 1.51610, 3.40, 1.22, 0.00, 0.00},
+ {3, 1.51694, 3.58, 1.31, 0.00, 0.00},
+ {3, 1.51646, 3.40, 1.26, 0.00, 0.00},
+ {3, 1.51655, 3.39, 1.28, 0.00, 0.00},
+ {3, 1.52121, 3.76, 0.58, 0.00, 0.00},
+ {3, 1.51776, 3.41, 1.52, 0.00, 0.00},
+ {3, 1.51796, 3.36, 1.63, 0.00, 0.09},
+ {3, 1.51832, 3.34, 1.54, 0.00, 0.00},
+ {3, 1.51934, 3.54, 0.75, 0.15, 0.24},
+ {3, 1.52211, 3.78, 0.91, 0.00, 0.37},
+ {7, 1.51131, 3.20, 1.81, 1.19, 0.00},
+ {7, 1.51838, 3.26, 2.22, 1.63, 0.00},
+ {7, 1.52315, 3.34, 1.23, 0.00, 0.00},
+ {7, 1.52247, 2.20, 2.06, 0.00, 0.00},
+ {7, 1.52365, 1.83, 1.31, 1.68, 0.00},
+ {7, 1.51613, 1.78, 1.79, 0.76, 0.00},
+ {7, 1.51602, 0.00, 2.38, 0.64, 0.09},
+ {7, 1.51623, 0.00, 2.79, 0.40, 0.09},
+ {7, 1.51719, 0.00, 2.00, 1.59, 0.08},
+ {7, 1.51683, 0.00, 1.98, 1.57, 0.07},
+ {7, 1.51545, 0.00, 2.68, 0.61, 0.05},
+ {7, 1.51556, 0.00, 2.54, 0.81, 0.01},
+ {7, 1.51727, 0.00, 2.34, 0.66, 0.00},
+ {7, 1.51531, 0.00, 2.66, 0.64, 0.00},
+ {7, 1.51609, 0.00, 2.51, 0.53, 0.00},
+ {7, 1.51508, 0.00, 2.25, 0.63, 0.00},
+ {7, 1.51653, 0.00, 1.19, 0.00, 0.00},
+ {7, 1.51514, 0.00, 2.42, 0.56, 0.00},
+ {7, 1.51658, 0.00, 1.99, 1.71, 0.00},
+ {7, 1.51617, 0.00, 2.27, 0.67, 0.00},
+ {7, 1.51732, 0.00, 1.80, 1.55, 0.00},
+ {7, 1.51645, 0.00, 1.87, 1.38, 0.00},
+ {7, 1.51831, 0.00, 1.82, 2.88, 0.00},
+ {7, 1.51640, 0.00, 2.74, 0.54, 0.00},
+ {7, 1.51623, 0.00, 2.88, 1.06, 0.00},
+ {7, 1.51685, 0.00, 1.99, 1.59, 0.00},
+ {7, 1.52065, 0.00, 2.02, 1.64, 0.00},
+ {7, 1.51651, 0.00, 1.94, 1.57, 0.00},
+ {7, 1.51711, 0.00, 2.08, 1.67, 0.00},
+ };
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java
new file mode 100644
index 0000000..c7b7fe8
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML multi-class logistic regression examples.
+ */
+package org.apache.ignite.examples.ml.regression.logistic.multiclass;
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
new file mode 100644
index 0000000..0817432
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.logistic.multiclass;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+
+/**
+ * Multi-class classification model backed by a set of binary Logistic Regression classifiers, one per class label.
+ * A prediction is the label whose classifier produces the highest score for the given observation.
+ */
+public class LogRegressionMultiClassModel implements Model<Vector, Double>, Exportable<LogRegressionMultiClassModel>, Serializable {
+    /** */
+    private static final long serialVersionUID = -114986533350117L;
+
+    /** Binary classifiers keyed by the class label each of them was added for. */
+    private final Map<Double, LogisticRegressionModel> models;
+
+    /** Creates a model without any classifiers; populate it with {@link #add(double, LogisticRegressionModel)}. */
+    public LogRegressionMultiClassModel() {
+        this.models = new HashMap<>();
+    }
+
+    /**
+     * Predicts the class label for the given observation.
+     *
+     * @param input Feature vector of the observation.
+     * @return Label of the classifier with the largest margin {@code input.dot(weights) + intercept}.
+     * @throws IllegalStateException If no classifier has been added to this model yet.
+     */
+    @Override public Double apply(Vector input) {
+        if (models.isEmpty())
+            throw new IllegalStateException("No binary classifiers were added to the multi-class model.");
+
+        Double bestLb = null;
+        double bestMargin = Double.NaN;
+
+        // The sigmoid is monotonic, so the class with the largest raw margin is also the class with the largest
+        // probability. Comparing raw margins avoids the false ties that appear when the sigmoid saturates
+        // to exactly 0.0 or 1.0 for large margins.
+        for (Map.Entry<Double, LogisticRegressionModel> e : models.entrySet()) {
+            LogisticRegressionModel mdl = e.getValue();
+
+            double margin = input.dot(mdl.weights()) + mdl.intercept();
+
+            // Double.compare gives a total order (NaN included); on an exact tie the later entry wins.
+            if (bestLb == null || Double.compare(margin, bestMargin) >= 0) {
+                bestLb = e.getKey();
+                bestMargin = margin;
+            }
+        }
+
+        return bestLb;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<LogRegressionMultiClassModel, P> exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        LogRegressionMultiClassModel mdl = (LogRegressionMultiClassModel)o;
+
+        return Objects.equals(models, mdl.models);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        return Objects.hash(models);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        StringBuilder wholeStr = new StringBuilder();
+
+        // One line per class label, in the iteration order of the underlying map.
+        models.forEach((clsLb, mdl) ->
+            wholeStr.append("The class with label ").append(clsLb).append(" has classifier: ").append(mdl.toString()).append(System.lineSeparator()));
+
+        return wholeStr.toString();
+    }
+
+    /**
+     * Adds a specific Log Regression binary classifier to the set of classifiers.
+     * A classifier previously added for the same label is replaced.
+     *
+     * @param clsLb The class label for the added model.
+     * @param mdl The model.
+     */
+    public void add(double clsLb, LogisticRegressionModel mdl) {
+        models.put(clsLb, mdl);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
new file mode 100644
index 0000000..e8ed67b
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.logistic.multiclass;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+/**
+ * Trainer of the multi-class logistic regression model. For every distinct class label found in the dataset it runs
+ * a binary {@link LogisticRegressionSGDTrainer} on labels rewritten to {@code 1.0} (this class) and {@code 0.0}
+ * (any other class). All common parameters are shared with the bunch of binary classification trainers.
+ *
+ * @param <P> Second type argument of the {@link UpdatesStrategy} passed to the binary trainers.
+ */
+public class LogRegressionMultiClassTrainer<P extends Serializable>
+    implements SingleLabelDatasetTrainer<LogRegressionMultiClassModel> {
+    /** Update strategy. */
+    private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
+
+    /** Max number of iteration. */
+    private int amountOfIterations;
+
+    /** Batch size. */
+    private int batchSize;
+
+    /** Number of local iterations. */
+    private int amountOfLocIterations;
+
+    /** Seed for random generator. */
+    private long seed;
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return Model holding one binary classifier per class label found in the dataset.
+     */
+    @Override public <K, V> LogRegressionMultiClassModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
+
+        LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel();
+
+        for (Double clsLb : classes) {
+            LogisticRegressionSGDTrainer<?> trainer =
+                new LogisticRegressionSGDTrainer<>(updatesStgy, amountOfIterations, batchSize, amountOfLocIterations, seed);
+
+            // Binary relabeling: the current class becomes 1.0, every other class becomes 0.0.
+            IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> lbExtractor.apply(k, v).equals(clsLb) ? 1.0 : 0.0;
+
+            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
+        }
+
+        return multiClsMdl;
+    }
+
+    /**
+     * Iterates among dataset and collects class labels.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param lbExtractor Label extractor.
+     * @return Distinct class labels found in the dataset; empty if none were found.
+     */
+    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
+        assert datasetBuilder != null;
+
+        PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
+
+        List<Double> res = new ArrayList<>();
+
+        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            partDataBuilder
+        )) {
+            final Set<Double> clsLabels = dataset.compute(data -> {
+                final Set<Double> locClsLabels = new HashSet<>();
+
+                final double[] lbs = data.getY();
+
+                for (double lb : lbs)
+                    locClsLabels.add(lb);
+
+                return locClsLabels;
+            }, (a, b) -> {
+                // Either side may be absent (defensive: e.g. nothing accumulated yet or a partition that
+                // contributed no result), so a missing operand must not break the merge.
+                if (a == null)
+                    return b;
+
+                if (b == null)
+                    return a;
+
+                return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
+            });
+
+            // The reduction yields null when there was nothing to reduce.
+            if (clsLabels != null)
+                res.addAll(clsLabels);
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return res;
+    }
+
+    /**
+     * Set up the batch size.
+     *
+     * @param batchSize The size of learning batch.
+     * @return Trainer with new batch size parameter value.
+     */
+    public LogRegressionMultiClassTrainer withBatchSize(int batchSize) {
+        this.batchSize = batchSize;
+        return this;
+    }
+
+    /**
+     * Gets the batch size.
+     *
+     * @return The parameter value (the stored {@code int} widened to {@code double}).
+     */
+    public double batchSize() {
+        return batchSize;
+    }
+
+    /**
+     * Gets the amount of outer iterations of SGD algorithm.
+     *
+     * @return The parameter value.
+     */
+    public int amountOfIterations() {
+        return amountOfIterations;
+    }
+
+    /**
+     * Set up the amount of outer iterations.
+     *
+     * @param amountOfIterations The parameter value.
+     * @return Trainer with new amountOfIterations parameter value.
+     */
+    public LogRegressionMultiClassTrainer withAmountOfIterations(int amountOfIterations) {
+        this.amountOfIterations = amountOfIterations;
+        return this;
+    }
+
+    /**
+     * Gets the amount of local iterations.
+     *
+     * @return The parameter value.
+     */
+    public int amountOfLocIterations() {
+        return amountOfLocIterations;
+    }
+
+    /**
+     * Set up the amount of local iterations of SGD algorithm.
+     *
+     * @param amountOfLocIterations The parameter value.
+     * @return Trainer with new amountOfLocIterations parameter value.
+     */
+    public LogRegressionMultiClassTrainer withAmountOfLocIterations(int amountOfLocIterations) {
+        this.amountOfLocIterations = amountOfLocIterations;
+        return this;
+    }
+
+    /**
+     * Set up the seed for random generator.
+     *
+     * @param seed Seed for random generator.
+     * @return Trainer with new seed parameter value.
+     */
+    public LogRegressionMultiClassTrainer withSeed(long seed) {
+        this.seed = seed;
+        return this;
+    }
+
+    /**
+     * Gets the seed for random generator.
+     *
+     * @return The parameter value.
+     */
+    public long seed() {
+        return seed;
+    }
+
+    /**
+     * Set up the update strategy.
+     *
+     * @param updatesStgy Update strategy.
+     * @return Trainer with new update strategy parameter value.
+     */
+    public LogRegressionMultiClassTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) {
+        this.updatesStgy = updatesStgy;
+        return this;
+    }
+
+    /**
+     * Gets the update strategy.
+     *
+     * @return The parameter value.
+     */
+    public UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy() {
+        return updatesStgy;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java
new file mode 100644
index 0000000..2e7b947
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains multi-class logistic regression.
+ */
+package org.apache.ignite.ml.regressions.logistic.multiclass;
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index 2d21d3b..021b567 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.regressions;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainerTest;
+import org.apache.ignite.ml.regressions.logistic.LogRegMultiClassTrainerTest;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModelTest;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainerTest;
import org.junit.runner.RunWith;
@@ -34,8 +35,9 @@ import org.junit.runners.Suite;
LinearRegressionLSQRTrainerTest.class,
LinearRegressionSGDTrainerTest.class,
LogisticRegressionModelTest.class,
- LogisticRegressionSGDTrainerTest.class
+ LogisticRegressionSGDTrainerTest.class,
+ LogRegMultiClassTrainerTest.class
})
public class RegressionsTestSuite {
// No-op.
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java
index aac24f4..7ca9121 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java
@@ -21,6 +21,8 @@ import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
import org.junit.Test;
/**
@@ -53,6 +55,21 @@ public class LinearRegressionModelTest {
}
/** */
+ @Test
+ public void testPredictWithMultiClasses() {
+ Vector weights1 = new DenseLocalOnHeapVector(new double[]{10.0, 0.0});
+ Vector weights2 = new DenseLocalOnHeapVector(new double[]{0.0, 10.0});
+ Vector weights3 = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0});
+ LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel();
+ // Each classifier gets its own label: adding two models under the same label keeps only the last one.
+ mdl.add(1, new LogisticRegressionModel(weights1, 0.0).withRawLabels(true));
+ mdl.add(2, new LogisticRegressionModel(weights2, 0.0).withRawLabels(true));
+ mdl.add(3, new LogisticRegressionModel(weights3, 0.0).withRawLabels(true));
+
+ // Margins for this observation: class 1 -> 10.0, class 2 -> 5.0, class 3 -> -1.5, so class 1 must win.
+ Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 0.5});
+ TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION);
+ }
+
+ /** */
@Test(expected = CardinalityException.class)
public void testPredictOnAnObservationWithWrongCardinality() {
Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0});
http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
new file mode 100644
index 0000000..d26a4ca
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.logistic;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
+import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer;
+import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link LogRegressionMultiClassTrainer}.
+ */
+public class LogRegMultiClassTrainerTest {
+    /** Fixed size of Dataset. */
+    private static final int AMOUNT_OF_OBSERVATIONS = 1000;
+
+    /** Fixed size of columns in Dataset. */
+    private static final int AMOUNT_OF_FEATURES = 2;
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /**
+     * Test trainer on classification model y = x: points above the line get label 1, the others get label -1.
+     */
+    @Test
+    public void testTrainWithTheLinearlySeparableCase() {
+        Map<Integer, double[]> data = new HashMap<>();
+
+        // ThreadLocalRandom.current() always returns the same per-thread generator, so one reference is enough.
+        ThreadLocalRandom rnd = ThreadLocalRandom.current();
+
+        for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
+            double x = rnd.nextDouble(-1000, 1000);
+            double y = rnd.nextDouble(-1000, 1000);
+            double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+            vec[0] = y - x > 0 ? 1 : -1; // assign label.
+            vec[1] = x;
+            vec[2] = y;
+            data.put(i, vec);
+        }
+
+        final UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> stgy = new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        );
+
+        LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>()
+            .withUpdatesStgy(stgy)
+            .withAmountOfIterations(1000)
+            .withAmountOfLocIterations(10)
+            .withBatchSize(100)
+            .withSeed(123L);
+
+        // JUnit convention: expected value first, actual value second.
+        Assert.assertEquals(1000, trainer.amountOfIterations());
+        Assert.assertEquals(10, trainer.amountOfLocIterations());
+        Assert.assertEquals(100, trainer.batchSize(), PRECISION);
+        Assert.assertEquals(123L, trainer.seed());
+        Assert.assertEquals(stgy, trainer.updatesStgy());
+
+        // First element of each row is the label, the rest are the features.
+        LogRegressionMultiClassModel mdl = trainer.fit(
+            data,
+            10,
+            (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+            (k, v) -> v[0]
+        );
+
+        TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION);
+        TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION);
+    }
+}