You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@ignite.apache.org by sb...@apache.org on 2017/12/08 08:34:52 UTC
[21/30] ignite git commit: IGNITE-6872: Linear regression should
implement Model API
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
----------------------------------------------------------------------
Review notes for LocalModelsTest.java (placed before the "diff --git" header, so patch tools ignore this text):
* executeModelTest() asserts that the model file exists immediately after Files.createTempFile() and before
  code.apply() runs, so the "File %s not found." check always passes. Unlike the removed version, which asserted
  existence after mdl.saveModel(...), it no longer verifies that the export actually wrote the file. Consider
  asserting on the file (e.g. non-zero size) inside each test body after saveModel().
* Assert.assertNotNull(mdlPath) is redundant: Files.createTempFile() returns the path of the file it creates.
* Both lambdas end with "return null" only to satisfy Function<String, Void>; a Consumer<String> would state the
  intent directly.
* getAbstractMultipleLinearRegressionModel() repeats the training fixture used by
  OLSMultipleLinearRegressionModelTest.testPerfectFit(); a shared test helper would keep them in sync.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
index acc5649..d0d1247 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
@@ -20,12 +20,14 @@ package org.apache.ignite.ml;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
-import java.nio.file.Paths;
+import java.util.function.Function;
import org.apache.ignite.ml.clustering.KMeansLocalClusterer;
import org.apache.ignite.ml.clustering.KMeansModel;
import org.apache.ignite.ml.math.EuclideanDistance;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.junit.After;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModel;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModelFormat;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionTrainer;
import org.junit.Assert;
import org.junit.Test;
@@ -34,39 +36,68 @@ import org.junit.Test;
*/
public class LocalModelsTest {
/** */
- private String mdlFilePath = "model.mlmod";
-
- /**
- *
- */
- @After
- public void cleanUp() throws IOException {
- Files.deleteIfExists(Paths.get(mdlFilePath));
+ @Test
+ public void importExportKMeansModelTest() throws IOException {
+ executeModelTest(mdlFilePath -> {
+ KMeansModel mdl = getClusterModel();
+
+ Exporter<KMeansModelFormat, String> exporter = new FileExporter<>();
+
+ mdl.saveModel(exporter, mdlFilePath);
+
+ KMeansModelFormat load = exporter.load(mdlFilePath);
+
+ Assert.assertNotNull(load);
+
+ KMeansModel importedMdl = new KMeansModel(load.getCenters(), load.getDistance());
+
+ Assert.assertTrue("", mdl.equals(importedMdl));
+
+ return null;
+ });
}
- /**
- *
- */
+ /** */
@Test
- public void importExportKMeansModelTest() {
- Path mdlPath = Paths.get(mdlFilePath);
+ public void importExportOLSMultipleLinearRegressionModelTest() throws IOException {
+ executeModelTest(mdlFilePath -> {
+ OLSMultipleLinearRegressionModel mdl = getAbstractMultipleLinearRegressionModel();
+
+ Exporter<OLSMultipleLinearRegressionModelFormat, String> exporter = new FileExporter<>();
- KMeansModel mdl = getClusterModel();
+ mdl.saveModel(exporter, mdlFilePath);
- Exporter<KMeansModelFormat, String> exporter = new FileExporter<>();
- mdl.saveModel(exporter, mdlFilePath);
+ OLSMultipleLinearRegressionModelFormat load = exporter.load(mdlFilePath);
- Assert.assertTrue(String.format("File %s not found.", mdlPath.toString()), Files.exists(mdlPath));
+ Assert.assertNotNull(load);
- KMeansModelFormat load = exporter.load(mdlFilePath);
- KMeansModel importedMdl = new KMeansModel(load.getCenters(), load.getDistance());
+ OLSMultipleLinearRegressionModel importedMdl = load.getOLSMultipleLinearRegressionModel();
- Assert.assertTrue("", mdl.equals(importedMdl));
+ Assert.assertTrue("", mdl.equals(importedMdl));
+
+ return null;
+ });
}
- /**
- *
- */
+ /** */
+ private void executeModelTest(Function<String, Void> code) throws IOException {
+ Path mdlPath = Files.createTempFile(null, null);
+
+ Assert.assertNotNull(mdlPath);
+
+ try {
+ String mdlFilePath = mdlPath.toAbsolutePath().toString();
+
+ Assert.assertTrue(String.format("File %s not found.", mdlFilePath), Files.exists(mdlPath));
+
+ code.apply(mdlFilePath);
+ }
+ finally {
+ Files.deleteIfExists(mdlPath);
+ }
+ }
+
+ /** */
private KMeansModel getClusterModel() {
KMeansLocalClusterer clusterer = new KMeansLocalClusterer(new EuclideanDistance(), 1, 1L);
@@ -77,4 +108,22 @@ public class LocalModelsTest {
return clusterer.cluster(points, 1);
}
+
+ /** */
+ private OLSMultipleLinearRegressionModel getAbstractMultipleLinearRegressionModel() {
+ double[] data = new double[] {
+ 0, 0, 0, 0, 0, 0, // IMPL NOTE values in this row are later replaced (with 1.0)
+ 0, 2.0, 0, 0, 0, 0,
+ 0, 0, 3.0, 0, 0, 0,
+ 0, 0, 0, 4.0, 0, 0,
+ 0, 0, 0, 0, 5.0, 0,
+ 0, 0, 0, 0, 0, 6.0};
+
+ final int nobs = 6, nvars = 5;
+
+ OLSMultipleLinearRegressionTrainer trainer
+ = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, new DenseLocalOnHeapMatrix(1, 1));
+
+ return trainer.train(data);
+ }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
----------------------------------------------------------------------
Review note for MathImplLocalTestSuite.java (placed before the "diff --git" header, so patch tools ignore this text):
* Only registers the new QRDSolverTest: one import plus one entry in @Suite.SuiteClasses, listed right after
  QRDecompositionTest.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
index 216fd7b..af2154e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.math;
import org.apache.ignite.ml.math.decompositions.CholeskyDecompositionTest;
import org.apache.ignite.ml.math.decompositions.EigenDecompositionTest;
import org.apache.ignite.ml.math.decompositions.LUDecompositionTest;
+import org.apache.ignite.ml.math.decompositions.QRDSolverTest;
import org.apache.ignite.ml.math.decompositions.QRDecompositionTest;
import org.apache.ignite.ml.math.decompositions.SingularValueDecompositionTest;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOffHeapMatrixConstructorTest;
@@ -116,6 +117,7 @@ import org.junit.runners.Suite;
EigenDecompositionTest.class,
CholeskyDecompositionTest.class,
QRDecompositionTest.class,
+ QRDSolverTest.class,
SingularValueDecompositionTest.class
})
public class MathImplLocalTestSuite {
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java
----------------------------------------------------------------------
Review fix for QRDSolverTest.java (placed before the "diff --git" header, so patch tools ignore this text):
* The check that R has zeros below its diagonal now loops "col < row". The previous bound "col < row - 1" skipped
  the sub-diagonal, so for this 3x3 input r(1,0) and r(2,1) were never verified.
* The Q * Q^T identity check now reads expIdentity.get(row, col), matching the "(row,col)" in its failure message.
  The old code read get(col, row).
* Both changes keep the hunk at 87 lines, so the "@@ -0,0 +1,87 @@" header below remains valid.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java
new file mode 100644
index 0000000..d3e8e76
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.decompositions;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/** */
+public class QRDSolverTest {
+ /** */
+ @Test
+ public void basicTest() {
+ Matrix m = new DenseLocalOnHeapMatrix(new double[][] {
+ {2.0d, -1.0d, 0.0d},
+ {-1.0d, 2.0d, -1.0d},
+ {0.0d, -1.0d, 2.0d}
+ });
+
+ QRDecomposition dec = new QRDecomposition(m);
+ assertTrue("Unexpected value for full rank in decomposition " + dec, dec.hasFullRank());
+
+ Matrix q = dec.getQ();
+ Matrix r = dec.getR();
+
+ assertNotNull("Matrix q is expected to be not null.", q);
+ assertNotNull("Matrix r is expected to be not null.", r);
+
+ Matrix qSafeCp = safeCopy(q);
+
+ Matrix expIdentity = qSafeCp.times(qSafeCp.transpose());
+
+ final double delta = 0.0001;
+
+ for (int row = 0; row < expIdentity.rowSize(); row++)
+ for (int col = 0; col < expIdentity.columnSize(); col++)
+ assertEquals("Unexpected identity matrix value at (" + row + "," + col + ").",
+ row == col ? 1d : 0d, expIdentity.get(row, col), delta);
+
+ for (int row = 0; row < r.rowSize(); row++)
+ for (int col = 0; col < row; col++)
+ assertEquals("Unexpected upper triangular matrix value at (" + row + "," + col + ").",
+ 0d, r.get(row, col), delta);
+
+ Matrix recomposed = qSafeCp.times(r);
+
+ for (int row = 0; row < m.rowSize(); row++)
+ for (int col = 0; col < m.columnSize(); col++)
+ assertEquals("Unexpected recomposed matrix value at (" + row + "," + col + ").",
+ m.get(row, col), recomposed.get(row, col), delta);
+
+ Matrix sol = new QRDSolver(q, r).solve(new DenseLocalOnHeapMatrix(3, 10));
+ assertEquals("Unexpected rows in solution matrix.", 3, sol.rowSize());
+ assertEquals("Unexpected cols in solution matrix.", 10, sol.columnSize());
+
+ for (int row = 0; row < sol.rowSize(); row++)
+ for (int col = 0; col < sol.columnSize(); col++)
+ assertEquals("Unexpected solution matrix value at (" + row + "," + col + ").",
+ 0d, sol.get(row, col), delta);
+
+ dec.destroy();
+ }
+
+ /** */
+ private Matrix safeCopy(Matrix orig) {
+ return new DenseLocalOnHeapMatrix(orig.rowSize(), orig.columnSize()).assign(orig);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
Review notes for DistributedBlockOLSMultipleLinearRegressionTest.java (placed before the "diff --git" header, so patch tools ignore this text):
* The changes are mostly cosmetic: clearer fail(...) messages, removal of stray blank lines, and createRegression()
  narrowed from protected to private.
* In each expected-exception test, the fail(...) after the try/catch never executes. fail(...) always throws, and
  its assertion error is not caught by the narrower catch clause. The catch clause itself returns.
* testSingularCalculateBeta now calls mdl.newSampleData(...) inside the try. A SingularMatrixException thrown by
  newSampleData itself, not only by calculateBeta(), therefore also satisfies the test.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
index a482737..8c9d429 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
@@ -35,7 +35,6 @@ import org.junit.Assert;
/**
* Tests for {@link OLSMultipleLinearRegression}.
*/
-
@GridCommonTest(group = "Distributed Models")
public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonAbstractTest {
/** */
@@ -95,7 +94,7 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
}
/** */
- protected OLSMultipleLinearRegression createRegression() {
+ private OLSMultipleLinearRegression createRegression() {
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
regression.newSampleData(new SparseBlockDistributedVector(y), new SparseBlockDistributedMatrix(x));
return regression;
@@ -243,7 +242,6 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
// Check R-Square statistics against R
Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12);
Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12);
-
}
/**
@@ -533,12 +531,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
try {
createRegression().newSampleData(null, new SparseBlockDistributedMatrix(new double[][] {{1}}));
- fail("NullArgumentException");
+ fail("Expected NullArgumentException was not caught.");
}
catch (NullArgumentException e) {
return;
}
- fail("NullArgumentException");
+ fail("Expected NullArgumentException was not caught.");
}
/** */
@@ -547,13 +545,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
try {
createRegression().newSampleData(new SparseBlockDistributedVector(new double[] {1}), null);
- fail("NullArgumentException");
+ fail("Expected NullArgumentException was not caught.");
}
catch (NullArgumentException e) {
return;
}
- fail("NullArgumentException");
-
+ fail("Expected NullArgumentException was not caught.");
}
/**
@@ -830,17 +827,16 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
public void testSingularCalculateBeta() {
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15);
- mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseBlockDistributedMatrix());
try {
+ mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseBlockDistributedMatrix());
mdl.calculateBeta();
- fail("SingularMatrixException");
+ fail("Expected SingularMatrixException was not caught.");
}
catch (SingularMatrixException e) {
return;
}
- fail("SingularMatrixException");
-
+ fail("Expected SingularMatrixException was not caught.");
}
/** */
@@ -850,13 +846,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
try {
mdl.calculateBeta();
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
catch (NullPointerException e) {
return;
}
- fail("java.lang.NullPointerException");
-
+ fail("Expected NullPointerException was not caught.");
}
/** */
@@ -866,12 +861,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
try {
mdl.calculateHat();
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
catch (NullPointerException e) {
return;
}
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
/** */
@@ -881,13 +876,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
try {
mdl.calculateTotalSumOfSquares();
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
catch (NullPointerException e) {
return;
}
- fail("java.lang.NullPointerException");
-
+ fail("Expected NullPointerException was not caught.");
}
/** */
@@ -897,11 +891,11 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
try {
mdl.validateSampleData(new SparseBlockDistributedMatrix(1, 2), new SparseBlockDistributedVector(1));
- fail("MathIllegalArgumentException");
+ fail("Expected MathIllegalArgumentException was not caught.");
}
catch (MathIllegalArgumentException e) {
return;
}
- fail("MathIllegalArgumentException");
+ fail("Expected MathIllegalArgumentException was not caught.");
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
Review notes for DistributedOLSMultipleLinearRegressionTest.java (placed before the "diff --git" header, so patch tools ignore this text):
* This mirrors the DistributedBlock* changes for the SparseDistributedMatrix/Vector variant: clearer fail(...)
  messages, stray blank lines removed, and createRegression() narrowed from protected to private.
* In each expected-exception test, the fail(...) after the try/catch never executes. fail(...) always throws, and
  its assertion error is not caught by the narrower catch clause. The catch clause itself returns.
* testSingularCalculateBeta now calls mdl.newSampleData(...) inside the try. A SingularMatrixException thrown by
  newSampleData itself also satisfies the test.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
index a2d1e5f..f720406 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
@@ -35,7 +35,6 @@ import org.junit.Assert;
/**
* Tests for {@link OLSMultipleLinearRegression}.
*/
-
@GridCommonTest(group = "Distributed Models")
public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstractTest {
/** */
@@ -58,9 +57,7 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
/** */
public DistributedOLSMultipleLinearRegressionTest() {
-
super(false);
-
}
/** {@inheritDoc} */
@@ -97,7 +94,7 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
}
/** */
- protected OLSMultipleLinearRegression createRegression() {
+ private OLSMultipleLinearRegression createRegression() {
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
regression.newSampleData(new SparseDistributedVector(y), new SparseDistributedMatrix(x));
return regression;
@@ -245,7 +242,6 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
// Check R-Square statistics against R
Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12);
Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12);
-
}
/**
@@ -526,7 +522,6 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
}
for (int i = 0; i < combinedY.size(); i++)
Assert.assertEquals(combinedY.get(i), regression.getY().get(i), PRECISION);
-
}
/** */
@@ -535,12 +530,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
try {
createRegression().newSampleData(null, new SparseDistributedMatrix(new double[][] {{1}}));
- fail("NullArgumentException");
+ fail("Expected NullArgumentException was not caught.");
}
catch (NullArgumentException e) {
return;
}
- fail("NullArgumentException");
+ fail("Expected NullArgumentException was not caught.");
}
/** */
@@ -549,12 +544,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
try {
createRegression().newSampleData(new SparseDistributedVector(new double[] {1}), null);
- fail("NullArgumentException");
+ fail("Expected NullArgumentException was not caught.");
}
catch (NullArgumentException e) {
return;
}
- fail("NullArgumentException");
+ fail("Expected NullArgumentException was not caught.");
}
@@ -832,16 +827,16 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
public void testSingularCalculateBeta() {
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15);
- mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseDistributedMatrix());
try {
+ mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseDistributedMatrix());
mdl.calculateBeta();
- fail("SingularMatrixException");
+ fail("Expected SingularMatrixException was not caught.");
}
catch (SingularMatrixException e) {
return;
}
- fail("SingularMatrixException");
+ fail("Expected SingularMatrixException was not caught.");
}
@@ -852,12 +847,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
try {
mdl.calculateBeta();
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
catch (NullPointerException e) {
return;
}
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
@@ -868,12 +863,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
try {
mdl.calculateHat();
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
catch (NullPointerException e) {
return;
}
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
/** */
@@ -883,13 +878,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
try {
mdl.calculateTotalSumOfSquares();
- fail("java.lang.NullPointerException");
+ fail("Expected NullPointerException was not caught.");
}
catch (NullPointerException e) {
return;
}
- fail("java.lang.NullPointerException");
-
+ fail("Expected NullPointerException was not caught.");
}
/** */
@@ -899,11 +893,11 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
try {
mdl.validateSampleData(new SparseDistributedMatrix(1, 2), new SparseDistributedVector(1));
- fail("MathIllegalArgumentException");
+ fail("Expected MathIllegalArgumentException was not caught.");
}
catch (MathIllegalArgumentException e) {
return;
}
- fail("MathIllegalArgumentException");
+ fail("Expected MathIllegalArgumentException was not caught.");
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
----------------------------------------------------------------------
Review notes for OLSMultipleLinearRegressionModelTest.java (placed before the "diff --git" header, so patch tools ignore this text):
* testPerfectFit trains with OLSMultipleLinearRegressionTrainer on a 6x6 flat array (nobs = 6, nvars = 5).
  It then asserts that every element of val - mdl.predict(val) is 0 within 1e-13.
* The same training fixture (data array, nobs/nvars, trainer construction) is repeated in
  LocalModelsTest.getAbstractMultipleLinearRegressionModel(). A shared helper would keep the two in sync.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
new file mode 100644
index 0000000..37c972c
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link OLSMultipleLinearRegressionModel}.
+ */
+public class OLSMultipleLinearRegressionModelTest {
+ /** */
+ @Test
+ public void testPerfectFit() {
+ Vector val = new DenseLocalOnHeapVector(new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0});
+
+ double[] data = new double[] {
+ 0, 0, 0, 0, 0, 0, // IMPL NOTE values in this row are later replaced (with 1.0)
+ 0, 2.0, 0, 0, 0, 0,
+ 0, 0, 3.0, 0, 0, 0,
+ 0, 0, 0, 4.0, 0, 0,
+ 0, 0, 0, 0, 5.0, 0,
+ 0, 0, 0, 0, 0, 6.0};
+
+ final int nobs = 6, nvars = 5;
+
+ OLSMultipleLinearRegressionTrainer trainer
+ = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, new DenseLocalOnHeapMatrix(1, 1));
+
+ OLSMultipleLinearRegressionModel mdl = trainer.train(data);
+
+ TestUtils.assertEquals(new double[] {0d, 0d, 0d, 0d, 0d, 0d},
+ val.minus(mdl.predict(val)).getStorage().data(), 1e-13);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index 2a0b111..be71934 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -25,7 +25,10 @@ import org.junit.runners.Suite;
*/
@RunWith(Suite.class)
@Suite.SuiteClasses({
- OLSMultipleLinearRegressionTest.class, DistributedOLSMultipleLinearRegressionTest.class, DistributedBlockOLSMultipleLinearRegressionTest.class
+ OLSMultipleLinearRegressionTest.class,
+ DistributedOLSMultipleLinearRegressionTest.class,
+ DistributedBlockOLSMultipleLinearRegressionTest.class,
+ OLSMultipleLinearRegressionModelTest.class
})
public class RegressionsTestSuite {
// No-op.