You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2017/12/08 08:34:53 UTC
[22/30] ignite git commit: IGNITE-6872: Linear regression should implement Model API
IGNITE-6872: Linear regression should implement Model API
This closes #3168
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/c5c512e4
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/c5c512e4
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/c5c512e4
Branch: refs/heads/ignite-zk
Commit: c5c512e460140c91fb77b527ff909ddbe3d1fd72
Parents: bbeb205
Author: Oleg Ignatenko <oi...@gridgain.com>
Authored: Thu Dec 7 18:14:51 2017 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Thu Dec 7 18:14:51 2017 +0300
----------------------------------------------------------------------
.../decompositions/QRDecompositionExample.java | 82 ++++++
.../DistributedRegressionExample.java | 149 -----------
.../examples/ml/math/trees/MNISTExample.java | 261 -------------------
.../examples/ml/math/trees/package-info.java | 22 --
.../apache/ignite/examples/ml/package-info.java | 22 ++
.../DistributedRegressionExample.java | 149 +++++++++++
.../DistributedRegressionModelExample.java | 134 ++++++++++
.../examples/ml/regression/package-info.java | 22 ++
.../ignite/examples/ml/trees/MNISTExample.java | 261 +++++++++++++++++++
.../ignite/examples/ml/trees/package-info.java | 22 ++
.../ml/math/decompositions/QRDSolver.java | 197 ++++++++++++++
.../ml/math/decompositions/QRDecomposition.java | 54 +---
.../AbstractMultipleLinearRegression.java | 20 ++
.../OLSMultipleLinearRegression.java | 41 +--
.../OLSMultipleLinearRegressionModel.java | 77 ++++++
.../OLSMultipleLinearRegressionModelFormat.java | 46 ++++
.../OLSMultipleLinearRegressionTrainer.java | 62 +++++
.../org/apache/ignite/ml/IgniteMLTestSuite.java | 3 +-
.../org/apache/ignite/ml/LocalModelsTest.java | 99 +++++--
.../ignite/ml/math/MathImplLocalTestSuite.java | 2 +
.../ml/math/decompositions/QRDSolverTest.java | 87 +++++++
...tedBlockOLSMultipleLinearRegressionTest.java | 38 ++-
...tributedOLSMultipleLinearRegressionTest.java | 38 ++-
.../OLSMultipleLinearRegressionModelTest.java | 53 ++++
.../ml/regressions/RegressionsTestSuite.java | 5 +-
25 files changed, 1371 insertions(+), 575 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java
new file mode 100644
index 0000000..bed99d1
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.math.decompositions;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.decompositions.QRDecomposition;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+
+/**
+ * Example of using {@link QRDecomposition}.
+ */
+public class QRDecompositionExample {
+ /**
+ * Executes example.
+ *
+ * @param args Command line arguments, none required.
+ */
+ public static void main(String[] args) {
+ System.out.println(">>> QR decomposition example started.");
+ Matrix m = new DenseLocalOnHeapMatrix(new double[][] {
+ {2.0d, -1.0d, 0.0d},
+ {-1.0d, 2.0d, -1.0d},
+ {0.0d, -1.0d, 2.0d}
+ });
+
+ System.out.println("\n>>> Input matrix:");
+ Tracer.showAscii(m);
+
+ QRDecomposition dec = new QRDecomposition(m);
+ System.out.println("\n>>> Value for full rank in decomposition: [" + dec.hasFullRank() + "].");
+
+ Matrix q = dec.getQ();
+ Matrix r = dec.getR();
+
+ System.out.println("\n>>> Orthogonal matrix Q:");
+ Tracer.showAscii(q);
+ System.out.println("\n>>> Upper triangular matrix R:");
+ Tracer.showAscii(r);
+
+ Matrix qSafeCp = safeCopy(q);
+
+ Matrix identity = qSafeCp.times(qSafeCp.transpose());
+
+ System.out.println("\n>>> Identity matrix obtained from Q:");
+ Tracer.showAscii(identity);
+
+ Matrix recomposed = qSafeCp.times(r);
+
+ System.out.println("\n>>> Recomposed input matrix:");
+ Tracer.showAscii(recomposed);
+
+ Matrix sol = dec.solve(new DenseLocalOnHeapMatrix(3, 10));
+
+ System.out.println("\n>>> Solved matrix:");
+ Tracer.showAscii(sol);
+
+ dec.destroy();
+
+ System.out.println("\n>>> QR decomposition example completed.");
+ }
+
+ /** */
+ private static Matrix safeCopy(Matrix orig) {
+ return new DenseLocalOnHeapMatrix(orig.rowSize(), orig.columnSize()).assign(orig);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java
deleted file mode 100644
index de2c541..0000000
--- a/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java
+++ /dev/null
@@ -1,149 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.math.regression;
-
-import java.util.Arrays;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
-import org.apache.ignite.ml.math.StorageConstants;
-import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
-import org.apache.ignite.ml.regressions.OLSMultipleLinearRegression;
-import org.apache.ignite.thread.IgniteThread;
-
-/**
- * Run linear regression over distributed matrix.
- *
- * TODO: IGNITE-6222, Currently works only in local mode.
- *
- * @see OLSMultipleLinearRegression
- */
-public class DistributedRegressionExample {
- /** Run example. */
- public static void main(String[] args) throws InterruptedException {
- System.out.println();
- System.out.println(">>> Linear regression over sparse distributed matrix API usage example started.");
- // Start ignite grid.
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
- // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
- // because we create ignite cache internally.
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), SparseDistributedMatrixExample.class.getSimpleName(), () -> {
-
- double[] data = {
- 8, 78, 284, 9.100000381, 109,
- 9.300000191, 68, 433, 8.699999809, 144,
- 7.5, 70, 739, 7.199999809, 113,
- 8.899999619, 96, 1792, 8.899999619, 97,
- 10.19999981, 74, 477, 8.300000191, 206,
- 8.300000191, 111, 362, 10.89999962, 124,
- 8.800000191, 77, 671, 10, 152,
- 8.800000191, 168, 636, 9.100000381, 162,
- 10.69999981, 82, 329, 8.699999809, 150,
- 11.69999981, 89, 634, 7.599999905, 134,
- 8.5, 149, 631, 10.80000019, 292,
- 8.300000191, 60, 257, 9.5, 108,
- 8.199999809, 96, 284, 8.800000191, 111,
- 7.900000095, 83, 603, 9.5, 182,
- 10.30000019, 130, 686, 8.699999809, 129,
- 7.400000095, 145, 345, 11.19999981, 158,
- 9.600000381, 112, 1357, 9.699999809, 186,
- 9.300000191, 131, 544, 9.600000381, 177,
- 10.60000038, 80, 205, 9.100000381, 127,
- 9.699999809, 130, 1264, 9.199999809, 179,
- 11.60000038, 140, 688, 8.300000191, 80,
- 8.100000381, 154, 354, 8.399999619, 103,
- 9.800000191, 118, 1632, 9.399999619, 101,
- 7.400000095, 94, 348, 9.800000191, 117,
- 9.399999619, 119, 370, 10.39999962, 88,
- 11.19999981, 153, 648, 9.899999619, 78,
- 9.100000381, 116, 366, 9.199999809, 102,
- 10.5, 97, 540, 10.30000019, 95,
- 11.89999962, 176, 680, 8.899999619, 80,
- 8.399999619, 75, 345, 9.600000381, 92,
- 5, 134, 525, 10.30000019, 126,
- 9.800000191, 161, 870, 10.39999962, 108,
- 9.800000191, 111, 669, 9.699999809, 77,
- 10.80000019, 114, 452, 9.600000381, 60,
- 10.10000038, 142, 430, 10.69999981, 71,
- 10.89999962, 238, 822, 10.30000019, 86,
- 9.199999809, 78, 190, 10.69999981, 93,
- 8.300000191, 196, 867, 9.600000381, 106,
- 7.300000191, 125, 969, 10.5, 162,
- 9.399999619, 82, 499, 7.699999809, 95,
- 9.399999619, 125, 925, 10.19999981, 91,
- 9.800000191, 129, 353, 9.899999619, 52,
- 3.599999905, 84, 288, 8.399999619, 110,
- 8.399999619, 183, 718, 10.39999962, 69,
- 10.80000019, 119, 540, 9.199999809, 57,
- 10.10000038, 180, 668, 13, 106,
- 9, 82, 347, 8.800000191, 40,
- 10, 71, 345, 9.199999809, 50,
- 11.30000019, 118, 463, 7.800000191, 35,
- 11.30000019, 121, 728, 8.199999809, 86,
- 12.80000019, 68, 383, 7.400000095, 57,
- 10, 112, 316, 10.39999962, 57,
- 6.699999809, 109, 388, 8.899999619, 94
- };
-
- final int nobs = 53;
- final int nvars = 4;
-
- System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread.");
- // Create SparseDistributedMatrix, new cache will be created automagically.
- SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(0, 0,
- StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
-
- System.out.println(">>> Create new linear regression object");
- OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
- regression.newSampleData(data, nobs, nvars, distributedMatrix);
- System.out.println();
-
- System.out.println(">>> Estimates the regression parameters b:");
- System.out.println(Arrays.toString(regression.estimateRegressionParameters()));
-
- System.out.println(">>> Estimates the residuals, ie u = y - X*b:");
- System.out.println(Arrays.toString(regression.estimateResiduals()));
-
- System.out.println(">>> Standard errors of the regression parameters:");
- System.out.println(Arrays.toString(regression.estimateRegressionParametersStandardErrors()));
-
- System.out.println(">>> Estimates the variance of the regression parameters, ie Var(b):");
- Tracer.showAscii(regression.estimateRegressionParametersVariance());
-
- System.out.println(">>> Estimates the standard error of the regression:");
- System.out.println(regression.estimateRegressionStandardError());
-
- System.out.println(">>> R-Squared statistic:");
- System.out.println(regression.calculateRSquared());
-
- System.out.println(">>> Adjusted R-squared statistic:");
- System.out.println(regression.calculateAdjustedRSquared());
-
- System.out.println(">>> Returns the variance of the regressand, ie Var(y):");
- System.out.println(regression.estimateErrorVariance());
- });
-
- igniteThread.start();
-
- igniteThread.join();
- }
- }
-
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java
deleted file mode 100644
index 6aaadd9..0000000
--- a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java
+++ /dev/null
@@ -1,261 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.math.trees;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.Random;
-import java.util.function.Function;
-import java.util.stream.Stream;
-import org.apache.commons.cli.BasicParser;
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.CommandLineParser;
-import org.apache.commons.cli.Option;
-import org.apache.commons.cli.OptionBuilder;
-import org.apache.commons.cli.Options;
-import org.apache.commons.cli.ParseException;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteDataStreamer;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.examples.ExampleNodeStartup;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.estimators.Estimators;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.trees.models.DecisionTreeModel;
-import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex;
-import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators;
-import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator;
-import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators;
-import org.apache.ignite.ml.util.MnistUtils;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * <p>
- * Example of usage of decision trees algorithm for MNIST dataset
- * (it can be found here: http://yann.lecun.com/exdb/mnist/). </p>
- * <p>
- * Remote nodes should always be started with special configuration file which
- * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
- * <p>
- * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
- * with {@code examples/config/example-ignite.xml} configuration.</p>
- * <p>
- * It is recommended to start at least one node prior to launching this example if you intend
- * to run it with default memory settings.</p>
- * <p>
- * This example should with program arguments, for example
- * -ts_i /path/to/train-images-idx3-ubyte
- * -ts_l /path/to/train-labels-idx1-ubyte
- * -tss_i /path/to/t10k-images-idx3-ubyte
- * -tss_l /path/to/t10k-labels-idx1-ubyte
- * -cfg examples/config/example-ignite.xml.</p>
- * <p>
- * -ts_i specifies path to training set images of MNIST;
- * -ts_l specifies path to training set labels of MNIST;
- * -tss_i specifies path to test set images of MNIST;
- * -tss_l specifies path to test set labels of MNIST;
- * -cfg specifies path to a config path.</p>
- */
-public class MNISTExample {
- /** Name of parameter specifying path to training set images. */
- private static final String MNIST_TRAINING_IMAGES_PATH = "ts_i";
-
- /** Name of parameter specifying path to training set labels. */
- private static final String MNIST_TRAINING_LABELS_PATH = "ts_l";
-
- /** Name of parameter specifying path to test set images. */
- private static final String MNIST_TEST_IMAGES_PATH = "tss_i";
-
- /** Name of parameter specifying path to test set labels. */
- private static final String MNIST_TEST_LABELS_PATH = "tss_l";
-
- /** Name of parameter specifying path of Ignite config. */
- private static final String CONFIG = "cfg";
-
- /** Default config path. */
- private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml";
-
- /**
- * Launches example.
- *
- * @param args Program arguments.
- */
- public static void main(String[] args) {
- String igniteCfgPath;
-
- CommandLineParser parser = new BasicParser();
-
- String trainingImagesPath;
- String trainingLabelsPath;
-
- String testImagesPath;
- String testLabelsPath;
-
- try {
- // Parse the command line arguments.
- CommandLine line = parser.parse(buildOptions(), args);
-
- trainingImagesPath = line.getOptionValue(MNIST_TRAINING_IMAGES_PATH);
- trainingLabelsPath = line.getOptionValue(MNIST_TRAINING_LABELS_PATH);
- testImagesPath = line.getOptionValue(MNIST_TEST_IMAGES_PATH);
- testLabelsPath = line.getOptionValue(MNIST_TEST_LABELS_PATH);
- igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG);
- }
- catch (ParseException e) {
- e.printStackTrace();
- return;
- }
-
- try (Ignite ignite = Ignition.start(igniteCfgPath)) {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- int ptsCnt = 60000;
- int featCnt = 28 * 28;
-
- Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt);
- Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000);
-
- IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite);
-
- loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite);
-
- ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
- new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
-
- System.out.println(">>> Training started");
- long before = System.currentTimeMillis();
- DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));
- System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
-
- IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage();
- Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
- System.out.println(">>> Errs percentage: " + accuracy);
- }
- catch (IOException e) {
- e.printStackTrace();
- }
- }
-
- /**
- * Build cli options.
- */
- @NotNull private static Options buildOptions() {
- Options options = new Options();
-
- Option trsImagesPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_IMAGES_PATH).withLongOpt(MNIST_TRAINING_IMAGES_PATH).hasArg()
- .withDescription("Path to the MNIST training set.")
- .isRequired(true).create();
-
- Option trsLabelsPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_LABELS_PATH).withLongOpt(MNIST_TRAINING_LABELS_PATH).hasArg()
- .withDescription("Path to the MNIST training set.")
- .isRequired(true).create();
-
- Option tssImagesPathOpt = OptionBuilder.withArgName(MNIST_TEST_IMAGES_PATH).withLongOpt(MNIST_TEST_IMAGES_PATH).hasArg()
- .withDescription("Path to the MNIST test set.")
- .isRequired(true).create();
-
- Option tssLabelsPathOpt = OptionBuilder.withArgName(MNIST_TEST_LABELS_PATH).withLongOpt(MNIST_TEST_LABELS_PATH).hasArg()
- .withDescription("Path to the MNIST test set.")
- .isRequired(true).create();
-
- Option configOpt = OptionBuilder.withArgName(CONFIG).withLongOpt(CONFIG).hasArg()
- .withDescription("Path to the config.")
- .isRequired(false).create();
-
- options.addOption(trsImagesPathOpt);
- options.addOption(trsLabelsPathOpt);
- options.addOption(tssImagesPathOpt);
- options.addOption(tssLabelsPathOpt);
- options.addOption(configOpt);
-
- return options;
- }
-
- /**
- * Creates cache where data for training is stored.
- *
- * @param ignite Ignite instance.
- * @return cache where data for training is stored.
- */
- private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) {
- CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>();
-
- // Write to primary.
- cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
-
- // Atomic transactions only.
- cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-
- // No eviction.
- cfg.setEvictionPolicy(null);
-
- // No copying of values.
- cfg.setCopyOnRead(false);
-
- // Cache is partitioned.
- cfg.setCacheMode(CacheMode.PARTITIONED);
-
- cfg.setBackups(0);
-
- cfg.setName("TMP_BI_INDEXED_CACHE");
-
- return ignite.getOrCreateCache(cfg);
- }
-
- /**
- * Loads vectors into cache.
- *
- * @param cacheName Name of cache.
- * @param vectorsIterator Iterator over vectors to load.
- * @param vectorSize Size of vector.
- * @param ignite Ignite instance.
- */
- private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIterator,
- int vectorSize, Ignite ignite) {
- try (IgniteDataStreamer<BiIndex, Double> streamer =
- ignite.dataStreamer(cacheName)) {
- int sampleIdx = 0;
-
- streamer.perNodeBufferSize(10000);
-
- while (vectorsIterator.hasNext()) {
- org.apache.ignite.ml.math.Vector next = vectorsIterator.next();
-
- for (int i = 0; i < vectorSize; i++)
- streamer.addData(new BiIndex(sampleIdx, i), next.getX(i));
-
- sampleIdx++;
-
- if (sampleIdx % 1000 == 0)
- System.out.println("Loaded " + sampleIdx + " vectors.");
- }
- }
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java
deleted file mode 100644
index 9b6867b..0000000
--- a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Decision trees examples.
- */
-package org.apache.ignite.examples.ml.math.trees;
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java
new file mode 100644
index 0000000..52778b5
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Machine learning examples.
+ */
+package org.apache.ignite.examples.ml;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java
new file mode 100644
index 0000000..3e65527
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression;
+
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegression;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run linear regression over distributed matrix.
+ * <p>
+ * Fits {@link OLSMultipleLinearRegression} to a built-in data set of 53 observations with 4 independent
+ * variables, loaded into a {@link SparseDistributedMatrix}, and prints the estimated parameters together with
+ * several fit statistics.</p>
+ *
+ * TODO: IGNITE-6222, Currently works only in local mode.
+ *
+ * @see OLSMultipleLinearRegression
+ */
+public class DistributedRegressionExample {
+    /**
+     * Run example.
+     *
+     * @param args Command line arguments, none required.
+     * @throws InterruptedException If interrupted while waiting for the example thread to finish.
+     */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Linear regression over sparse distributed matrix API usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+            // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
+            // because we create ignite cache internally.
+            // NOTE(review): the thread is named after SparseDistributedMatrixExample (copied from that example);
+            // consider DistributedRegressionExample.class.getSimpleName() and dropping the then-unused import.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                SparseDistributedMatrixExample.class.getSimpleName(), () -> {
+
+                // 53 rows of 5 values each (nvars + 1). Presumably the first value of every row is the
+                // regressand y and the remaining 4 are the regressors, per newSampleData's layout - confirm.
+                double[] data = {
+                    8, 78, 284, 9.100000381, 109,
+                    9.300000191, 68, 433, 8.699999809, 144,
+                    7.5, 70, 739, 7.199999809, 113,
+                    8.899999619, 96, 1792, 8.899999619, 97,
+                    10.19999981, 74, 477, 8.300000191, 206,
+                    8.300000191, 111, 362, 10.89999962, 124,
+                    8.800000191, 77, 671, 10, 152,
+                    8.800000191, 168, 636, 9.100000381, 162,
+                    10.69999981, 82, 329, 8.699999809, 150,
+                    11.69999981, 89, 634, 7.599999905, 134,
+                    8.5, 149, 631, 10.80000019, 292,
+                    8.300000191, 60, 257, 9.5, 108,
+                    8.199999809, 96, 284, 8.800000191, 111,
+                    7.900000095, 83, 603, 9.5, 182,
+                    10.30000019, 130, 686, 8.699999809, 129,
+                    7.400000095, 145, 345, 11.19999981, 158,
+                    9.600000381, 112, 1357, 9.699999809, 186,
+                    9.300000191, 131, 544, 9.600000381, 177,
+                    10.60000038, 80, 205, 9.100000381, 127,
+                    9.699999809, 130, 1264, 9.199999809, 179,
+                    11.60000038, 140, 688, 8.300000191, 80,
+                    8.100000381, 154, 354, 8.399999619, 103,
+                    9.800000191, 118, 1632, 9.399999619, 101,
+                    7.400000095, 94, 348, 9.800000191, 117,
+                    9.399999619, 119, 370, 10.39999962, 88,
+                    11.19999981, 153, 648, 9.899999619, 78,
+                    9.100000381, 116, 366, 9.199999809, 102,
+                    10.5, 97, 540, 10.30000019, 95,
+                    11.89999962, 176, 680, 8.899999619, 80,
+                    8.399999619, 75, 345, 9.600000381, 92,
+                    5, 134, 525, 10.30000019, 126,
+                    9.800000191, 161, 870, 10.39999962, 108,
+                    9.800000191, 111, 669, 9.699999809, 77,
+                    10.80000019, 114, 452, 9.600000381, 60,
+                    10.10000038, 142, 430, 10.69999981, 71,
+                    10.89999962, 238, 822, 10.30000019, 86,
+                    9.199999809, 78, 190, 10.69999981, 93,
+                    8.300000191, 196, 867, 9.600000381, 106,
+                    7.300000191, 125, 969, 10.5, 162,
+                    9.399999619, 82, 499, 7.699999809, 95,
+                    9.399999619, 125, 925, 10.19999981, 91,
+                    9.800000191, 129, 353, 9.899999619, 52,
+                    3.599999905, 84, 288, 8.399999619, 110,
+                    8.399999619, 183, 718, 10.39999962, 69,
+                    10.80000019, 119, 540, 9.199999809, 57,
+                    10.10000038, 180, 668, 13, 106,
+                    9, 82, 347, 8.800000191, 40,
+                    10, 71, 345, 9.199999809, 50,
+                    11.30000019, 118, 463, 7.800000191, 35,
+                    11.30000019, 121, 728, 8.199999809, 86,
+                    12.80000019, 68, 383, 7.400000095, 57,
+                    10, 112, 316, 10.39999962, 57,
+                    6.699999809, 109, 388, 8.899999619, 94
+                };
+
+                final int nobs = 53;
+                final int nvars = 4;
+
+                System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread.");
+                // Create SparseDistributedMatrix, new cache will be created automatically.
+                SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(0, 0,
+                    StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
+
+                System.out.println(">>> Create new linear regression object");
+                OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
+                regression.newSampleData(data, nobs, nvars, distributedMatrix);
+                System.out.println();
+
+                System.out.println(">>> Estimates the regression parameters b:");
+                System.out.println(Arrays.toString(regression.estimateRegressionParameters()));
+
+                System.out.println(">>> Estimates the residuals, ie u = y - X*b:");
+                System.out.println(Arrays.toString(regression.estimateResiduals()));
+
+                System.out.println(">>> Standard errors of the regression parameters:");
+                System.out.println(Arrays.toString(regression.estimateRegressionParametersStandardErrors()));
+
+                System.out.println(">>> Estimates the variance of the regression parameters, ie Var(b):");
+                Tracer.showAscii(regression.estimateRegressionParametersVariance());
+
+                System.out.println(">>> Estimates the standard error of the regression:");
+                System.out.println(regression.estimateRegressionStandardError());
+
+                System.out.println(">>> R-Squared statistic:");
+                System.out.println(regression.calculateRSquared());
+
+                System.out.println(">>> Adjusted R-squared statistic:");
+                System.out.println(regression.calculateAdjustedRSquared());
+
+                // estimateErrorVariance() estimates the variance of the error term, not the variance of the
+                // regressand Var(y), so the output is labelled accordingly.
+                System.out.println(">>> Estimates the variance of the error term:");
+                System.out.println(regression.estimateErrorVariance());
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java
new file mode 100644
index 0000000..ab1b17d
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModel;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionTrainer;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run linear regression model over distributed matrix.
+ * <p>
+ * Trains an {@link OLSMultipleLinearRegressionModel} with {@link OLSMultipleLinearRegressionTrainer} on a small
+ * data set backed by a {@link SparseDistributedMatrix} and prints the model prediction.</p>
+ *
+ * @see OLSMultipleLinearRegressionModel
+ * @see SparseDistributedMatrixExample
+ */
+public class DistributedRegressionModelExample {
+    /**
+     * Run example.
+     *
+     * @param args Command line arguments (not used).
+     * @throws InterruptedException If interrupted while waiting for the example thread to finish.
+     */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+            // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
+            // because we create ignite cache internally.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                DistributedRegressionModelExample.class.getSimpleName(), () -> {
+                // Flattened data set: nobs rows with nvars + 1 values each. Presumably laid out as in Commons Math
+                // newSampleData (regressand first, then nvars regressors) - confirm against the trainer.
+                double[] data = {
+                    8, 78, 284, 9.100000381, 109,
+                    9.300000191, 68, 433, 8.699999809, 144,
+                    7.5, 70, 739, 7.199999809, 113,
+                    8.899999619, 96, 1792, 8.899999619, 97,
+                    10.19999981, 74, 477, 8.300000191, 206,
+                    8.300000191, 111, 362, 10.89999962, 124,
+                    8.800000191, 77, 671, 10, 152,
+                    8.800000191, 168, 636, 9.100000381, 162,
+                    10.69999981, 82, 329, 8.699999809, 150,
+                    11.69999981, 89, 634, 7.599999905, 134,
+                    8.5, 149, 631, 10.80000019, 292,
+                    8.300000191, 60, 257, 9.5, 108,
+                    8.199999809, 96, 284, 8.800000191, 111,
+                    7.900000095, 83, 603, 9.5, 182,
+                    10.30000019, 130, 686, 8.699999809, 129,
+                    7.400000095, 145, 345, 11.19999981, 158,
+                    9.600000381, 112, 1357, 9.699999809, 186,
+                    9.300000191, 131, 544, 9.600000381, 177,
+                    10.60000038, 80, 205, 9.100000381, 127,
+                    9.699999809, 130, 1264, 9.199999809, 179,
+                    11.60000038, 140, 688, 8.300000191, 80,
+                    8.100000381, 154, 354, 8.399999619, 103,
+                    9.800000191, 118, 1632, 9.399999619, 101,
+                    7.400000095, 94, 348, 9.800000191, 117,
+                    9.399999619, 119, 370, 10.39999962, 88,
+                    11.19999981, 153, 648, 9.899999619, 78,
+                    9.100000381, 116, 366, 9.199999809, 102,
+                    10.5, 97, 540, 10.30000019, 95,
+                    11.89999962, 176, 680, 8.899999619, 80,
+                    8.399999619, 75, 345, 9.600000381, 92,
+                    5, 134, 525, 10.30000019, 126,
+                    9.800000191, 161, 870, 10.39999962, 108,
+                    9.800000191, 111, 669, 9.699999809, 77,
+                    10.80000019, 114, 452, 9.600000381, 60,
+                    10.10000038, 142, 430, 10.69999981, 71,
+                    10.89999962, 238, 822, 10.30000019, 86,
+                    9.199999809, 78, 190, 10.69999981, 93,
+                    8.300000191, 196, 867, 9.600000381, 106,
+                    7.300000191, 125, 969, 10.5, 162,
+                    9.399999619, 82, 499, 7.699999809, 95,
+                    9.399999619, 125, 925, 10.19999981, 91,
+                    9.800000191, 129, 353, 9.899999619, 52,
+                    3.599999905, 84, 288, 8.399999619, 110,
+                    8.399999619, 183, 718, 10.39999962, 69,
+                    10.80000019, 119, 540, 9.199999809, 57,
+                    10.10000038, 180, 668, 13, 106,
+                    9, 82, 347, 8.800000191, 40,
+                    10, 71, 345, 9.199999809, 50,
+                    11.30000019, 118, 463, 7.800000191, 35,
+                    11.30000019, 121, 728, 8.199999809, 86,
+                    12.80000019, 68, 383, 7.400000095, 57,
+                    10, 112, 316, 10.39999962, 57,
+                    6.699999809, 109, 388, 8.899999619, 94
+                };
+
+                final int nobs = 53;
+                final int nvars = 4;
+
+                System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread.");
+                // Create SparseDistributedMatrix, new cache will be created automatically.
+                SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(0, 0,
+                    StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
+
+                System.out.println(">>> Create new linear regression trainer object.");
+                OLSMultipleLinearRegressionTrainer trainer
+                    = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, distributedMatrix);
+                System.out.println(">>> Perform the training to get the model.");
+                OLSMultipleLinearRegressionModel mdl = trainer.train(data);
+                System.out.println();
+
+                // Model input: the first value of every data row (stride nvars + 1 through the flat array).
+                Vector val = new SparseDistributedVector(nobs).assign((i) -> data[i * (nvars + 1)]);
+
+                System.out.println(">>> The input data:");
+                Tracer.showAscii(val);
+
+                System.out.println(">>> Trained model prediction results:");
+                Tracer.showAscii(mdl.predict(val));
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java
new file mode 100644
index 0000000..c89c80c
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML regression examples.
+ */
+package org.apache.ignite.examples.ml.regression;
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java
new file mode 100644
index 0000000..6ff121e
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.trees;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Random;
+import java.util.function.Function;
+import java.util.stream.Stream;
+import org.apache.commons.cli.BasicParser;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.CommandLineParser;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.OptionBuilder;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteDataStreamer;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.CacheAtomicityMode;
+import org.apache.ignite.cache.CacheMode;
+import org.apache.ignite.cache.CacheWriteSynchronizationMode;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.examples.ExampleNodeStartup;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.estimators.Estimators;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.trees.models.DecisionTreeModel;
+import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex;
+import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput;
+import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
+import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators;
+import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator;
+import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators;
+import org.apache.ignite.ml.util.MnistUtils;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * <p>
+ * Example of usage of decision trees algorithm for MNIST dataset
+ * (it can be found here: http://yann.lecun.com/exdb/mnist/). </p>
+ * <p>
+ * Remote nodes should always be started with special configuration file which
+ * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
+ * <p>
+ * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
+ * with {@code examples/config/example-ignite.xml} configuration.</p>
+ * <p>
+ * It is recommended to start at least one node prior to launching this example if you intend
+ * to run it with default memory settings.</p>
+ * <p>
+ * This example should with program arguments, for example
+ * -ts_i /path/to/train-images-idx3-ubyte
+ * -ts_l /path/to/train-labels-idx1-ubyte
+ * -tss_i /path/to/t10k-images-idx3-ubyte
+ * -tss_l /path/to/t10k-labels-idx1-ubyte
+ * -cfg examples/config/example-ignite.xml.</p>
+ * <p>
+ * -ts_i specifies path to training set images of MNIST;
+ * -ts_l specifies path to training set labels of MNIST;
+ * -tss_i specifies path to test set images of MNIST;
+ * -tss_l specifies path to test set labels of MNIST;
+ * -cfg specifies path to a config path.</p>
+ */
+public class MNISTExample {
+ /** Name of parameter specifying path to training set images. */
+ private static final String MNIST_TRAINING_IMAGES_PATH = "ts_i";
+
+ /** Name of parameter specifying path to training set labels. */
+ private static final String MNIST_TRAINING_LABELS_PATH = "ts_l";
+
+ /** Name of parameter specifying path to test set images. */
+ private static final String MNIST_TEST_IMAGES_PATH = "tss_i";
+
+ /** Name of parameter specifying path to test set labels. */
+ private static final String MNIST_TEST_LABELS_PATH = "tss_l";
+
+ /** Name of parameter specifying path of Ignite config. */
+ private static final String CONFIG = "cfg";
+
+ /** Default config path. */
+ private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml";
+
+ /**
+ * Launches example.
+ *
+ * @param args Program arguments.
+ */
+ public static void main(String[] args) {
+ String igniteCfgPath;
+
+ CommandLineParser parser = new BasicParser();
+
+ String trainingImagesPath;
+ String trainingLabelsPath;
+
+ String testImagesPath;
+ String testLabelsPath;
+
+ try {
+ // Parse the command line arguments.
+ CommandLine line = parser.parse(buildOptions(), args);
+
+ trainingImagesPath = line.getOptionValue(MNIST_TRAINING_IMAGES_PATH);
+ trainingLabelsPath = line.getOptionValue(MNIST_TRAINING_LABELS_PATH);
+ testImagesPath = line.getOptionValue(MNIST_TEST_IMAGES_PATH);
+ testLabelsPath = line.getOptionValue(MNIST_TEST_LABELS_PATH);
+ igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG);
+ }
+ catch (ParseException e) {
+ e.printStackTrace();
+ return;
+ }
+
+ try (Ignite ignite = Ignition.start(igniteCfgPath)) {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ int ptsCnt = 60000;
+ int featCnt = 28 * 28;
+
+ Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt);
+ Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000);
+
+ IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite);
+
+ loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite);
+
+ ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
+ new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
+
+ System.out.println(">>> Training started");
+ long before = System.currentTimeMillis();
+ DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));
+ System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+
+ IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage();
+ Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
+ System.out.println(">>> Errs percentage: " + accuracy);
+ }
+ catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Build cli options.
+ */
+ @NotNull private static Options buildOptions() {
+ Options options = new Options();
+
+ Option trsImagesPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_IMAGES_PATH).withLongOpt(MNIST_TRAINING_IMAGES_PATH).hasArg()
+ .withDescription("Path to the MNIST training set.")
+ .isRequired(true).create();
+
+ Option trsLabelsPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_LABELS_PATH).withLongOpt(MNIST_TRAINING_LABELS_PATH).hasArg()
+ .withDescription("Path to the MNIST training set.")
+ .isRequired(true).create();
+
+ Option tssImagesPathOpt = OptionBuilder.withArgName(MNIST_TEST_IMAGES_PATH).withLongOpt(MNIST_TEST_IMAGES_PATH).hasArg()
+ .withDescription("Path to the MNIST test set.")
+ .isRequired(true).create();
+
+ Option tssLabelsPathOpt = OptionBuilder.withArgName(MNIST_TEST_LABELS_PATH).withLongOpt(MNIST_TEST_LABELS_PATH).hasArg()
+ .withDescription("Path to the MNIST test set.")
+ .isRequired(true).create();
+
+ Option configOpt = OptionBuilder.withArgName(CONFIG).withLongOpt(CONFIG).hasArg()
+ .withDescription("Path to the config.")
+ .isRequired(false).create();
+
+ options.addOption(trsImagesPathOpt);
+ options.addOption(trsLabelsPathOpt);
+ options.addOption(tssImagesPathOpt);
+ options.addOption(tssLabelsPathOpt);
+ options.addOption(configOpt);
+
+ return options;
+ }
+
+ /**
+ * Creates cache where data for training is stored.
+ *
+ * @param ignite Ignite instance.
+ * @return cache where data for training is stored.
+ */
+ private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) {
+ CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>();
+
+ // Write to primary.
+ cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
+
+ // Atomic transactions only.
+ cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
+
+ // No eviction.
+ cfg.setEvictionPolicy(null);
+
+ // No copying of values.
+ cfg.setCopyOnRead(false);
+
+ // Cache is partitioned.
+ cfg.setCacheMode(CacheMode.PARTITIONED);
+
+ cfg.setBackups(0);
+
+ cfg.setName("TMP_BI_INDEXED_CACHE");
+
+ return ignite.getOrCreateCache(cfg);
+ }
+
+ /**
+ * Loads vectors into cache.
+ *
+ * @param cacheName Name of cache.
+ * @param vectorsIterator Iterator over vectors to load.
+ * @param vectorSize Size of vector.
+ * @param ignite Ignite instance.
+ */
+ private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIterator,
+ int vectorSize, Ignite ignite) {
+ try (IgniteDataStreamer<BiIndex, Double> streamer =
+ ignite.dataStreamer(cacheName)) {
+ int sampleIdx = 0;
+
+ streamer.perNodeBufferSize(10000);
+
+ while (vectorsIterator.hasNext()) {
+ org.apache.ignite.ml.math.Vector next = vectorsIterator.next();
+
+ for (int i = 0; i < vectorSize; i++)
+ streamer.addData(new BiIndex(sampleIdx, i), next.getX(i));
+
+ sampleIdx++;
+
+ if (sampleIdx % 1000 == 0)
+ System.out.println("Loaded " + sampleIdx + " vectors.");
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java
new file mode 100644
index 0000000..d944f60
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Decision trees examples.
+ */
+package org.apache.ignite.examples.ml.trees;
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java
new file mode 100644
index 0000000..bb591ee
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.decompositions;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.NullArgumentException;
+import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
+import org.apache.ignite.ml.math.functions.Functions;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+
+import static org.apache.ignite.ml.math.util.MatrixUtil.like;
+
+/**
+ * Solver built from the factors of a QR decomposition.
+ * <p>
+ * For an {@code m x n} matrix {@code A} with {@code m >= n}, the QR decomposition
+ * is an {@code m x n} orthogonal matrix {@code Q} and an {@code n x n} upper
+ * triangular matrix {@code R} so that {@code A = Q*R}. This class keeps {@code Q} and {@code R}
+ * and uses them for least squares solutions, the hat matrix and the beta variance-covariance matrix.</p>
+ */
+public class QRDSolver implements Serializable {
+    /** Orthogonal factor {@code Q}. */
+    private final Matrix q;
+
+    /** Upper triangular factor {@code R}. */
+    private final Matrix r;
+
+    /**
+     * Constructs a new QR decomposition solver object.
+     *
+     * @param q An orthogonal matrix.
+     * @param r An upper triangular matrix.
+     * @throws NullArgumentException If {@code q} or {@code r} is {@code null}.
+     */
+    public QRDSolver(Matrix q, Matrix r) {
+        // Fail fast: every method below dereferences both factors.
+        if (q == null || r == null)
+            throw new NullArgumentException();
+
+        this.q = q;
+        this.r = r;
+    }
+
+    /**
+     * Least squares solution of {@code A*X = B}; {@code returns X}.
+     *
+     * @param mtx A matrix with as many rows as {@code A} and any number of cols.
+     * @return {@code X} that minimizes the two norm of {@code Q*R*X - B}.
+     * @throws NullArgumentException if {@code mtx} is {@code null}.
+     * @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
+     */
+    public Matrix solve(Matrix mtx) {
+        if (mtx == null)
+            throw new NullArgumentException();
+
+        if (mtx.rowSize() != q.rowSize())
+            throw new IllegalArgumentException("Matrix row dimensions must agree.");
+
+        int cols = mtx.columnSize();
+        Matrix x = like(r, r.columnSize(), cols);
+
+        Matrix qt = q.transpose();
+        Matrix y = qt.times(mtx);
+
+        // Back substitution on R, from the last pivot row up to the first.
+        for (int k = Math.min(r.columnSize(), q.rowSize()) - 1; k >= 0; k--) {
+            // X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as =
+            x.viewRow(k).map(y.viewRow(k), Functions.plusMult(1 / r.get(k, k)));
+
+            if (k == 0)
+                continue;
+
+            // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,]
+            Vector rCol = r.viewColumn(k).viewPart(0, k);
+
+            for (int c = 0; c < cols; c++)
+                y.viewColumn(c).viewPart(0, k).map(rCol, Functions.plusMult(-x.get(k, c)));
+        }
+
+        return x;
+    }
+
+    /**
+     * Least squares solution of {@code A*X = B}; {@code returns X}.
+     *
+     * @param vec A vector with as many rows as {@code A}.
+     * @return {@code X} that minimizes the two norm of {@code Q*R*X - B}.
+     * @throws NullArgumentException if {@code vec} is {@code null}.
+     * @throws NoDataException if {@code vec} is empty.
+     * @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
+     */
+    public Vector solve(Vector vec) {
+        if (vec == null)
+            throw new NullArgumentException();
+        if (vec.size() == 0)
+            throw new NoDataException();
+        // TODO: IGNITE-5826, Should we copy here?
+
+        // Solve for a single-column matrix and unwrap the result column.
+        Matrix res = solve(vec.likeMatrix(vec.size(), 1).assignColumn(0, vec));
+
+        return vec.like(res.rowSize()).assign(res.viewColumn(0));
+    }
+
+    /**
+     * <p>Compute the "hat" matrix.</p>
+     * <p>The hat matrix is defined in terms of the design matrix X
+     * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup></p>
+     * <p>The implementation here uses the QR decomposition to compute the
+     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
+     * p-dimensional identity matrix augmented by 0's. This computational
+     * formula is from "The Hat Matrix in Regression and ANOVA",
+     * David C. Hoaglin and Roy E. Welsch,
+     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.</p>
+     *
+     * @return The hat matrix.
+     */
+    public Matrix calculateHat() {
+        // Augmented identity: 1 on the first p diagonal elements (p = columns of R), 0 everywhere else.
+        Matrix augI = MatrixUtil.like(q, q.columnSize(), q.columnSize());
+
+        int n = augI.columnSize();
+        int p = r.columnSize();
+
+        for (int i = 0; i < n; i++)
+            for (int j = 0; j < n; j++)
+                augI.setX(i, j, i == j && i < p ? 1d : 0d);
+
+        return q.times(augI).times(q.transpose());
+    }
+
+    /**
+     * <p>Calculates the variance-covariance matrix of the regression parameters.</p>
+     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup></p>
+     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
+     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top-left {@code p x p} part
+     * of R included, where p = the length of the beta vector.</p>
+     *
+     * @param p Size of the beta variance-covariance matrix.
+     * @return The beta variance-covariance matrix.
+     * @throws SingularMatrixException if the design matrix is singular
+     */
+    public Matrix calculateBetaVariance(int p) {
+        // Copy the p x p view into a standalone matrix before inverting it.
+        Matrix rAug = MatrixUtil.copy(r.viewPart(0, p, 0, p));
+        Matrix rInv = rAug.inverse();
+
+        return rInv.times(rInv.transpose());
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        QRDSolver solver = (QRDSolver)o;
+
+        return q.equals(solver.q) && r.equals(solver.r);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        int res = q.hashCode();
+        res = 31 * res + r.hashCode();
+        return res;
+    }
+
+    /**
+     * Returns a rough string rendition of a QRD solver.
+     */
+    @Override public String toString() {
+        return String.format("QRD Solver(%d x %d)", q.rowSize(), r.columnSize());
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
index 3d0bb5d..c069683 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
@@ -46,8 +46,6 @@ public class QRDecomposition implements Destroyable {
private final int rows;
/** */
private final int cols;
- /** */
- private double threshold;
/**
* @param v Value to be checked for being an ordinary double.
@@ -89,7 +87,6 @@ public class QRDecomposition implements Destroyable {
boolean fullRank = true;
r = like(mtx, min, cols);
- this.threshold = threshold;
for (int i = 0; i < min; i++) {
Vector qi = qTmp.viewColumn(i);
@@ -129,6 +126,8 @@ public class QRDecomposition implements Destroyable {
else
q = qTmp;
+ verifyNonSingularR(threshold);
+
this.fullRank = fullRank;
}
@@ -170,32 +169,7 @@ public class QRDecomposition implements Destroyable {
* @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
*/
public Matrix solve(Matrix mtx) {
- if (mtx.rowSize() != rows)
- throw new IllegalArgumentException("Matrix row dimensions must agree.");
-
- int cols = mtx.columnSize();
- Matrix r = getR();
- checkSingular(r, threshold, true);
- Matrix x = like(mType, this.cols, cols);
-
- Matrix qt = getQ().transpose();
- Matrix y = qt.times(mtx);
-
- for (int k = Math.min(this.cols, rows) - 1; k >= 0; k--) {
- // X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as =
- x.viewRow(k).map(y.viewRow(k), Functions.plusMult(1 / r.get(k, k)));
-
- if (k == 0)
- continue;
-
- // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,]
- Vector rCol = r.viewColumn(k).viewPart(0, k);
-
- for (int c = 0; c < cols; c++)
- y.viewColumn(c).viewPart(0, k).map(rCol, Functions.plusMult(-x.get(k, c)));
- }
-
- return x;
+ return new QRDSolver(q, r).solve(mtx);
}
/**
@@ -206,8 +180,7 @@ public class QRDecomposition implements Destroyable {
* @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
*/
public Vector solve(Vector vec) {
- Matrix res = solve(vec.likeMatrix(vec.size(), 1).assignColumn(0, vec));
- return vec.like(res.rowSize()).assign(res.viewColumn(0));
+ return new QRDSolver(q, r).solve(vec);
}
/**
@@ -220,27 +193,20 @@ public class QRDecomposition implements Destroyable {
/**
* Check singularity.
*
- * @param r R matrix.
* @param min Singularity threshold.
- * @param raise Whether to raise a {@link SingularMatrixException} if any element of the diagonal fails the check.
- * @return {@code true} if any element of the diagonal is smaller or equal to {@code min}.
* @throws SingularMatrixException if the matrix is singular and {@code raise} is {@code true}.
*/
- private static boolean checkSingular(Matrix r, double min, boolean raise) {
- // TODO: IGNITE-5828, Not a very fast approach for distributed matrices. would be nice if we could independently check
- // parts on different nodes for singularity and do fold with 'or'.
+ private void verifyNonSingularR(double min) {
+ // TODO: IGNITE-5828, Not a very fast approach for distributed matrices. would be nice if we could independently
+ // check parts on different nodes for singularity and do fold with 'or'.
- final int len = r.columnSize();
+ final int len = r.columnSize() > r.rowSize() ? r.rowSize() : r.columnSize();
for (int i = 0; i < len; i++) {
final double d = r.getX(i, i);
if (Math.abs(d) <= min)
- if (raise)
- throw new SingularMatrixException("Number is too small (%f, while " +
- "threshold is %f). Index of diagonal element is (%d, %d)", d, min, i, i);
- else
- return true;
+ throw new SingularMatrixException("Number is too small (%f, while " +
+ "threshold is %f). Index of diagonal element is (%d, %d)", d, min, i, i);
}
- return false;
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java
index a2a8f16..5bc92c9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java
@@ -355,4 +355,24 @@ public abstract class AbstractMultipleLinearRegression implements MultipleLinear
return yVector.minus(xMatrix.times(b));
}
+    /**
+     * {@inheritDoc}
+     * <p>
+     * Two regressions are equal when they are of the same class and share the intercept setting and the sample
+     * data ({@code xMatrix} and {@code yVector}). Instances with no sample data loaded yet are handled without
+     * throwing.</p>
+     */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        AbstractMultipleLinearRegression that = (AbstractMultipleLinearRegression)o;
+
+        // Null-safe comparisons: xMatrix / yVector may not be set before sample data is loaded.
+        return noIntercept == that.noIntercept
+            && java.util.Objects.equals(xMatrix, that.xMatrix)
+            && java.util.Objects.equals(yVector, that.yVector);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        // Uses exactly the fields compared in equals(), null-safely.
+        int res = java.util.Objects.hashCode(xMatrix);
+
+        res = 31 * res + java.util.Objects.hashCode(yVector);
+        res = 31 * res + (noIntercept ? 1 : 0);
+
+        return res;
+    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java
index 36d5f2c..aafeae8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java
@@ -18,11 +18,11 @@ package org.apache.ignite.ml.regressions;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.decompositions.QRDSolver;
import org.apache.ignite.ml.math.decompositions.QRDecomposition;
import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
import org.apache.ignite.ml.math.functions.Functions;
-import org.apache.ignite.ml.math.util.MatrixUtil;
/**
* This class is based on the corresponding class from Apache Common Math lib.
@@ -51,7 +51,7 @@ import org.apache.ignite.ml.math.util.MatrixUtil;
*/
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
- /** Cached QR decomposition of X matrix */
- private QRDecomposition qr = null;
+ /** Cached solver built from the QR decomposition of X matrix */
+ private QRDSolver solver = null;
/** Singularity threshold for QR decomposition */
private final double threshold;
@@ -94,7 +94,8 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/
@Override public void newSampleData(double[] data, int nobs, int nvars, Matrix like) {
super.newSampleData(data, nobs, nvars, like);
- qr = new QRDecomposition(getX(), threshold);
+ QRDecomposition qr = new QRDecomposition(getX(), threshold);
+ solver = new QRDSolver(qr.getQ(), qr.getR());
}
/**
@@ -118,24 +119,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @throws NullPointerException unless method {@code newSampleData} has been called beforehand.
*/
public Matrix calculateHat() {
- // Create augmented identity matrix
- // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
- Matrix q = qr.getQ();
- Matrix augI = MatrixUtil.like(q, q.columnSize(), q.columnSize());
-
- int n = augI.columnSize();
- int p = qr.getR().columnSize();
-
- for (int i = 0; i < n; i++)
- for (int j = 0; j < n; j++)
- if (i == j && i < p)
- augI.setX(i, j, 1d);
- else
- augI.setX(i, j, 0d);
-
- // Compute and return Hat matrix
- // No DME advertised - args valid if we get here
- return q.times(augI).times(q.transpose());
+ return solver.calculateHat();
}
/**
@@ -226,7 +210,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/
@Override protected void newXSampleData(Matrix x) {
super.newXSampleData(x);
- qr = new QRDecomposition(getX());
+ // Honor the configured singularity threshold, exactly as newSampleData(...) does.
+ QRDecomposition qr = new QRDecomposition(getX(), threshold);
+ solver = new QRDSolver(qr.getQ(), qr.getR());
}
/**
@@ -241,7 +227,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @throws NullPointerException if the data for the model have not been loaded
*/
@Override protected Vector calculateBeta() {
- return qr.solve(getY());
+ return solver.solve(getY());
}
/**
@@ -262,11 +248,11 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @throws NullPointerException if the data for the model have not been loaded
*/
@Override protected Matrix calculateBetaVariance() {
- int p = getX().columnSize();
-
- Matrix rAug = MatrixUtil.copy(qr.getR().viewPart(0, p, 0, p));
- Matrix rInv = rAug.inverse();
+ return solver.calculateBetaVariance(getX().columnSize());
+ }
- return rInv.times(rInv.transpose());
+ /** @return Cached QR solver for the current X sample data ({@code null} until sample data is loaded). */
+ QRDSolver solver() {
+ return solver;
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
new file mode 100644
index 0000000..76a90fc
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.decompositions.QRDSolver;
+import org.apache.ignite.ml.math.decompositions.QRDecomposition;
+
+/**
+ * Linear regression model produced by {@link OLSMultipleLinearRegressionTrainer}.
+ * <p>
+ * The model keeps the X sample data matrix together with a {@link QRDSolver} built from the
+ * {@link QRDecomposition} of that matrix. {@link #predict(Vector)} treats its argument as a vector of
+ * observed responses, solves the ordinary least squares problem {@code X * b ~ val} for the coefficients
+ * {@code b} and returns the fitted values {@code X * b}.
+ */
+public class OLSMultipleLinearRegressionModel implements Model<Vector, Vector>,
+    Exportable<OLSMultipleLinearRegressionModelFormat> {
+    /** X sample data matrix. */
+    private final Matrix xMatrix;
+
+    /** Solver built from the QR decomposition of {@link #xMatrix}. */
+    private final QRDSolver solver;
+
+    /**
+     * Construct linear regression model.
+     *
+     * @param xMatrix X sample data matrix; expected to be the matrix whose {@link QRDecomposition} was used
+     *      to build {@code solver}.
+     * @param solver Linear regression solver object.
+     * @throws NullPointerException If {@code xMatrix} or {@code solver} is {@code null}.
+     */
+    public OLSMultipleLinearRegressionModel(Matrix xMatrix, QRDSolver solver) {
+        // Fail fast: equals(), hashCode() and predict() dereference both fields unconditionally.
+        this.xMatrix = java.util.Objects.requireNonNull(xMatrix, "xMatrix");
+        this.solver = java.util.Objects.requireNonNull(solver, "solver");
+    }
+
+    /**
+     * Solves for the regression coefficients against the given responses and returns the fitted values.
+     *
+     * @param val Observed responses (y); expected to have one entry per row of the X sample data.
+     * @return Fitted values {@code X * b}, where {@code b} is the solution computed by the stored solver
+     *      for {@code val}.
+     */
+    @Override public Vector predict(Vector val) {
+        return xMatrix.times(solver.solve(val));
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<OLSMultipleLinearRegressionModelFormat, P> exporter, P path) {
+        exporter.save(new OLSMultipleLinearRegressionModelFormat(xMatrix, solver), path);
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        OLSMultipleLinearRegressionModel mdl = (OLSMultipleLinearRegressionModel)o;
+
+        return xMatrix.equals(mdl.xMatrix) && solver.equals(mdl.solver);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        int res = xMatrix.hashCode();
+
+        res = 31 * res + solver.hashCode();
+
+        return res;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java
new file mode 100644
index 0000000..fc44968
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.decompositions.QRDSolver;
+
+/**
+ * Serializable representation of an {@link OLSMultipleLinearRegressionModel}; this is the object that
+ * {@link OLSMultipleLinearRegressionModel#saveModel} passes to an exporter.
+ *
+ * @see OLSMultipleLinearRegressionModel
+ */
+public class OLSMultipleLinearRegressionModelFormat implements Serializable {
+    /** X sample data. */
+    private final Matrix xMatrix;
+
+    /** Solver built from the QR decomposition of the X sample data. */
+    private final QRDSolver solver;
+
+    /**
+     * Creates a representation holding the given model state.
+     *
+     * @param xMatrix X sample data matrix.
+     * @param solver Linear regression solver built from the QR decomposition of {@code xMatrix}.
+     */
+    public OLSMultipleLinearRegressionModelFormat(Matrix xMatrix, QRDSolver solver) {
+        this.xMatrix = xMatrix;
+        this.solver = solver;
+    }
+
+    /**
+     * Restores the model from this representation.
+     *
+     * @return Linear regression model built from the stored X sample data and solver.
+     */
+    public OLSMultipleLinearRegressionModel getOLSMultipleLinearRegressionModel() {
+        return new OLSMultipleLinearRegressionModel(xMatrix, solver);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java
new file mode 100644
index 0000000..dde0aca
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.Matrix;
+
+/**
+ * Trainer that fits an ordinary least squares linear regression.
+ * <p>
+ * Each {@link #train(double[])} call loads the data into a fresh {@link OLSMultipleLinearRegression} and
+ * wraps the resulting X sample data and QR solver into an {@link OLSMultipleLinearRegressionModel}.
+ */
+public class OLSMultipleLinearRegressionTrainer implements Trainer<OLSMultipleLinearRegressionModel, double[]> {
+    /** Singularity threshold for the QR decomposition. */
+    private final double threshold;
+
+    /** Number of observations (rows). */
+    private final int nobs;
+
+    /** Number of independent variables (columns, not counting y). */
+    private final int nvars;
+
+    /** Matrix (maybe empty) indicating how data should be stored. */
+    private final Matrix like;
+
+    /**
+     * Construct linear regression trainer.
+     *
+     * @param threshold Singularity threshold for QR decomposition.
+     * @param nobs Number of observations (rows).
+     * @param nvars Number of independent variables (columns, not counting y).
+     * @param like Matrix (maybe empty) indicating how data should be stored.
+     */
+    public OLSMultipleLinearRegressionTrainer(double threshold, int nobs, int nvars, Matrix like) {
+        this.like = like;
+        this.nvars = nvars;
+        this.nobs = nobs;
+        this.threshold = threshold;
+    }
+
+    /**
+     * Fits an ordinary least squares regression to the given data.
+     *
+     * @param data Training data, in the layout accepted by
+     *      {@link OLSMultipleLinearRegression#newSampleData(double[], int, int, Matrix)}.
+     * @return Model holding the X sample data and the QR solver of the fitted regression.
+     */
+    @Override public OLSMultipleLinearRegressionModel train(double[] data) {
+        OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression(threshold);
+        ols.newSampleData(data, nobs, nvars, like);
+
+        Matrix x = ols.getX();
+
+        return new OLSMultipleLinearRegressionModel(x, ols.solver());
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index 47910c8..7a61bad 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -32,7 +32,8 @@ import org.junit.runners.Suite;
MathImplMainTestSuite.class,
RegressionsTestSuite.class,
ClusteringTestSuite.class,
- DecisionTreesTestSuite.class
+ DecisionTreesTestSuite.class,
+ LocalModelsTest.class
})
public class IgniteMLTestSuite {
// No-op.