You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by nt...@apache.org on 2017/11/17 13:07:57 UTC
[1/3] ignite git commit: IGNITE-5846 Add support of distributed
matrices for OLS regression. This closes #3030.
Repository: ignite
Updated Branches:
refs/heads/master cbd7e39cf -> b0a860186
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
new file mode 100644
index 0000000..e3c2979
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
@@ -0,0 +1,926 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
+import org.apache.ignite.ml.math.exceptions.NullArgumentException;
+import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
+import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.apache.ignite.testframework.junits.common.GridCommonTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link OLSMultipleLinearRegression}.
+ */
+
+@GridCommonTest(group = "Distributed Models")
+public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonAbstractTest {
+ /** */
+ private double[] y;
+
+ /** */
+ private double[][] x;
+
+ private AbstractMultipleLinearRegression regression;
+
+ /** Number of nodes in grid */
+ private static final int NODE_COUNT = 3;
+
+ private static final double PRECISION = 1E-12;
+
+ /** Grid instance. */
+ private Ignite ignite;
+
+ public DistributedBlockOLSMultipleLinearRegressionTest() {
+
+ super(false);
+
+
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() throws Exception {
+ stopAllGrids();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ ignite = grid(NODE_COUNT);
+
+ ignite.configuration().setPeerClassLoadingEnabled(true);
+
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+
+ y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
+ x = new double[6][];
+ x[0] = new double[]{0, 0, 0, 0, 0};
+ x[1] = new double[]{2.0, 0, 0, 0, 0};
+ x[2] = new double[]{0, 3.0, 0, 0, 0};
+ x[3] = new double[]{0, 0, 4.0, 0, 0};
+ x[4] = new double[]{0, 0, 0, 5.0, 0};
+ x[5] = new double[]{0, 0, 0, 0, 6.0};
+
+ regression = createRegression();
+ }
+
+ /** */
+ protected OLSMultipleLinearRegression createRegression() {
+ OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
+ regression.newSampleData(new SparseBlockDistributedVector(y), new SparseBlockDistributedMatrix(x));
+ return regression;
+ }
+
+ /** */
+ @Test
+ public void testPerfectFit() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ double[] betaHat = regression.estimateRegressionParameters();
+ System.out.println("Beta hat is " + betaHat);
+ TestUtils.assertEquals(new double[]{11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0},
+ betaHat,
+ 1e-13);
+ double[] residuals = regression.estimateResiduals();
+ TestUtils.assertEquals(new double[]{0d, 0d, 0d, 0d, 0d, 0d}, residuals,
+ 1e-13);
+ Matrix errors = regression.estimateRegressionParametersVariance();
+ final double[] s = {1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0};
+ Matrix refVar = new SparseBlockDistributedMatrix(s.length, s.length);
+ for (int i = 0; i < refVar.rowSize(); i++)
+ for (int j = 0; j < refVar.columnSize(); j++) {
+ if (i == 0) {
+ refVar.setX(i, j, s[j]);
+ continue;
+ }
+ double x = s[i] * s[j];
+ refVar.setX(i, j, (i == j) ? 2 * x : x);
+ }
+ Assert.assertEquals(0.0,
+ TestUtils.maximumAbsoluteRowSum(errors.minus(refVar)),
+ 5.0e-16 * TestUtils.maximumAbsoluteRowSum(refVar));
+ Assert.assertEquals(1, ((OLSMultipleLinearRegression) regression).calculateRSquared(), 1E-12);
+ }
+
+ /**
+ * Test Longley dataset against certified values provided by NIST.
+ * Data Source: J. Longley (1967) "An Appraisal of Least Squares
+ * Programs for the Electronic Computer from the Point of View of the User"
+ * Journal of the American Statistical Association, vol. 62. September,
+ * pp. 819-841.
+ *
+ * Certified values (and data) are from NIST:
+ * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
+ */
+ @Test
+ public void testLongly() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ // Y values are first, then independent vars
+ // Each row is one observation
+ double[] design = new double[]{
+ 60323, 83.0, 234289, 2356, 1590, 107608, 1947,
+ 61122, 88.5, 259426, 2325, 1456, 108632, 1948,
+ 60171, 88.2, 258054, 3682, 1616, 109773, 1949,
+ 61187, 89.5, 284599, 3351, 1650, 110929, 1950,
+ 63221, 96.2, 328975, 2099, 3099, 112075, 1951,
+ 63639, 98.1, 346999, 1932, 3594, 113270, 1952,
+ 64989, 99.0, 365385, 1870, 3547, 115094, 1953,
+ 63761, 100.0, 363112, 3578, 3350, 116219, 1954,
+ 66019, 101.2, 397469, 2904, 3048, 117388, 1955,
+ 67857, 104.6, 419180, 2822, 2857, 118734, 1956,
+ 68169, 108.4, 442769, 2936, 2798, 120445, 1957,
+ 66513, 110.8, 444546, 4681, 2637, 121950, 1958,
+ 68655, 112.6, 482704, 3813, 2552, 123366, 1959,
+ 69564, 114.2, 502601, 3931, 2514, 125368, 1960,
+ 69331, 115.7, 518173, 4806, 2572, 127852, 1961,
+ 70551, 116.9, 554894, 4007, 2827, 130081, 1962
+ };
+
+ final int nobs = 16;
+ final int nvars = 6;
+
+ // Estimate the model
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ mdl.newSampleData(design, nobs, nvars, new SparseBlockDistributedMatrix());
+
+ // Check expected beta values from NIST
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ -3482258.63459582, 15.0618722713733,
+ -0.358191792925910E-01, -2.02022980381683,
+ -1.03322686717359, -0.511041056535807E-01,
+ 1829.15146461355}, 2E-6); //
+
+ // Check expected residuals from R
+ double[] residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[]{
+ 267.340029759711, -94.0139423988359, 46.28716775752924,
+ -410.114621930906, 309.7145907602313, -249.3112153297231,
+ -164.0489563956039, -13.18035686637081, 14.30477260005235,
+ 455.394094551857, -17.26892711483297, -39.0550425226967,
+ -155.5499735953195, -85.6713080421283, 341.9315139607727,
+ -206.7578251937366},
+ 1E-7);
+
+ // Check standard errors from NIST
+ double[] errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[]{
+ 890420.383607373,
+ 84.9149257747669,
+ 0.334910077722432E-01,
+ 0.488399681651699,
+ 0.214274163161675,
+ 0.226073200069370,
+ 455.478499142212}, errors, 1E-6);
+
+ // Check regression standard error against R
+ Assert.assertEquals(304.8540735619638, mdl.estimateRegressionStandardError(), 1E-8);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.995479004577296, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.992465007628826, mdl.calculateAdjustedRSquared(), 1E-12);
+
+ // TODO: IGNITE-5826, uncomment.
+ // checkVarianceConsistency(model);
+
+ // Estimate model without intercept
+ mdl.setNoIntercept(true);
+ mdl.newSampleData(design, nobs, nvars, new SparseBlockDistributedMatrix());
+
+ // Check expected beta values from R
+ betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ -52.99357013868291, 0.07107319907358,
+ -0.42346585566399, -0.57256866841929,
+ -0.41420358884978, 48.41786562001326}, 1E-8);
+
+ // Check standard errors from R
+ errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[]{
+ 129.54486693117232, 0.03016640003786,
+ 0.41773654056612, 0.27899087467676, 0.32128496193363,
+ 17.68948737819961}, errors, 1E-11);
+
+ // Check expected residuals from R
+ residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[]{
+ 279.90274927293092, -130.32465380836874, 90.73228661967445, -401.31252201634948,
+ -440.46768772620027, -543.54512853774793, 201.32111639536299, 215.90889365977932,
+ 73.09368242049943, 913.21694494481869, 424.82484953610174, -8.56475876776709,
+ -361.32974610842876, 27.34560497213464, 151.28955976355002, -492.49937355336846},
+ 1E-8);
+
+ // Check regression standard error against R
+ Assert.assertEquals(475.1655079819517, mdl.estimateRegressionStandardError(), 1E-10);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12);
+
+ }
+
+ /**
+ * Test R Swiss fertility dataset against R.
+ * Data Source: R datasets package
+ */
+ @Test
+ public void testSwissFertility() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ double[] design = new double[]{
+ 80.2, 17.0, 15, 12, 9.96,
+ 83.1, 45.1, 6, 9, 84.84,
+ 92.5, 39.7, 5, 5, 93.40,
+ 85.8, 36.5, 12, 7, 33.77,
+ 76.9, 43.5, 17, 15, 5.16,
+ 76.1, 35.3, 9, 7, 90.57,
+ 83.8, 70.2, 16, 7, 92.85,
+ 92.4, 67.8, 14, 8, 97.16,
+ 82.4, 53.3, 12, 7, 97.67,
+ 82.9, 45.2, 16, 13, 91.38,
+ 87.1, 64.5, 14, 6, 98.61,
+ 64.1, 62.0, 21, 12, 8.52,
+ 66.9, 67.5, 14, 7, 2.27,
+ 68.9, 60.7, 19, 12, 4.43,
+ 61.7, 69.3, 22, 5, 2.82,
+ 68.3, 72.6, 18, 2, 24.20,
+ 71.7, 34.0, 17, 8, 3.30,
+ 55.7, 19.4, 26, 28, 12.11,
+ 54.3, 15.2, 31, 20, 2.15,
+ 65.1, 73.0, 19, 9, 2.84,
+ 65.5, 59.8, 22, 10, 5.23,
+ 65.0, 55.1, 14, 3, 4.52,
+ 56.6, 50.9, 22, 12, 15.14,
+ 57.4, 54.1, 20, 6, 4.20,
+ 72.5, 71.2, 12, 1, 2.40,
+ 74.2, 58.1, 14, 8, 5.23,
+ 72.0, 63.5, 6, 3, 2.56,
+ 60.5, 60.8, 16, 10, 7.72,
+ 58.3, 26.8, 25, 19, 18.46,
+ 65.4, 49.5, 15, 8, 6.10,
+ 75.5, 85.9, 3, 2, 99.71,
+ 69.3, 84.9, 7, 6, 99.68,
+ 77.3, 89.7, 5, 2, 100.00,
+ 70.5, 78.2, 12, 6, 98.96,
+ 79.4, 64.9, 7, 3, 98.22,
+ 65.0, 75.9, 9, 9, 99.06,
+ 92.2, 84.6, 3, 3, 99.46,
+ 79.3, 63.1, 13, 13, 96.83,
+ 70.4, 38.4, 26, 12, 5.62,
+ 65.7, 7.7, 29, 11, 13.79,
+ 72.7, 16.7, 22, 13, 11.22,
+ 64.4, 17.6, 35, 32, 16.92,
+ 77.6, 37.6, 15, 7, 4.97,
+ 67.6, 18.7, 25, 7, 8.65,
+ 35.0, 1.2, 37, 53, 42.34,
+ 44.7, 46.6, 16, 29, 50.43,
+ 42.8, 27.7, 22, 29, 58.33
+ };
+
+ final int nobs = 47;
+ final int nvars = 4;
+
+ // Estimate the model
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ mdl.newSampleData(design, nobs, nvars, new SparseBlockDistributedMatrix());
+
+ // Check expected beta values from R
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ 91.05542390271397,
+ -0.22064551045715,
+ -0.26058239824328,
+ -0.96161238456030,
+ 0.12441843147162}, 1E-12);
+
+ // Check expected residuals from R
+ double[] residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[]{
+ 7.1044267859730512, 1.6580347433531366,
+ 4.6944952770029644, 8.4548022690166160, 13.6547432343186212,
+ -9.3586864458500774, 7.5822446330520386, 15.5568995563859289,
+ 0.8113090736598980, 7.1186762732484308, 7.4251378771228724,
+ 2.6761316873234109, 0.8351584810309354, 7.1769991119615177,
+ -3.8746753206299553, -3.1337779476387251, -0.1412575244091504,
+ 1.1186809170469780, -6.3588097346816594, 3.4039270429434074,
+ 2.3374058329820175, -7.9272368576900503, -7.8361010968497959,
+ -11.2597369269357070, 0.9445333697827101, 6.6544245101380328,
+ -0.9146136301118665, -4.3152449403848570, -4.3536932047009183,
+ -3.8907885169304661, -6.3027643926302188, -7.8308982189289091,
+ -3.1792280015332750, -6.7167298771158226, -4.8469946718041754,
+ -10.6335664353633685, 11.1031134362036958, 6.0084032641811733,
+ 5.4326230830188482, -7.2375578629692230, 2.1671550814448222,
+ 15.0147574652763112, 4.8625103516321015, -7.1597256413907706,
+ -0.4515205619767598, -10.2916870903837587, -15.7812984571900063},
+ 1E-12);
+
+ // Check standard errors from R
+ double[] errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[]{
+ 6.94881329475087,
+ 0.07360008972340,
+ 0.27410957467466,
+ 0.19454551679325,
+ 0.03726654773803}, errors, 1E-10);
+
+ // Check regression standard error against R
+ Assert.assertEquals(7.73642194433223, mdl.estimateRegressionStandardError(), 1E-12);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.649789742860228, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.6164363850373927, mdl.calculateAdjustedRSquared(), 1E-12);
+
+ // TODO: IGNITE-5826, uncomment.
+ // checkVarianceConsistency(model);
+
+ // Estimate the model with no intercept
+ mdl = new OLSMultipleLinearRegression();
+ mdl.setNoIntercept(true);
+ mdl.newSampleData(design, nobs, nvars, new SparseBlockDistributedMatrix());
+
+ // Check expected beta values from R
+ betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ 0.52191832900513,
+ 2.36588087917963,
+ -0.94770353802795,
+ 0.30851985863609}, 1E-12);
+
+ // Check expected residuals from R
+ residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[]{
+ 44.138759883538249, 27.720705122356215, 35.873200836126799,
+ 34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
+ 1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
+ -9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
+ -17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
+ -10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
+ -13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
+ -1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
+ -16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
+ -17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
+ -0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
+ 2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615},
+ 1E-12);
+
+ // Check standard errors from R
+ errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[]{
+ 0.10470063765677, 0.41684100584290,
+ 0.43370143099691, 0.07694953606522}, errors, 1E-10);
+
+ // Check regression standard error against R
+ Assert.assertEquals(17.24710630547, mdl.estimateRegressionStandardError(), 1E-10);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.946350722085, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.9413600915813, mdl.calculateAdjustedRSquared(), 1E-12);
+ }
+
+ /**
+ * Test hat matrix computation
+ */
+ @Test
+ public void testHat() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ /*
+ * This example is from "The Hat Matrix in Regression and ANOVA",
+ * David C. Hoaglin and Roy E. Welsch,
+ * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
+ *
+ */
+ double[] design = new double[]{
+ 11.14, .499, 11.1,
+ 12.74, .558, 8.9,
+ 13.13, .604, 8.8,
+ 11.51, .441, 8.9,
+ 12.38, .550, 8.8,
+ 12.60, .528, 9.9,
+ 11.13, .418, 10.7,
+ 11.7, .480, 10.5,
+ 11.02, .406, 10.5,
+ 11.41, .467, 10.7
+ };
+
+ int nobs = 10;
+ int nvars = 2;
+
+ // Estimate the model
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ mdl.newSampleData(design, nobs, nvars, new SparseBlockDistributedMatrix());
+
+ Matrix hat = mdl.calculateHat();
+
+ // Reference data is upper half of symmetric hat matrix
+ double[] refData = new double[]{
+ .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242,
+ .242, .292, .136, .243, .128, -.041, .033, -.035, .004,
+ .417, -.019, .273, .187, -.126, .044, -.153, .004,
+ .604, .197, -.038, .168, -.022, .275, -.028,
+ .252, .111, -.030, .019, -.010, -.010,
+ .148, .042, .117, .012, .111,
+ .262, .145, .277, .174,
+ .154, .120, .168,
+ .315, .148,
+ .187
+ };
+
+ // Check against reference data and verify symmetry
+ int k = 0;
+ for (int i = 0; i < 10; i++) {
+ for (int j = i; j < 10; j++) {
+ Assert.assertEquals(refData[k], hat.getX(i, j), 10e-3);
+ Assert.assertEquals(hat.getX(i, j), hat.getX(j, i), 10e-12);
+ k++;
+ }
+ }
+
+ /*
+ * Verify that residuals computed using the hat matrix are close to
+ * what we get from direct computation, i.e. r = (I - H) y
+ */
+ double[] residuals = mdl.estimateResiduals();
+ Matrix id = MatrixUtil.identityLike(hat, 10);
+ double[] hatResiduals = id.minus(hat).times(mdl.getY()).getStorage().data();
+ TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
+ }
+
+ /**
+ * test calculateYVariance
+ */
+ @Test
+ public void testYVariance() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ // assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ mdl.newSampleData(new SparseBlockDistributedVector(y), new SparseBlockDistributedMatrix(x));
+ TestUtils.assertEquals(mdl.calculateYVariance(), 3.5, 0);
+ }
+
+ /**
+ * Verifies that setting X and Y separately has the same effect as newSample(X,Y).
+ */
+ @Test
+ public void testNewSample2() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] y = new double[]{1, 2, 3, 4};
+ double[][] x = new double[][]{
+ {19, 22, 33},
+ {20, 30, 40},
+ {25, 35, 45},
+ {27, 37, 47}
+ };
+ OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
+ regression.newSampleData(new SparseBlockDistributedVector(y), new SparseBlockDistributedMatrix(x));
+ Matrix combinedX = regression.getX().copy();
+ Vector combinedY = regression.getY().copy();
+ regression.newXSampleData(new SparseBlockDistributedMatrix(x));
+ regression.newYSampleData(new SparseBlockDistributedVector(y));
+ for (int i = 0; i < combinedX.rowSize(); i++) {
+ for (int j = 0; j < combinedX.columnSize(); j++)
+ Assert.assertEquals(combinedX.get(i, j), regression.getX().get(i, j), PRECISION);
+
+ }
+ for (int i = 0; i < combinedY.size(); i++)
+ Assert.assertEquals(combinedY.get(i), regression.getY().get(i), PRECISION);
+
+
+ // No intercept
+ regression.setNoIntercept(true);
+ regression.newSampleData(new SparseBlockDistributedVector(y), new SparseBlockDistributedMatrix(x));
+ combinedX = regression.getX().copy();
+ combinedY = regression.getY().copy();
+ regression.newXSampleData(new SparseBlockDistributedMatrix(x));
+ regression.newYSampleData(new SparseBlockDistributedVector(y));
+
+ for (int i = 0; i < combinedX.rowSize(); i++) {
+ for (int j = 0; j < combinedX.columnSize(); j++)
+ Assert.assertEquals(combinedX.get(i, j), regression.getX().get(i, j), PRECISION);
+
+ }
+ for (int i = 0; i < combinedY.size(); i++)
+ Assert.assertEquals(combinedY.get(i), regression.getY().get(i), PRECISION);
+
+ }
+
+ /** */
+ @Test(expected = NullArgumentException.class)
+ public void testNewSampleDataYNull() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ try {
+ createRegression().newSampleData(null, new SparseBlockDistributedMatrix(new double[][]{{1}}));
+ fail("NullArgumentException");
+ } catch (NullArgumentException e) {
+ return;
+ }
+ fail("NullArgumentException");
+ }
+
+ /** */
+ public void testNewSampleDataXNull() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ try {
+ createRegression().newSampleData(new SparseBlockDistributedVector(new double[]{1}), null);
+ fail("NullArgumentException");
+ } catch (NullArgumentException e) {
+ return;
+ }
+ fail("NullArgumentException");
+
+
+ }
+
+ /**
+ * This is a test based on the Wampler1 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler1.shtml
+ */
+ @Test
+ public void testWampler1() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[]{
+ 1, 0,
+ 6, 1,
+ 63, 2,
+ 364, 3,
+ 1365, 4,
+ 3906, 5,
+ 9331, 6,
+ 19608, 7,
+ 37449, 8,
+ 66430, 9,
+ 111111, 10,
+ 177156, 11,
+ 271453, 12,
+ 402234, 13,
+ 579195, 14,
+ 813616, 15,
+ 1118481, 16,
+ 1508598, 17,
+ 2000719, 18,
+ 2613660, 19,
+ 3368421, 20};
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseBlockDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ 1.0,
+ 1.0, 1.0,
+ 1.0, 1.0,
+ 1.0}, 1E-8);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[]{
+ 0.0,
+ 0.0, 0.0,
+ 0.0, 0.0,
+ 0.0}, 1E-8);
+
+ TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7);
+ TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6);
+ }
+
+ /**
+ * This is a test based on the Wampler2 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler2.shtml
+ */
+ @Test
+ public void testWampler2() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[]{
+ 1.00000, 0,
+ 1.11111, 1,
+ 1.24992, 2,
+ 1.42753, 3,
+ 1.65984, 4,
+ 1.96875, 5,
+ 2.38336, 6,
+ 2.94117, 7,
+ 3.68928, 8,
+ 4.68559, 9,
+ 6.00000, 10,
+ 7.71561, 11,
+ 9.92992, 12,
+ 12.75603, 13,
+ 16.32384, 14,
+ 20.78125, 15,
+ 26.29536, 16,
+ 33.05367, 17,
+ 41.26528, 18,
+ 51.16209, 19,
+ 63.00000, 20};
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseBlockDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ 1.0,
+ 1.0e-1,
+ 1.0e-2,
+ 1.0e-3, 1.0e-4,
+ 1.0e-5}, 1E-8);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[]{
+ 0.0,
+ 0.0, 0.0,
+ 0.0, 0.0,
+ 0.0}, 1E-8);
+ TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7);
+ TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6);
+ }
+
+ /**
+ * This is a test based on the Wampler3 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler3.shtml
+ */
+ @Test
+ public void testWampler3() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[]{
+ 760, 0,
+ -2042, 1,
+ 2111, 2,
+ -1684, 3,
+ 3888, 4,
+ 1858, 5,
+ 11379, 6,
+ 17560, 7,
+ 39287, 8,
+ 64382, 9,
+ 113159, 10,
+ 175108, 11,
+ 273291, 12,
+ 400186, 13,
+ 581243, 14,
+ 811568, 15,
+ 1121004, 16,
+ 1506550, 17,
+ 2002767, 18,
+ 2611612, 19,
+ 3369180, 20};
+
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseBlockDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0}, 1E-8);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[]{
+ 2152.32624678170,
+ 2363.55173469681, 779.343524331583,
+ 101.475507550350, 5.64566512170752,
+ 0.112324854679312}, 1E-8); //
+
+ TestUtils.assertEquals(.999995559025820, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(5570284.53333333, mdl.estimateErrorVariance(), 1.0e-6);
+ TestUtils.assertEquals(83554268.0000000, mdl.calculateResidualSumOfSquares(), 1.0e-5);
+ }
+
+ /**
+ * This is a test based on the Wampler4 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler4.shtml
+ */
+ @Test
+ public void testWampler4() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[]{
+ 75901, 0,
+ -204794, 1,
+ 204863, 2,
+ -204436, 3,
+ 253665, 4,
+ -200894, 5,
+ 214131, 6,
+ -185192, 7,
+ 221249, 8,
+ -138370, 9,
+ 315911, 10,
+ -27644, 11,
+ 455253, 12,
+ 197434, 13,
+ 783995, 14,
+ 608816, 15,
+ 1370781, 16,
+ 1303798, 17,
+ 2205519, 18,
+ 2408860, 19,
+ 3444321, 20};
+
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseBlockDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[]{
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0}, 1E-6);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[]{
+ 215232.624678170,
+ 236355.173469681, 77934.3524331583,
+ 10147.5507550350, 564.566512170752,
+ 11.2324854679312}, 1E-8);
+
+ TestUtils.assertEquals(.957478440825662, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(55702845333.3333, mdl.estimateErrorVariance(), 1.0e-4);
+ TestUtils.assertEquals(835542680000.000, mdl.calculateResidualSumOfSquares(), 1.0e-3);
+ }
+
+ /**
+ * Anything requiring beta calculation should advertise SME.
+ */
+ public void testSingularCalculateBeta() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15);
+ mdl.newSampleData(new double[]{1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseBlockDistributedMatrix());
+
+ try {
+ mdl.calculateBeta();
+ fail("SingularMatrixException");
+ } catch (SingularMatrixException e) {
+ return;
+ }
+ fail("SingularMatrixException");
+
+ }
+
+ /** */
+ public void testNoDataNPECalculateBeta() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+ try {
+ mdl.calculateBeta();
+ fail("java.lang.NullPointerException");
+ } catch (NullPointerException e) {
+ return;
+ }
+ fail("java.lang.NullPointerException");
+
+ }
+
+ /** */
+ public void testNoDataNPECalculateHat() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+ try {
+ mdl.calculateHat();
+ fail("java.lang.NullPointerException");
+ } catch (NullPointerException e) {
+ return;
+ }
+ fail("java.lang.NullPointerException");
+ }
+
+ /** */
+ public void testNoDataNPESSTO() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+ try {
+ mdl.calculateTotalSumOfSquares();
+ fail("java.lang.NullPointerException");
+ } catch (NullPointerException e) {
+ return;
+ }
+ fail("java.lang.NullPointerException");
+
+
+ }
+
+ /** */
+ public void testMathIllegalArgumentException() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+
+ try {
+ mdl.validateSampleData(new SparseBlockDistributedMatrix(1, 2), new SparseBlockDistributedVector(1));
+ fail("MathIllegalArgumentException");
+ } catch (MathIllegalArgumentException e) {
+ return;
+ }
+ fail("MathIllegalArgumentException");
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
new file mode 100644
index 0000000..764340c
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
@@ -0,0 +1,934 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
+import org.apache.ignite.ml.math.exceptions.NullArgumentException;
+import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.apache.ignite.testframework.junits.common.GridCommonTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link OLSMultipleLinearRegression}.
+ */
+
+@GridCommonTest(group = "Distributed Models")
+public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstractTest {
+    /** Dependent-variable sample (regressand values). */
+    private double[] y;
+
+    /** Independent-variable sample (regressors), one observation per row. */
+    private double[][] x;
+
+    /** Model under test; rebuilt for every test by {@code beforeTest()}. */
+    private AbstractMultipleLinearRegression regression;
+
+    /** Number of nodes in grid */
+    private static final int NODE_COUNT = 3;
+
+    /** Tolerance for exact-match double comparisons. */
+    private static final double PRECISION = 1E-12;
+
+    /** Grid instance. */
+    private Ignite ignite;
+
+    /** Constructs the test; grids are started explicitly in {@code beforeTestsStarted()}. */
+    public DistributedOLSMultipleLinearRegressionTest() {
+        super(false);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        // Start NODE_COUNT server nodes once for the whole test class.
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() throws Exception {
+        stopAllGrids();
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override protected void beforeTest() throws Exception {
+        // Grids are numbered 1..NODE_COUNT, so this picks the last one started.
+        ignite = grid(NODE_COUNT);
+
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+        // Sample that admits an exact (zero-residual) linear fit; see testPerfectFit().
+        y = new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
+        x = new double[6][];
+        x[0] = new double[] {0, 0, 0, 0, 0};
+        x[1] = new double[] {2.0, 0, 0, 0, 0};
+        x[2] = new double[] {0, 3.0, 0, 0, 0};
+        x[3] = new double[] {0, 0, 4.0, 0, 0};
+        x[4] = new double[] {0, 0, 0, 5.0, 0};
+        x[5] = new double[] {0, 0, 0, 0, 6.0};
+
+        regression = createRegression();
+    }
+
+    /** Builds an OLS model over the perfect-fit sample ({@code y}, {@code x}) set in {@code beforeTest()}. */
+    protected OLSMultipleLinearRegression createRegression() {
+        OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+        mdl.newSampleData(new SparseDistributedVector(y), new SparseDistributedMatrix(x));
+        return mdl;
+    }
+
+    /**
+     * The sample set in {@code beforeTest()} admits an exact fit: residuals must be
+     * numerically zero and R-squared must equal 1.
+     */
+    @Test
+    public void testPerfectFit() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+        double[] betaHat = regression.estimateRegressionParameters();
+        // Fixed: printing the bare array printed its identity hash, not its contents.
+        System.out.println("Beta hat is " + java.util.Arrays.toString(betaHat));
+        TestUtils.assertEquals(new double[] {11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0},
+            betaHat,
+            1e-13);
+        double[] residuals = regression.estimateResiduals();
+        TestUtils.assertEquals(new double[] {0d, 0d, 0d, 0d, 0d, 0d}, residuals,
+            1e-13);
+        Matrix errors = regression.estimateRegressionParametersVariance();
+        final double[] s = {1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0};
+        Matrix refVar = new SparseDistributedMatrix(s.length, s.length);
+        for (int i = 0; i < refVar.rowSize(); i++)
+            for (int j = 0; j < refVar.columnSize(); j++) {
+                if (i == 0) {
+                    refVar.setX(i, j, s[j]);
+                    continue;
+                }
+                // Renamed from 'x' to avoid shadowing the sample-data field of the same name.
+                double v = s[i] * s[j];
+                refVar.setX(i, j, (i == j) ? 2 * v : v);
+            }
+        Assert.assertEquals(0.0,
+            TestUtils.maximumAbsoluteRowSum(errors.minus(refVar)),
+            5.0e-16 * TestUtils.maximumAbsoluteRowSum(refVar));
+        Assert.assertEquals(1, ((OLSMultipleLinearRegression)regression).calculateRSquared(), 1E-12);
+    }
+
+ /**
+ * Test Longley dataset against certified values provided by NIST.
+ * Data Source: J. Longley (1967) "An Appraisal of Least Squares
+ * Programs for the Electronic Computer from the Point of View of the User"
+ * Journal of the American Statistical Association, vol. 62. September,
+ * pp. 819-841.
+ *
+ * Certified values (and data) are from NIST:
+ * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
+ */
+ @Test
+ public void testLongly() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ // Y values are first, then independent vars
+ // Each row is one observation
+ double[] design = new double[] {
+ 60323, 83.0, 234289, 2356, 1590, 107608, 1947,
+ 61122, 88.5, 259426, 2325, 1456, 108632, 1948,
+ 60171, 88.2, 258054, 3682, 1616, 109773, 1949,
+ 61187, 89.5, 284599, 3351, 1650, 110929, 1950,
+ 63221, 96.2, 328975, 2099, 3099, 112075, 1951,
+ 63639, 98.1, 346999, 1932, 3594, 113270, 1952,
+ 64989, 99.0, 365385, 1870, 3547, 115094, 1953,
+ 63761, 100.0, 363112, 3578, 3350, 116219, 1954,
+ 66019, 101.2, 397469, 2904, 3048, 117388, 1955,
+ 67857, 104.6, 419180, 2822, 2857, 118734, 1956,
+ 68169, 108.4, 442769, 2936, 2798, 120445, 1957,
+ 66513, 110.8, 444546, 4681, 2637, 121950, 1958,
+ 68655, 112.6, 482704, 3813, 2552, 123366, 1959,
+ 69564, 114.2, 502601, 3931, 2514, 125368, 1960,
+ 69331, 115.7, 518173, 4806, 2572, 127852, 1961,
+ 70551, 116.9, 554894, 4007, 2827, 130081, 1962
+ };
+
+ final int nobs = 16;
+ final int nvars = 6;
+
+ // Estimate the model
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ mdl.newSampleData(design, nobs, nvars, new SparseDistributedMatrix());
+
+ // Check expected beta values from NIST
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ -3482258.63459582, 15.0618722713733,
+ -0.358191792925910E-01, -2.02022980381683,
+ -1.03322686717359, -0.511041056535807E-01,
+ 1829.15146461355}, 2E-6); //
+
+ // Check expected residuals from R
+ double[] residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[] {
+ 267.340029759711, -94.0139423988359, 46.28716775752924,
+ -410.114621930906, 309.7145907602313, -249.3112153297231,
+ -164.0489563956039, -13.18035686637081, 14.30477260005235,
+ 455.394094551857, -17.26892711483297, -39.0550425226967,
+ -155.5499735953195, -85.6713080421283, 341.9315139607727,
+ -206.7578251937366},
+ 1E-7);
+
+ // Check standard errors from NIST
+ double[] errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[] {
+ 890420.383607373,
+ 84.9149257747669,
+ 0.334910077722432E-01,
+ 0.488399681651699,
+ 0.214274163161675,
+ 0.226073200069370,
+ 455.478499142212}, errors, 1E-6);
+
+ // Check regression standard error against R
+ Assert.assertEquals(304.8540735619638, mdl.estimateRegressionStandardError(), 1E-8);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.995479004577296, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.992465007628826, mdl.calculateAdjustedRSquared(), 1E-12);
+
+ // TODO: IGNITE-5826, uncomment.
+ // checkVarianceConsistency(model);
+
+ // Estimate model without intercept
+ mdl.setNoIntercept(true);
+ mdl.newSampleData(design, nobs, nvars, new SparseDistributedMatrix());
+
+ // Check expected beta values from R
+ betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ -52.99357013868291, 0.07107319907358,
+ -0.42346585566399, -0.57256866841929,
+ -0.41420358884978, 48.41786562001326}, 1E-8);
+
+ // Check standard errors from R
+ errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[] {
+ 129.54486693117232, 0.03016640003786,
+ 0.41773654056612, 0.27899087467676, 0.32128496193363,
+ 17.68948737819961}, errors, 1E-11);
+
+ // Check expected residuals from R
+ residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[] {
+ 279.90274927293092, -130.32465380836874, 90.73228661967445, -401.31252201634948,
+ -440.46768772620027, -543.54512853774793, 201.32111639536299, 215.90889365977932,
+ 73.09368242049943, 913.21694494481869, 424.82484953610174, -8.56475876776709,
+ -361.32974610842876, 27.34560497213464, 151.28955976355002, -492.49937355336846},
+ 1E-8);
+
+ // Check regression standard error against R
+ Assert.assertEquals(475.1655079819517, mdl.estimateRegressionStandardError(), 1E-10);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12);
+
+ }
+
+ /**
+ * Test R Swiss fertility dataset against R.
+ * Data Source: R datasets package
+ */
+ @Test
+ public void testSwissFertility() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ double[] design = new double[] {
+ 80.2, 17.0, 15, 12, 9.96,
+ 83.1, 45.1, 6, 9, 84.84,
+ 92.5, 39.7, 5, 5, 93.40,
+ 85.8, 36.5, 12, 7, 33.77,
+ 76.9, 43.5, 17, 15, 5.16,
+ 76.1, 35.3, 9, 7, 90.57,
+ 83.8, 70.2, 16, 7, 92.85,
+ 92.4, 67.8, 14, 8, 97.16,
+ 82.4, 53.3, 12, 7, 97.67,
+ 82.9, 45.2, 16, 13, 91.38,
+ 87.1, 64.5, 14, 6, 98.61,
+ 64.1, 62.0, 21, 12, 8.52,
+ 66.9, 67.5, 14, 7, 2.27,
+ 68.9, 60.7, 19, 12, 4.43,
+ 61.7, 69.3, 22, 5, 2.82,
+ 68.3, 72.6, 18, 2, 24.20,
+ 71.7, 34.0, 17, 8, 3.30,
+ 55.7, 19.4, 26, 28, 12.11,
+ 54.3, 15.2, 31, 20, 2.15,
+ 65.1, 73.0, 19, 9, 2.84,
+ 65.5, 59.8, 22, 10, 5.23,
+ 65.0, 55.1, 14, 3, 4.52,
+ 56.6, 50.9, 22, 12, 15.14,
+ 57.4, 54.1, 20, 6, 4.20,
+ 72.5, 71.2, 12, 1, 2.40,
+ 74.2, 58.1, 14, 8, 5.23,
+ 72.0, 63.5, 6, 3, 2.56,
+ 60.5, 60.8, 16, 10, 7.72,
+ 58.3, 26.8, 25, 19, 18.46,
+ 65.4, 49.5, 15, 8, 6.10,
+ 75.5, 85.9, 3, 2, 99.71,
+ 69.3, 84.9, 7, 6, 99.68,
+ 77.3, 89.7, 5, 2, 100.00,
+ 70.5, 78.2, 12, 6, 98.96,
+ 79.4, 64.9, 7, 3, 98.22,
+ 65.0, 75.9, 9, 9, 99.06,
+ 92.2, 84.6, 3, 3, 99.46,
+ 79.3, 63.1, 13, 13, 96.83,
+ 70.4, 38.4, 26, 12, 5.62,
+ 65.7, 7.7, 29, 11, 13.79,
+ 72.7, 16.7, 22, 13, 11.22,
+ 64.4, 17.6, 35, 32, 16.92,
+ 77.6, 37.6, 15, 7, 4.97,
+ 67.6, 18.7, 25, 7, 8.65,
+ 35.0, 1.2, 37, 53, 42.34,
+ 44.7, 46.6, 16, 29, 50.43,
+ 42.8, 27.7, 22, 29, 58.33
+ };
+
+ final int nobs = 47;
+ final int nvars = 4;
+
+ // Estimate the model
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ mdl.newSampleData(design, nobs, nvars, new SparseDistributedMatrix());
+
+ // Check expected beta values from R
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ 91.05542390271397,
+ -0.22064551045715,
+ -0.26058239824328,
+ -0.96161238456030,
+ 0.12441843147162}, 1E-12);
+
+ // Check expected residuals from R
+ double[] residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[] {
+ 7.1044267859730512, 1.6580347433531366,
+ 4.6944952770029644, 8.4548022690166160, 13.6547432343186212,
+ -9.3586864458500774, 7.5822446330520386, 15.5568995563859289,
+ 0.8113090736598980, 7.1186762732484308, 7.4251378771228724,
+ 2.6761316873234109, 0.8351584810309354, 7.1769991119615177,
+ -3.8746753206299553, -3.1337779476387251, -0.1412575244091504,
+ 1.1186809170469780, -6.3588097346816594, 3.4039270429434074,
+ 2.3374058329820175, -7.9272368576900503, -7.8361010968497959,
+ -11.2597369269357070, 0.9445333697827101, 6.6544245101380328,
+ -0.9146136301118665, -4.3152449403848570, -4.3536932047009183,
+ -3.8907885169304661, -6.3027643926302188, -7.8308982189289091,
+ -3.1792280015332750, -6.7167298771158226, -4.8469946718041754,
+ -10.6335664353633685, 11.1031134362036958, 6.0084032641811733,
+ 5.4326230830188482, -7.2375578629692230, 2.1671550814448222,
+ 15.0147574652763112, 4.8625103516321015, -7.1597256413907706,
+ -0.4515205619767598, -10.2916870903837587, -15.7812984571900063},
+ 1E-12);
+
+ // Check standard errors from R
+ double[] errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[] {
+ 6.94881329475087,
+ 0.07360008972340,
+ 0.27410957467466,
+ 0.19454551679325,
+ 0.03726654773803}, errors, 1E-10);
+
+ // Check regression standard error against R
+ Assert.assertEquals(7.73642194433223, mdl.estimateRegressionStandardError(), 1E-12);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.649789742860228, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.6164363850373927, mdl.calculateAdjustedRSquared(), 1E-12);
+
+ // TODO: IGNITE-5826, uncomment.
+ // checkVarianceConsistency(model);
+
+ // Estimate the model with no intercept
+ mdl = new OLSMultipleLinearRegression();
+ mdl.setNoIntercept(true);
+ mdl.newSampleData(design, nobs, nvars, new SparseDistributedMatrix());
+
+ // Check expected beta values from R
+ betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ 0.52191832900513,
+ 2.36588087917963,
+ -0.94770353802795,
+ 0.30851985863609}, 1E-12);
+
+ // Check expected residuals from R
+ residuals = mdl.estimateResiduals();
+ TestUtils.assertEquals(residuals, new double[] {
+ 44.138759883538249, 27.720705122356215, 35.873200836126799,
+ 34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
+ 1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
+ -9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
+ -17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
+ -10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
+ -13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
+ -1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
+ -16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
+ -17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
+ -0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
+ 2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615},
+ 1E-12);
+
+ // Check standard errors from R
+ errors = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(new double[] {
+ 0.10470063765677, 0.41684100584290,
+ 0.43370143099691, 0.07694953606522}, errors, 1E-10);
+
+ // Check regression standard error against R
+ Assert.assertEquals(17.24710630547, mdl.estimateRegressionStandardError(), 1E-10);
+
+ // Check R-Square statistics against R
+ Assert.assertEquals(0.946350722085, mdl.calculateRSquared(), 1E-12);
+ Assert.assertEquals(0.9413600915813, mdl.calculateAdjustedRSquared(), 1E-12);
+ }
+
+ /**
+ * Test hat matrix computation
+ */
+ @Test
+ public void testHat() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ /*
+ * This example is from "The Hat Matrix in Regression and ANOVA",
+ * David C. Hoaglin and Roy E. Welsch,
+ * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
+ *
+ */
+ double[] design = new double[] {
+ 11.14, .499, 11.1,
+ 12.74, .558, 8.9,
+ 13.13, .604, 8.8,
+ 11.51, .441, 8.9,
+ 12.38, .550, 8.8,
+ 12.60, .528, 9.9,
+ 11.13, .418, 10.7,
+ 11.7, .480, 10.5,
+ 11.02, .406, 10.5,
+ 11.41, .467, 10.7
+ };
+
+ int nobs = 10;
+ int nvars = 2;
+
+ // Estimate the model
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ mdl.newSampleData(design, nobs, nvars, new SparseDistributedMatrix());
+
+ Matrix hat = mdl.calculateHat();
+
+ // Reference data is upper half of symmetric hat matrix
+ double[] refData = new double[] {
+ .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242,
+ .242, .292, .136, .243, .128, -.041, .033, -.035, .004,
+ .417, -.019, .273, .187, -.126, .044, -.153, .004,
+ .604, .197, -.038, .168, -.022, .275, -.028,
+ .252, .111, -.030, .019, -.010, -.010,
+ .148, .042, .117, .012, .111,
+ .262, .145, .277, .174,
+ .154, .120, .168,
+ .315, .148,
+ .187
+ };
+
+ // Check against reference data and verify symmetry
+ int k = 0;
+ for (int i = 0; i < 10; i++) {
+ for (int j = i; j < 10; j++) {
+ Assert.assertEquals(refData[k], hat.getX(i, j), 10e-3);
+ Assert.assertEquals(hat.getX(i, j), hat.getX(j, i), 10e-12);
+ k++;
+ }
+ }
+
+ /*
+ * Verify that residuals computed using the hat matrix are close to
+ * what we get from direct computation, i.e. r = (I - H) y
+ */
+ double[] residuals = mdl.estimateResiduals();
+ Matrix id = MatrixUtil.identityLike(hat, 10);
+ double[] hatResiduals = id.minus(hat).times(mdl.getY()).getStorage().data();
+ TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
+ }
+
+    /**
+     * test calculateYVariance
+     */
+    @Test
+    public void testYVariance() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        // assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
+        OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+        mdl.newSampleData(new SparseDistributedVector(y), new SparseDistributedMatrix(x));
+        // Sample variance of {11..16} is 17.5 / 5 = 3.5 (n-1 denominator).
+        TestUtils.assertEquals(mdl.calculateYVariance(), 3.5, 0);
+    }
+
+    /**
+     * Verifies that setting X and Y separately has the same effect as newSample(X,Y),
+     * both with and without an intercept.
+     */
+    @Test
+    public void testNewSample2() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        double[] y = new double[] {1, 2, 3, 4};
+        double[][] x = new double[][] {
+            {19, 22, 33},
+            {20, 30, 40},
+            {25, 35, 45},
+            {27, 37, 47}
+        };
+        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
+
+        // With intercept.
+        regression.newSampleData(new SparseDistributedVector(y), new SparseDistributedMatrix(x));
+        Matrix combinedX = regression.getX().copy();
+        Vector combinedY = regression.getY().copy();
+        regression.newXSampleData(new SparseDistributedMatrix(x));
+        regression.newYSampleData(new SparseDistributedVector(y));
+        assertSampleUnchanged(regression, combinedX, combinedY);
+
+        // No intercept.
+        regression.setNoIntercept(true);
+        regression.newSampleData(new SparseDistributedVector(y), new SparseDistributedMatrix(x));
+        combinedX = regression.getX().copy();
+        combinedY = regression.getY().copy();
+        regression.newXSampleData(new SparseDistributedMatrix(x));
+        regression.newYSampleData(new SparseDistributedVector(y));
+        assertSampleUnchanged(regression, combinedX, combinedY);
+    }
+
+    /** Asserts that the regression's current X and Y match the given snapshots element-wise. */
+    private static void assertSampleUnchanged(AbstractMultipleLinearRegression regression, Matrix expX, Vector expY) {
+        for (int i = 0; i < expX.rowSize(); i++)
+            for (int j = 0; j < expX.columnSize(); j++)
+                Assert.assertEquals(expX.get(i, j), regression.getX().get(i, j), PRECISION);
+
+        for (int i = 0; i < expY.size(); i++)
+            Assert.assertEquals(expY.get(i), regression.getY().get(i), PRECISION);
+    }
+
+    /**
+     * {@code newSampleData} must reject a null Y vector.
+     *
+     * NOTE(review): the original declared {@code @Test(expected = NullArgumentException.class)}
+     * while also catching the exception and returning, so under JUnit 4 the expected exception
+     * could never propagate and the annotation could never be satisfied. The catch-and-fail
+     * pattern (matching the sibling tests) is kept and the contradictory attribute dropped.
+     */
+    @Test
+    public void testNewSampleDataYNull() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+        try {
+            createRegression().newSampleData(null, new SparseDistributedMatrix(new double[][] {{1}}));
+            fail("NullArgumentException");
+        }
+        catch (NullArgumentException e) {
+            // Expected.
+        }
+    }
+
+    /** {@code newSampleData} must reject a null X matrix. */
+    public void testNewSampleDataXNull() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+        try {
+            createRegression().newSampleData(new SparseDistributedVector(new double[] {1}), null);
+            fail("NullArgumentException");
+        }
+        catch (NullArgumentException e) {
+            // Expected.
+        }
+        // NOTE(review): removed an unreachable trailing fail(...) present in the original.
+    }
+
+ /**
+ * This is a test based on the Wampler1 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler1.shtml
+ */
+ @Test
+ public void testWampler1() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[] {
+ 1, 0,
+ 6, 1,
+ 63, 2,
+ 364, 3,
+ 1365, 4,
+ 3906, 5,
+ 9331, 6,
+ 19608, 7,
+ 37449, 8,
+ 66430, 9,
+ 111111, 10,
+ 177156, 11,
+ 271453, 12,
+ 402234, 13,
+ 579195, 14,
+ 813616, 15,
+ 1118481, 16,
+ 1508598, 17,
+ 2000719, 18,
+ 2613660, 19,
+ 3368421, 20};
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ 1.0,
+ 1.0, 1.0,
+ 1.0, 1.0,
+ 1.0}, 1E-8);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[] {
+ 0.0,
+ 0.0, 0.0,
+ 0.0, 0.0,
+ 0.0}, 1E-8);
+
+ TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7);
+ TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6);
+ }
+
+ /**
+ * This is a test based on the Wampler2 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler2.shtml
+ */
+ @Test
+ public void testWampler2() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[] {
+ 1.00000, 0,
+ 1.11111, 1,
+ 1.24992, 2,
+ 1.42753, 3,
+ 1.65984, 4,
+ 1.96875, 5,
+ 2.38336, 6,
+ 2.94117, 7,
+ 3.68928, 8,
+ 4.68559, 9,
+ 6.00000, 10,
+ 7.71561, 11,
+ 9.92992, 12,
+ 12.75603, 13,
+ 16.32384, 14,
+ 20.78125, 15,
+ 26.29536, 16,
+ 33.05367, 17,
+ 41.26528, 18,
+ 51.16209, 19,
+ 63.00000, 20};
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ 1.0,
+ 1.0e-1,
+ 1.0e-2,
+ 1.0e-3, 1.0e-4,
+ 1.0e-5}, 1E-8);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[] {
+ 0.0,
+ 0.0, 0.0,
+ 0.0, 0.0,
+ 0.0}, 1E-8);
+ TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7);
+ TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6);
+ }
+
+ /**
+ * This is a test based on the Wampler3 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler3.shtml
+ */
+ @Test
+ public void testWampler3() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[] {
+ 760, 0,
+ -2042, 1,
+ 2111, 2,
+ -1684, 3,
+ 3888, 4,
+ 1858, 5,
+ 11379, 6,
+ 17560, 7,
+ 39287, 8,
+ 64382, 9,
+ 113159, 10,
+ 175108, 11,
+ 273291, 12,
+ 400186, 13,
+ 581243, 14,
+ 811568, 15,
+ 1121004, 16,
+ 1506550, 17,
+ 2002767, 18,
+ 2611612, 19,
+ 3369180, 20};
+
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0}, 1E-8);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[] {
+ 2152.32624678170,
+ 2363.55173469681, 779.343524331583,
+ 101.475507550350, 5.64566512170752,
+ 0.112324854679312}, 1E-8); //
+
+ TestUtils.assertEquals(.999995559025820, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(5570284.53333333, mdl.estimateErrorVariance(), 1.0e-6);
+ TestUtils.assertEquals(83554268.0000000, mdl.calculateResidualSumOfSquares(), 1.0e-5);
+ }
+
+ /**
+ * This is a test based on the Wampler4 data set
+ * http://www.itl.nist.gov/div898/strd/lls/data/Wampler4.shtml
+ */
+ @Test
+ public void testWampler4() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ double[] data = new double[] {
+ 75901, 0,
+ -204794, 1,
+ 204863, 2,
+ -204436, 3,
+ 253665, 4,
+ -200894, 5,
+ 214131, 6,
+ -185192, 7,
+ 221249, 8,
+ -138370, 9,
+ 315911, 10,
+ -27644, 11,
+ 455253, 12,
+ 197434, 13,
+ 783995, 14,
+ 608816, 15,
+ 1370781, 16,
+ 1303798, 17,
+ 2205519, 18,
+ 2408860, 19,
+ 3444321, 20};
+
+ OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+ final int nvars = 5;
+ final int nobs = 21;
+ double[] tmp = new double[(nvars + 1) * nobs];
+ int off = 0;
+ int off2 = 0;
+ for (int i = 0; i < nobs; i++) {
+ tmp[off2] = data[off];
+ tmp[off2 + 1] = data[off + 1];
+ tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
+ tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
+ tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
+ tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
+ off2 += (nvars + 1);
+ off += 2;
+ }
+ mdl.newSampleData(tmp, nobs, nvars, new SparseDistributedMatrix());
+ double[] betaHat = mdl.estimateRegressionParameters();
+ TestUtils.assertEquals(betaHat,
+ new double[] {
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0}, 1E-6);
+
+ double[] se = mdl.estimateRegressionParametersStandardErrors();
+ TestUtils.assertEquals(se,
+ new double[] {
+ 215232.624678170,
+ 236355.173469681, 77934.3524331583,
+ 10147.5507550350, 564.566512170752,
+ 11.2324854679312}, 1E-8);
+
+ TestUtils.assertEquals(.957478440825662, mdl.calculateRSquared(), 1.0e-10);
+ TestUtils.assertEquals(55702845333.3333, mdl.estimateErrorVariance(), 1.0e-4);
+ TestUtils.assertEquals(835542680000.000, mdl.calculateResidualSumOfSquares(), 1.0e-3);
+ }
+
+    /**
+     * Anything requiring beta calculation should advertise SME.
+     */
+    public void testSingularCalculateBeta() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        // Three identical observations make the design matrix singular; the tight
+        // 1e-15 threshold guarantees the decomposition reports it.
+        OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15);
+        mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseDistributedMatrix());
+
+        try {
+            mdl.calculateBeta();
+            fail("SingularMatrixException");
+        }
+        catch (SingularMatrixException e) {
+            // Expected.
+        }
+        // NOTE(review): removed an unreachable trailing fail(...) present in the original.
+    }
+
+    /**
+     * Calling {@code calculateBeta()} before any sample data is loaded must fail with NPE.
+     */
+    public void testNoDataNPECalculateBeta() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+        try {
+            mdl.calculateBeta();
+            fail("java.lang.NullPointerException");
+        }
+        catch (NullPointerException e) {
+            // Expected: no sample data has been set on the model.
+        }
+        // NOTE(review): removed an unreachable trailing fail(...) present in the original.
+    }
+
+    /**
+     * Calling {@code calculateHat()} before any sample data is loaded must fail with NPE.
+     */
+    public void testNoDataNPECalculateHat() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+        try {
+            mdl.calculateHat();
+            fail("java.lang.NullPointerException");
+        }
+        catch (NullPointerException e) {
+            // Expected: no sample data has been set on the model.
+        }
+        // NOTE(review): removed an unreachable trailing fail(...) present in the original.
+    }
+
+
+    /**
+     * Calling {@code calculateTotalSumOfSquares()} before any sample data is loaded must fail with NPE.
+     */
+    public void testNoDataNPESSTO() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+        try {
+            mdl.calculateTotalSumOfSquares();
+            fail("java.lang.NullPointerException");
+        }
+        catch (NullPointerException e) {
+            // Expected: no sample data has been set on the model.
+        }
+        // NOTE(review): removed an unreachable trailing fail(...) present in the original.
+    }
+
+    /**
+     * {@code validateSampleData} must reject an undersized sample
+     * (here: a 1x2 design matrix with a single observation).
+     */
+    public void testMathIllegalArgumentException() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
+
+        try {
+            mdl.validateSampleData(new SparseDistributedMatrix(1, 2), new SparseDistributedVector(1));
+            fail("MathIllegalArgumentException");
+        }
+        catch (MathIllegalArgumentException e) {
+            // Expected: the sample is too small to be valid.
+        }
+        // NOTE(review): removed an unreachable trailing fail(...) present in the original.
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java
index 4be7336..2774028 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java
@@ -418,6 +418,7 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
Matrix hat = mdl.calculateHat();
+
// Reference data is upper half of symmetric hat matrix
double[] refData = new double[] {
.418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242,
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index a54a4e3..2a0b111 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -25,7 +25,7 @@ import org.junit.runners.Suite;
*/
@RunWith(Suite.class)
@Suite.SuiteClasses({
- OLSMultipleLinearRegressionTest.class
+ OLSMultipleLinearRegressionTest.class, DistributedOLSMultipleLinearRegressionTest.class, DistributedBlockOLSMultipleLinearRegressionTest.class
})
public class RegressionsTestSuite {
// No-op.
[2/3] ignite git commit: IGNITE-5846 Add support of distributed
matrices for OLS regression. This closes #3030.
Posted by nt...@apache.org.
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java
index cd76e5a..411b038 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java
@@ -17,14 +17,6 @@
package org.apache.ignite.ml.math.impls.storage.matrix;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Set;
-import java.util.UUID;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CacheAtomicityMode;
@@ -38,16 +30,22 @@ import org.apache.ignite.ml.math.MatrixStorage;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.distributed.CacheUtils;
import org.apache.ignite.ml.math.distributed.DistributedStorage;
-import org.apache.ignite.ml.math.distributed.keys.impl.BlockMatrixKey;
-import org.apache.ignite.ml.math.impls.matrix.BlockEntry;
+import org.apache.ignite.ml.math.distributed.keys.impl.MatrixBlockKey;
+import org.apache.ignite.ml.math.impls.matrix.MatrixBlockEntry;
import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
-import static org.apache.ignite.ml.math.impls.matrix.BlockEntry.MAX_BLOCK_SIZE;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.*;
+
+import static org.apache.ignite.ml.math.impls.matrix.MatrixBlockEntry.MAX_BLOCK_SIZE;
/**
* Storage for {@link SparseBlockDistributedMatrix}.
*/
-public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, StorageConstants, DistributedStorage<BlockMatrixKey> {
+public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, StorageConstants, DistributedStorage<MatrixBlockKey> {
/** Cache name used for all instances of {@link BlockMatrixStorage}. */
private static final String CACHE_NAME = "ML_BLOCK_SPARSE_MATRICES_CONTAINER";
/** */
@@ -65,8 +63,8 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
/** Actual distributed storage. */
private IgniteCache<
- BlockMatrixKey /* Matrix block number with uuid. */,
- BlockEntry /* Block of matrix, local sparse matrix. */
+ MatrixBlockKey /* Matrix block number with uuid. */,
+ MatrixBlockEntry /* Block of matrix, local sparse matrix. */
> cache = null;
/**
@@ -98,7 +96,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
/**
*
*/
- public IgniteCache<BlockMatrixKey, BlockEntry> cache() {
+ public IgniteCache<MatrixBlockKey, MatrixBlockEntry> cache() {
return cache;
}
@@ -132,20 +130,6 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
return RANDOM_ACCESS_MODE;
}
- /**
- * @return Blocks in column.
- */
- public int blocksInCol() {
- return blocksInCol;
- }
-
- /**
- * @return Blocks in row.
- */
- public int blocksInRow() {
- return blocksInRow;
- }
-
/** {@inheritDoc} */
@Override public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(rows);
@@ -210,8 +194,8 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
*
* NB: NOT cell indices.
*/
- public BlockMatrixKey getCacheKey(long blockIdRow, long blockIdCol) {
- return new BlockMatrixKey(blockIdRow, blockIdCol, uuid, getAffinityKey(blockIdRow, blockIdCol));
+ public MatrixBlockKey getCacheKey(long blockIdRow, long blockIdCol) {
+ return new MatrixBlockKey(blockIdRow, blockIdCol, uuid, getAffinityKey(blockIdRow, blockIdCol));
}
/**
@@ -219,17 +203,17 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
*
* NB: NOT cell indices.
*/
- public BlockMatrixKey getCacheKey(IgnitePair<Long> blockId) {
- return new BlockMatrixKey(blockId.get1(), blockId.get2(), uuid, getAffinityKey(blockId.get1(), blockId.get2()));
+ private MatrixBlockKey getCacheKey(IgnitePair<Long> blockId) {
+ return new MatrixBlockKey(blockId.get1(), blockId.get2(), uuid, getAffinityKey(blockId.get1(), blockId.get2()));
}
/** {@inheritDoc} */
- @Override public Set<BlockMatrixKey> getAllKeys() {
+ @Override public Set<MatrixBlockKey> getAllKeys() {
int maxRowIdx = rows - 1;
int maxColIdx = cols - 1;
IgnitePair<Long> maxBlockId = getBlockId(maxRowIdx, maxColIdx);
- Set<BlockMatrixKey> keyset = new HashSet<>();
+ Set<MatrixBlockKey> keyset = new HashSet<>();
for(int i = 0; i <= maxBlockId.get1(); i++)
for(int j = 0; j <= maxBlockId.get2(); j++)
@@ -249,8 +233,8 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
* @param blockId block id.
* @return The list of block entries.
*/
- public List<BlockEntry> getRowForBlock(IgnitePair<Long> blockId) {
- List<BlockEntry> res = new LinkedList<>();
+ public List<MatrixBlockEntry> getRowForBlock(IgnitePair<Long> blockId) {
+ List<MatrixBlockEntry> res = new LinkedList<>();
for (int i = 0; i < blocksInCol; i++)
res.add(getEntryById(new IgnitePair<>(blockId.get1(), (long) i)));
@@ -265,8 +249,8 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
* @param blockId block id.
* @return The list of block entries.
*/
- public List<BlockEntry> getColForBlock(IgnitePair<Long> blockId) {
- List<BlockEntry> res = new LinkedList<>();
+ public List<MatrixBlockEntry> getColForBlock(IgnitePair<Long> blockId) {
+ List<MatrixBlockEntry> res = new LinkedList<>();
for (int i = 0; i < blocksInRow; i++)
res.add(getEntryById(new IgnitePair<>((long) i, blockId.get2())));
@@ -308,10 +292,10 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
* @param blockId blockId
* @return BlockEntry
*/
- private BlockEntry getEntryById(IgnitePair<Long> blockId) {
- BlockMatrixKey key = getCacheKey(blockId.get1(), blockId.get2());
+ private MatrixBlockEntry getEntryById(IgnitePair<Long> blockId) {
+ MatrixBlockKey key = getCacheKey(blockId.get1(), blockId.get2());
- BlockEntry entry = cache.localPeek(key, CachePeekMode.PRIMARY);
+ MatrixBlockEntry entry = cache.localPeek(key, CachePeekMode.PRIMARY);
entry = entry != null ? entry : cache.get(key);
if (entry == null)
@@ -325,8 +309,8 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
* @param blockId blockId
* @return Empty BlockEntry
*/
- private BlockEntry getEmptyBlockEntry(IgnitePair<Long> blockId) {
- BlockEntry entry;
+ private MatrixBlockEntry getEmptyBlockEntry(IgnitePair<Long> blockId) {
+ MatrixBlockEntry entry;
int rowMod = rows % maxBlockEdge;
int colMod = cols % maxBlockEdge;
@@ -345,7 +329,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
else
colSize = blockId.get2() != (blocksInCol - 1) ? maxBlockEdge : colMod;
- entry = new BlockEntry(rowSize, colSize);
+ entry = new MatrixBlockEntry(rowSize, colSize);
return entry;
}
@@ -354,7 +338,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
*
* Get affinity key for the given id.
*/
- private IgniteUuid getAffinityKey(long blockIdRow, long blockIdCol) {
+ private UUID getAffinityKey(long blockIdRow, long blockIdCol) {
return null;
}
@@ -368,13 +352,13 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
private void matrixSet(int a, int b, double v) {
IgnitePair<Long> blockId = getBlockId(a, b);
// Remote set on the primary node (where given row or column is stored locally).
- ignite().compute(groupForKey(CACHE_NAME, blockId)).run(() -> {
- IgniteCache<BlockMatrixKey, BlockEntry> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
+ ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, blockId)).run(() -> {
+ IgniteCache<MatrixBlockKey, MatrixBlockEntry> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
- BlockMatrixKey key = getCacheKey(blockId.get1(), blockId.get2());
+ MatrixBlockKey key = getCacheKey(blockId.get1(), blockId.get2());
// Local get.
- BlockEntry block = getEntryById(blockId);
+ MatrixBlockEntry block = getEntryById(blockId);
block.set(a % block.rowSize(), b % block.columnSize(), v);
@@ -402,13 +386,13 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
*/
private double matrixGet(int a, int b) {
// Remote get from the primary node (where given row or column is stored locally).
- return ignite().compute(groupForKey(CACHE_NAME, getBlockId(a, b))).call(() -> {
- IgniteCache<BlockMatrixKey, BlockEntry> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
+ return ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, getBlockId(a, b))).call(() -> {
+ IgniteCache<MatrixBlockKey, MatrixBlockEntry> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
- BlockMatrixKey key = getCacheKey(getBlockId(a, b));
+ MatrixBlockKey key = getCacheKey(getBlockId(a, b));
// Local get.
- BlockEntry block = cache.localPeek(key, CachePeekMode.PRIMARY);
+ MatrixBlockEntry block = cache.localPeek(key, CachePeekMode.PRIMARY);
if (block == null)
block = cache.get(key);
@@ -420,8 +404,8 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto
/**
* Create new ML cache if needed.
*/
- private IgniteCache<BlockMatrixKey, BlockEntry> newCache() {
- CacheConfiguration<BlockMatrixKey, BlockEntry> cfg = new CacheConfiguration<>();
+ private IgniteCache<MatrixBlockKey, MatrixBlockEntry> newCache() {
+ CacheConfiguration<MatrixBlockKey, MatrixBlockEntry> cfg = new CacheConfiguration<>();
// Write to primary.
cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockVectorStorage.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockVectorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockVectorStorage.java
new file mode 100644
index 0000000..a44ed8e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockVectorStorage.java
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.storage.matrix;
+
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.CacheAtomicityMode;
+import org.apache.ignite.cache.CacheMode;
+import org.apache.ignite.cache.CachePeekMode;
+import org.apache.ignite.cache.CacheWriteSynchronizationMode;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.lang.IgnitePair;
+import org.apache.ignite.internal.util.typedef.internal.U;
+import org.apache.ignite.lang.IgniteUuid;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.VectorStorage;
+import org.apache.ignite.ml.math.distributed.CacheUtils;
+import org.apache.ignite.ml.math.distributed.DistributedStorage;
+import org.apache.ignite.ml.math.distributed.keys.impl.VectorBlockKey;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+import org.apache.ignite.ml.math.impls.vector.VectorBlockEntry;
+import org.jetbrains.annotations.NotNull;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.*;
+
+import static org.apache.ignite.ml.math.impls.matrix.MatrixBlockEntry.MAX_BLOCK_SIZE;
+
+/**
+ * Storage for {@link SparseBlockDistributedVector}.
+ */
+public class BlockVectorStorage extends CacheUtils implements VectorStorage, StorageConstants, DistributedStorage<VectorBlockKey> {
+ /** Cache name used for all instances of {@link BlockVectorStorage}. */
+ private static final String CACHE_NAME = "ML_BLOCK_SPARSE_MATRICES_CONTAINER";
+ /** */
+ private int blocks;
+ /** Amount of columns in the vector. */
+ private int size;
+ /** Vector uuid. */
+ private UUID uuid;
+ /** Block size about 8 KB of data. */
+ private int maxBlockEdge = MAX_BLOCK_SIZE;
+
+ /** Actual distributed storage. */
+ private IgniteCache<
+ VectorBlockKey /* Vector block number with uuid. */,
+ VectorBlockEntry /* Block of vector, local sparse vector. */
+ > cache = null;
+
+ /**
+ *
+ */
+ public BlockVectorStorage() {
+ // No-op.
+ }
+
+ /**
+ * @param size Amount of columns in the vector.
+ */
+ public BlockVectorStorage(int size) {
+
+ assert size > 0;
+
+ this.size = size;
+
+ this.blocks = size % maxBlockEdge == 0 ? size / maxBlockEdge : size / maxBlockEdge + 1;
+
+ cache = newCache();
+
+ uuid = UUID.randomUUID();
+ }
+
+ /**
+ *
+ */
+ public IgniteCache<VectorBlockKey, VectorBlockEntry> cache() {
+ return cache;
+ }
+
+ /** {@inheritDoc} */
+ @Override public double get(int x) {
+ return matrixGet(x);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void set(int x, double v) {
+ matrixSet(x, v);
+ }
+
+ /** {@inheritDoc} */
+ @Override public int size() {
+ return size;
+ }
+
+
+ /**
+ * @return Blocks in row.
+ */
+ public int blocksInRow() {
+ return blocks;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(size);
+ out.writeInt(blocks);
+ out.writeObject(uuid);
+ out.writeUTF(cache.getName());
+ }
+
+ /** {@inheritDoc} */
+ @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ size = in.readInt();
+ blocks = in.readInt();
+ uuid = (UUID) in.readObject();
+
+ cache = ignite().getOrCreateCache(in.readUTF());
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isSequentialAccess() {
+ return false;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isDense() {
+ return false;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isRandomAccess() {
+ return true;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isDistributed() {
+ return true;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isArrayBased() {
+ return false;
+ }
+
+ /** Delete all data from cache. */
+ @Override public void destroy() {
+ cache.clearAll(getAllKeys());
+ }
+
+ /**
+ * Get storage UUID.
+ *
+ * @return storage UUID.
+ */
+ public UUID getUUID() {
+ return uuid;
+ }
+
+ /**
+ * Build the cache key for the given block id.
+ *
+ * NB: NOT cell indices.
+ */
+ public VectorBlockKey getCacheKey(long blockId) {
+ return new VectorBlockKey(blockId, uuid, getAffinityKey(blockId));
+ }
+
+
+ /** {@inheritDoc} */
+ @Override public Set<VectorBlockKey> getAllKeys() {
+ int maxIndex = size - 1;
+ long maxBlockId = getBlockId(maxIndex);
+
+ Set<VectorBlockKey> keyset = new HashSet<>();
+
+ for (int i = 0; i <= maxBlockId; i++)
+ keyset.add(getCacheKey(i));
+
+ return keyset;
+ }
+
+ /** {@inheritDoc} */
+ @Override public String cacheName() {
+ return CACHE_NAME;
+ }
+
+
+ /**
+ * Get column for current block.
+ *
+ * @param blockId block id.
+ * @return The list of block entries.
+ */
+ public List<VectorBlockEntry> getColForBlock(long blockId) {
+ List<VectorBlockEntry> res = new LinkedList<>();
+
+ for (int i = 0; i < blocks; i++)
+ res.add(getEntryById(i));
+
+ return res;
+ }
+
+ /** {@inheritDoc} */
+ @Override public int hashCode() {
+ int res = 1;
+
+ res = res * 37 + size;
+ res = res * 37 + uuid.hashCode();
+ res = res * 37 + cache.hashCode();
+
+ return res;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+
+ if (obj == null || getClass() != obj.getClass())
+ return false;
+
+ BlockVectorStorage that = (BlockVectorStorage) obj;
+
+ return size == that.size && uuid.equals(that.uuid)
+ && (cache != null ? cache.equals(that.cache) : that.cache == null);
+ }
+
+ /**
+ *
+ */
+ private VectorBlockEntry getEntryById(long blockId) {
+ VectorBlockKey key = getCacheKey(blockId);
+
+ VectorBlockEntry entry = cache.localPeek(key, CachePeekMode.PRIMARY);
+ entry = entry != null ? entry : cache.get(key);
+
+ if (entry == null)
+ entry = getEmptyBlockEntry(blockId);
+
+ return entry;
+ }
+
+ @NotNull
+ private VectorBlockEntry getEmptyBlockEntry(long blockId) {
+ VectorBlockEntry entry;
+ int colMod = size % maxBlockEdge;
+
+ int colSize;
+
+ if (colMod == 0)
+ colSize = maxBlockEdge;
+ else
+ colSize = blockId != (blocks - 1) ? maxBlockEdge : colMod;
+
+ entry = new VectorBlockEntry(colSize);
+ return entry;
+ }
+
+ /**
+ * TODO: IGNITE-5646, WIP
+ *
+ * Get affinity key for the given id.
+ */
+ private UUID getAffinityKey(long blockId) {
+ return null;
+ }
+
+ /**
+ * Distributed vector set.
+ *
+ * @param idx Row or column index.
+ * @param v New value to set.
+ */
+ private void matrixSet(int idx, double v) {
+ long blockId = getBlockId(idx);
+ // Remote set on the primary node (where given row or column is stored locally).
+ ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, blockId)).run(() -> {
+ IgniteCache<VectorBlockKey, VectorBlockEntry> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
+
+ VectorBlockKey key = getCacheKey(blockId);
+
+ // Local get.
+ VectorBlockEntry block = getEntryById(blockId);
+
+ block.set(idx % block.size(), v);
+
+ // Local put.
+ cache.put(key, block);
+ });
+ }
+
+ /** */
+ private long getBlockId(int x) {
+ return (long) x / maxBlockEdge;
+ }
+
+ /**
+ * Distributed vector get.
+ *
+ * @param idx index.
+ * @return Vector value at (idx) index.
+ */
+ private double matrixGet(int idx) {
+ // Remote get from the primary node (where given row or column is stored locally).
+ return ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, getBlockId(idx))).call(() -> {
+ IgniteCache<VectorBlockKey, VectorBlockEntry> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
+
+ VectorBlockKey key = getCacheKey(getBlockId(idx));
+
+ // Local get.
+ VectorBlockEntry block = cache.localPeek(key, CachePeekMode.PRIMARY);
+
+ if (block == null)
+ block = cache.get(key);
+
+ return block == null ? 0.0 : block.get(idx % block.size());
+ });
+ }
+
+ /**
+ * Create new ML cache if needed.
+ */
+ private IgniteCache<VectorBlockKey, VectorBlockEntry> newCache() {
+ CacheConfiguration<VectorBlockKey, VectorBlockEntry> cfg = new CacheConfiguration<>();
+
+ // Write to primary.
+ cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
+
+ // Atomic transactions only.
+ cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
+
+ // No eviction.
+ cfg.setEvictionPolicy(null);
+
+ // No copying of values.
+ cfg.setCopyOnRead(false);
+
+ // Cache is partitioned.
+ cfg.setCacheMode(CacheMode.PARTITIONED);
+
+ // Random cache name.
+ cfg.setName(CACHE_NAME);
+
+ return Ignition.localIgnite().getOrCreateCache(cfg);
+ }
+
+ /**
+ * Avoid this method for large vectors
+ *
+ * @return data presented as array
+ */
+ @Override public double[] data() {
+ double[] res = new double[this.size];
+ for (int i = 0; i < this.size; i++) res[i] = this.get(i);
+ return res;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/MapWrapperStorage.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/MapWrapperStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/MapWrapperStorage.java
index 4648421..91db30e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/MapWrapperStorage.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/MapWrapperStorage.java
@@ -17,13 +17,14 @@
package org.apache.ignite.ml.math.impls.storage.matrix;
+import org.apache.ignite.internal.util.GridArgumentCheck;
+import org.apache.ignite.ml.math.VectorStorage;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Map;
import java.util.Set;
-import org.apache.ignite.internal.util.GridArgumentCheck;
-import org.apache.ignite.ml.math.VectorStorage;
/**
* Storage for wrapping given map.
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java
index c40e73d..e976899 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java
@@ -19,14 +19,6 @@ package org.apache.ignite.ml.math.impls.storage.matrix;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleRBTreeMap;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Map;
-import java.util.Set;
-import java.util.UUID;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CacheAtomicityMode;
@@ -42,6 +34,15 @@ import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey;
import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
/**
* {@link MatrixStorage} implementation for {@link SparseDistributedMatrix}.
*/
@@ -161,7 +162,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix
*/
private double matrixGet(int a, int b) {
// Remote get from the primary node (where given row or column is stored locally).
- return ignite().compute(groupForKey(CACHE_NAME, a)).call(() -> {
+ return ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, a)).call(() -> {
IgniteCache<RowColMatrixKey, Map<Integer, Double>> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
// Local get.
@@ -183,7 +184,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix
*/
private void matrixSet(int a, int b, double v) {
// Remote set on the primary node (where given row or column is stored locally).
- ignite().compute(groupForKey(CACHE_NAME, a)).run(() -> {
+ ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, a)).run(() -> {
IgniteCache<RowColMatrixKey, Map<Integer, Double>> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
// Local get.
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorage.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorage.java
new file mode 100644
index 0000000..8f79413
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorage.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.storage.vector;
+
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.CacheAtomicityMode;
+import org.apache.ignite.cache.CacheMode;
+import org.apache.ignite.cache.CacheWriteSynchronizationMode;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.lang.IgniteUuid;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.VectorStorage;
+import org.apache.ignite.ml.math.distributed.CacheUtils;
+import org.apache.ignite.ml.math.distributed.DistributedStorage;
+import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
+import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey;
+import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Set;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * {@link VectorStorage} implementation for {@link SparseDistributedVector}.
+ */
+public class SparseDistributedVectorStorage extends CacheUtils implements VectorStorage, StorageConstants, DistributedStorage<RowColMatrixKey> {
+ /** Cache name used for all instances of {@link SparseDistributedVectorStorage}. */
+ private static final String CACHE_NAME = "ML_SPARSE_VECTORS_CONTAINER";
+ /** Amount of elements in the vector. */
+ private int size;
+ /** Random or sequential access mode. */
+ private int acsMode;
+ /** Vector uuid. */
+ private UUID uuid;
+
+ /** Actual distributed storage. */
+ private IgniteCache<RowColMatrixKey, Double> cache = null;
+
+ /**
+ *
+ */
+ public SparseDistributedVectorStorage() {
+ // No-op.
+ }
+
+ /**
+ * @param size Amount of elements in the vector.
+ * @param acsMode Random or sequential access mode.
+ */
+ public SparseDistributedVectorStorage(int size, int acsMode) {
+
+ assert size > 0;
+ assertAccessMode(acsMode);
+
+ this.size = size;
+ this.acsMode = acsMode;
+
+ cache = newCache();
+
+ uuid = UUID.randomUUID();
+ }
+
+ /**
+ * Create new ML cache if needed.
+ */
+ private IgniteCache<RowColMatrixKey, Double> newCache() {
+ CacheConfiguration<RowColMatrixKey, Double> cfg = new CacheConfiguration<>();
+
+ // Write to primary.
+ cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
+
+ // Atomic transactions only.
+ cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
+
+ // No eviction.
+ cfg.setEvictionPolicy(null);
+
+ // No copying of values.
+ cfg.setCopyOnRead(false);
+
+ // Cache is partitioned.
+ cfg.setCacheMode(CacheMode.PARTITIONED);
+
+ // Random cache name.
+ cfg.setName(CACHE_NAME);
+
+ return Ignition.localIgnite().getOrCreateCache(cfg);
+ }
+
+ /**
+ * Gets cache
+ *
+ * @return cache
+ */
+ public IgniteCache<RowColMatrixKey, Double> cache() {
+ return cache;
+ }
+
+ /**
+ * Gets access mode
+ *
+ * @return code of access mode
+ */
+ public int accessMode() {
+ return acsMode;
+ }
+
+ /**
+ * Gets vector element by element index
+ *
+ * @param i Vector element index.
+ * @return vector element
+ */
+ @Override public double get(int i) {
+ // Remote get from the primary node (where given row or column is stored locally).
+ return ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, getCacheKey(i))).call(() -> {
+ IgniteCache<RowColMatrixKey, Double> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
+ Double res = cache.get(getCacheKey(i));
+ if (res == null) return 0.0;
+ return res;
+ });
+ }
+
+ /**
+ * Sets vector element by index
+ *
+ * @param i Vector element index.
+ * @param v Value to set at given index.
+ */
+ @Override public void set(int i, double v) {
+ // Remote set on the primary node (where given row or column is stored locally).
+ ignite().compute(getClusterGroupForGivenKey(CACHE_NAME, getCacheKey(i))).run(() -> {
+ IgniteCache<RowColMatrixKey, Double> cache = Ignition.localIgnite().getOrCreateCache(CACHE_NAME);
+
+ RowColMatrixKey cacheKey = getCacheKey(i);
+
+ if (v != 0.0)
+ cache.put(cacheKey, v);
+ else if (cache.containsKey(cacheKey)) // remove zero elements
+ cache.remove(cacheKey);
+
+ });
+ }
+
+
+ /** {@inheritDoc} */
+ @Override public int size() {
+ return size;
+ }
+
+
+ /** {@inheritDoc} */
+ @Override public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(size);
+ out.writeInt(acsMode);
+ out.writeObject(uuid);
+ out.writeUTF(cache.getName());
+ }
+
+ /** {@inheritDoc} */
+ @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ size = in.readInt();
+ acsMode = in.readInt();
+ uuid = (UUID) in.readObject();
+ cache = ignite().getOrCreateCache(in.readUTF());
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isSequentialAccess() {
+ return acsMode == SEQUENTIAL_ACCESS_MODE;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isDense() {
+ return false;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isRandomAccess() {
+ return acsMode == RANDOM_ACCESS_MODE;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isDistributed() {
+ return true;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isArrayBased() {
+ return false;
+ }
+
+ /** Delete all data from cache. */
+ @Override public void destroy() {
+ Set<RowColMatrixKey> keyset = IntStream.range(0, size).mapToObj(this::getCacheKey).collect(Collectors.toSet());
+ cache.clearAll(keyset);
+ }
+
+ /** {@inheritDoc} */
+ @Override public int hashCode() {
+ int res = 1;
+
+ res = res * 37 + size;
+ res = res * 37 + acsMode;
+ res = res * 37 + uuid.hashCode();
+ res = res * 37 + cache.hashCode();
+
+ return res;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+
+ if (obj == null || getClass() != obj.getClass())
+ return false;
+
+ SparseDistributedVectorStorage that = (SparseDistributedVectorStorage) obj;
+
+ return size == that.size && acsMode == that.acsMode
+ && uuid.equals(that.uuid) && (cache != null ? cache.equals(that.cache) : that.cache == null);
+ }
+
+ /**
+ * Builds cache key for vector element
+ *
+ * @param idx Index
+ * @return RowColMatrixKey
+ */
+ public RowColMatrixKey getCacheKey(int idx) {
+ return new SparseMatrixKey(idx, uuid, null);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Set<RowColMatrixKey> getAllKeys() {
+ int range = size;
+
+ return IntStream.range(0, range).mapToObj(i -> new SparseMatrixKey(i, getUUID(), null)).collect(Collectors.toSet());
+ }
+
+ /** {@inheritDoc} */
+ @Override public String cacheName() {
+ return CACHE_NAME;
+ }
+
+ /** */
+ public UUID getUUID() {
+ return uuid;
+ }
+
+ @Override
+ public double[] data() {
+ double[] result = new double[this.size];
+ for (int i = 0; i < this.size; i++) result[i] = this.get(i);
+ return result;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVector.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVector.java
new file mode 100644
index 0000000..e460f9f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVector.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.vector;
+
+import org.apache.ignite.lang.IgniteUuid;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.distributed.CacheUtils;
+import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
+import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
+import org.apache.ignite.ml.math.impls.storage.matrix.BlockMatrixStorage;
+import org.apache.ignite.ml.math.impls.storage.matrix.BlockVectorStorage;
+import org.apache.ignite.ml.math.impls.storage.vector.SparseDistributedVectorStorage;
+
+import java.util.UUID;
+
+/**
+ * Sparse distributed vector implementation based on data grid.
+ * <p>
+ * Unlike {@link CacheVector} that is based on existing cache, this implementation creates distributed
+ * cache internally and doesn't rely on pre-existing cache.</p>
+ * <p>
+ * You also need to call {@link #destroy()} to remove the underlying cache when you no longer need this
+ * vector.</p>
+ * <p>
+ * <b>Currently fold supports only commutative operations.</b></p>
+ */
+public class SparseBlockDistributedVector extends AbstractVector implements StorageConstants {
+ /**
+ *
+ */
+ public SparseBlockDistributedVector() {
+ // No-op.
+ }
+
+ /**
+ * @param size Vector size
+ */
+ public SparseBlockDistributedVector(int size) {
+
+ assert size > 0;
+ setStorage(new BlockVectorStorage(size));
+ }
+
+
+ /**
+ * @param data Data to fill storage
+ */
+ public SparseBlockDistributedVector(double[] data) {
+ setStorage(new BlockVectorStorage(data.length));
+ for (int i = 0; i < data.length; i++) {
+ double val = data[i];
+ if (val != 0.0) storage().set(i, val);
+ }
+ }
+
+
+ /** */
+ public BlockVectorStorage storage() {
+ return (BlockVectorStorage) getStorage();
+ }
+
+ /**
+ * Return the same matrix with updates values (broken contract).
+ *
+ * @param d Value to divide to.
+ */
+ @Override public Vector divide(double d) {
+ return mapOverValues(v -> v / d);
+ }
+
+ @Override public Vector like(int size) {
+ return new SparseBlockDistributedVector(size);
+ }
+
+ @Override public Matrix likeMatrix(int rows, int cols) {
+ return new SparseBlockDistributedMatrix(rows, cols);
+ }
+
+ /**
+ * Return the same matrix with updates values (broken contract).
+ *
+ * @param x Value to add.
+ */
+ @Override public Vector plus(double x) {
+ return mapOverValues(v -> v + x);
+ }
+
+ /**
+ * Return the same matrix with updates values (broken contract).
+ *
+ * @param x Value to multiply.
+ */
+ @Override public Vector times(double x) {
+ return mapOverValues(v -> v * x);
+ }
+
+
+ /** {@inheritDoc} */
+ @Override public Vector assign(double val) {
+ return mapOverValues(v -> val);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Vector map(IgniteDoubleFunction<Double> fun) {
+ return mapOverValues(fun);
+ }
+
+ /**
+ * @param mapper Mapping function.
+ * @return Vector with mapped values.
+ */
+ private Vector mapOverValues(IgniteDoubleFunction<Double> mapper) {
+ CacheUtils.sparseMapForVector(getUUID(), mapper, storage().cacheName());
+
+ return this;
+ }
+
+ /** */
+ public UUID getUUID() {
+ return ((BlockVectorStorage) getStorage()).getUUID();
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVector.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVector.java
new file mode 100644
index 0000000..b773bfa
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVector.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.vector;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.Affinity;
+import org.apache.ignite.cluster.ClusterNode;
+import org.apache.ignite.lang.IgniteUuid;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.distributed.CacheUtils;
+import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
+import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
+import org.apache.ignite.ml.math.impls.matrix.*;
+import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage;
+import org.apache.ignite.ml.math.impls.storage.vector.SparseDistributedVectorStorage;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * Sparse distributed vector implementation based on data grid.
+ * <p>
+ * Unlike {@link CacheVector} that is based on existing cache, this implementation creates distributed
+ * cache internally and doesn't rely on pre-existing cache.</p>
+ * <p>
+ * You also need to call {@link #destroy()} to remove the underlying cache when you no longer need this
+ * vector.</p>
+ * <p>
+ * <b>Currently fold supports only commutative operations.</b></p>
+ */
+public class SparseDistributedVector extends AbstractVector implements StorageConstants {
+ /**
+ *
+ */
+ public SparseDistributedVector() {
+ // No-op.
+ }
+
+ /**
+ * @param size Vector size.
+ * @param acsMode Vector elements access mode..
+ */
+ public SparseDistributedVector(int size, int acsMode) {
+
+ assert size > 0;
+ assertAccessMode(acsMode);
+
+
+ setStorage(new SparseDistributedVectorStorage(size, acsMode));
+ }
+
+ public SparseDistributedVector(int size) {
+ this(size, StorageConstants.RANDOM_ACCESS_MODE);
+ }
+
+ /**
+ * @param data
+ */
+ public SparseDistributedVector(double[] data) {
+ setStorage(new SparseDistributedVectorStorage(data.length, StorageConstants.RANDOM_ACCESS_MODE));
+ for (int i = 0; i < data.length; i++) {
+ double value = data[i];
+ if (value != 0.0) storage().set(i, value);
+ }
+ }
+
+
+ /** */
+ public SparseDistributedVectorStorage storage() {
+ return (SparseDistributedVectorStorage) getStorage();
+ }
+
+ /**
+ * Return the same matrix with updates values (broken contract).
+ *
+ * @param d Value to divide to.
+ */
+ @Override public Vector divide(double d) {
+ return mapOverValues(v -> v / d);
+ }
+
+ @Override
+ public Vector like(int size) {
+ return new SparseDistributedVector(size, storage().accessMode());
+ }
+
+ @Override
+ public Matrix likeMatrix(int rows, int cols) {
+ return new SparseDistributedMatrix(rows, cols);
+ }
+
+ /**
+ * Return the same matrix with updates values (broken contract).
+ *
+ * @param x Value to add.
+ */
+ @Override public Vector plus(double x) {
+ return mapOverValues(v -> v + x);
+ }
+
+ /**
+ * Return the same matrix with updates values (broken contract).
+ *
+ * @param x Value to multiply.
+ */
+ @Override public Vector times(double x) {
+ return mapOverValues(v -> v * x);
+ }
+
+
+ /** {@inheritDoc} */
+ @Override public Vector assign(double val) {
+ return mapOverValues(v -> val);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Vector map(IgniteDoubleFunction<Double> fun) {
+ return mapOverValues(fun);
+ }
+
+ /**
+ * @param mapper Mapping function.
+ * @return Vector with mapped values.
+ */
+ private Vector mapOverValues(IgniteDoubleFunction<Double> mapper) {
+ CacheUtils.sparseMapForVector(getUUID(), mapper, storage().cacheName());
+
+ return this;
+ }
+
+ /** */
+ public UUID getUUID() {
+ return ((SparseDistributedVectorStorage) getStorage()).getUUID();
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/VectorBlockEntry.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/VectorBlockEntry.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/VectorBlockEntry.java
new file mode 100644
index 0000000..ad795c4
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/vector/VectorBlockEntry.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.vector;
+
+
+import org.apache.ignite.ml.math.Vector;
+
+
+/**
+ * Block for {@link SparseBlockDistributedVector}.
+ */
public final class VectorBlockEntry extends SparseLocalVector {
    /** Max number of elements a single block may hold. */
    public static final int MAX_BLOCK_SIZE = 32;

    /** Default no-op constructor. */
    public VectorBlockEntry() {
        // No-op.
    }

    /**
     * @param size Block size; must not exceed {@link #MAX_BLOCK_SIZE}.
     */
    public VectorBlockEntry(int size) {
        super(size, RANDOM_ACCESS_MODE);
        assert size <= MAX_BLOCK_SIZE;
    }

    /**
     * Wraps the given vector's storage without copying it: mutations of this block
     * are visible through {@code v} and vice versa.
     *
     * @param v Vector to take the storage from; its size must not exceed {@link #MAX_BLOCK_SIZE}.
     */
    public VectorBlockEntry(Vector v) {
        assert v.size() <= MAX_BLOCK_SIZE;

        setStorage(v.getStorage());
    }

}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
index c0a57d7..0ab568c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
@@ -71,7 +71,7 @@ public class MatrixUtil {
* @return Like matrix.
*/
public static Matrix like(Matrix matrix, int rows, int cols) {
- if (isCopyLikeSupport(matrix) || isDistributed(matrix))
+ if (isCopyLikeSupport(matrix))
return new DenseLocalOnHeapMatrix(rows, cols);
else
return matrix.like(rows, cols);
@@ -85,7 +85,7 @@ public class MatrixUtil {
* @return Like vector.
*/
public static Vector likeVector(Matrix matrix, int crd) {
- if (isCopyLikeSupport(matrix) || isDistributed(matrix))
+ if (isCopyLikeSupport(matrix))
return new DenseLocalOnHeapVector(crd);
else
return matrix.likeVector(crd);
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplDistributedTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplDistributedTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplDistributedTestSuite.java
index 5dc860c..784a455 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplDistributedTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplDistributedTestSuite.java
@@ -21,7 +21,10 @@ import org.apache.ignite.ml.math.impls.matrix.CacheMatrixTest;
import org.apache.ignite.ml.math.impls.matrix.SparseDistributedBlockMatrixTest;
import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrixTest;
import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorageTest;
+import org.apache.ignite.ml.math.impls.storage.vector.SparseDistributedVectorStorageTest;
import org.apache.ignite.ml.math.impls.vector.CacheVectorTest;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVectorTest;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVectorTest;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -30,11 +33,14 @@ import org.junit.runners.Suite;
*/
@RunWith(Suite.class)
@Suite.SuiteClasses({
- CacheVectorTest.class,
- CacheMatrixTest.class,
- SparseDistributedMatrixStorageTest.class,
- SparseDistributedMatrixTest.class,
- SparseDistributedBlockMatrixTest.class
+ CacheVectorTest.class,
+ CacheMatrixTest.class,
+ SparseDistributedMatrixStorageTest.class,
+ SparseDistributedMatrixTest.class,
+ SparseDistributedBlockMatrixTest.class,
+ SparseDistributedVectorStorageTest.class,
+ SparseDistributedVectorTest.class,
+ SparseBlockDistributedVectorTest.class
})
public class MathImplDistributedTestSuite {
// No-op.
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
index fd6ed78..b4f5c2d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
@@ -17,26 +17,23 @@
package org.apache.ignite.ml.math.impls.matrix;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.util.Collection;
-import java.util.Set;
-import java.util.UUID;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distributed.DistributedStorage;
-import org.apache.ignite.ml.math.distributed.keys.impl.BlockMatrixKey;
-import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
+import org.apache.ignite.ml.math.distributed.keys.impl.MatrixBlockKey;
import org.apache.ignite.ml.math.impls.MathTestConstants;
import org.apache.ignite.ml.math.impls.storage.matrix.BlockMatrixStorage;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
import org.apache.ignite.testframework.junits.common.GridCommonTest;
+import java.io.*;
+import java.util.Collection;
+import java.util.Set;
+
import static org.apache.ignite.ml.math.impls.MathTestConstants.UNEXPECTED_VAL;
/**
@@ -207,14 +204,14 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
cacheMatrix = new SparseBlockDistributedMatrix(rows, cols);
- try {
- cacheMatrix.copy();
- fail("UnsupportedOperationException expected.");
- }
- catch (UnsupportedOperationException e) {
- return;
- }
- fail("UnsupportedOperationException expected.");
+ cacheMatrix.set(rows-1, cols -1, 1);
+
+
+ Matrix newMatrix = cacheMatrix.copy();
+ assert newMatrix.columnSize() == cols;
+ assert newMatrix.rowSize() == rows;
+ assert newMatrix.get(rows-1,cols-1) == 1;
+
}
/** Test cache behaviour for matrix with different blocks */
@@ -226,7 +223,7 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
/** Test cache behaviour for matrix with homogeneous blocks */
public void testCacheBehaviourWithHomogeneousBlocks(){
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- int size = BlockEntry.MAX_BLOCK_SIZE * 3;
+ int size = MatrixBlockEntry.MAX_BLOCK_SIZE * 3;
cacheBehaviorLogic(size);
}
@@ -242,10 +239,10 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
assert cacheNames.contains(((DistributedStorage)cacheMatrix1.getStorage()).cacheName());
- IgniteCache<BlockMatrixKey, Object> cache = ignite.getOrCreateCache(((DistributedStorage)cacheMatrix1.getStorage()).cacheName());
+ IgniteCache<MatrixBlockKey, Object> cache = ignite.getOrCreateCache(((DistributedStorage)cacheMatrix1.getStorage()).cacheName());
- Set<BlockMatrixKey> keySet1 = buildKeySet(cacheMatrix1);
- Set<BlockMatrixKey> keySet2 = buildKeySet(cacheMatrix2);
+ Set<MatrixBlockKey> keySet1 = buildKeySet(cacheMatrix1);
+ Set<MatrixBlockKey> keySet2 = buildKeySet(cacheMatrix2);
assert cache.containsKeys(keySet1);
assert cache.containsKeys(keySet2);
@@ -275,14 +272,10 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
cacheMatrix = new SparseBlockDistributedMatrix(rows, cols);
- try {
- cacheMatrix.likeVector(1);
- fail("UnsupportedOperationException expected.");
- }
- catch (UnsupportedOperationException e) {
- return;
- }
- fail("UnsupportedOperationException expected.");
+ Vector v = cacheMatrix.likeVector(1);
+ assert v.size() == 1;
+ assert v instanceof SparseBlockDistributedVector;
+
}
/**
@@ -298,7 +291,7 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
* Simple test for two square matrices with size which is proportional to MAX_BLOCK_SIZE constant
*/
public void testSquareMatrixTimesWithHomogeneousBlocks(){
- int size = BlockEntry.MAX_BLOCK_SIZE * 3;
+ int size = MatrixBlockEntry.MAX_BLOCK_SIZE * 3;
squareMatrixTimesLogic(size);
}
@@ -331,8 +324,8 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
public void testNonSquareMatrixTimes(){
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- int size = BlockEntry.MAX_BLOCK_SIZE + 1;
- int size2 = BlockEntry.MAX_BLOCK_SIZE * 2 + 1;
+ int size = MatrixBlockEntry.MAX_BLOCK_SIZE + 1;
+ int size2 = MatrixBlockEntry.MAX_BLOCK_SIZE * 2 + 1;
Matrix cacheMatrix1 = new SparseBlockDistributedMatrix(size2, size);
Matrix cacheMatrix2 = new SparseBlockDistributedMatrix(size, size2);
@@ -358,8 +351,8 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
public void testNonSquareMatrixTimes2(){
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- int size = BlockEntry.MAX_BLOCK_SIZE + 1;
- int size2 = BlockEntry.MAX_BLOCK_SIZE * 2 + 1;
+ int size = MatrixBlockEntry.MAX_BLOCK_SIZE + 1;
+ int size2 = MatrixBlockEntry.MAX_BLOCK_SIZE * 2 + 1;
Matrix cacheMatrix1 = new SparseBlockDistributedMatrix(size, size2);
Matrix cacheMatrix2 = new SparseBlockDistributedMatrix(size2, size);
@@ -379,6 +372,26 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
assertEquals(UNEXPECTED_VAL + " for "+ i +":"+ j, 0, res.get(i, j), PRECISION);
}
+ public void testMatrixVectorTimes(){
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ SparseBlockDistributedMatrix a = new SparseBlockDistributedMatrix(new double[][] {{2.0, 4.0, 0.0}, {-2.0, 1.0, 3.0}, {-1.0, 0.0, 1.0}});
+ SparseBlockDistributedVector b = new SparseBlockDistributedVector(new double[] {1.0, 2.0, -1.0});
+ SparseBlockDistributedVector result = new SparseBlockDistributedVector(new double[] {10, -3.0, -2.0});
+
+
+ Vector calculatedResult = a.times(b);
+
+ for(int i = 0; i < calculatedResult.size(); i++)
+ assertEquals(UNEXPECTED_VAL + " for "+ i, result.get(i), calculatedResult.get(i), PRECISION);
+
+
+ }
+
+
+
+
+
/** */
private void initMtx(Matrix m) {
for (int i = 0; i < m.rowSize(); i++)
@@ -387,7 +400,8 @@ public class SparseDistributedBlockMatrixTest extends GridCommonAbstractTest {
}
/** Build key set for SparseBlockDistributedMatrix. */
- private Set<BlockMatrixKey> buildKeySet(SparseBlockDistributedMatrix m){
+ private Set<MatrixBlockKey> buildKeySet(SparseBlockDistributedMatrix m){
+
BlockMatrixStorage storage = (BlockMatrixStorage)m.getStorage();
return storage.getAllKeys();
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrixTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrixTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrixTest.java
index 1955588..a9343f4 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrixTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrixTest.java
@@ -30,11 +30,14 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distributed.DistributedStorage;
import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.math.impls.MathTestConstants;
import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
import org.apache.ignite.testframework.junits.common.GridCommonTest;
@@ -212,14 +215,13 @@ public class SparseDistributedMatrixTest extends GridCommonAbstractTest {
cacheMatrix = new SparseDistributedMatrix(rows, cols, StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
- try {
- cacheMatrix.copy();
- fail("UnsupportedOperationException expected.");
- }
- catch (UnsupportedOperationException e) {
- return;
+ Matrix copiedMtx = cacheMatrix.copy();
+
+ for (int i = 0; i < cacheMatrix.rowSize(); i++) {
+ for (int j = 0; j < cacheMatrix.columnSize(); j++) {
+ assert copiedMtx.get(i,j) == cacheMatrix.get(i,j);
+ }
}
- fail("UnsupportedOperationException expected.");
}
/** */
@@ -274,14 +276,9 @@ public class SparseDistributedMatrixTest extends GridCommonAbstractTest {
cacheMatrix = new SparseDistributedMatrix(rows, cols, StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
- try {
- cacheMatrix.likeVector(1);
- fail("UnsupportedOperationException expected.");
- }
- catch (UnsupportedOperationException e) {
- return;
- }
- fail("UnsupportedOperationException expected.");
+ Vector v = cacheMatrix.likeVector(1);
+ assert v.size() == 1;
+ assert v instanceof SparseDistributedVector;
}
/** */
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorageTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorageTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorageTest.java
new file mode 100644
index 0000000..9b6aa32
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/SparseDistributedVectorStorageTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.storage.vector;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.impls.MathTestConstants;
+import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.apache.ignite.testframework.junits.common.GridCommonTest;
+
+/**
+ * Tests for {@link SparseDistributedVectorStorage}.
+ */
+@GridCommonTest(group = "Distributed Models")
+public class SparseDistributedVectorStorageTest extends GridCommonAbstractTest {
+ /** Number of nodes in grid */
+ private static final int NODE_COUNT = 3;
+ /** Cache name. */
+ private static final String CACHE_NAME = "test-cache";
+ /** */
+ private static final String UNEXPECTED_ATTRIBUTE_VALUE = "Unexpected attribute value.";
+ /** Grid instance. */
+ private Ignite ignite;
+
+ /**
+ * Default constructor.
+ */
+ public SparseDistributedVectorStorageTest() {
+ super(false);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() throws Exception {
+ stopAllGrids();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ ignite = grid(NODE_COUNT);
+
+ ignite.configuration().setPeerClassLoadingEnabled(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTest() throws Exception {
+ ignite.destroyCache(CACHE_NAME);
+ }
+
+ /** */
+ public void testCacheCreation() throws Exception {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ final int size = MathTestConstants.STORAGE_SIZE;
+
+ SparseDistributedVectorStorage storage = new SparseDistributedVectorStorage(size, StorageConstants.RANDOM_ACCESS_MODE);
+
+ assertNotNull("SparseDistributedMatrixStorage cache is null.", storage.cache());
+ }
+
+ /** */
+ public void testSetGet() throws Exception {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ final int size = MathTestConstants.STORAGE_SIZE;
+
+ SparseDistributedVectorStorage storage = new SparseDistributedVectorStorage(size, StorageConstants.RANDOM_ACCESS_MODE);
+
+ for (int i = 0; i < size; i++) {
+ double v = Math.random();
+ storage.set(i, v);
+
+ assert Double.compare(v, storage.get(i)) == 0;
+ assert Double.compare(v, storage.get(i)) == 0;
+ }
+ }
+
+ /** */
+ public void testAttributes() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ final int size = MathTestConstants.STORAGE_SIZE;
+
+ SparseDistributedVectorStorage storage = new SparseDistributedVectorStorage(size, StorageConstants.RANDOM_ACCESS_MODE);
+
+ assertEquals(UNEXPECTED_ATTRIBUTE_VALUE, storage.size(), size);
+
+ assertFalse(UNEXPECTED_ATTRIBUTE_VALUE, storage.isArrayBased());
+ assertFalse(UNEXPECTED_ATTRIBUTE_VALUE, storage.isDense());
+ assertTrue(UNEXPECTED_ATTRIBUTE_VALUE, storage.isDistributed());
+
+ assertEquals(UNEXPECTED_ATTRIBUTE_VALUE, storage.isRandomAccess(), !storage.isSequentialAccess());
+ assertTrue(UNEXPECTED_ATTRIBUTE_VALUE, storage.isRandomAccess());
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVectorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVectorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVectorTest.java
new file mode 100644
index 0000000..4ac2845
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseBlockDistributedVectorTest.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.vector;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.MathTestConstants;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.apache.ignite.testframework.junits.common.GridCommonTest;
+
+import java.io.*;
+
+import static org.apache.ignite.ml.math.impls.MathTestConstants.UNEXPECTED_VAL;
+
+/**
+ * Tests for {@link SparseDistributedVector}.
+ */
+@GridCommonTest(group = "Distributed Models")
+public class SparseBlockDistributedVectorTest extends GridCommonAbstractTest {
+ /** Number of nodes in grid */
+ private static final int NODE_COUNT = 3;
+ /** Precision. */
+ private static final double PRECISION = 0.0;
+ /** Grid instance. */
+ private Ignite ignite;
+ /** Vector size */
+ private final int size = MathTestConstants.STORAGE_SIZE;
+ /** Vector for tests */
+ private SparseBlockDistributedVector sparseBlockDistributedVector;
+
+ /**
+ * Default constructor.
+ */
+ public SparseBlockDistributedVectorTest() {
+ super(false);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() throws Exception {
+ stopAllGrids();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ ignite = grid(NODE_COUNT);
+
+ ignite.configuration().setPeerClassLoadingEnabled(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTest() throws Exception {
+ if (sparseBlockDistributedVector != null) {
+ sparseBlockDistributedVector.destroy();
+ sparseBlockDistributedVector = null;
+ }
+ }
+
+ /** */
+ public void testGetSet() throws Exception {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseBlockDistributedVector = new SparseBlockDistributedVector(size);
+
+ for (int i = 0; i < size; i++) {
+ double v = Math.random();
+ sparseBlockDistributedVector.set(i, v);
+ assertEquals("Unexpected value for vector element[" + i + "]", v, sparseBlockDistributedVector.get(i), PRECISION);
+ }
+ }
+
+ /** */
+ public void testExternalize() throws IOException, ClassNotFoundException {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseBlockDistributedVector = new SparseBlockDistributedVector(size);
+
+ sparseBlockDistributedVector.set(1, 1.0);
+
+ ByteArrayOutputStream byteArrOutputStream = new ByteArrayOutputStream();
+ ObjectOutputStream objOutputStream = new ObjectOutputStream(byteArrOutputStream);
+
+ objOutputStream.writeObject(sparseBlockDistributedVector);
+
+ ByteArrayInputStream byteArrInputStream = new ByteArrayInputStream(byteArrOutputStream.toByteArray());
+ ObjectInputStream objInputStream = new ObjectInputStream(byteArrInputStream);
+
+ SparseBlockDistributedVector objRestored = (SparseBlockDistributedVector)objInputStream.readObject();
+
+ assertTrue(MathTestConstants.VAL_NOT_EQUALS, sparseBlockDistributedVector.equals(objRestored));
+ assertEquals(MathTestConstants.VAL_NOT_EQUALS, objRestored.get(1), 1.0, PRECISION);
+ }
+
+ /** Test simple math. */
+ public void testMath() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseBlockDistributedVector = new SparseBlockDistributedVector(size);
+ initVector(sparseBlockDistributedVector);
+
+ sparseBlockDistributedVector.assign(2.0);
+ for (int i = 0; i < sparseBlockDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 2.0, sparseBlockDistributedVector.get(i), PRECISION);
+
+ sparseBlockDistributedVector.plus(3.0);
+ for (int i = 0; i < sparseBlockDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 5.0, sparseBlockDistributedVector.get(i), PRECISION);
+
+ sparseBlockDistributedVector.times(2.0);
+ for (int i = 0; i < sparseBlockDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 10.0, sparseBlockDistributedVector.get(i), PRECISION);
+
+ sparseBlockDistributedVector.divide(10.0);
+ for (int i = 0; i < sparseBlockDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 1.0, sparseBlockDistributedVector.get(i), PRECISION);
+ }
+
+
+ /** */
+ public void testMap() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseBlockDistributedVector = new SparseBlockDistributedVector(size);
+ initVector(sparseBlockDistributedVector);
+
+ sparseBlockDistributedVector.map(i -> 100.0);
+ for (int i = 0; i < sparseBlockDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 100.0, sparseBlockDistributedVector.get(i), PRECISION);
+ }
+
+ /** */
+ public void testCopy() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseBlockDistributedVector = new SparseBlockDistributedVector(size);
+
+ Vector cp = sparseBlockDistributedVector.copy();
+ assertNotNull(cp);
+ for (int i = 0; i < size; i++)
+ assertEquals(UNEXPECTED_VAL, cp.get(i), sparseBlockDistributedVector.get(i), PRECISION);
+ }
+
+ /** */
+ public void testLike() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseBlockDistributedVector = new SparseBlockDistributedVector(size);
+
+ assertNotNull(sparseBlockDistributedVector.like(1));
+ }
+
+
+ /** */
+ private void initVector(Vector v) {
+ for (int i = 0; i < v.size(); i++)
+ v.set(i, 1.0);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVectorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVectorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVectorTest.java
new file mode 100644
index 0000000..416e254
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/SparseDistributedVectorTest.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.vector;
+
+import org.apache.ignite.Ignite;
+
+import org.apache.ignite.internal.util.IgniteUtils;
+
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Vector;
+
+import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
+import org.apache.ignite.ml.math.impls.MathTestConstants;
+
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.apache.ignite.testframework.junits.common.GridCommonTest;
+import org.junit.Ignore;
+
+import java.io.*;
+
+
+import static org.apache.ignite.ml.math.impls.MathTestConstants.UNEXPECTED_VAL;
+
+/**
+ * Tests for {@link SparseDistributedVector}.
+ */
+@GridCommonTest(group = "Distributed Models")
+public class SparseDistributedVectorTest extends GridCommonAbstractTest {
+ /** Number of nodes in grid */
+ private static final int NODE_COUNT = 3;
+ /** Precision. */
+ private static final double PRECISION = 0.0;
+ /** Grid instance. */
+ private Ignite ignite;
+ /** Vector size */
+ private final int size = MathTestConstants.STORAGE_SIZE;
+ /** Vector for tests */
+ private SparseDistributedVector sparseDistributedVector;
+
+ /**
+ * Default constructor.
+ */
+ public SparseDistributedVectorTest() {
+ super(false);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() throws Exception {
+ stopAllGrids();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ ignite = grid(NODE_COUNT);
+
+ ignite.configuration().setPeerClassLoadingEnabled(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTest() throws Exception {
+ if (sparseDistributedVector != null) {
+ sparseDistributedVector.destroy();
+ sparseDistributedVector = null;
+ }
+ }
+
+ /** */
+ public void testGetSet() throws Exception {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseDistributedVector = new SparseDistributedVector(size, StorageConstants.RANDOM_ACCESS_MODE);
+
+ for (int i = 0; i < size; i++) {
+ double v = Math.random();
+ sparseDistributedVector.set(i, v);
+ assertEquals("Unexpected value for vector element[" + i + "]", v, sparseDistributedVector.get(i), PRECISION);
+ }
+ }
+
+ /** */
+ public void testExternalize() throws IOException, ClassNotFoundException {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseDistributedVector = new SparseDistributedVector(size, StorageConstants.RANDOM_ACCESS_MODE);
+
+ sparseDistributedVector.set(1, 1.0);
+
+ ByteArrayOutputStream byteArrOutputStream = new ByteArrayOutputStream();
+ ObjectOutputStream objOutputStream = new ObjectOutputStream(byteArrOutputStream);
+
+ objOutputStream.writeObject(sparseDistributedVector);
+
+ ByteArrayInputStream byteArrInputStream = new ByteArrayInputStream(byteArrOutputStream.toByteArray());
+ ObjectInputStream objInputStream = new ObjectInputStream(byteArrInputStream);
+
+ SparseDistributedVector objRestored = (SparseDistributedVector)objInputStream.readObject();
+
+ assertTrue(MathTestConstants.VAL_NOT_EQUALS, sparseDistributedVector.equals(objRestored));
+ assertEquals(MathTestConstants.VAL_NOT_EQUALS, objRestored.get(1), 1.0, PRECISION);
+ }
+
+ /** Test simple math. */
+ public void testMath() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseDistributedVector = new SparseDistributedVector(size, StorageConstants.RANDOM_ACCESS_MODE);
+ initVector(sparseDistributedVector);
+
+ sparseDistributedVector.assign(2.0);
+ for (int i = 0; i < sparseDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 2.0, sparseDistributedVector.get(i), PRECISION);
+
+ sparseDistributedVector.plus(3.0);
+ for (int i = 0; i < sparseDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 5.0, sparseDistributedVector.get(i), PRECISION);
+
+ sparseDistributedVector.times(2.0);
+ for (int i = 0; i < sparseDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 10.0, sparseDistributedVector.get(i), PRECISION);
+
+ sparseDistributedVector.divide(10.0);
+ for (int i = 0; i < sparseDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 1.0, sparseDistributedVector.get(i), PRECISION);
+
+ // assertEquals(UNEXPECTED_VAL, sparseDistributedVector.rowSize() * sparseDistributedVector.columnSize(), sparseDistributedVector.sum(), PRECISION);
+ }
+
+
+ /** */
+ public void testMap() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseDistributedVector = new SparseDistributedVector(size, StorageConstants.RANDOM_ACCESS_MODE);
+ initVector(sparseDistributedVector);
+
+ sparseDistributedVector.map(i -> 100.0);
+ for (int i = 0; i < sparseDistributedVector.size(); i++)
+ assertEquals(UNEXPECTED_VAL, 100.0, sparseDistributedVector.get(i), PRECISION);
+ }
+
+ /** */
+ public void testCopy() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseDistributedVector = new SparseDistributedVector(size, StorageConstants.RANDOM_ACCESS_MODE);
+
+ Vector copy = sparseDistributedVector.copy();
+ assertNotNull(copy);
+ for (int i = 0; i < size; i++)
+ assertEquals(UNEXPECTED_VAL, copy.get(i), sparseDistributedVector.get(i), PRECISION);
+ }
+
+
+ /** */
+ public void testLike() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+ sparseDistributedVector = new SparseDistributedVector(size, StorageConstants.RANDOM_ACCESS_MODE);
+
+ assertNotNull(sparseDistributedVector.like(1));
+ }
+
+
+ /** */
+ private void initVector(Vector v) {
+ for (int i = 0; i < v.size(); i++)
+ v.set(i, 1.0);
+ }
+}
[3/3] ignite git commit: IGNITE-5846 Add support of distributed
matrices for OLS regression. This closes #3030.
Posted by nt...@apache.org.
IGNITE-5846 Add support of distributed matrices for OLS regression. This closes #3030.
Signed-off-by: nikolay_tikhonov <nt...@gridgain.com>
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/b0a86018
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/b0a86018
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/b0a86018
Branch: refs/heads/master
Commit: b0a86018693581065f789635facb88b1e8dac834
Parents: cbd7e39
Author: YuriBabak <y....@gmail.com>
Authored: Fri Nov 17 16:06:34 2017 +0300
Committer: nikolay_tikhonov <nt...@gridgain.com>
Committed: Fri Nov 17 16:07:51 2017 +0300
----------------------------------------------------------------------
.../clustering/KMeansDistributedClusterer.java | 13 +-
.../ignite/ml/math/distributed/CacheUtils.java | 278 +++---
.../math/distributed/keys/BlockMatrixKey.java | 38 -
.../distributed/keys/DataStructureCacheKey.java | 35 +
.../math/distributed/keys/MatrixBlockKey.java | 38 +
.../math/distributed/keys/MatrixCacheKey.java | 35 -
.../math/distributed/keys/RowColMatrixKey.java | 2 +-
.../math/distributed/keys/VectorBlockKey.java | 34 +
.../distributed/keys/impl/BlockMatrixKey.java | 164 ----
.../distributed/keys/impl/MatrixBlockKey.java | 162 ++++
.../distributed/keys/impl/SparseMatrixKey.java | 14 +-
.../distributed/keys/impl/VectorBlockKey.java | 151 +++
.../ignite/ml/math/functions/Functions.java | 3 +-
.../ml/math/impls/matrix/AbstractMatrix.java | 24 +-
.../ignite/ml/math/impls/matrix/BlockEntry.java | 50 -
.../ml/math/impls/matrix/MatrixBlockEntry.java | 50 +
.../matrix/SparseBlockDistributedMatrix.java | 153 ++-
.../impls/matrix/SparseDistributedMatrix.java | 102 +-
.../storage/matrix/BlockMatrixStorage.java | 96 +-
.../storage/matrix/BlockVectorStorage.java | 374 ++++++++
.../impls/storage/matrix/MapWrapperStorage.java | 5 +-
.../matrix/SparseDistributedMatrixStorage.java | 21 +-
.../vector/SparseDistributedVectorStorage.java | 280 ++++++
.../vector/SparseBlockDistributedVector.java | 139 +++
.../impls/vector/SparseDistributedVector.java | 157 ++++
.../ml/math/impls/vector/VectorBlockEntry.java | 49 +
.../apache/ignite/ml/math/util/MatrixUtil.java | 4 +-
.../ml/math/MathImplDistributedTestSuite.java | 16 +-
.../SparseDistributedBlockMatrixTest.java | 86 +-
.../matrix/SparseDistributedMatrixTest.java | 27 +-
.../SparseDistributedVectorStorageTest.java | 121 +++
.../SparseBlockDistributedVectorTest.java | 181 ++++
.../vector/SparseDistributedVectorTest.java | 192 ++++
...tedBlockOLSMultipleLinearRegressionTest.java | 926 ++++++++++++++++++
...tributedOLSMultipleLinearRegressionTest.java | 934 +++++++++++++++++++
.../OLSMultipleLinearRegressionTest.java | 1 +
.../ml/regressions/RegressionsTestSuite.java | 2 +-
37 files changed, 4351 insertions(+), 606 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java
index 6c25edc..4286f42 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java
@@ -196,12 +196,11 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri
return list;
},
- key -> key.matrixId().equals(uid),
+ key -> key.dataStructureId().equals(uid),
(list1, list2) -> {
list1.addAll(list2);
return list1;
- },
- ArrayList::new
+ }, ArrayList::new
);
}
@@ -216,7 +215,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri
return map;
},
- key -> key.matrixId().equals(points.getUUID()),
+ key -> key.dataStructureId().equals(points.getUUID()),
(map1, map2) -> {
map1.putAll(map2);
return map1;
@@ -247,10 +246,10 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri
countMap.compute(resInd, (ind, v) -> v != null ? v + 1 : 1);
return countMap;
},
- key -> key.matrixId().equals(uid),
+ key -> key.dataStructureId().equals(uid),
(map1, map2) -> MapUtil.mergeMaps(map1, map2, (integer, integer2) -> integer2 + integer,
ConcurrentHashMap::new),
- ConcurrentHashMap::new);
+ ConcurrentHashMap::new);
}
/** */
@@ -278,7 +277,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri
return counts;
},
- key -> key.matrixId().equals(uid),
+ key -> key.dataStructureId().equals(uid),
SumsAndCounts::merge, SumsAndCounts::new
);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
index b9eb386..37384b8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
@@ -17,16 +17,6 @@
package org.apache.ignite.ml.math.distributed;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.BinaryOperator;
-import java.util.stream.Stream;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -40,18 +30,21 @@ import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.lang.IgniteCallable;
import org.apache.ignite.lang.IgnitePredicate;
import org.apache.ignite.lang.IgniteRunnable;
-import org.apache.ignite.lang.IgniteUuid;
import org.apache.ignite.ml.math.KeyMapper;
-import org.apache.ignite.ml.math.distributed.keys.BlockMatrixKey;
-import org.apache.ignite.ml.math.distributed.keys.MatrixCacheKey;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
-import org.apache.ignite.ml.math.functions.IgniteConsumer;
-import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
-import org.apache.ignite.ml.math.impls.matrix.BlockEntry;
+import org.apache.ignite.ml.math.distributed.keys.DataStructureCacheKey;
+import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
+import org.apache.ignite.ml.math.distributed.keys.impl.MatrixBlockKey;
+import org.apache.ignite.ml.math.distributed.keys.impl.VectorBlockKey;
+import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
+import org.apache.ignite.ml.math.functions.*;
+import org.apache.ignite.ml.math.impls.matrix.MatrixBlockEntry;
+import org.apache.ignite.ml.math.impls.vector.VectorBlockEntry;
+
+import javax.cache.Cache;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BinaryOperator;
+import java.util.stream.Stream;
/**
* Distribution-related misc. support.
@@ -104,11 +97,11 @@ public class CacheUtils {
/**
* @param cacheName Cache name.
- * @param k Key into the cache.
- * @param <K> Key type.
+ * @param k Key into the cache.
+ * @param <K> Key type.
* @return Cluster group for given key.
*/
- public static <K> ClusterGroup groupForKey(String cacheName, K k) {
+ protected static <K> ClusterGroup getClusterGroupForGivenKey(String cacheName, K k) {
return ignite().cluster().forNode(ignite().affinity(cacheName).mapKeyToNode(k));
}
@@ -116,8 +109,8 @@ public class CacheUtils {
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
* @return Sum of the values obtained for valid keys.
*/
public static <K, V> double sum(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
@@ -126,8 +119,7 @@ public class CacheUtils {
double v = valMapper.toDouble(ce.entry().getValue());
return acc == null ? v : acc + v;
- }
- else
+ } else
return acc;
});
@@ -146,19 +138,17 @@ public class CacheUtils {
Collection<Double> subSums = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
V v = ce.entry().getValue();
- double sum = 0.0;
+ double sum;
if (v instanceof Map) {
- Map<Integer, Double> map = (Map<Integer, Double>)v;
+ Map<Integer, Double> map = (Map<Integer, Double>) v;
sum = sum(map.values());
- }
- else if (v instanceof BlockEntry) {
- BlockEntry be = (BlockEntry)v;
+ } else if (v instanceof MatrixBlockEntry) {
+ MatrixBlockEntry be = (MatrixBlockEntry) v;
sum = be.sum();
- }
- else
+ } else
throw new UnsupportedOperationException();
return acc == null ? sum : acc + sum;
@@ -180,8 +170,8 @@ public class CacheUtils {
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
* @return Minimum value for valid keys.
*/
public static <K, V> double min(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
@@ -193,8 +183,7 @@ public class CacheUtils {
return v;
else
return Math.min(acc, v);
- }
- else
+ } else
return acc;
});
@@ -216,16 +205,14 @@ public class CacheUtils {
double min;
if (v instanceof Map) {
- Map<Integer, Double> map = (Map<Integer, Double>)v;
+ Map<Integer, Double> map = (Map<Integer, Double>) v;
min = Collections.min(map.values());
- }
- else if (v instanceof BlockEntry) {
- BlockEntry be = (BlockEntry)v;
+ } else if (v instanceof MatrixBlockEntry) {
+ MatrixBlockEntry be = (MatrixBlockEntry) v;
min = be.minValue();
- }
- else
+ } else
throw new UnsupportedOperationException();
if (acc == null)
@@ -253,16 +240,14 @@ public class CacheUtils {
double max;
if (v instanceof Map) {
- Map<Integer, Double> map = (Map<Integer, Double>)v;
+ Map<Integer, Double> map = (Map<Integer, Double>) v;
max = Collections.max(map.values());
- }
- else if (v instanceof BlockEntry) {
- BlockEntry be = (BlockEntry)v;
+ } else if (v instanceof MatrixBlockEntry) {
+ MatrixBlockEntry be = (MatrixBlockEntry) v;
max = be.maxValue();
- }
- else
+ } else
throw new UnsupportedOperationException();
if (acc == null)
@@ -279,8 +264,8 @@ public class CacheUtils {
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
* @return Maximum value for valid keys.
*/
public static <K, V> double max(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
@@ -292,8 +277,7 @@ public class CacheUtils {
return v;
else
return Math.max(acc, v);
- }
- else
+ } else
return acc;
});
@@ -304,12 +288,12 @@ public class CacheUtils {
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
- * @param mapper Mapping {@link IgniteFunction}.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param mapper Mapping {@link IgniteFunction}.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
*/
public static <K, V> void map(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper,
- IgniteFunction<Double, Double> mapper) {
+ IgniteFunction<Double, Double> mapper) {
foreach(cacheName, (CacheEntry<K, V> ce) -> {
K k = ce.entry().getKey();
@@ -321,7 +305,7 @@ public class CacheUtils {
/**
* @param matrixUuid Matrix UUID.
- * @param mapper Mapping {@link IgniteFunction}.
+ * @param mapper Mapping {@link IgniteFunction}.
*/
@SuppressWarnings("unchecked")
public static <K, V> void sparseMap(UUID matrixUuid, IgniteDoubleFunction<Double> mapper, String cacheName) {
@@ -335,18 +319,16 @@ public class CacheUtils {
V v = ce.entry().getValue();
if (v instanceof Map) {
- Map<Integer, Double> map = (Map<Integer, Double>)v;
+ Map<Integer, Double> map = (Map<Integer, Double>) v;
for (Map.Entry<Integer, Double> e : (map.entrySet()))
e.setValue(mapper.apply(e.getValue()));
- }
- else if (v instanceof BlockEntry) {
- BlockEntry be = (BlockEntry)v;
+ } else if (v instanceof MatrixBlockEntry) {
+ MatrixBlockEntry be = (MatrixBlockEntry) v;
be.map(mapper);
- }
- else
+ } else
throw new UnsupportedOperationException();
ce.cache().put(k, v);
@@ -360,34 +342,40 @@ public class CacheUtils {
*/
private static <K> IgnitePredicate<K> sparseKeyFilter(UUID matrixUuid) {
return key -> {
- if (key instanceof MatrixCacheKey)
- return ((MatrixCacheKey)key).matrixId().equals(matrixUuid);
+ if (key instanceof DataStructureCacheKey)
+ return ((DataStructureCacheKey) key).dataStructureId().equals(matrixUuid);
else if (key instanceof IgniteBiTuple)
- return ((IgniteBiTuple<Integer, UUID>)key).get2().equals(matrixUuid);
+ return ((IgniteBiTuple<Integer, UUID>) key).get2().equals(matrixUuid);
+ else if (key instanceof MatrixBlockKey)
+ return ((MatrixBlockKey) key).dataStructureId().equals(matrixUuid);
+ else if (key instanceof RowColMatrixKey)
+ return ((RowColMatrixKey) key).dataStructureId().equals(matrixUuid);
+ else if (key instanceof VectorBlockKey)
+ return ((VectorBlockKey) key).dataStructureId().equals(matrixUuid);
else
- throw new UnsupportedOperationException();
+ throw new UnsupportedOperationException(); // TODO: handle my poor doubles
};
}
/**
* @param cacheName Cache name.
- * @param fun An operation that accepts a cache entry and processes it.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param fun An operation that accepts a cache entry and processes it.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
*/
- public static <K, V> void foreach(String cacheName, IgniteConsumer<CacheEntry<K, V>> fun) {
+ private static <K, V> void foreach(String cacheName, IgniteConsumer<CacheEntry<K, V>> fun) {
foreach(cacheName, fun, null);
}
/**
* @param cacheName Cache name.
- * @param fun An operation that accepts a cache entry and processes it.
+ * @param fun An operation that accepts a cache entry and processes it.
* @param keyFilter Cache keys filter.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
*/
- public static <K, V> void foreach(String cacheName, IgniteConsumer<CacheEntry<K, V>> fun,
- IgnitePredicate<K> keyFilter) {
+ protected static <K, V> void foreach(String cacheName, IgniteConsumer<CacheEntry<K, V>> fun,
+ IgnitePredicate<K> keyFilter) {
bcast(cacheName, () -> {
Ignite ignite = Ignition.localIgnite();
IgniteCache<K, V> cache = ignite.getOrCreateCache(cacheName);
@@ -405,7 +393,7 @@ public class CacheUtils {
// Iterate over given partition.
// Query returns an empty cursor if this partition is not stored on this node.
for (Cache.Entry<K, V> entry : cache.query(new ScanQuery<K, V>(part,
- (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
+ (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
fun.accept(new CacheEntry<>(entry, cache));
}
});
@@ -413,14 +401,14 @@ public class CacheUtils {
/**
* @param cacheName Cache name.
- * @param fun An operation that accepts a cache entry and processes it.
- * @param ignite Ignite.
- * @param keysGen Keys generator.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param fun An operation that accepts a cache entry and processes it.
+ * @param ignite Ignite.
+ * @param keysGen Keys generator.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
*/
public static <K, V> void update(String cacheName, Ignite ignite,
- IgniteBiFunction<Ignite, Cache.Entry<K, V>, Stream<Cache.Entry<K, V>>> fun, IgniteSupplier<Set<K>> keysGen) {
+ IgniteBiFunction<Ignite, Cache.Entry<K, V>, Stream<Cache.Entry<K, V>>> fun, IgniteSupplier<Set<K>> keysGen) {
bcast(cacheName, ignite, () -> {
Ignite ig = Ignition.localIgnite();
IgniteCache<K, V> cache = ig.getOrCreateCache(cacheName);
@@ -447,14 +435,14 @@ public class CacheUtils {
/**
* @param cacheName Cache name.
- * @param fun An operation that accepts a cache entry and processes it.
- * @param ignite Ignite.
- * @param keysGen Keys generator.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
+ * @param fun An operation that accepts a cache entry and processes it.
+ * @param ignite Ignite.
+ * @param keysGen Keys generator.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
*/
public static <K, V> void update(String cacheName, Ignite ignite, IgniteConsumer<Cache.Entry<K, V>> fun,
- IgniteSupplier<Set<K>> keysGen) {
+ IgniteSupplier<Set<K>> keysGen) {
bcast(cacheName, ignite, () -> {
Ignite ig = Ignition.localIgnite();
IgniteCache<K, V> cache = ig.getOrCreateCache(cacheName);
@@ -485,10 +473,10 @@ public class CacheUtils {
* <b>Currently fold supports only commutative operations.<b/>
*
* @param cacheName Cache name.
- * @param folder Fold function operating over cache entries.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
- * @param <A> Fold result type.
+ * @param folder Fold function operating over cache entries.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
+ * @param <A> Fold result type.
* @return Fold operation result.
*/
public static <K, V, A> Collection<A> fold(String cacheName, IgniteBiFunction<CacheEntry<K, V>, A, A> folder) {
@@ -499,14 +487,14 @@ public class CacheUtils {
* <b>Currently fold supports only commutative operations.</b>
*
* @param cacheName Cache name.
- * @param folder Fold function operating over cache entries.
- * @param <K> Cache key object type.
- * @param <V> Cache value object type.
- * @param <A> Fold result type.
+ * @param folder Fold function operating over cache entries.
+ * @param <K> Cache key object type.
+ * @param <V> Cache value object type.
+ * @param <A> Fold result type.
* @return Fold operation result.
*/
public static <K, V, A> Collection<A> fold(String cacheName, IgniteBiFunction<CacheEntry<K, V>, A, A> folder,
- IgnitePredicate<K> keyFilter) {
+ IgnitePredicate<K> keyFilter) {
return bcast(cacheName, () -> {
Ignite ignite = Ignition.localIgnite();
IgniteCache<K, V> cache = ignite.getOrCreateCache(cacheName);
@@ -526,7 +514,7 @@ public class CacheUtils {
// Iterate over given partition.
// Query returns an empty cursor if this partition is not stored on this node.
for (Cache.Entry<K, V> entry : cache.query(new ScanQuery<K, V>(part,
- (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
+ (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
a = folder.apply(new CacheEntry<>(entry, cache), a);
}
@@ -537,34 +525,34 @@ public class CacheUtils {
/**
* Distributed version of fold operation.
*
- * @param cacheName Cache name.
- * @param folder Folder.
- * @param keyFilter Key filter.
+ * @param cacheName Cache name.
+ * @param folder Folder.
+ * @param keyFilter Key filter.
* @param accumulator Accumulator.
* @param zeroValSupp Zero value supplier.
*/
public static <K, V, A> A distributedFold(String cacheName, IgniteBiFunction<Cache.Entry<K, V>, A, A> folder,
- IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, IgniteSupplier<A> zeroValSupp) {
+ IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, IgniteSupplier<A> zeroValSupp) {
return sparseFold(cacheName, folder, keyFilter, accumulator, zeroValSupp, null, null, 0,
- false);
+ false);
}
/**
* Sparse version of fold. This method also applicable to sparse zeroes.
*
- * @param cacheName Cache name.
- * @param folder Folder.
- * @param keyFilter Key filter.
+ * @param cacheName Cache name.
+ * @param folder Folder.
+ * @param keyFilter Key filter.
* @param accumulator Accumulator.
* @param zeroValSupp Zero value supplier.
- * @param defVal Default value.
- * @param defKey Default key.
- * @param defValCnt Def value count.
+ * @param defVal Default value.
+ * @param defKey Default key.
+ * @param defValCnt Def value count.
* @param isNilpotent Is nilpotent.
*/
private static <K, V, A> A sparseFold(String cacheName, IgniteBiFunction<Cache.Entry<K, V>, A, A> folder,
- IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, IgniteSupplier<A> zeroValSupp, V defVal, K defKey,
- long defValCnt, boolean isNilpotent) {
+ IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, IgniteSupplier<A> zeroValSupp, V defVal, K defKey,
+ long defValCnt, boolean isNilpotent) {
A defRes = zeroValSupp.get();
@@ -591,7 +579,7 @@ public class CacheUtils {
// Iterate over given partition.
// Query returns an empty cursor if this partition is not stored on this node.
for (Cache.Entry<K, V> entry : cache.query(new ScanQuery<K, V>(part,
- (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
+ (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
a = folder.apply(entry, a);
}
@@ -601,10 +589,10 @@ public class CacheUtils {
}
public static <K, V, A, W> A reduce(String cacheName, Ignite ignite,
- IgniteTriFunction<W, Cache.Entry<K, V>, A, A> acc,
- IgniteSupplier<W> supp,
- IgniteSupplier<Iterable<Cache.Entry<K, V>>> entriesGen, IgniteBinaryOperator<A> comb,
- IgniteSupplier<A> zeroValSupp) {
+ IgniteTriFunction<W, Cache.Entry<K, V>, A, A> acc,
+ IgniteSupplier<W> supp,
+ IgniteSupplier<Iterable<Cache.Entry<K, V>>> entriesGen, IgniteBinaryOperator<A> comb,
+ IgniteSupplier<A> zeroValSupp) {
A defRes = zeroValSupp.get();
@@ -624,15 +612,15 @@ public class CacheUtils {
}
public static <K, V, A, W> A reduce(String cacheName, IgniteTriFunction<W, Cache.Entry<K, V>, A, A> acc,
- IgniteSupplier<W> supp,
- IgniteSupplier<Iterable<Cache.Entry<K, V>>> entriesGen, IgniteBinaryOperator<A> comb,
- IgniteSupplier<A> zeroValSupp) {
+ IgniteSupplier<W> supp,
+ IgniteSupplier<Iterable<Cache.Entry<K, V>>> entriesGen, IgniteBinaryOperator<A> comb,
+ IgniteSupplier<A> zeroValSupp) {
return reduce(cacheName, Ignition.localIgnite(), acc, supp, entriesGen, comb, zeroValSupp);
}
/**
* @param cacheName Cache name.
- * @param run {@link Runnable} to broadcast to cache nodes for given cache name.
+ * @param run {@link Runnable} to broadcast to cache nodes for given cache name.
*/
public static void bcast(String cacheName, Ignite ignite, IgniteRunnable run) {
ignite.compute(ignite.cluster().forDataNodes(cacheName)).broadcast(run);
@@ -640,8 +628,9 @@ public class CacheUtils {
/**
* Broadcast runnable to data nodes of given cache.
+ *
* @param cacheName Cache name.
- * @param run Runnable.
+ * @param run Runnable.
*/
public static void bcast(String cacheName, IgniteRunnable run) {
bcast(cacheName, ignite(), run);
@@ -649,8 +638,8 @@ public class CacheUtils {
/**
* @param cacheName Cache name.
- * @param call {@link IgniteCallable} to broadcast to cache nodes for given cache name.
- * @param <A> Type returned by the callable.
+ * @param call {@link IgniteCallable} to broadcast to cache nodes for given cache name.
+ * @param <A> Type returned by the callable.
*/
public static <A> Collection<A> bcast(String cacheName, IgniteCallable<A> call) {
return bcast(cacheName, ignite(), call);
@@ -658,13 +647,42 @@ public class CacheUtils {
/**
* Broadcast callable to data nodes of given cache.
+ *
* @param cacheName Cache name.
- * @param ignite Ignite instance.
- * @param call Callable to broadcast.
- * @param <A> Type of callable result.
+ * @param ignite Ignite instance.
+ * @param call Callable to broadcast.
+ * @param <A> Type of callable result.
* @return Results of callable from each node.
*/
public static <A> Collection<A> bcast(String cacheName, Ignite ignite, IgniteCallable<A> call) {
return ignite.compute(ignite.cluster().forDataNodes(cacheName)).broadcast(call);
}
+
+ /**
+ * @param vectorUuid Vector UUID.
+ * @param mapper Mapping {@link IgniteDoubleFunction}.
+ * @param cacheName Cache name.
+ */
+ @SuppressWarnings("unchecked")
+ public static <K, V> void sparseMapForVector(UUID vectorUuid, IgniteDoubleFunction<V> mapper, String cacheName) {
+ A.notNull(vectorUuid, "vectorUuid");
+ A.notNull(cacheName, "cacheName");
+ A.notNull(mapper, "mapper");
+
+ foreach(cacheName, (CacheEntry<K, V> ce) -> {
+ K k = ce.entry().getKey();
+
+ V v = ce.entry().getValue();
+
+ if (v instanceof VectorBlockEntry) {
+ VectorBlockEntry entry = (VectorBlockEntry) v;
+
+ for (int i = 0; i < entry.size(); i++) entry.set(i, (Double) mapper.apply(entry.get(i)));
+
+ ce.cache().put(k, (V) entry);
+ } else {
+ V mappingRes = mapper.apply((Double) v);
+ ce.cache().put(k, mappingRes);
+ }
+ }, sparseKeyFilter(vectorUuid));
+ }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/BlockMatrixKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/BlockMatrixKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/BlockMatrixKey.java
deleted file mode 100644
index 091b325..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/BlockMatrixKey.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.math.distributed.keys;
-
-import org.apache.ignite.internal.util.lang.IgnitePair;
-import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
-
-/**
- * Cache key for blocks in {@link SparseBlockDistributedMatrix}.
- *
- * TODO: check if using {@link IgnitePair} will be better for block id.
- */
-public interface BlockMatrixKey extends MatrixCacheKey {
- /**
- * @return block row id.
- */
- public long blockRowId();
-
- /**
- * @return block col id.
- */
- public long blockColId();
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/DataStructureCacheKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/DataStructureCacheKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/DataStructureCacheKey.java
new file mode 100644
index 0000000..d99ea48
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/DataStructureCacheKey.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.distributed.keys;
+
+import java.util.UUID;
+
+/**
+ * Base cache key for distributed data structures (matrices and vectors).
+ */
+public interface DataStructureCacheKey {
+ /**
+ * @return data structure id.
+ */
+ public UUID dataStructureId();
+
+ /**
+ * @return affinity key.
+ */
+ public Object affinityKey();
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixBlockKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixBlockKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixBlockKey.java
new file mode 100644
index 0000000..9c76568
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixBlockKey.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.distributed.keys;
+
+import org.apache.ignite.internal.util.lang.IgnitePair;
+import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
+
+/**
+ * Cache key for blocks in {@link SparseBlockDistributedMatrix}.
+ *
+ * TODO: check if using {@link IgnitePair} will be better for block id.
+ */
+public interface MatrixBlockKey extends DataStructureCacheKey {
+ /**
+ * @return block row id.
+ */
+ public long blockRowId();
+
+ /**
+ * @return block col id.
+ */
+ public long blockColId();
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java
deleted file mode 100644
index 0242560..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.math.distributed.keys;
-
-import java.util.UUID;
-
-/**
- * Base matrix cache key.
- */
-public interface MatrixCacheKey {
- /**
- * @return matrix id.
- */
- public UUID matrixId();
-
- /**
- * @return affinity key.
- */
- public Object affinityKey();
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/RowColMatrixKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/RowColMatrixKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/RowColMatrixKey.java
index 168f49f..78af2e8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/RowColMatrixKey.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/RowColMatrixKey.java
@@ -22,7 +22,7 @@ import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
/**
* Cache key for {@link SparseDistributedMatrix}.
*/
-public interface RowColMatrixKey extends MatrixCacheKey {
+public interface RowColMatrixKey extends DataStructureCacheKey {
/**
* Return index value(blockId, Row/Col index, etc.)
*/
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/VectorBlockKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/VectorBlockKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/VectorBlockKey.java
new file mode 100644
index 0000000..32af965
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/VectorBlockKey.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.distributed.keys;
+
+import org.apache.ignite.internal.util.lang.IgnitePair;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+
+/**
+ * Cache key for blocks in {@link SparseBlockDistributedVector}.
+ *
+ * TODO: check if using {@link IgnitePair} will be better for block id.
+ */
+public interface VectorBlockKey extends DataStructureCacheKey {
+ /**
+ * @return block id.
+ */
+ public long blockId();
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java
deleted file mode 100644
index cc8c488..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.math.distributed.keys.impl;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.UUID;
-import org.apache.ignite.binary.BinaryObjectException;
-import org.apache.ignite.binary.BinaryRawReader;
-import org.apache.ignite.binary.BinaryRawWriter;
-import org.apache.ignite.binary.BinaryReader;
-import org.apache.ignite.binary.BinaryWriter;
-import org.apache.ignite.binary.Binarylizable;
-import org.apache.ignite.internal.binary.BinaryUtils;
-import org.apache.ignite.internal.util.typedef.F;
-import org.apache.ignite.internal.util.typedef.internal.S;
-import org.apache.ignite.internal.util.typedef.internal.U;
-import org.apache.ignite.lang.IgniteUuid;
-import org.apache.ignite.ml.math.impls.matrix.BlockEntry;
-import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
-import org.jetbrains.annotations.Nullable;
-
-/**
- * Key implementation for {@link BlockEntry} using for {@link SparseBlockDistributedMatrix}.
- */
-public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.keys.BlockMatrixKey, Externalizable, Binarylizable {
- /** */
- private static final long serialVersionUID = 0L;
- /** Block row ID */
- private long blockIdRow;
- /** Block col ID */
- private long blockIdCol;
- /** Matrix ID */
- private UUID matrixUuid;
- /** Block affinity key. */
- private IgniteUuid affinityKey;
-
- /**
- * Empty constructor required for {@link Externalizable}.
- */
- public BlockMatrixKey() {
- // No-op.
- }
-
- /**
- * Construct matrix block key.
- *
- * @param matrixUuid Matrix uuid.
- * @param affinityKey Affinity key.
- */
- public BlockMatrixKey(long rowId, long colId, UUID matrixUuid, @Nullable IgniteUuid affinityKey) {
- assert rowId >= 0;
- assert colId >= 0;
- assert matrixUuid != null;
-
- this.blockIdRow = rowId;
- this.blockIdCol = colId;
- this.matrixUuid = matrixUuid;
- this.affinityKey = affinityKey;
- }
-
- /** {@inheritDoc} */
- @Override public long blockRowId() {
- return blockIdRow;
- }
-
- /** {@inheritDoc} */
- @Override public long blockColId() {
- return blockIdCol;
- }
-
- /** {@inheritDoc} */
- @Override public UUID matrixId() {
- return matrixUuid;
- }
-
- /** {@inheritDoc} */
- @Override public IgniteUuid affinityKey() {
- return affinityKey;
- }
-
- /** {@inheritDoc} */
- @Override public void writeExternal(ObjectOutput out) throws IOException {
- out.writeObject(matrixUuid);
- U.writeGridUuid(out, affinityKey);
- out.writeLong(blockIdRow);
- out.writeLong(blockIdCol);
- }
-
- /** {@inheritDoc} */
- @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- matrixUuid = (UUID)in.readObject();
- affinityKey = U.readGridUuid(in);
- blockIdRow = in.readLong();
- blockIdCol = in.readLong();
- }
-
- /** {@inheritDoc} */
- @Override public void writeBinary(BinaryWriter writer) throws BinaryObjectException {
- BinaryRawWriter out = writer.rawWriter();
-
- out.writeUuid(matrixUuid);
- BinaryUtils.writeIgniteUuid(out, affinityKey);
- out.writeLong(blockIdRow);
- out.writeLong(blockIdCol);
- }
-
- /** {@inheritDoc} */
- @Override public void readBinary(BinaryReader reader) throws BinaryObjectException {
- BinaryRawReader in = reader.rawReader();
-
- matrixUuid = in.readUuid();
- affinityKey = BinaryUtils.readIgniteUuid(in);
- blockIdRow = in.readLong();
- blockIdCol = in.readLong();
- }
-
- /** {@inheritDoc} */
- @Override public int hashCode() {
- int res = 37;
-
- res += res * 37 + blockIdCol;
- res += res * 37 + blockIdRow;
- res += res * 37 + matrixUuid.hashCode();
-
- return res;
- }
-
- /** {@inheritDoc} */
- @Override public boolean equals(Object obj) {
- if (obj == this)
- return true;
-
- if (obj == null || obj.getClass() != getClass())
- return false;
-
- BlockMatrixKey that = (BlockMatrixKey)obj;
-
- return blockIdRow == that.blockIdRow && blockIdCol == that.blockIdCol && matrixUuid.equals(that.matrixUuid)
- && F.eq(affinityKey, that.affinityKey);
- }
-
- /** {@inheritDoc} */
- @Override public String toString() {
- return S.toString(BlockMatrixKey.class, this);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/MatrixBlockKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/MatrixBlockKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/MatrixBlockKey.java
new file mode 100644
index 0000000..9e8d81e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/MatrixBlockKey.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.distributed.keys.impl;
+
+import org.apache.ignite.binary.*;
+import org.apache.ignite.internal.binary.BinaryUtils;
+import org.apache.ignite.internal.util.typedef.F;
+import org.apache.ignite.internal.util.typedef.internal.S;
+import org.apache.ignite.internal.util.typedef.internal.U;
+import org.apache.ignite.lang.IgniteUuid;
+import org.apache.ignite.ml.math.impls.matrix.MatrixBlockEntry;
+import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
+import org.jetbrains.annotations.Nullable;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.UUID;
+
+/**
+ * Key implementation for {@link MatrixBlockEntry} used in {@link SparseBlockDistributedMatrix}.
+ */
+public class MatrixBlockKey implements org.apache.ignite.ml.math.distributed.keys.MatrixBlockKey, Externalizable, Binarylizable {
+ /** */
+ private static final long serialVersionUID = 0L;
+ /** Block row ID */
+ private long blockIdRow;
+ /** Block col ID */
+ private long blockIdCol;
+ /** Matrix ID */
+ private UUID matrixUuid;
+ /** Block affinity key. */
+ private UUID affinityKey;
+
+ /**
+ * Empty constructor required for {@link Externalizable}.
+ */
+ public MatrixBlockKey() {
+ // No-op.
+ }
+
+ /**
+ * Construct matrix block key.
+ *
+ * @param matrixUuid Matrix uuid.
+ * @param affinityKey Affinity key.
+ */
+ public MatrixBlockKey(long rowId, long colId, UUID matrixUuid, @Nullable UUID affinityKey) {
+ assert rowId >= 0;
+ assert colId >= 0;
+ assert matrixUuid != null;
+
+ this.blockIdRow = rowId;
+ this.blockIdCol = colId;
+ this.matrixUuid = matrixUuid;
+ this.affinityKey = affinityKey;
+ }
+
+ /** {@inheritDoc} */
+ @Override public long blockRowId() {
+ return blockIdRow;
+ }
+
+ /** {@inheritDoc} */
+ @Override public long blockColId() {
+ return blockIdCol;
+ }
+
+ /** {@inheritDoc} */
+ @Override public UUID dataStructureId() {
+ return matrixUuid;
+ }
+
+ /** {@inheritDoc} */
+ @Override public UUID affinityKey() {
+ return affinityKey;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeObject(matrixUuid);
+ out.writeObject(affinityKey);
+ out.writeLong(blockIdRow);
+ out.writeLong(blockIdCol);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ matrixUuid = (UUID)in.readObject();
+ affinityKey = (UUID)in.readObject();
+ blockIdRow = in.readLong();
+ blockIdCol = in.readLong();
+ }
+
+ /** {@inheritDoc} */
+ @Override public void writeBinary(BinaryWriter writer) throws BinaryObjectException {
+ BinaryRawWriter out = writer.rawWriter();
+
+ out.writeUuid(matrixUuid);
+ out.writeUuid(affinityKey);
+ out.writeLong(blockIdRow);
+ out.writeLong(blockIdCol);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void readBinary(BinaryReader reader) throws BinaryObjectException {
+ BinaryRawReader in = reader.rawReader();
+
+ matrixUuid = in.readUuid();
+ affinityKey = in.readUuid();
+ blockIdRow = in.readLong();
+ blockIdCol = in.readLong();
+ }
+
+ /** {@inheritDoc} */
+ @Override public int hashCode() {
+ int res = 37;
+
+ res += res * 37 + blockIdCol;
+ res += res * 37 + blockIdRow;
+ res += res * 37 + matrixUuid.hashCode();
+
+ return res;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean equals(Object obj) {
+ if (obj == this)
+ return true;
+
+ if (obj == null || obj.getClass() != getClass())
+ return false;
+
+ MatrixBlockKey that = (MatrixBlockKey)obj;
+
+ return blockIdRow == that.blockIdRow && blockIdCol == that.blockIdCol && matrixUuid.equals(that.matrixUuid)
+ && F.eq(affinityKey, that.affinityKey);
+ }
+
+ /** {@inheritDoc} */
+ @Override public String toString() {
+ return S.toString(MatrixBlockKey.class, this);
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java
index aa5e0ad..980d433 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java
@@ -17,17 +17,18 @@
package org.apache.ignite.ml.math.distributed.keys.impl;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.UUID;
import org.apache.ignite.cache.affinity.AffinityKeyMapped;
import org.apache.ignite.internal.util.typedef.F;
import org.apache.ignite.internal.util.typedef.internal.S;
import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.UUID;
+
/**
* Key implementation for {@link SparseDistributedMatrix}.
*/
@@ -65,7 +66,7 @@ public class SparseMatrixKey implements RowColMatrixKey, Externalizable {
}
/** {@inheritDoc} */
- @Override public UUID matrixId() {
+ @Override public UUID dataStructureId() {
return matrixId;
}
@@ -76,7 +77,6 @@ public class SparseMatrixKey implements RowColMatrixKey, Externalizable {
/** {@inheritDoc} */
@Override public void writeExternal(ObjectOutput out) throws IOException {
-// U.writeGridUuid(out, matrixId);
out.writeObject(matrixId);
out.writeObject(affinityKey);
out.writeInt(idx);
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/VectorBlockKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/VectorBlockKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/VectorBlockKey.java
new file mode 100644
index 0000000..6052010
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/VectorBlockKey.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.distributed.keys.impl;
+
+import org.apache.ignite.binary.*;
+import org.apache.ignite.internal.binary.BinaryUtils;
+import org.apache.ignite.internal.util.typedef.F;
+import org.apache.ignite.internal.util.typedef.internal.S;
+import org.apache.ignite.internal.util.typedef.internal.U;
+import org.apache.ignite.lang.IgniteUuid;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+import org.apache.ignite.ml.math.impls.vector.VectorBlockEntry;
+import org.jetbrains.annotations.Nullable;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.UUID;
+
+/**
+ * Key implementation for {@link VectorBlockEntry} using for {@link SparseBlockDistributedVector}.
+ */
+public class VectorBlockKey implements org.apache.ignite.ml.math.distributed.keys.VectorBlockKey, Externalizable, Binarylizable {
+    /** */
+    private static final long serialVersionUID = 0L;
+
+    /** Block ID. */
+    private long blockId;
+
+    /** ID of the vector this block belongs to. */
+    private UUID vectorUuid;
+
+    /** Block affinity key. */
+    private UUID affinityKey;
+
+    /**
+     * Empty constructor required for {@link Externalizable}.
+     */
+    public VectorBlockKey() {
+        // No-op.
+    }
+
+    /**
+     * Construct vector block key.
+     *
+     * @param blockId Block ID (must be non-negative).
+     * @param vectorUuid Vector uuid.
+     * @param affinityKey Affinity key.
+     */
+    public VectorBlockKey(long blockId, UUID vectorUuid, @Nullable UUID affinityKey) {
+        assert blockId >= 0;
+        assert vectorUuid != null;
+
+        this.blockId = blockId;
+        this.vectorUuid = vectorUuid;
+        this.affinityKey = affinityKey;
+    }
+
+    /** {@inheritDoc} */
+    @Override public long blockId() {
+        return blockId;
+    }
+
+    /** {@inheritDoc} */
+    @Override public UUID dataStructureId() {
+        return vectorUuid;
+    }
+
+    /** {@inheritDoc} */
+    @Override public UUID affinityKey() {
+        return affinityKey;
+    }
+
+    /** {@inheritDoc} */
+    @Override public void writeExternal(ObjectOutput out) throws IOException {
+        out.writeObject(vectorUuid);
+        out.writeObject(affinityKey);
+        out.writeLong(blockId);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+        // Read order must mirror writeExternal: vectorUuid, affinityKey, blockId.
+        vectorUuid = (UUID)in.readObject();
+        affinityKey = (UUID)in.readObject();
+        blockId = in.readLong();
+    }
+
+    /** {@inheritDoc} */
+    @Override public void writeBinary(BinaryWriter writer) throws BinaryObjectException {
+        BinaryRawWriter out = writer.rawWriter();
+
+        out.writeUuid(vectorUuid);
+        out.writeUuid(affinityKey);
+        out.writeLong(blockId);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void readBinary(BinaryReader reader) throws BinaryObjectException {
+        // Raw read order must mirror writeBinary: vectorUuid, affinityKey, blockId.
+        BinaryRawReader in = reader.rawReader();
+
+        vectorUuid = in.readUuid();
+        affinityKey = in.readUuid();
+        blockId = in.readLong();
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        int res = 37;
+
+        // Long.hashCode mixes both 32-bit halves of the block ID; for blockId < 2^31
+        // this yields the same value as the plain narrowing cast, so it is compatible.
+        res += res * 37 + Long.hashCode(blockId);
+        res += res * 37 + vectorUuid.hashCode();
+
+        // NOTE(review): affinityKey is excluded from the hash while equals() compares it.
+        // That keeps hashCode/equals consistent (equal objects hash equal), but confirm
+        // the asymmetry is intended and matches MatrixBlockKey.
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object obj) {
+        if (obj == this)
+            return true;
+
+        if (obj == null || obj.getClass() != getClass())
+            return false;
+
+        VectorBlockKey that = (VectorBlockKey)obj;
+
+        return blockId == that.blockId && vectorUuid.equals(that.vectorUuid)
+            && F.eq(affinityKey, that.affinityKey);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        return S.toString(VectorBlockKey.class, this);
+    }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
index 0b4ad12..ce534bd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
@@ -17,10 +17,11 @@
package org.apache.ignite.ml.math.functions;
+import org.apache.ignite.lang.IgniteBiTuple;
+
import java.util.Comparator;
import java.util.List;
import java.util.function.BiFunction;
-import org.apache.ignite.lang.IgniteBiTuple;
/**
* Compatibility with Apache Mahout.
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
index 06fb34c..89f567e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
@@ -43,6 +43,7 @@ import org.apache.ignite.ml.math.functions.IgniteTriFunction;
import org.apache.ignite.ml.math.functions.IntIntToDoubleFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.math.impls.vector.MatrixVectorView;
+import org.apache.ignite.ml.math.util.MatrixUtil;
/**
* This class provides a helper implementation of the {@link Matrix}
@@ -282,7 +283,7 @@ public abstract class AbstractMatrix implements Matrix {
*
* @param row Row index.
*/
- private void checkRowIndex(int row) {
+ void checkRowIndex(int row) {
if (row < 0 || row >= rowSize())
throw new RowIndexException(row);
}
@@ -292,7 +293,7 @@ public abstract class AbstractMatrix implements Matrix {
*
* @param col Column index.
*/
- private void checkColumnIndex(int col) {
+ void checkColumnIndex(int col) {
if (col < 0 || col >= columnSize())
throw new ColumnIndexException(col);
}
@@ -303,7 +304,7 @@ public abstract class AbstractMatrix implements Matrix {
* @param row Row index.
* @param col Column index.
*/
- protected void checkIndex(int row, int col) {
+ private void checkIndex(int row, int col) {
checkRowIndex(row);
checkColumnIndex(col);
}
@@ -739,11 +740,12 @@ public abstract class AbstractMatrix implements Matrix {
/** {@inheritDoc} */
@Override public Vector getCol(int col) {
checkColumnIndex(col);
-
- Vector res = new DenseLocalOnHeapVector(rowSize());
+ Vector res;
+ if (isDistributed()) res = MatrixUtil.likeVector(this, rowSize());
+ else res = new DenseLocalOnHeapVector(rowSize());
for (int i = 0; i < rowSize(); i++)
- res.setX(i, getX(i,col));
+ res.setX(i, getX(i, col));
return res;
}
@@ -974,4 +976,14 @@ public abstract class AbstractMatrix implements Matrix {
@Override public void compute(int row, int col, IgniteTriFunction<Integer, Integer, Double, Double> f) {
setX(row, col, f.apply(row, col, getX(row, col)));
}
+
+
+ /**
+ * Returns the length of the longest row of {@code data}; rows may be ragged (differ in length).
+ *
+ * @param data Two-dimensional source array.
+ * @return Maximum row length, or {@code 0} for a zero-length array.
+ */
+ protected int getMaxAmountOfColumns(double[][] data) {
+ int maxAmountOfColumns = 0;
+
+ for (int i = 0; i < data.length; i++)
+ maxAmountOfColumns = Math.max(maxAmountOfColumns, data[i].length);
+
+ return maxAmountOfColumns;
+ }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/BlockEntry.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/BlockEntry.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/BlockEntry.java
deleted file mode 100644
index 47f07ce..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/BlockEntry.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.math.impls.matrix;
-
-import org.apache.ignite.ml.math.Matrix;
-
-/**
- * Block for {@link SparseBlockDistributedMatrix}.
- */
-public final class BlockEntry extends SparseLocalOnHeapMatrix {
- /** Max block size. */
- public static final int MAX_BLOCK_SIZE = 32;
-
- /** */
- public BlockEntry() {
- // No-op.
- }
-
- /** */
- public BlockEntry(int row, int col) {
- super(row, col);
-
- assert col <= MAX_BLOCK_SIZE;
- assert row <= MAX_BLOCK_SIZE;
- }
-
- /** */
- public BlockEntry(Matrix mtx) {
- assert mtx.columnSize() <= MAX_BLOCK_SIZE;
- assert mtx.rowSize() <= MAX_BLOCK_SIZE;
-
- setStorage(mtx.getStorage());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/MatrixBlockEntry.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/MatrixBlockEntry.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/MatrixBlockEntry.java
new file mode 100644
index 0000000..a2d13a1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/MatrixBlockEntry.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.impls.matrix;
+
+import org.apache.ignite.ml.math.Matrix;
+
+/**
+ * Block for {@link SparseBlockDistributedMatrix}.
+ */
+public final class MatrixBlockEntry extends SparseLocalOnHeapMatrix {
+ /** Max block size. */
+ public static final int MAX_BLOCK_SIZE = 32;
+
+ /** Creates an empty block entry. */
+ public MatrixBlockEntry() {
+ // No-op.
+ }
+
+ /** Creates a {@code row x col} block; both dimensions must not exceed {@link #MAX_BLOCK_SIZE}. */
+ public MatrixBlockEntry(int row, int col) {
+ super(row, col);
+
+ assert col <= MAX_BLOCK_SIZE;
+ assert row <= MAX_BLOCK_SIZE;
+ }
+
+ /** Wraps the given matrix (at most {@link #MAX_BLOCK_SIZE} per dimension), reusing its storage — no copy is made. */
+ public MatrixBlockEntry(Matrix mtx) {
+ assert mtx.columnSize() <= MAX_BLOCK_SIZE;
+ assert mtx.rowSize() <= MAX_BLOCK_SIZE;
+
+ setStorage(mtx.getStorage());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java
index e829168..ea9fb8c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java
@@ -17,10 +17,6 @@
package org.apache.ignite.ml.math.impls.matrix;
-import java.util.Collection;
-import java.util.List;
-import java.util.Map;
-import java.util.UUID;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -31,16 +27,25 @@ import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distributed.CacheUtils;
-import org.apache.ignite.ml.math.distributed.keys.impl.BlockMatrixKey;
+import org.apache.ignite.ml.math.distributed.keys.impl.MatrixBlockKey;
+import org.apache.ignite.ml.math.distributed.keys.impl.VectorBlockKey;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
-import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
import org.apache.ignite.ml.math.impls.storage.matrix.BlockMatrixStorage;
+import org.apache.ignite.ml.math.impls.storage.matrix.BlockVectorStorage;
+import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+import org.apache.ignite.ml.math.impls.vector.VectorBlockEntry;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
/**
- * Sparse block distributed matrix. This matrix represented by blocks 32x32 {@link BlockEntry}.
+ * Sparse block distributed matrix. This matrix represented by blocks 32x32 {@link MatrixBlockEntry}.
*
- * Using separate cache with keys {@link BlockMatrixKey} and values {@link BlockEntry}.
+ * Using separate cache with keys {@link MatrixBlockKey} and values {@link MatrixBlockEntry}.
*/
public class SparseBlockDistributedMatrix extends AbstractMatrix implements StorageConstants {
/**
@@ -62,6 +67,18 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor
}
/**
+ * @param data Data to fill the matrix
+ */
+ public SparseBlockDistributedMatrix(double[][] data) {
+ assert data.length > 0;
+ setStorage(new BlockMatrixStorage(data.length, getMaxAmountOfColumns(data)));
+
+ for (int i = 0; i < data.length; i++)
+ for (int j = 0; j < data[i].length; j++)
+ storage().set(i, j, data[i][j]);
+ }
+
+ /**
* Return the same matrix with updates values (broken contract).
*
* @param d Value to divide to.
@@ -100,22 +117,22 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor
throw new CardinalityException(columnSize(), mtx.rowSize());
SparseBlockDistributedMatrix matrixA = this;
- SparseBlockDistributedMatrix matrixB = (SparseBlockDistributedMatrix)mtx;
+ SparseBlockDistributedMatrix matrixB = (SparseBlockDistributedMatrix) mtx;
String cacheName = this.storage().cacheName();
SparseBlockDistributedMatrix matrixC = new SparseBlockDistributedMatrix(matrixA.rowSize(), matrixB.columnSize());
CacheUtils.bcast(cacheName, () -> {
Ignite ignite = Ignition.localIgnite();
- Affinity<BlockMatrixKey> affinity = ignite.affinity(cacheName);
+ Affinity<MatrixBlockKey> affinity = ignite.affinity(cacheName);
- IgniteCache<BlockMatrixKey, BlockEntry> cache = ignite.getOrCreateCache(cacheName);
+ IgniteCache<MatrixBlockKey, MatrixBlockEntry> cache = ignite.getOrCreateCache(cacheName);
ClusterNode locNode = ignite.cluster().localNode();
BlockMatrixStorage storageC = matrixC.storage();
- Map<ClusterNode, Collection<BlockMatrixKey>> keysCToNodes = affinity.mapKeysToNodes(storageC.getAllKeys());
- Collection<BlockMatrixKey> locKeys = keysCToNodes.get(locNode);
+ Map<ClusterNode, Collection<MatrixBlockKey>> keysCToNodes = affinity.mapKeysToNodes(storageC.getAllKeys());
+ Collection<MatrixBlockKey> locKeys = keysCToNodes.get(locNode);
if (locKeys == null)
return;
@@ -128,18 +145,18 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor
IgnitePair<Long> newBlockId = new IgnitePair<>(newBlockIdRow, newBlockIdCol);
- BlockEntry blockC = null;
+ MatrixBlockEntry blockC = null;
- List<BlockEntry> aRow = matrixA.storage().getRowForBlock(newBlockId);
- List<BlockEntry> bCol = matrixB.storage().getColForBlock(newBlockId);
+ List<MatrixBlockEntry> aRow = matrixA.storage().getRowForBlock(newBlockId);
+ List<MatrixBlockEntry> bCol = matrixB.storage().getColForBlock(newBlockId);
for (int i = 0; i < aRow.size(); i++) {
- BlockEntry blockA = aRow.get(i);
- BlockEntry blockB = bCol.get(i);
+ MatrixBlockEntry blockA = aRow.get(i);
+ MatrixBlockEntry blockB = bCol.get(i);
- BlockEntry tmpBlock = new BlockEntry(blockA.times(blockB));
+ MatrixBlockEntry tmpBlock = new MatrixBlockEntry(blockA.times(blockB));
- blockC = blockC == null ? tmpBlock : new BlockEntry(blockC.plus(tmpBlock));
+ blockC = blockC == null ? tmpBlock : new MatrixBlockEntry(blockC.plus(tmpBlock));
}
cache.put(storageC.getCacheKey(newBlockIdRow, newBlockIdCol), blockC);
@@ -149,6 +166,90 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor
return matrixC;
}
+
+ /**
+ * {@inheritDoc}
+ */
+ @SuppressWarnings({"unchecked"})
+ @Override public Vector times(final Vector vec) {
+ if (vec == null)
+ throw new IllegalArgumentException("The vector should be not null.");
+
+ if (columnSize() != vec.size())
+ throw new CardinalityException(columnSize(), vec.size());
+
+ // Input must be block-distributed too: the cast below fails with ClassCastException otherwise.
+ SparseBlockDistributedMatrix matrixA = this;
+ SparseBlockDistributedVector vectorB = (SparseBlockDistributedVector) vec;
+
+
+ String cacheName = this.storage().cacheName();
+ SparseBlockDistributedVector vectorC = new SparseBlockDistributedVector(matrixA.rowSize());
+
+ CacheUtils.bcast(cacheName, () -> {
+ Ignite ignite = Ignition.localIgnite();
+ Affinity<VectorBlockKey> affinity = ignite.affinity(cacheName);
+
+ IgniteCache<VectorBlockKey, VectorBlockEntry> cache = ignite.getOrCreateCache(cacheName);
+ ClusterNode locNode = ignite.cluster().localNode();
+
+ BlockVectorStorage storageC = vectorC.storage();
+
+ // Each node computes only the result blocks whose keys map to it.
+ Map<ClusterNode, Collection<VectorBlockKey>> keysCToNodes = affinity.mapKeysToNodes(storageC.getAllKeys());
+ Collection<VectorBlockKey> locKeys = keysCToNodes.get(locNode);
+
+ if (locKeys == null)
+ return;
+
+ // compute Cij locally on each node
+ // TODO: IGNITE:5114, exec in parallel
+ locKeys.forEach(key -> {
+ long newBlockId = key.blockId();
+
+
+ IgnitePair<Long> newBlockIdForMtx = new IgnitePair<>(newBlockId, 0L);
+
+ VectorBlockEntry blockC = null;
+
+ // Pairs the i-th block of A's block-row with the i-th block of b.
+ // NOTE(review): assumes aRow and bCol have equal length (same block partitioning) — confirm.
+ List<MatrixBlockEntry> aRow = matrixA.storage().getRowForBlock(newBlockIdForMtx);
+ List<VectorBlockEntry> bCol = vectorB.storage().getColForBlock(newBlockId);
+
+ for (int i = 0; i < aRow.size(); i++) {
+ MatrixBlockEntry blockA = aRow.get(i);
+ VectorBlockEntry blockB = bCol.get(i);
+
+ VectorBlockEntry tmpBlock = new VectorBlockEntry(blockA.times(blockB));
+
+ blockC = blockC == null ? tmpBlock : new VectorBlockEntry(blockC.plus(tmpBlock));
+ }
+
+ // NOTE(review): if aRow is empty, blockC is still null here and a null value is cached — verify unreachable.
+ cache.put(storageC.getCacheKey(newBlockId), blockC);
+ });
+ });
+ return vectorC;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Vector getCol(int col) {
+ checkColumnIndex(col);
+
+ // NOTE(review): returns an entry-distributed SparseDistributedVector even though this
+ // matrix is block-distributed — confirm the mix of distribution schemes is intended.
+ Vector res = new SparseDistributedVector(rowSize());
+
+ for (int i = 0; i < rowSize(); i++)
+ res.setX(i, getX(i, col));
+ return res;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Vector getRow(int row) {
+ checkRowIndex(row);
+
+ // NOTE(review): entry-distributed result from a block-distributed matrix (see getCol);
+ // element-by-element copy costs one distributed getX per column.
+ Vector res = new SparseDistributedVector(columnSize());
+
+ for (int i = 0; i < columnSize(); i++)
+ res.setX(i, getX(row, i));
+ return res;
+ }
+
/** {@inheritDoc} */
@Override public Matrix assign(double val) {
return mapOverValues(v -> val);
@@ -176,7 +277,11 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor
/** {@inheritDoc} */
@Override public Matrix copy() {
- throw new UnsupportedOperationException();
+ Matrix cp = like(rowSize(), columnSize());
+
+ cp.assign(this);
+
+ return cp;
}
/** {@inheritDoc} */
@@ -186,12 +291,12 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor
/** {@inheritDoc} */
@Override public Vector likeVector(int crd) {
- throw new UnsupportedOperationException();
+ return new SparseBlockDistributedVector(crd);
}
/** */
private UUID getUUID() {
- return ((BlockMatrixStorage)getStorage()).getUUID();
+ return ((BlockMatrixStorage) getStorage()).getUUID();
}
/**
@@ -208,6 +313,6 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor
*
*/
private BlockMatrixStorage storage() {
- return (BlockMatrixStorage)getStorage();
+ return (BlockMatrixStorage) getStorage();
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0a86018/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java
index 594aebc..497241d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java
@@ -17,24 +17,24 @@
package org.apache.ignite.ml.math.impls.matrix;
-import java.util.Collection;
-import java.util.Map;
-import java.util.UUID;
import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cluster.ClusterNode;
-import org.apache.ignite.lang.IgniteUuid;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distributed.CacheUtils;
import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
-import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage;
+import org.apache.ignite.ml.math.impls.storage.vector.SparseDistributedVectorStorage;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.UUID;
/**
* Sparse distributed matrix implementation based on data grid.
@@ -68,6 +68,28 @@ public class SparseDistributedMatrix extends AbstractMatrix implements StorageCo
assertStorageMode(stoMode);
setStorage(new SparseDistributedMatrixStorage(rows, cols, stoMode, acsMode));
+
+ }
+
+ /**
+ * @param data Data to fill the matrix
+ */
+ public SparseDistributedMatrix(double[][] data) {
+ assert data.length > 0;
+ setStorage(new SparseDistributedMatrixStorage(data.length, getMaxAmountOfColumns(data), StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE));
+
+ for (int i = 0; i < data.length; i++)
+ for (int j = 0; j < data[i].length; j++)
+ storage().set(i,j,data[i][j]);
+ }
+
+
+ /**
+ * @param rows Amount of rows in the matrix.
+ * @param cols Amount of columns in the matrix.
+ */
+ public SparseDistributedMatrix(int rows, int cols) {
+ this(rows, cols, StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
}
/** */
@@ -122,7 +144,6 @@ public class SparseDistributedMatrix extends AbstractMatrix implements StorageCo
Ignite ignite = Ignition.localIgnite();
Affinity<RowColMatrixKey> affinity = ignite.affinity(cacheName);
- IgniteCache<RowColMatrixKey, BlockEntry> cache = ignite.getOrCreateCache(cacheName);
ClusterNode locNode = ignite.cluster().localNode();
SparseDistributedMatrixStorage storageC = matrixC.storage();
@@ -141,17 +162,17 @@ public class SparseDistributedMatrix extends AbstractMatrix implements StorageCo
int idx = key.index();
if (isRowMode){
- Vector Aik = matrixA.getCol(idx);
+ Vector Aik = matrixA.getRow(idx);
- for (int i = 0; i < columnSize(); i++) {
- Vector Bkj = matrixB.getRow(i);
+ for (int i = 0; i < matrixB.columnSize(); i++) {
+ Vector Bkj = matrixB.getCol(i);
matrixC.set(idx, i, Aik.times(Bkj).sum());
}
} else {
- Vector Bkj = matrixB.getRow(idx);
+ Vector Bkj = matrixB.getCol(idx);
- for (int i = 0; i < rowSize(); i++) {
- Vector Aik = matrixA.getCol(i);
+ for (int i = 0; i < matrixA.rowSize(); i++) {
+ Vector Aik = matrixA.getRow(i);
matrixC.set(idx, i, Aik.times(Bkj).sum());
}
}
@@ -161,6 +182,49 @@ public class SparseDistributedMatrix extends AbstractMatrix implements StorageCo
return matrixC;
}
+
+ /** {@inheritDoc} */
+ @Override public Vector times(Vector vec) {
+ if (vec == null)
+ throw new IllegalArgumentException("The vector should be not null.");
+
+ if (columnSize() != vec.size())
+ throw new CardinalityException(columnSize(), vec.size());
+
+ // Input must itself be a SparseDistributedVector: the cast fails with ClassCastException otherwise.
+ SparseDistributedMatrix matrixA = this;
+ SparseDistributedVector vectorB = (SparseDistributedVector) vec;
+
+ String cacheName = storage().cacheName();
+ int rows = this.rowSize();
+
+ SparseDistributedVector vectorC = (SparseDistributedVector) likeVector(rows);
+
+ CacheUtils.bcast(cacheName, () -> {
+ Ignite ignite = Ignition.localIgnite();
+ Affinity<RowColMatrixKey> affinity = ignite.affinity(cacheName);
+
+ ClusterNode locNode = ignite.cluster().localNode();
+
+ SparseDistributedVectorStorage storageC = vectorC.storage();
+
+ // Each node computes only the result entries whose keys map to it.
+ Map<ClusterNode, Collection<RowColMatrixKey>> keysCToNodes = affinity.mapKeysToNodes(storageC.getAllKeys());
+ Collection<RowColMatrixKey> locKeys = keysCToNodes.get(locNode);
+
+ if (locKeys == null)
+ return;
+
+ // compute Cij locally on each node
+ // TODO: IGNITE:5114, exec in parallel
+ locKeys.forEach(key -> {
+ int idx = key.index();
+ // c[idx] = dot(row idx of A, b); getRow materializes the whole row per entry.
+ Vector Aik = matrixA.getRow(idx);
+ vectorC.set(idx, Aik.times(vectorB).sum());
+ });
+ });
+
+ return vectorC;
+ }
+
/** {@inheritDoc} */
@Override public Matrix assign(double val) {
return mapOverValues(v -> val);
@@ -198,17 +262,23 @@ public class SparseDistributedMatrix extends AbstractMatrix implements StorageCo
/** {@inheritDoc} */
@Override public Matrix copy() {
- throw new UnsupportedOperationException();
+ Matrix cp = like(rowSize(), columnSize());
+
+ cp.assign(this);
+
+ return cp;
}
/** {@inheritDoc} */
@Override public Matrix like(int rows, int cols) {
- return new SparseDistributedMatrix(rows, cols, storage().storageMode(), storage().accessMode());
+ if(storage()==null) return new SparseDistributedMatrix(rows, cols);
+ else return new SparseDistributedMatrix(rows, cols, storage().storageMode(), storage().accessMode());
+
}
/** {@inheritDoc} */
@Override public Vector likeVector(int crd) {
- throw new UnsupportedOperationException();
+ return new SparseDistributedVector(crd, StorageConstants.RANDOM_ACCESS_MODE);
}
/** */