You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2017/12/29 09:28:27 UTC
[10/15] ignite git commit: IGNITE-5217: Gradient descent for OLS lin reg
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java
deleted file mode 100644
index 2774028..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java
+++ /dev/null
@@ -1,820 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions;
-
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
-import org.apache.ignite.ml.math.exceptions.NullArgumentException;
-import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.math.util.MatrixUtil;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-/**
- * Tests for {@link OLSMultipleLinearRegression}.
- */
-public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegressionTest {
- /** */
- private double[] y;
-
- /** */
- private double[][] x;
-
- /** */
- @Before
- @Override public void setUp() {
- y = new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
- x = new double[6][];
- x[0] = new double[] {0, 0, 0, 0, 0};
- x[1] = new double[] {2.0, 0, 0, 0, 0};
- x[2] = new double[] {0, 3.0, 0, 0, 0};
- x[3] = new double[] {0, 0, 4.0, 0, 0};
- x[4] = new double[] {0, 0, 0, 5.0, 0};
- x[5] = new double[] {0, 0, 0, 0, 6.0};
- super.setUp();
- }
-
- /** */
- @Override protected OLSMultipleLinearRegression createRegression() {
- OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
- regression.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x));
- return regression;
- }
-
- /** */
- @Override protected int getNumberOfRegressors() {
- return x[0].length + 1;
- }
-
- /** */
- @Override protected int getSampleSize() {
- return y.length;
- }
-
- /** */
- @Test(expected = MathIllegalArgumentException.class)
- public void cannotAddSampleDataWithSizeMismatch() {
- double[] y = new double[] {1.0, 2.0};
- double[][] x = new double[1][];
- x[0] = new double[] {1.0, 0};
- createRegression().newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x));
- }
-
- /** */
- @Test
- public void testPerfectFit() {
- double[] betaHat = regression.estimateRegressionParameters();
- TestUtils.assertEquals(new double[] {11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0},
- betaHat,
- 1e-13);
- double[] residuals = regression.estimateResiduals();
- TestUtils.assertEquals(new double[] {0d, 0d, 0d, 0d, 0d, 0d}, residuals,
- 1e-13);
- Matrix errors = regression.estimateRegressionParametersVariance();
- final double[] s = {1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0};
- Matrix refVar = new DenseLocalOnHeapMatrix(s.length, s.length);
- for (int i = 0; i < refVar.rowSize(); i++)
- for (int j = 0; j < refVar.columnSize(); j++) {
- if (i == 0) {
- refVar.setX(i, j, s[j]);
- continue;
- }
- double x = s[i] * s[j];
- refVar.setX(i, j, (i == j) ? 2 * x : x);
- }
- Assert.assertEquals(0.0,
- TestUtils.maximumAbsoluteRowSum(errors.minus(refVar)),
- 5.0e-16 * TestUtils.maximumAbsoluteRowSum(refVar));
- Assert.assertEquals(1, ((OLSMultipleLinearRegression)regression).calculateRSquared(), 1E-12);
- }
-
- /**
- * Test Longley dataset against certified values provided by NIST.
- * Data Source: J. Longley (1967) "An Appraisal of Least Squares
- * Programs for the Electronic Computer from the Point of View of the User"
- * Journal of the American Statistical Association, vol. 62. September,
- * pp. 819-841.
- *
- * Certified values (and data) are from NIST:
- * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
- */
- @Test
- public void testLongly() {
- // Y values are first, then independent vars
- // Each row is one observation
- double[] design = new double[] {
- 60323, 83.0, 234289, 2356, 1590, 107608, 1947,
- 61122, 88.5, 259426, 2325, 1456, 108632, 1948,
- 60171, 88.2, 258054, 3682, 1616, 109773, 1949,
- 61187, 89.5, 284599, 3351, 1650, 110929, 1950,
- 63221, 96.2, 328975, 2099, 3099, 112075, 1951,
- 63639, 98.1, 346999, 1932, 3594, 113270, 1952,
- 64989, 99.0, 365385, 1870, 3547, 115094, 1953,
- 63761, 100.0, 363112, 3578, 3350, 116219, 1954,
- 66019, 101.2, 397469, 2904, 3048, 117388, 1955,
- 67857, 104.6, 419180, 2822, 2857, 118734, 1956,
- 68169, 108.4, 442769, 2936, 2798, 120445, 1957,
- 66513, 110.8, 444546, 4681, 2637, 121950, 1958,
- 68655, 112.6, 482704, 3813, 2552, 123366, 1959,
- 69564, 114.2, 502601, 3931, 2514, 125368, 1960,
- 69331, 115.7, 518173, 4806, 2572, 127852, 1961,
- 70551, 116.9, 554894, 4007, 2827, 130081, 1962
- };
-
- final int nobs = 16;
- final int nvars = 6;
-
- // Estimate the model
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix());
-
- // Check expected beta values from NIST
- double[] betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- -3482258.63459582, 15.0618722713733,
- -0.358191792925910E-01, -2.02022980381683,
- -1.03322686717359, -0.511041056535807E-01,
- 1829.15146461355}, 2E-6); //
-
- // Check expected residuals from R
- double[] residuals = mdl.estimateResiduals();
- TestUtils.assertEquals(residuals, new double[] {
- 267.340029759711, -94.0139423988359, 46.28716775752924,
- -410.114621930906, 309.7145907602313, -249.3112153297231,
- -164.0489563956039, -13.18035686637081, 14.30477260005235,
- 455.394094551857, -17.26892711483297, -39.0550425226967,
- -155.5499735953195, -85.6713080421283, 341.9315139607727,
- -206.7578251937366},
- 1E-7);
-
- // Check standard errors from NIST
- double[] errors = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(new double[] {
- 890420.383607373,
- 84.9149257747669,
- 0.334910077722432E-01,
- 0.488399681651699,
- 0.214274163161675,
- 0.226073200069370,
- 455.478499142212}, errors, 1E-6);
-
- // Check regression standard error against R
- Assert.assertEquals(304.8540735619638, mdl.estimateRegressionStandardError(), 1E-8);
-
- // Check R-Square statistics against R
- Assert.assertEquals(0.995479004577296, mdl.calculateRSquared(), 1E-12);
- Assert.assertEquals(0.992465007628826, mdl.calculateAdjustedRSquared(), 1E-12);
-
- // TODO: IGNITE-5826, uncomment.
- // checkVarianceConsistency(model);
-
- // Estimate model without intercept
- mdl.setNoIntercept(true);
- mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix());
-
- // Check expected beta values from R
- betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- -52.99357013868291, 0.07107319907358,
- -0.42346585566399, -0.57256866841929,
- -0.41420358884978, 48.41786562001326}, 1E-8);
-
- // Check standard errors from R
- errors = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(new double[] {
- 129.54486693117232, 0.03016640003786,
- 0.41773654056612, 0.27899087467676, 0.32128496193363,
- 17.68948737819961}, errors, 1E-11);
-
- // Check expected residuals from R
- residuals = mdl.estimateResiduals();
- TestUtils.assertEquals(residuals, new double[] {
- 279.90274927293092, -130.32465380836874, 90.73228661967445, -401.31252201634948,
- -440.46768772620027, -543.54512853774793, 201.32111639536299, 215.90889365977932,
- 73.09368242049943, 913.21694494481869, 424.82484953610174, -8.56475876776709,
- -361.32974610842876, 27.34560497213464, 151.28955976355002, -492.49937355336846},
- 1E-8);
-
- // Check regression standard error against R
- Assert.assertEquals(475.1655079819517, mdl.estimateRegressionStandardError(), 1E-10);
-
- // Check R-Square statistics against R
- Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12);
- Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12);
-
- }
-
- /**
- * Test R Swiss fertility dataset against R.
- * Data Source: R datasets package
- */
- @Test
- public void testSwissFertility() {
- double[] design = new double[] {
- 80.2, 17.0, 15, 12, 9.96,
- 83.1, 45.1, 6, 9, 84.84,
- 92.5, 39.7, 5, 5, 93.40,
- 85.8, 36.5, 12, 7, 33.77,
- 76.9, 43.5, 17, 15, 5.16,
- 76.1, 35.3, 9, 7, 90.57,
- 83.8, 70.2, 16, 7, 92.85,
- 92.4, 67.8, 14, 8, 97.16,
- 82.4, 53.3, 12, 7, 97.67,
- 82.9, 45.2, 16, 13, 91.38,
- 87.1, 64.5, 14, 6, 98.61,
- 64.1, 62.0, 21, 12, 8.52,
- 66.9, 67.5, 14, 7, 2.27,
- 68.9, 60.7, 19, 12, 4.43,
- 61.7, 69.3, 22, 5, 2.82,
- 68.3, 72.6, 18, 2, 24.20,
- 71.7, 34.0, 17, 8, 3.30,
- 55.7, 19.4, 26, 28, 12.11,
- 54.3, 15.2, 31, 20, 2.15,
- 65.1, 73.0, 19, 9, 2.84,
- 65.5, 59.8, 22, 10, 5.23,
- 65.0, 55.1, 14, 3, 4.52,
- 56.6, 50.9, 22, 12, 15.14,
- 57.4, 54.1, 20, 6, 4.20,
- 72.5, 71.2, 12, 1, 2.40,
- 74.2, 58.1, 14, 8, 5.23,
- 72.0, 63.5, 6, 3, 2.56,
- 60.5, 60.8, 16, 10, 7.72,
- 58.3, 26.8, 25, 19, 18.46,
- 65.4, 49.5, 15, 8, 6.10,
- 75.5, 85.9, 3, 2, 99.71,
- 69.3, 84.9, 7, 6, 99.68,
- 77.3, 89.7, 5, 2, 100.00,
- 70.5, 78.2, 12, 6, 98.96,
- 79.4, 64.9, 7, 3, 98.22,
- 65.0, 75.9, 9, 9, 99.06,
- 92.2, 84.6, 3, 3, 99.46,
- 79.3, 63.1, 13, 13, 96.83,
- 70.4, 38.4, 26, 12, 5.62,
- 65.7, 7.7, 29, 11, 13.79,
- 72.7, 16.7, 22, 13, 11.22,
- 64.4, 17.6, 35, 32, 16.92,
- 77.6, 37.6, 15, 7, 4.97,
- 67.6, 18.7, 25, 7, 8.65,
- 35.0, 1.2, 37, 53, 42.34,
- 44.7, 46.6, 16, 29, 50.43,
- 42.8, 27.7, 22, 29, 58.33
- };
-
- final int nobs = 47;
- final int nvars = 4;
-
- // Estimate the model
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix());
-
- // Check expected beta values from R
- double[] betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- 91.05542390271397,
- -0.22064551045715,
- -0.26058239824328,
- -0.96161238456030,
- 0.12441843147162}, 1E-12);
-
- // Check expected residuals from R
- double[] residuals = mdl.estimateResiduals();
- TestUtils.assertEquals(residuals, new double[] {
- 7.1044267859730512, 1.6580347433531366,
- 4.6944952770029644, 8.4548022690166160, 13.6547432343186212,
- -9.3586864458500774, 7.5822446330520386, 15.5568995563859289,
- 0.8113090736598980, 7.1186762732484308, 7.4251378771228724,
- 2.6761316873234109, 0.8351584810309354, 7.1769991119615177,
- -3.8746753206299553, -3.1337779476387251, -0.1412575244091504,
- 1.1186809170469780, -6.3588097346816594, 3.4039270429434074,
- 2.3374058329820175, -7.9272368576900503, -7.8361010968497959,
- -11.2597369269357070, 0.9445333697827101, 6.6544245101380328,
- -0.9146136301118665, -4.3152449403848570, -4.3536932047009183,
- -3.8907885169304661, -6.3027643926302188, -7.8308982189289091,
- -3.1792280015332750, -6.7167298771158226, -4.8469946718041754,
- -10.6335664353633685, 11.1031134362036958, 6.0084032641811733,
- 5.4326230830188482, -7.2375578629692230, 2.1671550814448222,
- 15.0147574652763112, 4.8625103516321015, -7.1597256413907706,
- -0.4515205619767598, -10.2916870903837587, -15.7812984571900063},
- 1E-12);
-
- // Check standard errors from R
- double[] errors = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(new double[] {
- 6.94881329475087,
- 0.07360008972340,
- 0.27410957467466,
- 0.19454551679325,
- 0.03726654773803}, errors, 1E-10);
-
- // Check regression standard error against R
- Assert.assertEquals(7.73642194433223, mdl.estimateRegressionStandardError(), 1E-12);
-
- // Check R-Square statistics against R
- Assert.assertEquals(0.649789742860228, mdl.calculateRSquared(), 1E-12);
- Assert.assertEquals(0.6164363850373927, mdl.calculateAdjustedRSquared(), 1E-12);
-
- // TODO: IGNITE-5826, uncomment.
- // checkVarianceConsistency(model);
-
- // Estimate the model with no intercept
- mdl = new OLSMultipleLinearRegression();
- mdl.setNoIntercept(true);
- mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix());
-
- // Check expected beta values from R
- betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- 0.52191832900513,
- 2.36588087917963,
- -0.94770353802795,
- 0.30851985863609}, 1E-12);
-
- // Check expected residuals from R
- residuals = mdl.estimateResiduals();
- TestUtils.assertEquals(residuals, new double[] {
- 44.138759883538249, 27.720705122356215, 35.873200836126799,
- 34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
- 1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
- -9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
- -17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
- -10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
- -13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
- -1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
- -16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
- -17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
- -0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
- 2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615},
- 1E-12);
-
- // Check standard errors from R
- errors = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(new double[] {
- 0.10470063765677, 0.41684100584290,
- 0.43370143099691, 0.07694953606522}, errors, 1E-10);
-
- // Check regression standard error against R
- Assert.assertEquals(17.24710630547, mdl.estimateRegressionStandardError(), 1E-10);
-
- // Check R-Square statistics against R
- Assert.assertEquals(0.946350722085, mdl.calculateRSquared(), 1E-12);
- Assert.assertEquals(0.9413600915813, mdl.calculateAdjustedRSquared(), 1E-12);
- }
-
- /**
- * Test hat matrix computation
- */
- @Test
- public void testHat() {
-
- /*
- * This example is from "The Hat Matrix in Regression and ANOVA",
- * David C. Hoaglin and Roy E. Welsch,
- * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
- *
- */
- double[] design = new double[] {
- 11.14, .499, 11.1,
- 12.74, .558, 8.9,
- 13.13, .604, 8.8,
- 11.51, .441, 8.9,
- 12.38, .550, 8.8,
- 12.60, .528, 9.9,
- 11.13, .418, 10.7,
- 11.7, .480, 10.5,
- 11.02, .406, 10.5,
- 11.41, .467, 10.7
- };
-
- int nobs = 10;
- int nvars = 2;
-
- // Estimate the model
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix());
-
- Matrix hat = mdl.calculateHat();
-
-
- // Reference data is upper half of symmetric hat matrix
- double[] refData = new double[] {
- .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242,
- .242, .292, .136, .243, .128, -.041, .033, -.035, .004,
- .417, -.019, .273, .187, -.126, .044, -.153, .004,
- .604, .197, -.038, .168, -.022, .275, -.028,
- .252, .111, -.030, .019, -.010, -.010,
- .148, .042, .117, .012, .111,
- .262, .145, .277, .174,
- .154, .120, .168,
- .315, .148,
- .187
- };
-
- // Check against reference data and verify symmetry
- int k = 0;
- for (int i = 0; i < 10; i++) {
- for (int j = i; j < 10; j++) {
- Assert.assertEquals(refData[k], hat.getX(i, j), 10e-3);
- Assert.assertEquals(hat.getX(i, j), hat.getX(j, i), 10e-12);
- k++;
- }
- }
-
- /*
- * Verify that residuals computed using the hat matrix are close to
- * what we get from direct computation, i.e. r = (I - H) y
- */
- double[] residuals = mdl.estimateResiduals();
- Matrix id = MatrixUtil.identityLike(hat, 10);
- double[] hatResiduals = id.minus(hat).times(mdl.getY()).getStorage().data();
- TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
- }
-
- /**
- * test calculateYVariance
- */
- @Test
- public void testYVariance() {
- // assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x));
- TestUtils.assertEquals(mdl.calculateYVariance(), 3.5, 0);
- }
-
- /**
- * Verifies that setting X and Y separately has the same effect as newSample(X,Y).
- */
- @Test
- public void testNewSample2() {
- double[] y = new double[] {1, 2, 3, 4};
- double[][] x = new double[][] {
- {19, 22, 33},
- {20, 30, 40},
- {25, 35, 45},
- {27, 37, 47}
- };
- OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
- regression.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x));
- Matrix combinedX = regression.getX().copy();
- Vector combinedY = regression.getY().copy();
- regression.newXSampleData(new DenseLocalOnHeapMatrix(x));
- regression.newYSampleData(new DenseLocalOnHeapVector(y));
- Assert.assertEquals(combinedX, regression.getX());
- Assert.assertEquals(combinedY, regression.getY());
-
- // No intercept
- regression.setNoIntercept(true);
- regression.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x));
- combinedX = regression.getX().copy();
- combinedY = regression.getY().copy();
- regression.newXSampleData(new DenseLocalOnHeapMatrix(x));
- regression.newYSampleData(new DenseLocalOnHeapVector(y));
- Assert.assertEquals(combinedX, regression.getX());
- Assert.assertEquals(combinedY, regression.getY());
- }
-
- /** */
- @Test(expected = NullArgumentException.class)
- public void testNewSampleDataYNull() {
- createRegression().newSampleData(null, new DenseLocalOnHeapMatrix(new double[][] {{1}}));
- }
-
- /** */
- @Test(expected = NullArgumentException.class)
- public void testNewSampleDataXNull() {
- createRegression().newSampleData(new DenseLocalOnHeapVector(new double[] {}), null);
- }
-
- /**
- * This is a test based on the Wampler1 data set
- * http://www.itl.nist.gov/div898/strd/lls/data/Wampler1.shtml
- */
- @Test
- public void testWampler1() {
- double[] data = new double[] {
- 1, 0,
- 6, 1,
- 63, 2,
- 364, 3,
- 1365, 4,
- 3906, 5,
- 9331, 6,
- 19608, 7,
- 37449, 8,
- 66430, 9,
- 111111, 10,
- 177156, 11,
- 271453, 12,
- 402234, 13,
- 579195, 14,
- 813616, 15,
- 1118481, 16,
- 1508598, 17,
- 2000719, 18,
- 2613660, 19,
- 3368421, 20};
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
-
- final int nvars = 5;
- final int nobs = 21;
- double[] tmp = new double[(nvars + 1) * nobs];
- int off = 0;
- int off2 = 0;
- for (int i = 0; i < nobs; i++) {
- tmp[off2] = data[off];
- tmp[off2 + 1] = data[off + 1];
- tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
- tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
- tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
- tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
- off2 += (nvars + 1);
- off += 2;
- }
- mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix());
- double[] betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- 1.0,
- 1.0, 1.0,
- 1.0, 1.0,
- 1.0}, 1E-8);
-
- double[] se = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(se,
- new double[] {
- 0.0,
- 0.0, 0.0,
- 0.0, 0.0,
- 0.0}, 1E-8);
-
- TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10);
- TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7);
- TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6);
- }
-
- /**
- * This is a test based on the Wampler2 data set
- * http://www.itl.nist.gov/div898/strd/lls/data/Wampler2.shtml
- */
- @Test
- public void testWampler2() {
- double[] data = new double[] {
- 1.00000, 0,
- 1.11111, 1,
- 1.24992, 2,
- 1.42753, 3,
- 1.65984, 4,
- 1.96875, 5,
- 2.38336, 6,
- 2.94117, 7,
- 3.68928, 8,
- 4.68559, 9,
- 6.00000, 10,
- 7.71561, 11,
- 9.92992, 12,
- 12.75603, 13,
- 16.32384, 14,
- 20.78125, 15,
- 26.29536, 16,
- 33.05367, 17,
- 41.26528, 18,
- 51.16209, 19,
- 63.00000, 20};
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
-
- final int nvars = 5;
- final int nobs = 21;
- double[] tmp = new double[(nvars + 1) * nobs];
- int off = 0;
- int off2 = 0;
- for (int i = 0; i < nobs; i++) {
- tmp[off2] = data[off];
- tmp[off2 + 1] = data[off + 1];
- tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
- tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
- tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
- tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
- off2 += (nvars + 1);
- off += 2;
- }
- mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix());
- double[] betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- 1.0,
- 1.0e-1,
- 1.0e-2,
- 1.0e-3, 1.0e-4,
- 1.0e-5}, 1E-8);
-
- double[] se = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(se,
- new double[] {
- 0.0,
- 0.0, 0.0,
- 0.0, 0.0,
- 0.0}, 1E-8);
- TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10);
- TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7);
- TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6);
- }
-
- /**
- * This is a test based on the Wampler3 data set
- * http://www.itl.nist.gov/div898/strd/lls/data/Wampler3.shtml
- */
- @Test
- public void testWampler3() {
- double[] data = new double[] {
- 760, 0,
- -2042, 1,
- 2111, 2,
- -1684, 3,
- 3888, 4,
- 1858, 5,
- 11379, 6,
- 17560, 7,
- 39287, 8,
- 64382, 9,
- 113159, 10,
- 175108, 11,
- 273291, 12,
- 400186, 13,
- 581243, 14,
- 811568, 15,
- 1121004, 16,
- 1506550, 17,
- 2002767, 18,
- 2611612, 19,
- 3369180, 20};
-
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- final int nvars = 5;
- final int nobs = 21;
- double[] tmp = new double[(nvars + 1) * nobs];
- int off = 0;
- int off2 = 0;
- for (int i = 0; i < nobs; i++) {
- tmp[off2] = data[off];
- tmp[off2 + 1] = data[off + 1];
- tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
- tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
- tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
- tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
- off2 += (nvars + 1);
- off += 2;
- }
- mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix());
- double[] betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0}, 1E-8);
-
- double[] se = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(se,
- new double[] {
- 2152.32624678170,
- 2363.55173469681, 779.343524331583,
- 101.475507550350, 5.64566512170752,
- 0.112324854679312}, 1E-8); //
-
- TestUtils.assertEquals(.999995559025820, mdl.calculateRSquared(), 1.0e-10);
- TestUtils.assertEquals(5570284.53333333, mdl.estimateErrorVariance(), 1.0e-6);
- TestUtils.assertEquals(83554268.0000000, mdl.calculateResidualSumOfSquares(), 1.0e-5);
- }
-
- /**
- * This is a test based on the Wampler4 data set
- * http://www.itl.nist.gov/div898/strd/lls/data/Wampler4.shtml
- */
- @Test
- public void testWampler4() {
- double[] data = new double[] {
- 75901, 0,
- -204794, 1,
- 204863, 2,
- -204436, 3,
- 253665, 4,
- -200894, 5,
- 214131, 6,
- -185192, 7,
- 221249, 8,
- -138370, 9,
- 315911, 10,
- -27644, 11,
- 455253, 12,
- 197434, 13,
- 783995, 14,
- 608816, 15,
- 1370781, 16,
- 1303798, 17,
- 2205519, 18,
- 2408860, 19,
- 3444321, 20};
-
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- final int nvars = 5;
- final int nobs = 21;
- double[] tmp = new double[(nvars + 1) * nobs];
- int off = 0;
- int off2 = 0;
- for (int i = 0; i < nobs; i++) {
- tmp[off2] = data[off];
- tmp[off2 + 1] = data[off + 1];
- tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
- tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
- tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
- tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
- off2 += (nvars + 1);
- off += 2;
- }
- mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix());
- double[] betaHat = mdl.estimateRegressionParameters();
- TestUtils.assertEquals(betaHat,
- new double[] {
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0}, 1E-6);
-
- double[] se = mdl.estimateRegressionParametersStandardErrors();
- TestUtils.assertEquals(se,
- new double[] {
- 215232.624678170,
- 236355.173469681, 77934.3524331583,
- 10147.5507550350, 564.566512170752,
- 11.2324854679312}, 1E-8);
-
- TestUtils.assertEquals(.957478440825662, mdl.calculateRSquared(), 1.0e-10);
- TestUtils.assertEquals(55702845333.3333, mdl.estimateErrorVariance(), 1.0e-4);
- TestUtils.assertEquals(835542680000.000, mdl.calculateResidualSumOfSquares(), 1.0e-3);
- }
-
- /**
- * Anything requiring beta calculation should advertise SME.
- */
- @Test(expected = SingularMatrixException.class)
- public void testSingularCalculateBeta() {
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15);
- mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new DenseLocalOnHeapMatrix());
- mdl.calculateBeta();
- }
-
- /** */
- @Test(expected = NullPointerException.class)
- public void testNoDataNPECalculateBeta() {
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.calculateBeta();
- }
-
- /** */
- @Test(expected = NullPointerException.class)
- public void testNoDataNPECalculateHat() {
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.calculateHat();
- }
-
- /** */
- @Test(expected = NullPointerException.class)
- public void testNoDataNPESSTO() {
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.calculateTotalSumOfSquares();
- }
-
- /** */
- @Test(expected = MathIllegalArgumentException.class)
- public void testMathIllegalArgumentException() {
- OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
- mdl.validateSampleData(new DenseLocalOnHeapMatrix(1, 2), new DenseLocalOnHeapVector(1));
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index be71934..5c79c8f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -17,6 +17,13 @@
package org.apache.ignite.ml.regressions;
+import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionQRTrainerTest;
+import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionSGDTrainerTest;
+import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionQRTrainerTest;
+import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionSGDTrainerTest;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest;
+import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionQRTrainerTest;
+import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionSGDTrainerTest;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -25,11 +32,14 @@ import org.junit.runners.Suite;
*/
@RunWith(Suite.class)
@Suite.SuiteClasses({
- OLSMultipleLinearRegressionTest.class,
- DistributedOLSMultipleLinearRegressionTest.class,
- DistributedBlockOLSMultipleLinearRegressionTest.class,
- OLSMultipleLinearRegressionModelTest.class
+ LinearRegressionModelTest.class,
+ LocalLinearRegressionQRTrainerTest.class,
+ LocalLinearRegressionSGDTrainerTest.class,
+ DistributedLinearRegressionQRTrainerTest.class,
+ DistributedLinearRegressionSGDTrainerTest.class,
+ BlockDistributedLinearRegressionQRTrainerTest.class,
+ BlockDistributedLinearRegressionSGDTrainerTest.class
})
public class RegressionsTestSuite {
// No-op.
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java
new file mode 100644
index 0000000..ed6bf36
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+/**
+ * Artificial regression datasets to be used in regression trainers tests. These datasets were generated by scikit-learn
+ * tools, {@code sklearn.datasets.make_regression} procedure.
+ */
+public class ArtificialRegressionDatasets {
+ /**
+ * Artificial dataset with 10 observations described by 1 feature.
+ */
+ public static final TestDataset regression10x1 = new TestDataset(new double[][] {
+ {1.97657990214, 0.197725444973},
+ {-5.0835948878, -0.279921224228},
+ {-5.09032600779, -0.352291245969},
+ {9.67660993007, 0.755464872441},
+ {4.95927629958, 0.451981771462},
+ {29.2635107429, 2.2277440173},
+ {-18.3122588459, -1.25363275369},
+ {-3.61729307199, -0.273362913982},
+ {-7.19042139249, -0.473846634967},
+ {3.68008403347, 0.353883097536}
+ }, new double[] {13.554054703}, -0.808655936776);
+
+ /**
+ * Artificial dataset with 10 observations described by 5 features.
+ */
+ public static final TestDataset regression10x5 = new TestDataset(new double[][] {
+ {118.635647237, 0.687593385888, -1.18956185502, -0.305420702986, 1.98794097418, -0.776629036361},
+ {-18.2808432286, -0.165921853684, -0.156162539573, 1.56284391134, -0.198876782109, -0.0921618505605},
+ {22.6110523992, 0.0268106268606, 0.702141470035, -0.41503615392, -1.09726502337, 1.30830482813},
+ {209.820435262, 0.379809113402, -0.192097238579, -1.27460497119, 2.48052002019, -0.574430888865},
+ {-253.750024054, -1.48044570917, -0.331747484523, 0.387993627712, 0.372583756237, -2.27404065923},
+ {-24.6467766166, -0.66991474156, 0.269042238935, -0.271412703096, -0.561166818525, 1.37067541854},
+ {-311.903650717, 0.268274438122, -1.10491275353, -1.06738703543, -2.24387799735, -0.207431467989},
+ {74.2055323536, -0.329489531894, -0.493350762533, -0.644851462227, 0.661220945573, 1.65950140864},
+ {57.0312289904, -1.07266578457, 0.80375035572, -0.45207210139, 1.69314420969, -1.10526080856},
+ {12.149399645, 1.46504629281, -1.05843246079, 0.266225365277, -0.0113100353869, -0.983495425471}
+ }, new double[] {99.8393653561, 82.4948224094, 20.2087724072, 97.3306384162, 55.7502297387}, 3.98444039189);
+
+ /**
+ * Artificial dataset with 100 observations described by 5 features.
+ */
+ public static final TestDataset regression100x5 = new TestDataset(new double[][] {
+ {-44.2310642946, -0.0331360137605, -0.5290800706, -0.634340342338, -0.428433927151, 0.830582347183},
+ {76.2539139721, -0.216200869652, 0.513212019048, -0.693404511747, 0.132995973133, 1.28470259833},
+ {293.369799914, 2.90735870802, 0.457740818846, -0.490470696097, -0.442343455187, 0.584038258781},
+ {124.258807314, 1.64158129148, 0.0616936820145, 1.24082841519, -1.20126518593, -0.542298907742},
+ {13.6610807249, -1.10834821778, 0.545508208111, 1.81361288715, -0.786543112444, 0.250772626496},
+ {101.924582305, -0.433526394969, 0.257594734335, 1.22333193911, 0.76626554927, -0.0400734567005},
+ {25.5963186303, -0.202003301507, 0.717101151637, -0.486881225605, 1.15215024807, -0.921615554612},
+ {75.7959681263, -0.604173187402, 0.0364386836472, 1.67544714536, 0.394743148877, 0.0237966550759},
+ {-97.539357166, -0.774517689169, -0.0966902473883, -0.152250704254, -0.325472625458, 0.0720711851256},
+ {0.394748999236, -0.559303402754, -0.0493339259273, -1.10840277768, -0.0800969523557, 1.80939282066},
+ {-62.0138166431, 0.062614716778, -0.844143618016, 0.55269949861, -2.32580899335, 1.58020577369},
+ {584.427692931, 2.13184767906, 1.22222461994, 1.71894070494, 2.69512281718, 0.294123497874},
+ {-59.8323709765, 1.00006112818, -1.54481230765, -0.781282316493, 0.0255925284853, -0.0821173744608},
+ {101.565711925, -0.38699836725, 1.06934591441, -0.260429311097, 1.02628949564, 0.0431473245174},
+ {-141.592607814, 0.993279116267, -0.371768203378, -0.851483217286, -1.96241293548, -0.612279404296},
+ {34.8038723379, -0.0182719243972, 0.306367604506, -0.650526589206, 1.30693112283, -0.587465952557},
+ {-16.9554534069, -0.703006786668, -0.770718401931, 0.748423272307, 0.502544067819, 0.346625621533},
+ {-76.2896177709, -0.16440174812, -1.77431555198, 0.195326723837, 2.01240994405, -1.19559207119},
+ {-3.23827624818, -0.674138419631, -1.62238580284, 2.02235607862, 0.679194838679, 0.150203732584},
+ {-21.962456854, -0.766271014206, 0.958599712131, -0.313045794728, 0.232655576106, -0.360950549871},
+ {349.583669646, 1.75976166947, 1.47271612346, 0.0346005603489, 0.474907228495, 0.61379496381},
+ {-418.397356757, -1.83395936566, -0.911702678716, -0.532478094882, -2.03835348133, -0.423005552518},
+ {55.0298153952, -0.0301384716096, -0.0137929430966, -0.348583692759, 0.986486580719, 0.154436524434},
+ {127.150063206, 1.92682560465, -0.434844790414, 0.1082898967, -0.00723338222402, -0.513199251824},
+ {89.6172507626, 1.02463790902, 0.744369837717, 1.250323683, -1.58252612128, -0.588242778808},
+ {92.5124829355, -0.403298547743, 0.0422774545428, -0.175000467434, 1.61110066857, 0.422330077287},
+ {-303.040366788, 0.611569308879, -1.21926246291, -2.49250330276, -0.789166929605, -1.30166501196},
+ {-17.4020602839, 1.72337202371, -1.83540537288, 0.731588761841, -0.338642535062, -1.11053518125},
+ {114.918701324, 0.437385758628, 0.975885170381, 0.439444038872, 1.51666514156, -1.93095020264},
+ {-8.43548064928, -0.799507968686, -0.00842968328782, -0.154994093964, 1.09169753491, -0.0114818657732},
+ {109.209286025, 2.56472965015, -2.07047248035, -0.46764001177, 0.845267147375, -0.236767841427},
+ {61.5259982971, -0.379391870148, -0.131017762354, -0.220275015864, 1.82097825699, -0.0568354876403},
+ {-71.3872099588, 0.642138455414, -1.00242489879, 0.536780074488, 0.350977275771, -1.8204862883},
+ {-21.2768078629, -0.454268998895, 0.0992324274219, 0.0363496803224, 0.281940751723, -0.198435570828},
+ {-8.07838891387, -0.331642089041, -0.494067341253, 0.386035842816, -0.738221128298, 1.18236299649},
+ {30.4818041751, 0.099206096537, 0.150688905006, 0.332932621949, 0.194845631964, -0.446717875795},
+ {237.209150991, 1.12560447042, 0.448488431264, -0.724623711259, 0.401868257097, 1.67129001163},
+ {185.172816475, 0.36594142556, -0.0796476435741, 0.473836257, 1.30890722633, 0.592415068693},
+ {19.8830237044, 1.52497319332, 0.466906090264, -0.716635613964, -1.19532276745, -0.697663531684},
+ {209.396793626, 0.368478789658, 0.699162303982, 1.96702434462, -0.815379139879, 0.863369634396},
+ {-215.100514168, -1.83902416164, -1.14966820385, -1.01044860587, 1.76881340629, -0.32165916241},
+ {-33.4687353426, -0.0451102002703, 0.642212950033, 0.580822065219, -1.02341504063, -0.781229325942},
+ {150.251474823, 0.220170650298, 0.224858901011, 0.541299425328, 1.15151550963, 0.0329044069571},
+ {92.2160506097, 1.86450932451, -0.991150940533, -1.49137866968, 1.02113774105, 0.0544762857136},
+ {41.2138467595, -0.778892265105, 0.714957464344, 1.79833618993, -0.335322825621, -0.397548301803},
+ {13.151262759, 0.301745607362, 0.129778280739, 0.260094818273, -0.10587841585, -0.599330307629},
+ {-367.864703951, -1.68695981263, -0.611957677512, -0.0362971579679, -1.2169760515, -1.43224375134},
+ {-57.218869838, 0.428806849751, 0.654302177028, -1.31651788496, 0.363857431276, -1.49953703016},
+ {53.0877462955, -0.411907760185, -0.192634094071, -0.275879375023, 0.603562526571, 1.16508196734},
+ {-8.11860742896, 1.00263982158, -0.157031169267, -1.11795623393, 0.35711440521, -0.851124640982},
+ {-49.1878248403, -0.0253797866589, -0.574767070714, 0.200339045636, -0.0107042446803, -0.351288977927},
+ {-73.8835407053, -2.07980276724, 1.12235566491, -0.917150593536, 0.741384768556, 0.56229424235},
+ {143.163604045, 0.33627769945, 1.07948757447, 0.894869929963, 1.18688316974, -1.54722487849},
+ {92.7045830908, 0.944091525689, 0.693296229491, 0.700097596814, -1.23666276942, -0.203890113084},
+ {79.1878852355, -0.221973023853, -0.566066329011, 1.57683748648, 0.52854717911, 0.147924782476},
+ {30.6547392801, -1.03466213359, 0.606784904328, -0.298096511956, 0.83332987683, 0.636339018254},
+ {-329.128386019, -1.41363866598, -1.34966434823, -0.989010564149, 0.46889477248, -1.20493210784},
+ {121.190205512, 0.0393914245697, 1.98392444232, -0.65310705226, -0.385899987099, 0.444982471471},
+ {-97.0333075649, 0.264325871992, -0.43074811924, -1.14737761316, -0.453134140655, -0.038507405311},
+ {158.273624516, 0.302255432981, -0.292046617818, 1.0704087606, 0.815965268115, 0.470631083546},
+ {8.24795061818, -1.15155524496, 1.29538707184, -0.4650881541, 0.805123486308, -0.134706887329},
+ {87.1140049059, -0.103540823781, -0.192259440773, 1.79648860085, -1.07525447993, 1.06985127941},
+ {-25.1300772481, -0.97140742052, 0.033393948794, -0.698311192672, 0.74417168942, 0.752776770225},
+ {-285.477057638, -0.480612406803, -1.46081500036, -1.92518386336, -0.426454066275, -0.0539099489597},
+ {-65.1269988498, -1.22733468764, 0.121538452336, 0.752958777557, -0.40643211762, 0.257674949803},
+ {-17.1813504942, 0.823753836891, 0.445142465255, 0.185644700144, -1.99733367514, -0.247899323048},
+ {-46.7543447303, 0.183482778928, -0.934858705943, -1.21961947396, 0.460921844744, 0.571388077177},
+ {-1.7536190499, -0.107517908181, 0.0334282610968, -0.556676121428, -0.485957577159, 0.943570398164},
+ {-42.8460452689, 0.944999215632, 0.00530052154909, -0.348526283976, -1.724125354, -0.122649339813},
+ {62.6291497267, 0.249619894002, 1.3139125969, -1.5644227783, 0.117605482783, 0.304844650662},
+ {97.4552176343, 1.59332799639, -1.17868305562, 1.02998378902, -0.31959491258, -0.183038322076},
+ {-6.19358885758, 0.437951016253, 0.373339269494, -0.204072768495, 0.477969349931, -1.52176449389},
+ {34.0350630099, 0.839319087287, -0.610157662489, 1.73881448393, -1.89200107709, 0.204946415522},
+ {54.9790822536, -0.191792583114, 0.989791127554, -0.502154080064, 0.469939512389, -0.102304071079},
+ {58.8272402843, 0.0769623906454, 0.501297284297, -0.410054999243, 0.595712387781, -0.0968329050729},
+ {95.3620983209, 0.0661481959314, 0.0935137309086, 1.11823292347, -0.612960777903, 0.767865072757},
+ {62.4278196648, 0.78350610065, -1.09977017652, 0.526824784479, 1.41310104196, -0.887902707319},
+ {57.6298676729, 0.60084172954, -0.785932027202, 0.0271301584637, -0.134109499719, 0.877256170191},
+ {5.14112905382, -0.738359365006, 1.40242539359, -0.852833010305, -0.68365080837, 0.88561193696},
+ {11.6057244034, -0.958911227571, 1.15715937023, 1.20108425431, 0.882980929338, -1.77404120156},
+ {-265.758185272, -1.2092434823, -0.0550151798639, 0.00703735243613, -1.01767244359, -1.40616581707},
+ {180.625928828, -0.139091127126, 0.243250756129, 2.17509702585, -0.541735827898, 1.2109459934},
+ {-183.604103216, -0.324555097769, -1.71317286749, 1.03645005723, 0.497569347608, -1.96688185911},
+ {9.93237328848, 0.825483591345, 0.910287997312, -1.64938108528, 0.98964075968, -1.65748940528},
+ {-88.6846949813, -0.0759295112746, -0.593311990101, -0.578711915019, 0.256298822361, -0.429322890198},
+ {175.367391479, 0.9361754906, -0.0172852897292, 1.04078658833, 0.919566407184, -0.554923019093},
+ {-175.538247146, -1.43498590417, 0.37233438556, -0.897205352198, -0.339309952316, -0.0321624527843},
+ {-126.331680318, 0.160446617623, 0.816642363249, -1.39863371652, 0.199747744327, -2.13493607457},
+ {116.677107593, 1.19300905847, -0.404409346893, 0.646338976096, -0.534204093869, 0.36692724765},
+ {-181.675962893, -1.57613169533, -0.41549571451, -0.956673746013, 0.35723782515, 0.318317395128},
+ {-55.1457877823, 0.63723030991, -0.324480386466, 0.296028333894, -1.68117515658, -0.131945601375},
+ {25.2534791013, 0.594818219911, -0.0247380403547, -0.101492246071, -0.0745619242015, -0.370837128867},
+ {63.6006283756, -1.53493473818, 0.946464097439, 0.637741397831, 0.938866921166, 0.54405291856},
+ {-69.6245547661, 0.328482934094, -0.776881060846, -0.285133098443, -1.06107824512, 0.49952182341},
+ {233.425957233, 3.10582399189, -0.0854710508706, 0.455873479133, -0.0974589364949, -1.18914783551},
+ {-86.5564290626, -0.819839276484, 0.584745927593, -0.544737106102, -1.21927675581, 0.758502626434},
+ {425.357285631, 1.70712253847, 1.19892647853, 1.60619661301, 0.36832665241, 0.880791322709},
+ {111.797225426, 0.558940594145, -0.746492420236, 1.90172101792, 0.853590062366, -0.867970723941},
+ {-253.616801014, -0.426513440051, 0.0388582291888, -1.18576061365, -2.70895868242, 0.26982210287},
+ {-394.801501024, -1.65087241498, 0.735525201393, -2.02413077052, -0.96492749037, -1.89014065613}
+ }, new double[] {93.3843533037, 72.3610889215, 57.5295295915, 63.7287541653, 65.2263084024}, 6.85683020686);
+
+ /**
+ * Artificial dataset with 100 observations described by 10 features.
+ */
+ public static final TestDataset regression100x10 = new TestDataset(new double[][] {
+ {69.5794204114, -0.684238565877, 0.175665643732, 0.882115894035, 0.612844187624,
+ -0.685301720572, -0.8266500007, -0.0383407025118, 1.7105205222, 0.457436379836, -0.291563926494},
+ {80.1390102826, -1.80708821811, 0.811271788195, 0.30248512861, 0.910658009566,
+ -1.61869762501, -0.148325085362, -0.0714164596509, 0.671646742271, 2.15160094956, -0.0495754979721},
+ {-156.975447515, 0.170702943934, -0.973403372054, -0.093974528453, 1.54577255871,
+ -0.0969022857972, -1.10639617368, 1.51752480948, -2.86016865032, 1.24063030602, -0.521785751026},
+ {-158.134931891, 0.0890071395055, -0.0811824442353, -0.737354274843, -1.7575255492,
+ 0.265777246641, 0.0745347238144, -0.457603542683, -1.37034043839, 1.86011799875, 0.651214189491},
+ {-131.465820263, 0.0767565260375, 0.651724194978, 0.142113799753, 0.244367469855,
+ -0.334395162837, -0.069092305876, -0.691806779713, -1.28386786177, -1.43647491141, 0.00721053414234},
+ {-125.468890054, 0.43361925912, -0.800231440065, -0.576001094593, 0.0783664516431,
+ -1.33613252233, -0.968385062126, -1.22077801286, 0.193456109638, -3.09372314386, 0.817979620215},
+ {-44.1113403874, -0.595796803171, 1.29482131972, -0.784513985654, 0.364702038003,
+ -3.2452492093, -0.451605560847, 0.988546607514, 0.492096628873, -0.343018842342, -0.519231306954},
+ {61.2269707872, -0.0289059337716, -1.00409238976, 0.329908621635, 1.41965097539,
+ 0.0395065997587, -0.477939549336, 0.842336765911, -0.808790019648, 1.70241718768, -0.117194118865},
+ {301.434286126, 0.430005308515, 1.01290089725, -0.228221561554, 0.463405921629,
+ -0.602413489517, 1.13832440088, 0.930949226185, -0.196440161506, 1.46304624346, 1.23831509056},
+ {-270.454814681, -1.43805412632, -0.256309572507, -0.358047601174, 0.265151660237,
+ 1.07087986377, -1.93784654681, -0.854440691754, 0.665691996289, -1.87508012738, -0.387092423365},
+ {-97.6198688184, -1.67658167161, -0.170246709551, -2.26863722189, 0.280289356338,
+ -0.690038347855, -1.69282684019, 0.978606053022, 1.28237852256, -1.2941998486, 0.766405365374},
+ {-29.5630902399, -1.75615633921, 0.633927486329, -1.24117311555, -0.15884687004,
+ 0.31296863712, -1.29513272039, 0.344090683606, 1.19598425093, -1.96195019104, 1.81415061059},
+ {-130.896377427, 0.577719366939, -0.087267771748, -0.060088767013, 0.469803880788,
+ -1.03078212088, -1.41547398887, 1.38980586981, -0.37118000595, -1.81689513712, -0.3099432567},
+ {79.6300698059, 1.23408625633, 1.06464588017, 1.23403332691, -1.10993859098,
+ 0.874825200577, 0.589337796957, -1.10266185141, 0.842960469618, -0.89231962021, 0.284074900504},
+ {-154.712112815, -1.64474237898, -0.328581696933, 0.38834343178, 0.02682160335,
+ -0.251167527796, -0.199330632103, -0.0405837345525, -0.908200250794, -1.3283756975, 0.540894408264},
+ {233.447381562, 0.395156450609, 0.156412599781, 0.126453148554, 2.40829068933,
+ 1.01623530754, -0.0856520211145, -0.874970377099, 0.280617145254, -0.307070438514, 0.4599616054},
+ {209.012380432, -0.848646647675, 0.558383548084, -0.259628264419, 1.1624126549,
+ -0.0755949979572, -0.373930759448, 0.985903312667, 0.435839508011, -0.760916312668, 1.89847574116},
+ {-39.8987262091, 0.176656582642, 0.508538223618, 0.995038391204, -2.08809409812,
+ 0.743926580134, 0.246007971514, -0.458288599906, -0.579976479473, 0.0591577146017, 1.64321662761},
+ {222.078510236, -0.24031989218, -0.168104260522, -0.727838425954, 0.557181757624,
+ -0.164906646307, 2.01559331734, 0.897263594222, 0.0921535309562, 0.351910490325, -0.018228500121},
+ {-250.916272061, -2.71504637339, 0.498966191294, -3.16410707344, -0.842488891776,
+ 1.27425275951, 0.0141733666756, 0.695942743199, 0.0917995810179, -0.501447196978, -0.355738068451},
+ {134.07259088, 0.0845637591619, 0.237410106679, -0.291458113729, 1.39418566986,
+ -1.18813057956, -0.683117067763, -0.518910379335, 1.35998426879, -1.28404562245, 0.489131754943},
+ {104.988440209, 0.00770925058526, 0.47113239214, -0.606231247854, 0.310679840217,
+ 0.146297599928, 0.732013998647, -0.284544010865, 0.402622530153, -0.0217367745613, 0.0742970687987},
+ {155.558071031, 1.11171654653, 0.726629222799, -0.195820863177, 0.801333855535,
+ 0.744034755544, 1.11377275513, -0.75673532139, -0.114117607244, -0.158966474923, -0.29701120385},
+ {90.7600194013, -0.104364079622, -0.0165109945217, 0.933002972987, -1.80652594466,
+ -1.34760892883, -0.304511906801, 0.0584734540581, 1.5332169392, 0.478835797824, 1.71534051065},
+ {-313.910553214, 0.149908925551, 0.232806828559, -0.0708920471592, -0.0649553559745,
+ 0.377753357707, -0.957292311668, 0.545360522582, -1.37905464371, -0.940702110994, -1.53620430047},
+ {-80.9380113754, 0.135586606896, 0.95759558815, -1.36879020479, 0.735413996144,
+ 0.637984100201, -1.79563152885, 1.55025691631, 0.634702068786, -0.203690334141, -0.83954824721},
+ {-244.336816695, -0.179127343947, -2.12396005014, -0.431179356484, -0.860562153749,
+ -1.10270688639, -0.986886012982, -0.945091656162, -0.445428453767, 1.32269756209, -0.223712672168},
+ {123.069612745, 0.703857129626, 0.291605144784, 1.40233051946, 0.278603787802,
+ -0.693567967466, -0.15587953395, 2.10213915684, 0.130663329174, -0.393184478882, 0.0874812844555},
+ {-148.274944223, 1.66294967732, 0.0830002694123, 0.32492930502, 1.11864359687,
+ -0.381901627785, -1.06367037132, -0.392583620174, -1.16283326187, 0.104931461025, -1.64719611405},
+ {-82.0018788235, 0.497118817453, 0.731125358012, -0.00976413646786, -0.0178930713492,
+ -0.814978582886, 0.0602834712523, -0.661940479055, -0.957902899386, -1.34489251111, 0.22166518707},
+ {-35.742996986, 0.0661349516701, -0.204314495629, 1.17101314753, -2.53846825562,
+ -0.560282479298, -0.393442894828, 0.988953809491, -0.911281277704, 0.86862242698, 2.59576940486},
+ {-109.588885664, -0.0793151346628, -0.408962434518, -0.598817776528, 0.0277205469561,
+ 0.116291018958, 0.0280416838086, -0.72544170676, -0.669302814774, 0.0751898759816, -0.311002356179},
+ {57.8285173441, 0.53753903532, 0.676340503752, -2.10608342721, 0.477714987751,
+ 0.465695114442, 0.245966562421, -1.05230350808, -0.309794163113, -1.12067331828, 1.07841453304},
+ {204.660622582, -0.717565166685, 0.295179660279, -0.377579912697, 1.88425526905,
+ 0.251875238436, -0.900214103232, -1.02877401105, 0.291693915093, 1.24889067987, 1.78506220081},
+ {350.949109103, 2.82276814452, -0.429358342127, 1.12140362367, 1.18120725208,
+ -1.63913834939, 1.61441562446, -0.364003766916, -0.258752942225, -0.808124680189, 0.556463488303},
+ {170.960252153, 0.147245922081, 0.3257117575, 0.211749283649, -0.0150701808404,
+ -0.888523132148, 0.777862088798, 0.296729270892, -0.332927550718, 0.888968144245, 1.20913118467},
+ {112.192270383, 0.129846138824, -0.934371449036, -0.595825303214, 1.74749214629,
+ -0.0500069421443, -0.161976298602, -2.54100791613, 1.99632530735, -0.0691582773758, -0.863939367415},
+ {-56.7847711121, 0.0950532853751, -0.467349228201, -0.26457152362, -0.422134692317,
+ -0.0734763062127, 0.90128235602, -1.68470856275, -0.0699692697335, -0.463335845504, -0.301754321169},
+ {-37.9223252258, -1.40835827778, 0.566142056244, -3.22393318933, 0.228823495106,
+ -1.8480727782, 0.129468321643, -1.77392686536, 0.0112549619662, 0.146433267822, 1.29379901303},
+ {-59.7303066136, 0.835675535576, -0.552173157548, 1.90730898966, -0.520145317195,
+ 1.55174485912, -1.37531768692, -0.408165743742, 0.0939675842223, 0.318004128812, 0.324378038446},
+ {-0.916090786983, 0.425763794043, -0.295541268984, -0.066619586336, 2.03494974978,
+ -0.197109278058, -0.823307883209, 0.895531446352, -0.276435938737, -1.54580056755, -0.820051830246},
+ {-20.3601082842, 0.56420556369, 0.741234589387, -0.565853617392, -0.311399905686,
+ 2.24066463251, -0.071704904286, -1.22796531596, 0.186020404046, -0.786874824874, 0.23140277151},
+ {-22.9342855182, -0.0682789648279, -1.30680909143, 0.0486490588348, 0.890275695028,
+ -0.257961411112, -0.381531755985, 1.56251482581, -2.11808219232, 0.741828675202, 0.696388901165},
+ {-157.251026807, -2.3120966502, 0.183734662375, 1.02192264962, 0.591272941061,
+ -0.0132855098339, -1.02016546348, 1.19642432892, 0.867653154846, -1.37600041722, -1.08542822792},
+ {-68.6110752055, -1.2429968179, -0.950064269349, -0.332379873336, 0.25793632341,
+ 0.145780713577, -0.512109283074, -0.477887632032, 0.448960776324, -0.190215737958, 0.219578347563},
+ {-56.1204152481, -0.811729480846, -0.647410362207, 0.934547463984, -0.390943346216,
+ -0.409981308474, 0.0923465893049, 1.9281242912, -0.624713581674, -0.0599353282306, -0.0188591746808},
+ {348.530651658, 2.51721790231, 0.7560998114, -2.69620396681, 0.5174276585,
+ 0.403570816695, 0.901648571306, 0.269313230294, 1.07811463589, 0.986649559679, 0.514710327657},
+ {-105.719065924, 0.679016972998, 0.341319363316, -0.515209647377, 0.800000866847,
+ -0.795474442628, -0.866849274801, -1.32927961486, 0.17679343917, -1.93744422464, -0.476447619273},
+ {-197.389429553, -1.98585668879, -0.962610549884, -2.48860863254, -0.545990524642,
+ -0.13005685654, -1.23413782366, 1.17443427507, 1.4785554038, -0.193717671824, -0.466403609229},
+ {-23.9625285402, -0.392164367603, 1.07583388583, -0.412686712477, -0.89339030785,
+ -0.774862334739, -0.186491999529, -0.300162444329, 0.177377235999, 0.134038296039, 0.957945226616},
+ {-91.145725943, -0.154640540119, 0.732911957939, -0.206326119636, -0.569816760116,
+ 0.249393336416, -1.02762332953, 0.25096708081, 0.386927162941, -0.346382299592, 0.243099162109},
+ {-80.7295722208, -1.72670707303, 0.138139045677, 0.0648055728598, 0.186182854422,
+ 1.07226527747, -1.26133459043, 0.213883744163, 1.47115466163, -1.54791582859, 0.170924664865},
+ {-317.060323531, -0.349785690206, -0.740759426066, -0.407970845617, -0.689282767277,
+ -1.25608665316, -0.772546119412, -2.02925712813, 0.132949072522, -0.191465137244, -1.29079690284},
+ {-252.491508279, -1.24643122869, 1.55335609203, 0.356613424877, 0.817434495353,
+ -1.74503747683, -0.818046363088, -1.58284235058, 0.357919389759, -1.18942962791, -1.91728745247},
+ {-66.8121363157, -0.584246455697, -0.104254351782, 1.17911687508, -0.29288167882,
+ 0.891836132692, 0.232853863255, 0.423294355343, -0.669493690103, -1.15783890498, 0.188213983735},
+ {140.681464689, 1.33156046873, -1.8847915949, -0.666528837988, -0.513356191443,
+ 0.281290031669, -1.07815005006, 1.22384196227, 1.39093631269, 0.527644817197, 1.21595221509},
+ {-174.22326767, 0.475428766034, 0.856847216768, -0.734282773151, -0.923514989791,
+ 0.917510828772, 0.674878068543, 0.0644776431114, -0.607796192908, 0.867740011912, -1.97799769281},
+ {74.3899799579, 0.00915743526294, 0.553578683413, 1.66930486354, 0.15562803404,
+ 1.8455840688, -0.371704942927, 1.11228894843, -0.37464389118, -0.48789151589, 0.79553866342},
+ {70.1167175897, 0.154877045187, 1.47803572976, -0.0355743163524, -2.47914644675,
+ 0.672384381837, 1.63160379529, 1.81874583854, 1.22797339421, -0.0131258061634, -0.390265963676},
+ {-11.0364788877, 0.173049156249, -1.78140521797, -1.29982707214, -0.48025663179,
+ -0.469112922302, -1.98718063269, 0.585086542043, 0.264611327837, 1.48855512579, 2.00672263496},
+ {-112.711292736, -1.59239636827, -0.600613018822, -0.0209667499746, -1.81872893331,
+ -0.739893084955, 0.140261888569, -0.498107678308, 2.53664045504, -0.536385019089, -0.608755809378},
+ {-198.064468217, 0.737175509877, -2.01835515547, -2.18045950065, 0.428584922529,
+ -1.01848835019, -0.470645361539, -0.00703630153547, -2.2341302754, 1.51483167022, -0.410184418418},
+ {70.2747963991, 1.49474111532, -0.19517712503, 0.7392852909, -0.326060871666,
+ -0.566710349675, 0.14053094122, -0.562830341306, 0.22931613446, -0.0344439061448, 0.175150510551},
+ {207.909021337, 0.839887009159, 0.268826583246, -0.313047158862, 1.12009996015,
+ 0.214209976971, -0.396147338251, 2.16039704403, 0.699141312749, 0.756192350992, -0.145368196901},
+ {169.428609429, -1.13702350819, 1.23964530597, -0.864443556622, -0.885630795949,
+ -0.523872327352, 0.467159824748, 0.476596383923, 0.4343735578, 1.4075417896, 2.22939328991},
+ {-176.909833405, 0.0875512760866, -0.455542269288, 0.539742307764, -0.762003092788,
+ 0.41829123457, -0.818116139644, -2.01761645956, 0.557395073218, 1.5823271814, -1.0168826293},
+ {-27.734298611, -0.841257541979, 0.348961259301, 1.36935991472, -0.0694528057586,
+ -1.27303784913, 0.152155656569, 1.9279466651, 0.9589415766, -1.76634370106, -1.08831026428},
+ {-55.8416853588, 0.927711536927, 0.157856746063, -0.295628714893, 0.0296602829783,
+ 1.75198587897, -0.38285446366, -0.253287154535, -1.64032395229, -0.842089054965, 1.00493779183},
+ {56.0899797005, 0.326117761734, -1.93514762146, 1.0229172721, 0.125568968732,
+ 2.37760000658, -0.498532972011, -0.733375842271, -0.757445726993, -0.49515057432, 2.01559891524},
+ {-176.220234909, 1.571129843, -0.867707605929, -0.709690799512, -1.51535538937,
+ 1.27424225477, -0.109513704468, -1.46822183, 0.281077088939, -1.97084024232, -0.322309524179},
+ {37.7155152941, 0.363383774219, -0.0240881298641, -1.60692745228, -1.26961656439,
+ -0.41299134216, 1.2890099968, -1.34101694629, -0.455387485256, -0.14055003482, 1.5407059956},
+ {-102.163416997, -2.05927378316, -0.470182865756, -0.875528863204, 0.0361720859253,
+ -1.03713912263, 0.417362606334, 0.707587625276, -0.0591627772581, -2.58905252006, 0.516573345216},
+ {-206.47095321, 0.270030584651, 1.85544202116, -0.144189208964, -0.696400687327,
+ 0.0226388634283, -0.490952489106, -1.69209527849, 0.00973614309272, -0.484105876992, -0.991474668217},
+ {201.50637416, 0.513659215697, -0.335630132208, -0.140006500483, 0.149679720127,
+ -1.89526167503, -0.0614973894156, 0.0813221153552, 0.630952530848, 2.40201011339, 0.997708264073},
+ {-72.0667371571, 0.0841570292899, -0.216125859013, -1.77155215764, 2.15081767322,
+ 0.00953341785443, -1.0826077946, -0.791135571106, -0.989393577892, -0.791485083644, -0.063560999686},
+ {-162.903837815, -0.273764637097, 0.282387854873, -1.39881596931, 0.554941097854,
+ -0.88790718926, -0.693189960902, 0.398762630571, -1.61878562893, -0.345976341096, 0.138298909959},
+ {-34.3291926715, -0.499883755911, -0.847296893019, -0.323673126437, 0.531205373462,
+ -0.0204345595983, 0.284954510306, 0.565031773028, -0.272049818708, -0.130369799738, -0.617572026201},
+ {76.1272883187, -0.908810282403, -1.04139421904, 0.890678872055, 1.32990256154,
+ -0.0150445428835, 0.593918101047, 0.356897732999, 0.824651162423, -1.54544256217, -0.795703905296},
+ {171.833705285, -0.0425219657568, -0.884042952325, 1.91202504537, 0.381908223898,
+ -0.205693527739, 1.53656598237, 0.534880398015, 0.291950716831, -1.1258051056, -0.0612803476297},
+ {-235.445792009, 0.261252102941, -0.170931758001, 1.67878144235, 0.0278283741792,
+ -1.23194408479, -0.190931886594, 1.0000157972, -2.18792142659, -0.230654984288, -1.36626493512},
+ {348.968834231, 1.35713154434, 0.950377770072, 0.0700577471848, 0.96907140156,
+ 2.00890422081, 0.0896405239806, 0.614309607351, 1.07723409067, 2.58506968136, 0.202889806148},
+ {-61.0128039201, 0.465438505031, -1.31448530533, 0.374781933416, -0.0118298606041,
+ -0.477338357738, -0.587656108109, 1.66449545077, 0.435836048385, -0.287027953004, -1.06613472784},
+ {-50.687090469, 0.382331825989, -0.597140322197, 1.1276065465, -1.35593777887,
+ 1.14949964423, -0.858742432885, -0.563211485633, -0.57167161928, 0.0294891749132, 1.9571639493},
+ {-186.653649045, -0.00981380006029, 1.0371088941, -1.25319048981, -0.694043021068,
+ 1.7280802541, -0.191210409232, -0.866039238001, -0.0791927416078, -0.232228656558, -0.93723545053},
+ {34.5395591744, 0.680943971029, -0.075875481801, -0.144408300848, -0.869070791528,
+ 0.496870904214, 1.0940401388, -0.510489750436, -0.47562728601, 0.951406841944, 0.12983846382},
+ {-23.7618645627, 0.527032820313, -0.58295129357, -0.3894567306, -0.0547905472556,
+ -1.86103603537, 0.0506988360667, 1.02778539291, -0.0613720063422, 0.411280841442, -0.665810811374},
+ {116.007776415, 0.441750249008, 0.549342185228, 0.731558201455, -0.903624700864,
+ -2.13208328824, 0.381223328983, 0.283479210749, 1.17705098922, -2.38800904207, 1.32108350152},
+ {-148.479593311, -0.814604260049, -0.821204361946, -1.08768677334, -0.0659445766599,
+ 0.583741297405, 0.669345853296, -0.0935352010726, -0.254906787938, -0.394599725657, -1.26305927257},
+ {244.865845084, 0.776784257443, 0.267205388558, 2.37746488031, -0.379275360853,
+ -0.157454754411, -0.359580726073, 0.886887721861, 1.53707627973, 0.634390546684, 0.984864824122},
+ {-81.9954096721, 0.594841146008, -1.22273253129, 0.532466794358, 1.69864239257,
+ -0.12293671327, -2.06645974171, 0.611808231703, -1.32291985291, 0.722066660478, -0.0021343848511},
+ {-245.715046329, -1.77850303496, -0.176518810079, 1.20463434525, -0.597826204963,
+ -1.45842350123, -0.765730251727, -2.17764204443, 0.12996635702, -0.705509516482, 0.170639846082},
+ {123.011946043, -0.909707162714, 0.92357208515, 0.373251929121, 1.24629576577,
+ 0.0662688299998, -0.372240547929, -0.739353735168, 0.323495756066, 0.954154005738, 0.69606859977},
+ {-70.4564963177, 0.650682297051, 0.378131376232, 1.37860253614, -0.924042783872,
+ 0.802851073842, -0.450299927542, 0.235646185302, -0.148779896161, 1.01308126122, -0.48206889502},
+ {21.5288687935, 0.290876355386, 0.0765702960599, 0.905225489744, 0.252841861521,
+ 1.26729272819, 0.315397441908, -2.00317261368, -0.250990653758, 0.425615332405, 0.0875320802483},
+ {231.370169905, 0.535138021352, -1.07151617232, 0.824383756287, 1.84428896701,
+ -0.890892034494, 0.0480296332924, -0.59251208055, 0.267564961845, -0.230698441998, 0.857077278291},
+ {38.8318274023, 2.63547217711, -0.585553060394, 0.430550920323, -0.532619160993,
+ 1.25335488136, -1.65265278435, 0.0433880112291, -0.166143379872, 0.534066441314, 1.18929937797},
+ {116.362219013, -0.275949982433, 0.468069787645, -0.879814121059, 0.862799331322,
+ 1.18464846725, 0.747084253268, 1.39202500691, -1.23374181275, 0.0949815110503, 0.696546907194},
+ {260.540154731, 1.13798788241, -0.0991903174656, 0.1241636043, -0.201415073037,
+ 1.57683389508, 1.81535629587, 1.07873616646, -0.355800782882, 2.18333193195, 0.0711071144615},
+ {-165.835194521, -2.76613178307, 0.805314338858, 0.81526046683, -0.710489036197,
+ -1.20189542317, -0.692110074722, -0.117239516622, 1.0431459458, -0.111898596299, -0.0775811519297},
+ {-341.189958588, 0.668555635008, -1.0940034941, -0.497881262778, -0.603682823779,
+ -0.396875163796, -0.849144848521, 0.403936807183, -1.82076277475, -0.137500972546, -1.22769896568}
+ }, new double[] {45.8685095528, 11.9400336005, 16.3984976652, 79.9069814034, 5.65486853464,
+ 83.6427296424, 27.4571268153, 73.5881193584, 27.1465364511, 79.4095449062}, -5.14077007134);
+
+ /** */
+ public static class TestDataset {
+
+ /** */
+ private final double[][] data;
+
+ /** */
+ private final double[] expWeights;
+
+ /** */
+ private final double expIntercept;
+
+ /** */
+ TestDataset(double[][] data, double[] expWeights, double expIntercept) {
+ this.data = data;
+ this.expWeights = expWeights;
+ this.expIntercept = expIntercept;
+ }
+
+ /** */
+ public double[][] getData() {
+ return data;
+ }
+
+ /** */
+ public double[] getExpWeights() {
+ return expWeights;
+ }
+
+ /** */
+ public double getExpIntercept() {
+ return expIntercept;
+ }
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java
new file mode 100644
index 0000000..0c09d75
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+
+/**
+ * Tests for {@link LinearRegressionQRTrainer} on {@link SparseBlockDistributedMatrix}.
+ */
+ public class BlockDistributedLinearRegressionQRTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest {
+ /** Wires the QR trainer with block-distributed matrix/vector factories; comparison precision 1e-6. */
+ public BlockDistributedLinearRegressionQRTrainerTest() {
+ super(
+ new LinearRegressionQRTrainer(),
+ SparseBlockDistributedMatrix::new,
+ SparseBlockDistributedVector::new,
+ 1e-6
+ );
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java
new file mode 100644
index 0000000..58037e2
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+
+/**
+ * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseBlockDistributedMatrix}.
+ */
+ public class BlockDistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest {
+ /** Wires the SGD trainer with block-distributed factories; looser precision (1e-2) than QR, presumably since SGD converges approximately. */
+ public BlockDistributedLinearRegressionSGDTrainerTest() {
+ super(
+ new LinearRegressionSGDTrainer(100_000, 1e-12),
+ SparseBlockDistributedMatrix::new,
+ SparseBlockDistributedVector::new,
+ 1e-2);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java
new file mode 100644
index 0000000..2a506d9
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+
+/**
+ * Tests for {@link LinearRegressionQRTrainer} on {@link SparseDistributedMatrix}.
+ */
+ public class DistributedLinearRegressionQRTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest {
+ /** Wires the QR trainer with sparse-distributed matrix/vector factories; comparison precision 1e-6. */
+ public DistributedLinearRegressionQRTrainerTest() {
+ super(
+ new LinearRegressionQRTrainer(),
+ SparseDistributedMatrix::new,
+ SparseDistributedVector::new,
+ 1e-6
+ );
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java
new file mode 100644
index 0000000..71d3b3b
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+
+/**
+ * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseDistributedMatrix}.
+ */
+ public class DistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest {
+ /** Wires the SGD trainer with sparse-distributed factories; looser precision (1e-2) than QR, presumably since SGD converges approximately. */
+ public DistributedLinearRegressionSGDTrainerTest() {
+ super(
+ new LinearRegressionSGDTrainer(100_000, 1e-12),
+ SparseDistributedMatrix::new,
+ SparseDistributedVector::new,
+ 1e-2);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java
new file mode 100644
index 0000000..a55623c
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import java.util.Scanner;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.junit.Test;
+
+/**
+ * Base class for linear regression trainer tests: a subclass supplies a concrete trainer
+ * together with matrix/vector factories and a comparison precision, and inherits the test cases.
+ */
+public class GenericLinearRegressionTrainerTest {
+ /** Trainer under test. */
+ private final Trainer<LinearRegressionModel, Matrix> trainer;
+
+ /** Factory producing the matrix implementation the trainer is run on. */
+ private final IgniteFunction<double[][], Matrix> matrixCreator;
+
+ /** Factory producing the vector implementation used for predictions and expected weights. */
+ private final IgniteFunction<double[], Vector> vectorCreator;
+
+ /** Precision used when comparing trained model output with expected values. */
+ private final double precision;
+
+ /**
+ * @param trainer Trainer under test.
+ * @param matrixCreator Matrix factory.
+ * @param vectorCreator Vector factory.
+ * @param precision Comparison precision.
+ */
+ public GenericLinearRegressionTrainerTest(
+ Trainer<LinearRegressionModel, Matrix> trainer,
+ IgniteFunction<double[][], Matrix> matrixCreator,
+ IgniteFunction<double[], Vector> vectorCreator,
+ double precision) {
+ this.trainer = trainer;
+ this.matrixCreator = matrixCreator;
+ this.vectorCreator = vectorCreator;
+ this.precision = precision;
+ }
+
+ /**
+ * Test trainer on regression model y = 2 * x.
+ */
+ @Test
+ public void testTrainWithoutIntercept() {
+ Matrix data = matrixCreator.apply(new double[][] {
+ {2.0, 1.0},
+ {4.0, 2.0}
+ });
+
+ LinearRegressionModel mdl = trainer.train(data);
+
+ TestUtils.assertEquals(4, mdl.apply(vectorCreator.apply(new double[] {2})), precision);
+ TestUtils.assertEquals(6, mdl.apply(vectorCreator.apply(new double[] {3})), precision);
+ TestUtils.assertEquals(8, mdl.apply(vectorCreator.apply(new double[] {4})), precision);
+ }
+
+ /**
+ * Test trainer on regression model y = -1 * x + 1.
+ */
+ @Test
+ public void testTrainWithIntercept() {
+ Matrix data = matrixCreator.apply(new double[][] {
+ {1.0, 0.0},
+ {0.0, 1.0}
+ });
+
+ LinearRegressionModel mdl = trainer.train(data);
+
+ TestUtils.assertEquals(0.5, mdl.apply(vectorCreator.apply(new double[] {0.5})), precision);
+ TestUtils.assertEquals(2, mdl.apply(vectorCreator.apply(new double[] {-1})), precision);
+ TestUtils.assertEquals(-1, mdl.apply(vectorCreator.apply(new double[] {2})), precision);
+ }
+
+ /**
+ * Test trainer on diabetes dataset.
+ */
+ @Test
+ public void testTrainOnDiabetesDataset() {
+ Matrix data = loadDataset("datasets/regression/diabetes.csv", 442, 10);
+
+ LinearRegressionModel mdl = trainer.train(data);
+
+ Vector expWeights = vectorCreator.apply(new double[] {
+ -10.01219782, -239.81908937, 519.83978679, 324.39042769, -792.18416163,
+ 476.74583782, 101.04457032, 177.06417623, 751.27932109, 67.62538639
+ });
+
+ double expIntercept = 152.13348416;
+
+ TestUtils.assertEquals("Wrong weights", expWeights, mdl.getWeights(), precision);
+ TestUtils.assertEquals("Wrong intercept", expIntercept, mdl.getIntercept(), precision);
+ }
+
+ /**
+ * Test trainer on boston dataset.
+ */
+ @Test
+ public void testTrainOnBostonDataset() {
+ Matrix data = loadDataset("datasets/regression/boston.csv", 506, 13);
+
+ LinearRegressionModel mdl = trainer.train(data);
+
+ Vector expWeights = vectorCreator.apply(new double[] {
+ -1.07170557e-01, 4.63952195e-02, 2.08602395e-02, 2.68856140e+00, -1.77957587e+01, 3.80475246e+00,
+ 7.51061703e-04, -1.47575880e+00, 3.05655038e-01, -1.23293463e-02, -9.53463555e-01, 9.39251272e-03,
+ -5.25466633e-01
+ });
+
+ double expIntercept = 36.4911032804;
+
+ TestUtils.assertEquals("Wrong weights", expWeights, mdl.getWeights(), precision);
+ TestUtils.assertEquals("Wrong intercept", expIntercept, mdl.getIntercept(), precision);
+ }
+
+ /**
+ * Tests trainer on artificial dataset with 10 observations described by 1 feature.
+ */
+ @Test
+ public void testTrainOnArtificialDataset10x1() {
+ ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression10x1;
+
+ LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData()));
+
+ TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision);
+ TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision);
+ }
+
+ /**
+ * Tests trainer on artificial dataset with 10 observations described by 5 features.
+ */
+ @Test
+ public void testTrainOnArtificialDataset10x5() {
+ ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression10x5;
+
+ LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData()));
+
+ TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision);
+ TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision);
+ }
+
+ /**
+ * Tests trainer on artificial dataset with 100 observations described by 5 features.
+ */
+ @Test
+ public void testTrainOnArtificialDataset100x5() {
+ ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression100x5;
+
+ LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData()));
+
+ TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision);
+ TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision);
+ }
+
+ /**
+ * Tests trainer on artificial dataset with 100 observations described by 10 features.
+ */
+ @Test
+ public void testTrainOnArtificialDataset100x10() {
+ ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression100x10;
+
+ LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData()));
+
+ TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision);
+ TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision);
+ }
+
+ /**
+ * Loads dataset file and returns corresponding matrix.
+ *
+ * @param fileName Dataset file name (classpath CSV resource; target value in the first column)
+ * @param nobs Number of observations
+ * @param nvars Number of features
+ * @return Data matrix
+ */
+ private Matrix loadDataset(String fileName, int nobs, int nvars) {
+ double[][] matrix = new double[nobs][nvars + 1];
+
+ // Scanner is AutoCloseable: try-with-resources closes the underlying resource stream,
+ // which the previous version leaked.
+ try (Scanner scanner = new Scanner(this.getClass().getClassLoader().getResourceAsStream(fileName))) {
+ int i = 0;
+ while (scanner.hasNextLine()) {
+ String row = scanner.nextLine();
+ int j = 0;
+ for (String feature : row.split(",")) {
+ matrix[i][j] = Double.parseDouble(feature);
+ j++;
+ }
+ i++;
+ }
+ }
+
+ return matrixCreator.apply(matrix);
+ }
+}