Posted to commits@ignite.apache.org by ch...@apache.org on 2018/04/12 08:16:41 UTC
[1/2] ignite git commit: IGNITE-8176: Integrate gradient descent linear regression with partition based dataset
Repository: ignite
Updated Branches:
refs/heads/master 67023a88b -> df6356d5d
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
new file mode 100644
index 0000000..fa8fac4
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link LinearRegressionSGDTrainer}.
+ */
+@RunWith(Parameterized.class)
+public class LinearRegressionSGDTrainerTest {
+ /** Parameters. */
+ @Parameterized.Parameters(name = "Data divided on {0} partitions")
+ public static Iterable<Integer[]> data() {
+ return Arrays.asList(
+ new Integer[] {1},
+ new Integer[] {2},
+ new Integer[] {3},
+ new Integer[] {5},
+ new Integer[] {7},
+ new Integer[] {100}
+ );
+ }
+
+ /** Number of partitions. */
+ @Parameterized.Parameter
+ public int parts;
+
+ /**
+ * Tests {@code fit()} method on a simple small dataset.
+ */
+ @Test
+ public void testSmallDataFit() {
+ Map<Integer, double[]> data = new HashMap<>();
+ data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107});
+ data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867});
+ data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728});
+ data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991});
+ data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611});
+ data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197});
+ data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012});
+ data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889});
+ data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
+ data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
+
+ LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
+ new RPropUpdateCalculator(),
+ RPropParameterUpdate::sumLocal,
+ RPropParameterUpdate::avg
+ ), 100000, 10, 100, 123L);
+
+ LinearRegressionModel mdl = trainer.fit(
+ data,
+ parts,
+ (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+ (k, v) -> v[4]
+ );
+
+ assertArrayEquals(
+ new double[] {72.26948107, 15.95144674, 24.07403921, 66.73038781},
+ mdl.getWeights().getStorage().data(),
+ 1e-1
+ );
+
+ assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-1);
+ }
+}
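For reference, the (data, parts) overload exercised in the test above is the new convenience form; judging by the DatasetTrainer changes in this commit, it wraps the older builder-based call. A minimal equivalent sketch, assuming LocalDatasetBuilder keeps its (Map, partitions) constructor:

    // Same training call with the dataset builder constructed explicitly.
    LinearRegressionModel mdl = trainer.fit(
        new LocalDatasetBuilder<>(data, parts),           // same data and partition count
        (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), // features: first four columns
        (k, v) -> v[4]                                    // label: last column
    );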
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java
deleted file mode 100644
index bea164d..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions.linear;
-
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-
-/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
- */
-public class LocalLinearRegressionSGDTrainerTest extends GenericLinearRegressionTrainerTest {
- /** */
- public LocalLinearRegressionSGDTrainerTest() {
- super(
- new LinearRegressionSGDTrainer(100_000, 1e-12),
- DenseLocalOnHeapMatrix::new,
- DenseLocalOnHeapVector::new,
- 1e-2);
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
index 26ba2fb..0befd9b 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
@@ -17,14 +17,14 @@
package org.apache.ignite.ml.svm;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.junit.Test;
/**
* Tests for {@link SVMLinearBinaryClassificationTrainer}.
@@ -62,7 +62,8 @@ public class SVMBinaryTrainerTest {
SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
SVMLinearBinaryClassificationModel mdl = trainer.fit(
- new LocalDatasetBuilder<>(data, 10),
+ data,
+ 10,
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
index ad95eb4..31ab4d7 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
@@ -17,14 +17,14 @@
package org.apache.ignite.ml.svm;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.junit.Test;
/**
* Tests for {@link SVMLinearBinaryClassificationTrainer}.
@@ -65,7 +65,8 @@ public class SVMMultiClassTrainerTest {
.withAmountOfIterations(20);
SVMLinearMultiClassClassificationModel mdl = trainer.fit(
- new LocalDatasetBuilder<>(data, 10),
+ data,
+ 10,
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
index 94bca3f..d5b0b86 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
@@ -17,16 +17,16 @@
package org.apache.ignite.ml.tree;
-import java.util.Arrays;
-import java.util.Random;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import java.util.Arrays;
+import java.util.Random;
+
/**
* Tests for {@link DecisionTreeClassificationTrainer} that require to start the whole Ignite infrastructure.
*/
@@ -77,7 +77,8 @@ public class DecisionTreeClassificationTrainerIntegrationTest extends GridCommon
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
DecisionTreeNode tree = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, data),
+ ignite,
+ data,
(k, v) -> Arrays.copyOf(v, v.length - 1),
(k, v) -> v[v.length - 1]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
index 2599bfe..12ef698 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
@@ -17,17 +17,12 @@
package org.apache.ignite.ml.tree;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+import java.util.*;
+
import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertTrue;
@@ -68,7 +63,8 @@ public class DecisionTreeClassificationTrainerTest {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
DecisionTreeNode tree = trainer.fit(
- new LocalDatasetBuilder<>(data, parts),
+ data,
+ parts,
(k, v) -> Arrays.copyOf(v, v.length - 1),
(k, v) -> v[v.length - 1]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
index 754ff20..c2a4638 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
@@ -17,16 +17,16 @@
package org.apache.ignite.ml.tree;
-import java.util.Arrays;
-import java.util.Random;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import java.util.Arrays;
+import java.util.Random;
+
/**
* Tests for {@link DecisionTreeRegressionTrainer} that require to start the whole Ignite infrastructure.
*/
@@ -77,7 +77,8 @@ public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbst
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
DecisionTreeNode tree = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, data),
+ ignite,
+ data,
(k, v) -> Arrays.copyOf(v, v.length - 1),
(k, v) -> v[v.length - 1]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
index 3bdbf60..bcfb53f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
@@ -17,17 +17,12 @@
package org.apache.ignite.ml.tree;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+import java.util.*;
+
import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertTrue;
@@ -68,7 +63,8 @@ public class DecisionTreeRegressionTrainerTest {
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
DecisionTreeNode tree = trainer.fit(
- new LocalDatasetBuilder<>(data, parts),
+ data,
+ parts,
(k, v) -> Arrays.copyOf(v, v.length - 1),
(k, v) -> v[v.length - 1]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
index b259ec9..35f805e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
@@ -17,13 +17,11 @@
package org.apache.ignite.ml.tree.performance;
-import java.io.IOException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -31,6 +29,8 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor;
import org.apache.ignite.ml.util.MnistUtils;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import java.io.IOException;
+
/**
* Tests {@link DecisionTreeClassificationTrainer} on the MNIST dataset that require to start the whole Ignite
* infrastructure. For manual run.
@@ -81,7 +81,8 @@ public class DecisionTreeMNISTIntegrationTest extends GridCommonAbstractTest {
new SimpleStepFunctionCompressor<>());
DecisionTreeNode mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+ ignite,
+ trainingSet,
(k, v) -> v.getPixels(),
(k, v) -> (double) v.getLabel()
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
index 6dbd44c..b40c7ac 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
@@ -17,10 +17,6 @@
package org.apache.ignite.ml.tree.performance;
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -28,6 +24,10 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor;
import org.apache.ignite.ml.util.MnistUtils;
import org.junit.Test;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
import static junit.framework.TestCase.assertTrue;
/**
@@ -50,7 +50,8 @@ public class DecisionTreeMNISTTest {
new SimpleStepFunctionCompressor<>());
DecisionTreeNode mdl = trainer.fit(
- new LocalDatasetBuilder<>(trainingSet, 10),
+ trainingSet,
+ 10,
(k, v) -> v.getPixels(),
(k, v) -> (double) v.getLabel()
);
[2/2] ignite git commit: IGNITE-8176: Integrate gradient descent linear regression with partition based dataset
Posted by ch...@apache.org.
IGNITE-8176: Integrate gradient descent linear regression with partition based dataset
this closes #3787
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/df6356d5
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/df6356d5
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/df6356d5
Branch: refs/heads/master
Commit: df6356d5d1470337a6ea705a332cf07f1dce2222
Parents: 67023a8
Author: dmitrievanthony <dm...@gmail.com>
Authored: Thu Apr 12 11:16:22 2018 +0300
Committer: YuriBabak <y....@gmail.com>
Committed: Thu Apr 12 11:16:22 2018 +0300
----------------------------------------------------------------------
.../ml/knn/KNNClassificationExample.java | 11 +-
.../examples/ml/nn/MLPTrainerExample.java | 4 +-
.../ml/preprocessing/NormalizationExample.java | 17 +--
...nWithLSQRTrainerAndNormalizationExample.java | 23 ++--
...dLinearRegressionWithLSQRTrainerExample.java | 14 +--
...tedLinearRegressionWithQRTrainerExample.java | 9 +-
...edLinearRegressionWithSGDTrainerExample.java | 78 +++++++++---
.../binary/SVMBinaryClassificationExample.java | 11 +-
.../SVMMultiClassClassificationExample.java | 24 ++--
...ecisionTreeClassificationTrainerExample.java | 7 +-
.../DecisionTreeRegressionTrainerExample.java | 4 +-
.../org/apache/ignite/ml/nn/Activators.java | 20 ++++
.../org/apache/ignite/ml/nn/MLPTrainer.java | 46 ++++++--
.../ml/preprocessing/PreprocessingTrainer.java | 41 ++++++-
.../normalization/NormalizationTrainer.java | 35 ++++--
.../linear/FeatureExtractorWrapper.java | 55 +++++++++
.../linear/LinearRegressionLSQRTrainer.java | 38 +-----
.../linear/LinearRegressionSGDTrainer.java | 118 +++++++++++++------
.../ignite/ml/trainers/DatasetTrainer.java | 46 ++++++++
.../ignite/ml/knn/KNNClassificationTest.java | 20 ++--
.../ignite/ml/nn/MLPTrainerIntegrationTest.java | 14 +--
.../org/apache/ignite/ml/nn/MLPTrainerTest.java | 22 ++--
.../MLPTrainerMnistIntegrationTest.java | 7 +-
.../ml/nn/performance/MLPTrainerMnistTest.java | 11 +-
.../normalization/NormalizationTrainerTest.java | 10 +-
.../ml/regressions/RegressionsTestSuite.java | 15 +--
...stributedLinearRegressionSGDTrainerTest.java | 35 ------
...stributedLinearRegressionSGDTrainerTest.java | 35 ------
...wareAbstractLinearRegressionTrainerTest.java | 3 +
.../linear/LinearRegressionLSQRTrainerTest.java | 14 ++-
.../linear/LinearRegressionSGDTrainerTest.java | 94 +++++++++++++++
.../LocalLinearRegressionSGDTrainerTest.java | 35 ------
.../ignite/ml/svm/SVMBinaryTrainerTest.java | 11 +-
.../ignite/ml/svm/SVMMultiClassTrainerTest.java | 11 +-
...reeClassificationTrainerIntegrationTest.java | 9 +-
.../DecisionTreeClassificationTrainerTest.java | 12 +-
...ionTreeRegressionTrainerIntegrationTest.java | 9 +-
.../tree/DecisionTreeRegressionTrainerTest.java | 12 +-
.../DecisionTreeMNISTIntegrationTest.java | 7 +-
.../tree/performance/DecisionTreeMNISTTest.java | 11 +-
40 files changed, 612 insertions(+), 386 deletions(-)
----------------------------------------------------------------------
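In short, this commit replaces explicit dataset-builder construction at trainer call sites with convenience overloads that accept either an Ignite instance plus a cache, or a Map plus a partition count. A minimal before/after sketch distilled from the diffs below (trainer and extractor names are illustrative):

    // Before: callers built the dataset builder themselves.
    SVMLinearBinaryClassificationModel mdl = trainer.fit(
        new CacheBasedDatasetBuilder<>(ignite, dataCache),
        (k, v) -> Arrays.copyOfRange(v, 1, v.length), // features
        (k, v) -> v[0]                                // label
    );

    // After: the trainer builds the CacheBasedDatasetBuilder internally.
    SVMLinearBinaryClassificationModel mdl = trainer.fit(
        ignite,
        dataCache,
        (k, v) -> Arrays.copyOfRange(v, 1, v.length),
        (k, v) -> v[0]
    );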
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index f3cdbbe..39a8431 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -17,9 +17,6 @@
package org.apache.ignite.examples.ml.knn;
-import java.util.Arrays;
-import java.util.UUID;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -27,7 +24,6 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
@@ -35,6 +31,10 @@ import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.thread.IgniteThread;
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
/**
* Run kNN multi-class classification trainer over distributed dataset.
*
@@ -56,7 +56,8 @@ public class KNNClassificationExample {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
+ ignite,
+ dataCache,
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
).withK(3)
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
index efa1ba7..ce44cc6 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
@@ -23,7 +23,6 @@ import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ExampleNodeStartup;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.nn.Activators;
@@ -99,7 +98,8 @@ public class MLPTrainerExample {
// Train neural network and get multilayer perceptron model.
MultilayerPerceptron mlp = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+ ignite,
+ trainingSet,
(k, v) -> new double[] {v.x, v.y},
(k, v) -> new double[] {v.lb}
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
index e0bcd08..b2c4e12 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
@@ -17,21 +17,19 @@
package org.apache.ignite.examples.ml.preprocessing;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
-import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.DatasetFactory;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import java.util.Arrays;
+
/**
* Example that shows how to use normalization preprocessor to normalize data.
*
@@ -47,8 +45,6 @@ public class NormalizationExample {
IgniteCache<Integer, Person> persons = createCache(ignite);
- DatasetBuilder<Integer, Person> builder = new CacheBasedDatasetBuilder<>(ignite, persons);
-
// Defines first preprocessor that extracts features from an upstream data.
IgniteBiFunction<Integer, Person, double[]> featureExtractor = (k, v) -> new double[] {
v.getAge(),
@@ -56,14 +52,11 @@ public class NormalizationExample {
};
// Defines second preprocessor that normalizes features.
- NormalizationPreprocessor<Integer, Person> preprocessor = new NormalizationTrainer<Integer, Person>()
- .fit(builder, featureExtractor, 2);
+ IgniteBiFunction<Integer, Person, double[]> preprocessor = new NormalizationTrainer<Integer, Person>()
+ .fit(ignite, persons, featureExtractor);
// Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(
- builder,
- preprocessor
- )) {
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
// Calculation of the mean value. This calculation will be performed in map-reduce manner.
double[] mean = dataset.mean();
System.out.println("Mean \n\t" + Arrays.toString(mean));
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
index 567a599..99e6577 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
@@ -17,9 +17,6 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.util.Arrays;
-import java.util.UUID;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -28,7 +25,7 @@ import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
@@ -36,6 +33,10 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.thread.IgniteThread;
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
/**
* Run linear regression model over distributed matrix.
*
@@ -119,21 +120,17 @@ public class DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample {
NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>();
System.out.println(">>> Perform the training to get the normalization preprocessor.");
- NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
- (k, v) -> Arrays.copyOfRange(v, 1, v.length),
- 4
+ IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length)
);
System.out.println(">>> Create new linear regression trainer object.");
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
- preprocessor,
- (k, v) -> v[0]
- );
+ LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v[0]);
System.out.println(">>> Linear regression model: " + mdl);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java
index a853092..25aec0c 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java
@@ -17,9 +17,6 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.util.Arrays;
-import java.util.UUID;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -27,13 +24,15 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.thread.IgniteThread;
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
/**
* Run linear regression model over distributed matrix.
*
@@ -108,7 +107,7 @@ public class DistributedLinearRegressionWithLSQRTrainerExample {
// Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
// because we create ignite cache internally.
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- SparseDistributedMatrixExample.class.getSimpleName(), () -> {
+ DistributedLinearRegressionWithLSQRTrainerExample.class.getSimpleName(), () -> {
IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
System.out.println(">>> Create new linear regression trainer object.");
@@ -116,7 +115,8 @@ public class DistributedLinearRegressionWithLSQRTrainerExample {
System.out.println(">>> Perform the training to get the model.");
LinearRegressionModel mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
+ ignite,
+ dataCache,
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java
index 2b45aa2..98d5e4e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.Ignition;
import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
@@ -30,6 +29,8 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer;
import org.apache.ignite.thread.IgniteThread;
+import java.util.Arrays;
+
/**
* Run linear regression model over distributed matrix.
*
@@ -113,15 +114,15 @@ public class DistributedLinearRegressionWithQRTrainerExample {
Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer();
System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel model = trainer.train(distributedMatrix);
- System.out.println(">>> Linear regression model: " + model);
+ LinearRegressionModel mdl = trainer.train(distributedMatrix);
+ System.out.println(">>> Linear regression model: " + mdl);
System.out.println(">>> ---------------------------------");
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
for (double[] observation : data) {
Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length));
- double prediction = model.apply(inputs);
+ double prediction = mdl.apply(inputs);
double groundTruth = observation[0];
System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java
index f3b2655..44366e1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java
@@ -17,20 +17,26 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
-import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
import org.apache.ignite.thread.IgniteThread;
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
/**
* Run linear regression model over distributed matrix.
*
@@ -104,28 +110,43 @@ public class DistributedLinearRegressionWithSGDTrainerExample {
// Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
// because we create ignite cache internally.
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- SparseDistributedMatrixExample.class.getSimpleName(), () -> {
+ DistributedLinearRegressionWithSGDTrainerExample.class.getSimpleName(), () -> {
- // Create SparseDistributedMatrix, new cache will be created automagically.
- System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread.");
- SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data);
+ IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
System.out.println(">>> Create new linear regression trainer object.");
- Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionSGDTrainer(100_000, 1e-12);
+ LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
+ new RPropUpdateCalculator(),
+ RPropParameterUpdate::sumLocal,
+ RPropParameterUpdate::avg
+ ), 100000, 10, 100, 123L);
System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel model = trainer.train(distributedMatrix);
- System.out.println(">>> Linear regression model: " + model);
+ LinearRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+ (k, v) -> v[0]
+ );
+
+ System.out.println(">>> Linear regression model: " + mdl);
System.out.println(">>> ---------------------------------");
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
- for (double[] observation : data) {
- Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length));
- double prediction = model.apply(inputs);
- double groundTruth = observation[0];
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
+
+ double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
+
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
}
+
System.out.println(">>> ---------------------------------");
});
@@ -134,4 +155,23 @@ public class DistributedLinearRegressionWithSGDTrainerExample {
igniteThread.join();
}
}
+
+ /**
+ * Fills cache with data and returns it.
+ *
+ * @param ignite Ignite instance.
+ * @return Filled Ignite Cache.
+ */
+ private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+ CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+ cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+ cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+ IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+
+ for (int i = 0; i < data.length; i++)
+ cache.put(i, data[i]);
+
+ return cache;
+ }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
index f8bf521..ce37112 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
@@ -17,9 +17,6 @@
package org.apache.ignite.examples.ml.svm.binary;
-import java.util.Arrays;
-import java.util.UUID;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -27,12 +24,15 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
import org.apache.ignite.thread.IgniteThread;
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
/**
* Run SVM binary-class classification model over distributed dataset.
*
@@ -54,7 +54,8 @@ public class SVMBinaryClassificationExample {
SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
SVMLinearBinaryClassificationModel mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
+ ignite,
+ dataCache,
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java
index f8281e4..4054201 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java
@@ -17,9 +17,6 @@
package org.apache.ignite.examples.ml.svm.multiclass;
-import java.util.Arrays;
-import java.util.UUID;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -27,14 +24,17 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer;
import org.apache.ignite.thread.IgniteThread;
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
/**
* Run SVM multi-class classification trainer over distributed dataset to build two models:
* one with normalization and one without normalization.
@@ -57,7 +57,8 @@ public class SVMMultiClassClassificationExample {
SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer();
SVMLinearMultiClassClassificationModel mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
+ ignite,
+ dataCache,
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
);
@@ -67,14 +68,15 @@ public class SVMMultiClassClassificationExample {
NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>();
- NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
- (k, v) -> Arrays.copyOfRange(v, 1, v.length),
- 5
+ IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length)
);
SVMLinearMultiClassClassificationModel mdlWithNormalization = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, dataCache),
+ ignite,
+ dataCache,
preprocessor,
(k, v) -> v[0]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
index cef6368..1ecf460 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
@@ -17,17 +17,17 @@
package org.apache.ignite.examples.ml.tree;
-import java.util.Random;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.thread.IgniteThread;
+import java.util.Random;
+
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer}.
*/
@@ -65,7 +65,8 @@ public class DecisionTreeClassificationTrainerExample {
// Train decision tree model.
DecisionTreeNode mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+ ignite,
+ trainingSet,
(k, v) -> new double[]{v.x, v.y},
(k, v) -> v.lb
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
index 61ba5f9..19b15f3 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
@@ -22,7 +22,6 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
import org.apache.ignite.thread.IgniteThread;
@@ -61,7 +60,8 @@ public class DecisionTreeRegressionTrainerExample {
// Train decision tree model.
DecisionTreeNode mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+ ignite,
+ trainingSet,
(k, v) -> new double[] {v.x},
(k, v) -> v.y
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java
index f05bde8..4c34cd2 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java
@@ -58,4 +58,24 @@ public class Activators {
return Math.max(val, 0);
}
};
+
+ /**
+ * Linear unit activation function.
+ */
+ public static IgniteDifferentiableDoubleToDoubleFunction LINEAR = new IgniteDifferentiableDoubleToDoubleFunction() {
+ /** {@inheritDoc} */
+ @Override public double differential(double pnt) {
+ return 1.0;
+ }
+
+ /**
+ * Differential of linear at pnt.
+ *
+ * @param pnt Point to differentiate at.
+ * @return Differential at pnt.
+ */
+ @Override public Double apply(double pnt) {
+ return pnt;
+ }
+ };
}
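The new LINEAR activator is the identity function with a constant derivative of 1, which makes it a natural fit for regression output layers. A hedged usage sketch, assuming the existing MLPArchitecture.withAddedLayer(neurons, hasBias, activator) API; the layer sizes are illustrative:

    // Two-input network: ReLU hidden layer, linear output layer.
    MLPArchitecture arch = new MLPArchitecture(2).
        withAddedLayer(10, true, Activators.RELU).
        withAddedLayer(1, true, Activators.LINEAR);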
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
index 47d2022..fe955cb 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
@@ -17,11 +17,6 @@
package org.apache.ignite.ml.nn;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
@@ -37,17 +32,23 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.nn.initializers.RandomInitializer;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
+import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
import org.apache.ignite.ml.util.Utils;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
/**
* Multilayer perceptron trainer based on partition based {@link Dataset}.
*
* @param <P> Type of model update used in this trainer.
*/
public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrainer<MultilayerPerceptron> {
- /** Multilayer perceptron architecture that defines layers and activators. */
- private final MLPArchitecture arch;
+ /** Multilayer perceptron architecture supplier that defines layers and activators. */
+ private final IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier;
/** Loss function to be minimized during the training. */
private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
@@ -81,7 +82,25 @@ public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrai
public MLPTrainer(MLPArchitecture arch, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize,
int locIterations, long seed) {
- this.arch = arch;
+ this(dataset -> arch, loss, updatesStgy, maxIterations, batchSize, locIterations, seed);
+ }
+
+ /**
+ * Constructs a new instance of multilayer perceptron trainer.
+ *
+ * @param archSupplier Multilayer perceptron architecture supplier that defines layers and activators.
+ * @param loss Loss function to be minimized during the training.
+ * @param updatesStgy Update strategy that defines how to update model parameters during the training.
+ * @param maxIterations Maximal number of iterations before the training will be stopped.
+ * @param batchSize Batch size (per every partition).
+ * @param locIterations Maximal number of local iterations before synchronization.
+ * @param seed Random initializer seed.
+ */
+ public MLPTrainer(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier,
+ IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
+ UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize,
+ int locIterations, long seed) {
+ this.archSupplier = archSupplier;
this.loss = loss;
this.updatesStgy = updatesStgy;
this.maxIterations = maxIterations;
@@ -94,13 +113,14 @@ public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrai
public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
- MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
- ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();
-
try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
new EmptyContextBuilder<>(),
new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
)) {
+ MLPArchitecture arch = archSupplier.apply(dataset);
+ MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
+ ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();
+
for (int i = 0; i < maxIterations; i += locIterations) {
MultilayerPerceptron finalMdl = mdl;
@@ -163,12 +183,12 @@ public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrai
P update = updatesStgy.allUpdatesReducer().apply(totUp);
mdl = updater.update(mdl, update);
}
+
+ return mdl;
}
catch (Exception e) {
throw new RuntimeException(e);
}
-
- return mdl;
}
/**
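
For context, a minimal sketch of how the new architecture-supplier constructor can be wired up. The class names come from this patch; the update strategy mirrors the RProp configuration used by the tests, while the layer setup and hyper-parameters are illustrative assumptions:

    // Supplier that infers the input-layer width from the dataset at fit() time.
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier =
        dataset -> {
            int cols = dataset.compute(
                data -> data.getFeatures() == null ? null : data.getFeatures().length / data.getRows(),
                (a, b) -> a == null ? b : a);
            return new MLPArchitecture(cols).withAddedLayer(1, true, Activators.SIGMOID);
        };

    MLPTrainer<RPropParameterUpdate> trainer = new MLPTrainer<>(
        archSupplier,
        LossFunctions.MSE,
        new UpdatesStrategy<>(new RPropUpdateCalculator(),
            RPropParameterUpdate::sumLocal, RPropParameterUpdate::avg),
        3000,  // maxIterations
        10,    // batchSize (per partition)
        100,   // locIterations
        123L); // seed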
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
index f5a6bb0..1886ee5 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
@@ -17,9 +17,15 @@
package org.apache.ignite.ml.preprocessing;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import java.util.Map;
+
/**
* Trainer for preprocessor.
*
@@ -34,9 +40,40 @@ public interface PreprocessingTrainer<K, V, T, R> {
*
* @param datasetBuilder Dataset builder.
* @param basePreprocessor Base preprocessor.
- * @param cols Number of columns.
* @return Preprocessor.
*/
public IgniteBiFunction<K, V, R> fit(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, T> basePreprocessor, int cols);
+ IgniteBiFunction<K, V, T> basePreprocessor);
+
+ /**
+ * Fits preprocessor.
+ *
+ * @param ignite Ignite instance.
+ * @param cache Ignite cache.
+ * @param basePreprocessor Base preprocessor.
+ * @return Preprocessor.
+ */
+ public default IgniteBiFunction<K, V, R> fit(Ignite ignite, IgniteCache<K, V> cache,
+ IgniteBiFunction<K, V, T> basePreprocessor) {
+ return fit(
+ new CacheBasedDatasetBuilder<>(ignite, cache),
+ basePreprocessor
+ );
+ }
+
+ /**
+ * Fits preprocessor.
+ *
+ * @param data Data.
+ * @param parts Number of partitions.
+ * @param basePreprocessor Base preprocessor.
+ * @return Preprocessor.
+ */
+ public default IgniteBiFunction<K, V, R> fit(Map<K, V> data, int parts,
+ IgniteBiFunction<K, V, T> basePreprocessor) {
+ return fit(
+ new LocalDatasetBuilder<>(data, parts),
+ basePreprocessor
+ );
+ }
}
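
A minimal sketch of the new default overloads, assuming a running Ignite node and a cache of feature rows (the cache name and value layout are illustrative):

    // Cache-based overload: wraps the cache in a CacheBasedDatasetBuilder.
    Ignite ignite = Ignition.ignite();
    IgniteCache<Integer, double[]> cache = ignite.cache("features");

    IgniteBiFunction<Integer, double[], double[]> preprocessor =
        new NormalizationTrainer<Integer, double[]>().fit(ignite, cache, (k, v) -> v);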
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
index 16623ba..57acbad 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
@@ -33,33 +33,48 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
public class NormalizationTrainer<K, V> implements PreprocessingTrainer<K, V, double[], double[]> {
/** {@inheritDoc} */
@Override public NormalizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, double[]> basePreprocessor, int cols) {
+ IgniteBiFunction<K, V, double[]> basePreprocessor) {
try (Dataset<EmptyContext, NormalizationPartitionData> dataset = datasetBuilder.build(
(upstream, upstreamSize) -> new EmptyContext(),
(upstream, upstreamSize, ctx) -> {
- double[] min = new double[cols];
- double[] max = new double[cols];
-
- for (int i = 0; i < cols; i++) {
- min[i] = Double.MAX_VALUE;
- max[i] = -Double.MAX_VALUE;
- }
+ double[] min = null;
+ double[] max = null;
while (upstream.hasNext()) {
UpstreamEntry<K, V> entity = upstream.next();
double[] row = basePreprocessor.apply(entity.getKey(), entity.getValue());
- for (int i = 0; i < cols; i++) {
+
+ if (min == null) {
+ min = new double[row.length];
+ for (int i = 0; i < min.length; i++)
+ min[i] = Double.MAX_VALUE;
+ }
+ else
+ assert min.length == row.length : "Base preprocessor must return exactly " + min.length
+ + " features";
+
+ if (max == null) {
+ max = new double[row.length];
+ for (int i = 0; i < max.length; i++)
+ max[i] = -Double.MAX_VALUE;
+ }
+ else
+ assert max.length == row.length : "Base preprocessor must return exactly " + max.length
+ + " features";
+
+ for (int i = 0; i < row.length; i++) {
if (row[i] < min[i])
min[i] = row[i];
if (row[i] > max[i])
max[i] = row[i];
}
}
+
return new NormalizationPartitionData(min, max);
}
)) {
double[][] minMax = dataset.compute(
- data -> new double[][]{ data.getMin(), data.getMax() },
+ data -> data.getMin() != null ? new double[][]{ data.getMin(), data.getMax() } : null,
(a, b) -> {
if (a == null)
return b;
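
With the cols parameter removed, min and max are now sized from the first row each partition sees. A local sketch of the resulting behaviour (values illustrative, in line with the updated test further below):

    Map<Integer, double[]> data = new HashMap<>();
    data.put(1, new double[] {2, 4, 1});
    data.put(2, new double[] {1, 8, 22});
    data.put(3, new double[] {0, 10, 100});

    NormalizationPreprocessor<Integer, double[]> preprocessor =
        new NormalizationTrainer<Integer, double[]>().fit(data, 2, (k, v) -> v);

    double[] min = preprocessor.getMin();   // {0, 4, 1}
    double[] max = preprocessor.getMax();   // {2, 10, 100}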
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java
new file mode 100644
index 0000000..8e8f467
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+
+import java.util.Arrays;
+
+/**
+ * Feature extractor wrapper that adds an additional column filled with 1.0.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class FeatureExtractorWrapper<K, V> implements IgniteBiFunction<K, V, double[]> {
+ /** */
+ private static final long serialVersionUID = -2686524650955735635L;
+
+ /** Underlying feature extractor. */
+ private final IgniteBiFunction<K, V, double[]> featureExtractor;
+
+ /**
+ * Constructs a new instance of feature extractor wrapper.
+ *
+ * @param featureExtractor Underlying feature extractor.
+ */
+ FeatureExtractorWrapper(IgniteBiFunction<K, V, double[]> featureExtractor) {
+ this.featureExtractor = featureExtractor;
+ }
+
+ /** {@inheritDoc} */
+ @Override public double[] apply(K k, V v) {
+ double[] featureRow = featureExtractor.apply(k, v);
+ double[] row = Arrays.copyOf(featureRow, featureRow.length + 1);
+
+ row[featureRow.length] = 1.0;
+
+ return row;
+ }
+}
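
In effect the wrapper appends an intercept column to every feature row; a sketch (the constructor is package-private, so this assumes same-package access):

    IgniteBiFunction<Integer, double[], double[]> base = (k, v) -> v;
    FeatureExtractorWrapper<Integer, double[]> wrapped = new FeatureExtractorWrapper<>(base);

    double[] row = wrapped.apply(0, new double[] {3.0, 5.0});   // -> {3.0, 5.0, 1.0}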
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
index ae15f2f..9526db1 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
@@ -17,8 +17,6 @@
package org.apache.ignite.ml.regressions.linear;
-import java.util.Arrays;
-import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -27,6 +25,9 @@ import org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR;
import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap;
import org.apache.ignite.ml.math.isolve.lsqr.LSQRResult;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+import java.util.Arrays;
/**
 * Trainer of the linear regression model based on the LSQR algorithm.
@@ -55,37 +56,4 @@ public class LinearRegressionLSQRTrainer implements SingleLabelDatasetTrainer<Li
return new LinearRegressionModel(weights, x[x.length - 1]);
}
-
- /**
- * Feature extractor wrapper that adds additional column filled by 1.
- *
- * @param <K> Type of a key in {@code upstream} data.
- * @param <V> Type of a value in {@code upstream} data.
- */
- private static class FeatureExtractorWrapper<K, V> implements IgniteBiFunction<K, V, double[]> {
- /** */
- private static final long serialVersionUID = -2686524650955735635L;
-
- /** Underlying feature extractor. */
- private final IgniteBiFunction<K, V, double[]> featureExtractor;
-
- /**
- * Constructs a new instance of feature extractor wrapper.
- *
- * @param featureExtractor Underlying feature extractor.
- */
- FeatureExtractorWrapper(IgniteBiFunction<K, V, double[]> featureExtractor) {
- this.featureExtractor = featureExtractor;
- }
-
- /** {@inheritDoc} */
- @Override public double[] apply(K k, V v) {
- double[] featureRow = featureExtractor.apply(k, v);
- double[] row = Arrays.copyOf(featureRow, featureRow.length + 1);
-
- row[featureRow.length] = 1.0;
-
- return row;
- }
- }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
index aad4c7a..9be3fdd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
@@ -17,51 +17,99 @@
package org.apache.ignite.ml.regressions.linear;
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.optimization.BarzilaiBorweinUpdater;
-import org.apache.ignite.ml.optimization.GradientDescent;
-import org.apache.ignite.ml.optimization.LeastSquaresGradientFunction;
-import org.apache.ignite.ml.optimization.SimpleUpdater;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+
+import java.io.Serializable;
+import java.util.Arrays;
/**
- * Linear regression trainer based on least squares loss function and gradient descent optimization algorithm.
+ * Trainer of the linear regression model based on the stochastic gradient descent algorithm.
*/
-public class LinearRegressionSGDTrainer implements Trainer<LinearRegressionModel, Matrix> {
- /**
- * Gradient descent optimizer.
- */
- private final GradientDescent gradientDescent;
+public class LinearRegressionSGDTrainer<P extends Serializable> implements SingleLabelDatasetTrainer<LinearRegressionModel> {
+ /** Update strategy. */
+ private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
- /** */
- public LinearRegressionSGDTrainer(GradientDescent gradientDescent) {
- this.gradientDescent = gradientDescent;
- }
+ /** Max number of iterations. */
+ private final int maxIterations;
- /** */
- public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol) {
- this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new BarzilaiBorweinUpdater())
- .withMaxIterations(maxIterations)
- .withConvergenceTol(convergenceTol);
- }
+ /** Batch size. */
+ private final int batchSize;
- /** */
- public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol, double learningRate) {
- this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new SimpleUpdater(learningRate))
- .withMaxIterations(maxIterations)
- .withConvergenceTol(convergenceTol);
- }
+ /** Number of local iterations. */
+ private final int locIterations;
+
+ /** Seed for random generator. */
+ private final long seed;
/**
- * {@inheritDoc}
+ * Constructs a new instance of linear regression SGD trainer.
+ *
+ * @param updatesStgy Update strategy.
+ * @param maxIterations Max number of iterations.
+ * @param batchSize Batch size.
+ * @param locIterations Number of local iterations.
+ * @param seed Seed for random generator.
*/
- @Override public LinearRegressionModel train(Matrix data) {
- Vector variables = gradientDescent.optimize(data, data.likeVector(data.columnSize()));
- Vector weights = variables.viewPart(1, variables.size() - 1);
+ public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
+ int batchSize, int locIterations, long seed) {
+ this.updatesStgy = updatesStgy;
+ this.maxIterations = maxIterations;
+ this.batchSize = batchSize;
+ this.locIterations = locIterations;
+ this.seed = seed;
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+ IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
+
+ int cols = dataset.compute(data -> {
+ if (data.getFeatures() == null)
+ return null;
+ return data.getFeatures().length / data.getRows();
+ }, (a, b) -> a == null ? b : a);
+
+ MLPArchitecture architecture = new MLPArchitecture(cols);
+ architecture = architecture.withAddedLayer(1, true, Activators.LINEAR);
+
+ return architecture;
+ };
+
+ MLPTrainer<?> trainer = new MLPTrainer<>(
+ archSupplier,
+ LossFunctions.MSE,
+ updatesStgy,
+ maxIterations,
+ batchSize,
+ locIterations,
+ seed
+ );
+
+ IgniteBiFunction<K, V, double[]> lbE = (k, v) -> new double[] {lbExtractor.apply(k, v)};
+
+ MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE);
- double intercept = variables.get(0);
+ double[] p = mlp.parameters().getStorage().data();
- return new LinearRegressionModel(weights, intercept);
+ return new LinearRegressionModel(new DenseLocalOnHeapVector(Arrays.copyOf(p, p.length - 1)), p[p.length - 1]);
}
}
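
Putting it together, a sketch of training with the reworked trainer; the update strategy follows the RProp configuration used by the new test, while the data and hyper-parameters are illustrative:

    LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(
        new UpdatesStrategy<>(new RPropUpdateCalculator(),
            RPropParameterUpdate::sumLocal, RPropParameterUpdate::avg),
        100_000,  // maxIterations
        10,       // batchSize
        100,      // locIterations
        123L);    // seed

    Map<Integer, double[]> data = new HashMap<>();
    data.put(0, new double[] {1.0, 2.0, 10.0});
    data.put(1, new double[] {2.0, 1.0, 11.0});

    LinearRegressionModel mdl = trainer.fit(
        data, 2,                                             // local data, 2 partitions
        (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),    // features
        (k, v) -> v[v.length - 1]);                          // label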
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
index 8119a29..fcde3f5 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
@@ -17,10 +17,16 @@
package org.apache.ignite.ml.trainers;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import java.util.Map;
+
/**
* Interface for trainers. Trainer is just a function which produces model from the data.
*
@@ -40,4 +46,44 @@ public interface DatasetTrainer<M extends Model, L> {
*/
public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor,
IgniteBiFunction<K, V, L> lbExtractor);
+
+ /**
+ * Trains model based on the specified data.
+ *
+ * @param ignite Ignite instance.
+ * @param cache Ignite cache.
+ * @param featureExtractor Feature extractor.
+ * @param lbExtractor Label extractor.
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ * @return Model.
+ */
+ public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache, IgniteBiFunction<K, V, double[]> featureExtractor,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+ return fit(
+ new CacheBasedDatasetBuilder<>(ignite, cache),
+ featureExtractor,
+ lbExtractor
+ );
+ }
+
+ /**
+ * Trains model based on the specified data.
+ *
+ * @param data Data.
+ * @param parts Number of partitions.
+ * @param featureExtractor Feature extractor.
+ * @param lbExtractor Label extractor.
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ * @return Model.
+ */
+ public default <K, V> M fit(Map<K, V> data, int parts, IgniteBiFunction<K, V, double[]> featureExtractor,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+ return fit(
+ new LocalDatasetBuilder<>(data, parts),
+ featureExtractor,
+ lbExtractor
+ );
+ }
}
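
The two default overloads above are what the test changes in the remainder of this commit exercise; a compact sketch with the KNN trainer (data illustrative):

    KNNClassificationTrainer trainer = new KNNClassificationTrainer();

    Map<Integer, double[]> data = new HashMap<>();
    data.put(0, new double[] {1.0, 1.0, 1.0});
    data.put(1, new double[] {2.0, 2.0, 2.0});

    // Equivalent to fit(new LocalDatasetBuilder<>(data, 2), ...).
    KNNClassificationModel knnMdl = trainer.fit(
        data, 2,
        (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
        (k, v) -> v[2]);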
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index b5a4b54..b27fcba 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -17,11 +17,7 @@
package org.apache.ignite.ml.knn;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
@@ -29,6 +25,10 @@ import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
 /** Tests behaviour of KNN classification. */
public class KNNClassificationTest extends BaseKNNTest {
/** */
@@ -46,7 +46,8 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
- new LocalDatasetBuilder<>(data, 2),
+ data,
+ 2,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(3)
@@ -74,7 +75,8 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
- new LocalDatasetBuilder<>(data, 2),
+ data,
+ 2,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(1)
@@ -102,7 +104,8 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
- new LocalDatasetBuilder<>(data, 2),
+ data,
+ 2,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(3)
@@ -128,7 +131,8 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
- new LocalDatasetBuilder<>(data, 2),
+ data,
+ 2,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(3)
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java
index 5ca661f..038b880 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java
@@ -17,7 +17,6 @@
package org.apache.ignite.ml.nn;
-import java.io.Serializable;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
@@ -25,22 +24,18 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.internal.util.typedef.X;
import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.*;
import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import java.io.Serializable;
+
/**
 * Tests for {@link MLPTrainer} that require starting the whole Ignite infrastructure.
*/
@@ -137,7 +132,8 @@ public class MLPTrainerIntegrationTest extends GridCommonAbstractTest {
);
MultilayerPerceptron mlp = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, xorCache),
+ ignite,
+ xorCache,
(k, v) -> new double[]{ v.x, v.y },
(k, v) -> new double[]{ v.lb}
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
index 6906424..c53f6f1 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
@@ -17,24 +17,13 @@
package org.apache.ignite.ml.nn;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.*;
import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
import org.junit.Before;
import org.junit.Test;
@@ -42,6 +31,12 @@ import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
/**
 * Tests for {@link MLPTrainer} that don't require starting the whole Ignite infrastructure.
*/
@@ -140,7 +135,8 @@ public class MLPTrainerTest {
);
MultilayerPerceptron mlp = trainer.fit(
- new LocalDatasetBuilder<>(xorData, parts),
+ xorData,
+ parts,
(k, v) -> v[0],
(k, v) -> v[1]
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
index c787a47..a64af9b 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
@@ -17,13 +17,11 @@
package org.apache.ignite.ml.nn.performance;
-import java.io.IOException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
@@ -38,6 +36,8 @@ import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
import org.apache.ignite.ml.util.MnistUtils;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import java.io.IOException;
+
/**
 * Tests {@link MLPTrainer} on the MNIST dataset that require starting the whole Ignite infrastructure.
*/
@@ -104,7 +104,8 @@ public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest {
System.out.println("Start training...");
long start = System.currentTimeMillis();
MultilayerPerceptron mdl = trainer.fit(
- new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+ ignite,
+ trainingSet,
(k, v) -> v.getPixels(),
(k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
index 354af2c..d966484 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
@@ -17,10 +17,6 @@
package org.apache.ignite.ml.nn.performance;
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
@@ -35,6 +31,10 @@ import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
import org.apache.ignite.ml.util.MnistUtils;
import org.junit.Test;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
import static org.junit.Assert.assertTrue;
/**
@@ -74,7 +74,8 @@ public class MLPTrainerMnistTest {
System.out.println("Start training...");
long start = System.currentTimeMillis();
MultilayerPerceptron mdl = trainer.fit(
- new LocalDatasetBuilder<>(trainingSet, 1),
+ trainingSet,
+ 1,
(k, v) -> v.getPixels(),
(k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java
index 1548253..e7a0d47 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java
@@ -17,15 +17,16 @@
package org.apache.ignite.ml.preprocessing.normalization;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
import static org.junit.Assert.assertArrayEquals;
/**
@@ -66,8 +67,7 @@ public class NormalizationTrainerTest {
NormalizationPreprocessor<Integer, double[]> preprocessor = standardizationTrainer.fit(
datasetBuilder,
- (k, v) -> v,
- 3
+ (k, v) -> v
);
assertArrayEquals(new double[] {0, 4, 1}, preprocessor.getMin(), 1e-8);
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index 82b3a1b..b3c9368 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -17,14 +17,7 @@
package org.apache.ignite.ml.regressions;
-import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionQRTrainerTest;
-import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionSGDTrainerTest;
-import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionQRTrainerTest;
-import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionSGDTrainerTest;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest;
-import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionQRTrainerTest;
-import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionSGDTrainerTest;
+import org.apache.ignite.ml.regressions.linear.*;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -35,12 +28,10 @@ import org.junit.runners.Suite;
@Suite.SuiteClasses({
LinearRegressionModelTest.class,
LocalLinearRegressionQRTrainerTest.class,
- LocalLinearRegressionSGDTrainerTest.class,
DistributedLinearRegressionQRTrainerTest.class,
- DistributedLinearRegressionSGDTrainerTest.class,
BlockDistributedLinearRegressionQRTrainerTest.class,
- BlockDistributedLinearRegressionSGDTrainerTest.class,
- LinearRegressionLSQRTrainerTest.class
+ LinearRegressionLSQRTrainerTest.class,
+ LinearRegressionSGDTrainerTest.class
})
public class RegressionsTestSuite {
// No-op.
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java
deleted file mode 100644
index 58037e2..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions.linear;
-
-import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
-import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
-
-/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseBlockDistributedMatrix}.
- */
-public class BlockDistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest {
- /** */
- public BlockDistributedLinearRegressionSGDTrainerTest() {
- super(
- new LinearRegressionSGDTrainer(100_000, 1e-12),
- SparseBlockDistributedMatrix::new,
- SparseBlockDistributedVector::new,
- 1e-2);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java
deleted file mode 100644
index 71d3b3b..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions.linear;
-
-import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
-import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
-
-/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseDistributedMatrix}.
- */
-public class DistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest {
- /** */
- public DistributedLinearRegressionSGDTrainerTest() {
- super(
- new LinearRegressionSGDTrainer(100_000, 1e-12),
- SparseDistributedMatrix::new,
- SparseDistributedVector::new,
- 1e-2);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java
index 1a60b80..9b75bd4 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java
@@ -26,6 +26,9 @@ import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
import org.junit.Test;
+/**
+ * Grid-aware abstract linear regression trainer test.
+ */
public abstract class GridAwareAbstractLinearRegressionTrainerTest extends GridCommonAbstractTest {
/** Number of nodes in grid */
private static final int NODE_COUNT = 3;
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
index e3f60ec..2414236 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
@@ -17,14 +17,14 @@
package org.apache.ignite.ml.regressions.linear;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@@ -72,7 +72,8 @@ public class LinearRegressionLSQRTrainerTest {
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
LinearRegressionModel mdl = trainer.fit(
- new LocalDatasetBuilder<>(data, parts),
+ data,
+ parts,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[4]
);
@@ -110,7 +111,8 @@ public class LinearRegressionLSQRTrainerTest {
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
LinearRegressionModel mdl = trainer.fit(
- new LocalDatasetBuilder<>(data, parts),
+ data,
+ parts,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[coef.length]
);