You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2017/12/28 16:13:19 UTC
[5/6] ignite git commit: IGNITE-5217: Gradient descent for OLS lin reg
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java
deleted file mode 100644
index aafeae8..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java
+++ /dev/null
@@ -1,257 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.ignite.ml.regressions;
-
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.decompositions.QRDSolver;
-import org.apache.ignite.ml.math.decompositions.QRDecomposition;
-import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
-import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
-import org.apache.ignite.ml.math.functions.Functions;
-
-/**
- * This class is based on the corresponding class from Apache Common Math lib.
- * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
- * multiple linear regression model.</p>
- *
- * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
- * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
- *
- * <p>To solve the normal equations, this implementation uses QR decomposition
- * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
- * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
- * has rows corresponding to sample observations and columns corresponding to independent
- * variables. When the model is estimated using an intercept term (i.e. when
- * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
- * matrix includes an initial column identically equal to 1. We solve the normal equations
- * as follows:
- * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
- * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
- * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
- * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
- * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
- * R b = Q<sup>T</sup> y </code></pre></p>
- *
- * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
- */
-public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
- /** Cached QR decomposition of X matrix */
- private QRDSolver solver = null;
-
- /** Singularity threshold for QR decomposition */
- private final double threshold;
-
- /**
- * Create an empty OLSMultipleLinearRegression instance.
- */
- public OLSMultipleLinearRegression() {
- this(0d);
- }
-
- /**
- * Create an empty OLSMultipleLinearRegression instance, using the given
- * singularity threshold for the QR decomposition.
- *
- * @param threshold the singularity threshold
- */
- public OLSMultipleLinearRegression(final double threshold) {
- this.threshold = threshold;
- }
-
- /**
- * Loads model x and y sample data, overriding any previous sample.
- *
- * Computes and caches QR decomposition of the X matrix.
- *
- * @param y the {@code n}-sized vector representing the y sample
- * @param x the {@code n x k} matrix representing the x sample
- * @throws MathIllegalArgumentException if the x and y array data are not compatible for the regression
- */
- public void newSampleData(Vector y, Matrix x) throws MathIllegalArgumentException {
- validateSampleData(x, y);
- newYSampleData(y);
- newXSampleData(x);
- }
-
- /**
- * {@inheritDoc}
- * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
- */
- @Override public void newSampleData(double[] data, int nobs, int nvars, Matrix like) {
- super.newSampleData(data, nobs, nvars, like);
- QRDecomposition qr = new QRDecomposition(getX(), threshold);
- solver = new QRDSolver(qr.getQ(), qr.getR());
- }
-
- /**
- * <p>Compute the "hat" matrix.
- * </p>
- * <p>The hat matrix is defined in terms of the design matrix X
- * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
- * </p>
- * <p>The implementation here uses the QR decomposition to compute the
- * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
- * p-dimensional identity matrix augmented by 0's. This computational
- * formula is from "The Hat Matrix in Regression and ANOVA",
- * David C. Hoaglin and Roy E. Welsch,
- * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
- * </p>
- * <p>Data for the model must have been successfully loaded using one of
- * the {@code newSampleData} methods before invoking this method; otherwise
- * a {@code NullPointerException} will be thrown.</p>
- *
- * @return the hat matrix
- * @throws NullPointerException unless method {@code newSampleData} has been called beforehand.
- */
- public Matrix calculateHat() {
- return solver.calculateHat();
- }
-
- /**
- * <p>Returns the sum of squared deviations of Y from its mean.</p>
- *
- * <p>If the model has no intercept term, <code>0</code> is used for the
- * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
- *
- * <p>The value returned by this method is the SSTO value used in
- * the {@link #calculateRSquared() R-squared} computation.</p>
- *
- * @return SSTO - the total sum of squares
- * @throws NullPointerException if the sample has not been set
- * @see #isNoIntercept()
- */
- public double calculateTotalSumOfSquares() {
- if (isNoIntercept())
- return getY().foldMap(Functions.PLUS, Functions.SQUARE, 0.0);
- else {
- // TODO: IGNITE-5826, think about incremental update formula.
- final double mean = getY().sum() / getY().size();
- return getY().foldMap(Functions.PLUS, x -> (mean - x) * (mean - x), 0.0);
- }
- }
-
- /**
- * Returns the sum of squared residuals.
- *
- * @return residual sum of squares
- * @throws SingularMatrixException if the design matrix is singular
- * @throws NullPointerException if the data for the model have not been loaded
- */
- public double calculateResidualSumOfSquares() {
- final Vector residuals = calculateResiduals();
- // No advertised DME, args are valid
- return residuals.dot(residuals);
- }
-
- /**
- * Returns the R-Squared statistic, defined by the formula <pre>
- * R<sup>2</sup> = 1 - SSR / SSTO
- * </pre>
- * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
- * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
- *
- * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
- *
- * @return R-square statistic
- * @throws NullPointerException if the sample has not been set
- * @throws SingularMatrixException if the design matrix is singular
- */
- public double calculateRSquared() {
- return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
- }
-
- /**
- * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
- * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
- * </pre>
- * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
- * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
- * of observations and p is the number of parameters estimated (including the intercept).</p>
- *
- * <p>If the regression is estimated without an intercept term, what is returned is <pre>
- * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
- * </pre></p>
- *
- * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
- *
- * @return adjusted R-Squared statistic
- * @throws NullPointerException if the sample has not been set
- * @throws SingularMatrixException if the design matrix is singular
- * @see #isNoIntercept()
- */
- public double calculateAdjustedRSquared() {
- final double n = getX().rowSize();
- if (isNoIntercept())
- return 1 - (1 - calculateRSquared()) * (n / (n - getX().columnSize()));
- else
- return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
- (calculateTotalSumOfSquares() * (n - getX().columnSize()));
- }
-
- /**
- * {@inheritDoc}
- * <p>This implementation computes and caches the QR decomposition of the X matrix
- * once it is successfully loaded.</p>
- */
- @Override protected void newXSampleData(Matrix x) {
- super.newXSampleData(x);
- QRDecomposition qr = new QRDecomposition(getX());
- solver = new QRDSolver(qr.getQ(), qr.getR());
- }
-
- /**
- * Calculates the regression coefficients using OLS.
- *
- * <p>Data for the model must have been successfully loaded using one of
- * the {@code newSampleData} methods before invoking this method; otherwise
- * a {@code NullPointerException} will be thrown.</p>
- *
- * @return beta
- * @throws SingularMatrixException if the design matrix is singular
- * @throws NullPointerException if the data for the model have not been loaded
- */
- @Override protected Vector calculateBeta() {
- return solver.solve(getY());
- }
-
- /**
- * <p>Calculates the variance-covariance matrix of the regression parameters.
- * </p>
- * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
- * </p>
- * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
- * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
- * R included, where p = the length of the beta vector.</p>
- *
- * <p>Data for the model must have been successfully loaded using one of
- * the {@code newSampleData} methods before invoking this method; otherwise
- * a {@code NullPointerException} will be thrown.</p>
- *
- * @return The beta variance-covariance matrix
- * @throws SingularMatrixException if the design matrix is singular
- * @throws NullPointerException if the data for the model have not been loaded
- */
- @Override protected Matrix calculateBetaVariance() {
- return solver.calculateBetaVariance(getX().columnSize());
- }
-
- /** */
- QRDSolver solver() {
- return solver;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
deleted file mode 100644
index b95cbf3..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions;
-
-import org.apache.ignite.ml.Exportable;
-import org.apache.ignite.ml.Exporter;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.decompositions.QRDSolver;
-import org.apache.ignite.ml.math.decompositions.QRDecomposition;
-
-/**
- * Model for linear regression.
- */
-public class OLSMultipleLinearRegressionModel implements Model<Vector, Vector>,
- Exportable<OLSMultipleLinearRegressionModelFormat> {
- /** */
- private final Matrix xMatrix;
- /** */
- private final QRDSolver solver;
-
- /**
- * Construct linear regression model.
- *
- * @param xMatrix See {@link QRDecomposition#QRDecomposition(Matrix)}.
- * @param solver Linear regression solver object.
- */
- public OLSMultipleLinearRegressionModel(Matrix xMatrix, QRDSolver solver) {
- this.xMatrix = xMatrix;
- this.solver = solver;
- }
-
- /** {@inheritDoc} */
- @Override public Vector apply(Vector val) {
- return xMatrix.times(solver.solve(val));
- }
-
- /** {@inheritDoc} */
- @Override public <P> void saveModel(Exporter<OLSMultipleLinearRegressionModelFormat, P> exporter, P path) {
- exporter.save(new OLSMultipleLinearRegressionModelFormat(xMatrix, solver), path);
- }
-
- /** {@inheritDoc} */
- @Override public boolean equals(Object o) {
- if (this == o)
- return true;
- if (o == null || getClass() != o.getClass())
- return false;
-
- OLSMultipleLinearRegressionModel mdl = (OLSMultipleLinearRegressionModel)o;
-
- return xMatrix.equals(mdl.xMatrix) && solver.equals(mdl.solver);
- }
-
- /** {@inheritDoc} */
- @Override public int hashCode() {
- int res = xMatrix.hashCode();
- res = 31 * res + solver.hashCode();
- return res;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java
deleted file mode 100644
index fc44968..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions;
-
-import java.io.Serializable;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.decompositions.QRDSolver;
-
-/**
- * Linear regression model representation.
- *
- * @see OLSMultipleLinearRegressionModel
- */
-public class OLSMultipleLinearRegressionModelFormat implements Serializable {
- /** X sample data. */
- private final Matrix xMatrix;
-
- /** Whether or not the regression model includes an intercept. True means no intercept. */
- private final QRDSolver solver;
-
- /** */
- public OLSMultipleLinearRegressionModelFormat(Matrix xMatrix, QRDSolver solver) {
- this.xMatrix = xMatrix;
- this.solver = solver;
- }
-
- /** */
- public OLSMultipleLinearRegressionModel getOLSMultipleLinearRegressionModel() {
- return new OLSMultipleLinearRegressionModel(xMatrix, solver);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java
deleted file mode 100644
index dde0aca..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions;
-
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.math.Matrix;
-
-/**
- * Trainer for linear regression.
- */
-public class OLSMultipleLinearRegressionTrainer implements Trainer<OLSMultipleLinearRegressionModel, double[]> {
- /** */
- private final double threshold;
-
- /** */
- private final int nobs;
-
- /** */
- private final int nvars;
-
- /** */
- private final Matrix like;
-
- /**
- * Construct linear regression trainer.
- *
- * @param threshold the singularity threshold for QR decomposition
- * @param nobs number of observations (rows)
- * @param nvars number of independent variables (columns, not counting y)
- * @param like matrix(maybe empty) indicating how data should be stored
- */
- public OLSMultipleLinearRegressionTrainer(double threshold, int nobs, int nvars, Matrix like) {
- this.threshold = threshold;
- this.nobs = nobs;
- this.nvars = nvars;
- this.like = like;
- }
-
- /** {@inheritDoc} */
- @Override public OLSMultipleLinearRegressionModel train(double[] data) {
- OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(threshold);
-
- regression.newSampleData(data, nobs, nvars, like);
-
- return new OLSMultipleLinearRegressionModel(regression.getX(), regression.solver());
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java
deleted file mode 100644
index 883adca..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions;
-
-/**
- * This class contains various messages used in regressions,
- */
-public class RegressionsErrorMessages {
- /** Constant for string indicating that sample has insufficient observed points. */
- static final String INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE = "Insufficient observed points in sample.";
- /** */
- static final String NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS = "Not enough data (%d rows) for this many predictors (%d predictors)";
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java
new file mode 100644
index 0000000..6586a81
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import java.io.Serializable;
+import java.util.Objects;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Simple linear regression model which predicts the result value Y as a linear combination of input variables:
+ * Y = weights * X + intercept.
+ */
+public class LinearRegressionModel implements Model<Vector, Double>, Exportable<LinearRegressionModel>, Serializable {
+    /** Serialization version identifier. */
+    private static final long serialVersionUID = -105984600091550226L;
+
+    /** Multiplier of the object's vector required to make prediction. */
+    private final Vector weights;
+
+    /** Intercept of the linear regression model. */
+    private final double intercept;
+
+    /** Creates a model from the given weights vector and intercept; both are stored as passed. */
+    public LinearRegressionModel(Vector weights, double intercept) {
+        this.weights = weights;
+        this.intercept = intercept;
+    }
+
+    /** Returns the weights vector this model was constructed with. */
+    public Vector getWeights() {
+        return weights;
+    }
+
+    /** Returns the intercept this model was constructed with. */
+    public double getIntercept() {
+        return intercept;
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double apply(Vector input) {
+        return input.dot(weights) + intercept;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<LinearRegressionModel, P> exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+        LinearRegressionModel mdl = (LinearRegressionModel)o;
+        return Double.compare(mdl.intercept, intercept) == 0 &&
+            Objects.equals(weights, mdl.weights);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        // Hashes the same two fields that equals() compares.
+        return Objects.hash(weights, intercept);
+    }
+
+    /** Renders models with fewer than 10 weights as a signed formula, larger ones as a field dump. */
+    @Override public String toString() {
+        if (weights.size() < 10) {
+            StringBuilder builder = new StringBuilder();
+
+            for (int i = 0; i < weights.size(); i++) {
+                double weight = weights.get(i);
+
+                // Each term carries its own sign; the leading term gets a bare "-" instead of a separator.
+                builder.append(i == 0 ? (weight < 0 ? "-" : "") : (weight < 0 ? " - " : " + "))
+                    .append(String.format("%.4f", Math.abs(weight)))
+                    .append("*x").append(i);
+            }
+
+            String sign = builder.length() == 0 ? (intercept < 0 ? "-" : "") : (intercept < 0 ? " - " : " + ");
+            return builder.append(sign).append(String.format("%.4f", Math.abs(intercept))).toString();
+        }
+
+        return "LinearRegressionModel{" +
+            "weights=" + weights +
+            ", intercept=" + intercept +
+            '}';
+    }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java
new file mode 100644
index 0000000..5de3cda
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.decompositions.QRDSolver;
+import org.apache.ignite.ml.math.decompositions.QRDecomposition;
+import org.apache.ignite.ml.math.impls.vector.FunctionVector;
+
+/**
+ * Trainer that produces a {@link LinearRegressionModel} from a least squares fit solved through QR decomposition.
+ */
+public class LinearRegressionQRTrainer implements Trainer<LinearRegressionModel, Matrix> {
+    /**
+     * {@inheritDoc}
+     */
+    @Override public LinearRegressionModel train(Matrix data) {
+        // Column 0 of the data set is read as the target before the design matrix is built from a copy.
+        Vector labels = extractGroundTruth(data);
+        Matrix design = extractInputs(data);
+
+        QRDecomposition qr = new QRDecomposition(design);
+        Vector beta = new QRDSolver(qr.getQ(), qr.getR()).solve(labels);
+
+        // Entry 0 pairs with the column of ones and becomes the intercept; the rest become the weights.
+        Vector weights = beta.viewPart(1, beta.size() - 1);
+        double intercept = beta.get(0);
+
+        return new LinearRegressionModel(weights, intercept);
+    }
+
+    /**
+     * Returns column 0 of the data set matrix, which is used as the ground truth.
+     *
+     * @param data data to build model
+     * @return Ground truth vector
+     */
+    private Vector extractGroundTruth(Matrix data) {
+        return data.getCol(0);
+    }
+
+    /**
+     * Copies the data set matrix and overwrites column 0 of the copy with the constant 1.0.
+     *
+     * @param data data to build model
+     * @return Inputs matrix
+     */
+    private Matrix extractInputs(Matrix data) {
+        Matrix design = data.copy();
+        Vector ones = new FunctionVector(design.rowSize(), idx -> 1.0);
+
+        design.assignColumn(0, ones);
+        return design;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
new file mode 100644
index 0000000..aad4c7a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.optimization.BarzilaiBorweinUpdater;
+import org.apache.ignite.ml.optimization.GradientDescent;
+import org.apache.ignite.ml.optimization.LeastSquaresGradientFunction;
+import org.apache.ignite.ml.optimization.SimpleUpdater;
+
+/**
+ * Trainer that produces a {@link LinearRegressionModel} by running a {@link GradientDescent} optimizer on the data.
+ */
+public class LinearRegressionSGDTrainer implements Trainer<LinearRegressionModel, Matrix> {
+    /**
+     * Optimizer that {@link #train(Matrix)} delegates to.
+     */
+    private final GradientDescent gradientDescent;
+
+    /** Creates a trainer that uses the supplied optimizer as is. */
+    public LinearRegressionSGDTrainer(GradientDescent gradientDescent) {
+        this.gradientDescent = gradientDescent;
+    }
+
+    /** Creates a trainer with a least squares gradient function and a {@link BarzilaiBorweinUpdater}. */
+    public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol) {
+        this(new GradientDescent(new LeastSquaresGradientFunction(), new BarzilaiBorweinUpdater())
+            .withMaxIterations(maxIterations)
+            .withConvergenceTol(convergenceTol));
+    }
+
+    /** Creates a trainer with a least squares gradient function and a {@link SimpleUpdater} built from the rate. */
+    public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol, double learningRate) {
+        this(new GradientDescent(new LeastSquaresGradientFunction(), new SimpleUpdater(learningRate))
+            .withMaxIterations(maxIterations)
+            .withConvergenceTol(convergenceTol));
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override public LinearRegressionModel train(Matrix data) {
+        Vector start = data.likeVector(data.columnSize());
+        Vector solution = gradientDescent.optimize(data, start);
+
+        Vector weights = solution.viewPart(1, solution.size() - 1);
+        double intercept = solution.get(0);
+        return new LinearRegressionModel(weights, intercept);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java
new file mode 100644
index 0000000..086a824
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains various linear regressions.
+ */
+package org.apache.ignite.ml.regressions.linear;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
index 37dec77..862a9c1 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
@@ -28,9 +28,8 @@ import org.apache.ignite.ml.knn.models.KNNModelFormat;
import org.apache.ignite.ml.knn.models.KNNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModel;
-import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModelFormat;
-import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionTrainer;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.junit.Assert;
import org.junit.Test;
@@ -63,21 +62,16 @@ public class LocalModelsTest {
/** */
@Test
- public void importExportOLSMultipleLinearRegressionModelTest() throws IOException {
+ public void importExportLinearRegressionModelTest() throws IOException {
executeModelTest(mdlFilePath -> {
- OLSMultipleLinearRegressionModel mdl = getAbstractMultipleLinearRegressionModel();
+ LinearRegressionModel model = new LinearRegressionModel(new DenseLocalOnHeapVector(new double[]{1, 2}), 3);
+ Exporter<LinearRegressionModel, String> exporter = new FileExporter<>();
+ model.saveModel(exporter, mdlFilePath);
- Exporter<OLSMultipleLinearRegressionModelFormat, String> exporter = new FileExporter<>();
-
- mdl.saveModel(exporter, mdlFilePath);
-
- OLSMultipleLinearRegressionModelFormat load = exporter.load(mdlFilePath);
+ LinearRegressionModel load = exporter.load(mdlFilePath);
Assert.assertNotNull(load);
-
- OLSMultipleLinearRegressionModel importedMdl = load.getOLSMultipleLinearRegressionModel();
-
- Assert.assertTrue("", mdl.equals(importedMdl));
+ Assert.assertEquals("", model, load);
return null;
});
@@ -114,24 +108,6 @@ public class LocalModelsTest {
}
/** */
- private OLSMultipleLinearRegressionModel getAbstractMultipleLinearRegressionModel() {
- double[] data = new double[] {
- 0, 0, 0, 0, 0, 0, // IMPL NOTE values in this row are later replaced (with 1.0)
- 0, 2.0, 0, 0, 0, 0,
- 0, 0, 3.0, 0, 0, 0,
- 0, 0, 0, 4.0, 0, 0,
- 0, 0, 0, 0, 5.0, 0,
- 0, 0, 0, 0, 0, 6.0};
-
- final int nobs = 6, nvars = 5;
-
- OLSMultipleLinearRegressionTrainer trainer
- = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, new DenseLocalOnHeapMatrix(1, 1));
-
- return trainer.train(data);
- }
-
- /** */
@Test
public void importExportKNNModelTest() throws IOException {
executeModelTest(mdlFilePath -> {
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java
new file mode 100644
index 0000000..f6f4775
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization;
+
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GradientDescent}.
+ */
+public class GradientDescentTest {
+ /** Tolerance used when comparing the optimized value with the expected one. */
+ private static final double PRECISION = 1e-6;
+
+ /**
+ * Tests gradient descent on function y = x^2 (gradient 2 * x), starting from x = 2.0; the expected result is 0.
+ */
+ @Test
+ public void testOptimize() {
+ GradientDescent gradientDescent = new GradientDescent(
+ (inputs, groundTruth, point) -> point.times(2),
+ new SimpleUpdater(0.01)
+ );
+
+ Vector res = gradientDescent.optimize(new DenseLocalOnHeapMatrix(new double[1][1]),
+ new DenseLocalOnHeapVector(new double[]{ 2.0 }));
+
+ TestUtils.assertEquals(0, res.get(0), PRECISION);
+ }
+
+ /**
+ * Tests gradient descent on function y = (x - 2)^2 (gradient 2 * (x - 2)); the expected result is 2.
+ * NOTE(review): the start point 2.0 is already the minimum (zero gradient), so this may not exercise the optimizer. */
+ @Test
+ public void testOptimizeWithOffset() {
+ GradientDescent gradientDescent = new GradientDescent(
+ (inputs, groundTruth, point) -> point.minus(new DenseLocalOnHeapVector(new double[]{ 2.0 })).times(2.0),
+ new SimpleUpdater(0.01)
+ );
+
+ Vector res = gradientDescent.optimize(new DenseLocalOnHeapMatrix(new double[1][1]),
+ new DenseLocalOnHeapVector(new double[]{ 2.0 }));
+
+ TestUtils.assertEquals(2, res.get(0), PRECISION);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java
new file mode 100644
index 0000000..9017c43
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.util;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Tests for {@link SparseDistributedMatrixMapReducer}.
+ */
+public class SparseDistributedMatrixMapReducerTest extends GridCommonAbstractTest {
+ /** Number of nodes in the grid. */
+ private static final int NODE_COUNT = 2;
+
+ /** Ignite instance used by the tests; assigned in {@link #beforeTest()}. */
+ private Ignite ignite;
+
+ /** {@inheritDoc} Starts {@code NODE_COUNT} grid nodes with indexes 1..{@code NODE_COUNT}. */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} Stops all grids. */
+ @Override protected void afterTestsStopped() {
+ stopAllGrids();
+ }
+
+ /**
+ * {@inheritDoc} Selects the grid with index {@code NODE_COUNT} and sets its instance name as the current Ignite name.
+ */
+ @Override protected void beforeTest() throws Exception {
+ /* Grid instance. NOTE(review): peer class loading is set on an already started node - confirm it takes effect. */
+ ignite = grid(NODE_COUNT);
+ ignite.configuration().setPeerClassLoadingEnabled(true);
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ }
+
+ /**
+ * Tests that a 100x100 matrix filled with 1.0 and distributed across nodes is successfully processed (the sum of
+ * all elements is calculated) via {@link SparseDistributedMatrixMapReducer}.
+ */
+ public void testMapReduce() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(100, 100);
+ for (int i = 0; i < 100; i++)
+ for (int j = 0; j < 100; j++)
+ distributedMatrix.set(i, j, 1);
+ SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix);
+ double total = mapReducer.mapReduce(
+ (matrix, args) -> {
+ double partialSum = 0.0;
+ for (int i = 0; i < matrix.rowSize(); i++)
+ for (int j = 0; j < matrix.columnSize(); j++)
+ partialSum += matrix.get(i, j);
+ return partialSum;
+ },
+ sums -> {
+ double totalSum = 0;
+ for (Double partialSum : sums)
+ if (partialSum != null)
+ totalSum += partialSum;
+ return totalSum;
+ }, 0.0);
+ assertEquals(100.0 * 100.0, total, 1e-18);
+ }
+
+ /**
+ * Tests that a 100x100 matrix filled with 1.0 and distributed across nodes is successfully processed via
+ * {@link SparseDistributedMatrixMapReducer} even when the mapping function returns {@code null}; the expected total is 0.
+ */
+ public void testMapReduceWithNullValues() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(100, 100);
+ for (int i = 0; i < 100; i++)
+ for (int j = 0; j < 100; j++)
+ distributedMatrix.set(i, j, 1);
+ SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix);
+ double total = mapReducer.mapReduce(
+ (matrix, args) -> null,
+ sums -> {
+ double totalSum = 0;
+ for (Double partialSum : sums)
+ if (partialSum != null)
+ totalSum += partialSum;
+ return totalSum;
+ }, 0.0);
+ assertEquals(0, total, 1e-18);
+ }
+
+ /**
+ * Tests that a 1x100 matrix filled with 1.0 and distributed across nodes is successfully processed (the sum of
+ * all elements is calculated) via {@link SparseDistributedMatrixMapReducer} even when not all nodes contain data.
+ */
+ public void testMapReduceWithOneEmptyNode() {
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(1, 100);
+ for (int j = 0; j < 100; j++)
+ distributedMatrix.set(0, j, 1);
+ SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix);
+ double total = mapReducer.mapReduce(
+ (matrix, args) -> {
+ double partialSum = 0.0;
+ for (int i = 0; i < matrix.rowSize(); i++)
+ for (int j = 0; j < matrix.columnSize(); j++)
+ partialSum += matrix.get(i, j);
+ return partialSum;
+ },
+ sums -> {
+ double totalSum = 0;
+ for (Double partialSum : sums)
+ if (partialSum != null)
+ totalSum += partialSum;
+ return totalSum;
+ }, 0.0);
+ assertEquals(100.0, total, 1e-18);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java
deleted file mode 100644
index 6ad56a5..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions;
-
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
-import org.apache.ignite.ml.math.exceptions.NullArgumentException;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-/**
- * This class is based on the corresponding class from Apache Common Math lib.
- * Abstract base class for implementations of {@link MultipleLinearRegression}.
- */
-public abstract class AbstractMultipleLinearRegressionTest {
- /** */
- protected AbstractMultipleLinearRegression regression;
-
- /** */
- @Before
- public void setUp() {
- regression = createRegression();
- }
-
- /** */
- protected abstract AbstractMultipleLinearRegression createRegression();
-
- /** */
- protected abstract int getNumberOfRegressors();
-
- /** */
- protected abstract int getSampleSize();
-
- /** */
- @Test
- public void canEstimateRegressionParameters() {
- double[] beta = regression.estimateRegressionParameters();
- Assert.assertEquals(getNumberOfRegressors(), beta.length);
- }
-
- /** */
- @Test
- public void canEstimateResiduals() {
- double[] e = regression.estimateResiduals();
- Assert.assertEquals(getSampleSize(), e.length);
- }
-
- /** */
- @Test
- public void canEstimateRegressionParametersVariance() {
- Matrix var = regression.estimateRegressionParametersVariance();
- Assert.assertEquals(getNumberOfRegressors(), var.rowSize());
- }
-
- /** */
- @Test
- public void canEstimateRegressandVariance() {
- if (getSampleSize() > getNumberOfRegressors()) {
- double variance = regression.estimateRegressandVariance();
- Assert.assertTrue(variance > 0.0);
- }
- }
-
- /**
- * Verifies that newSampleData methods consistently insert unitary columns
- * in design matrix. Confirms the fix for MATH-411.
- */
- @Test
- public void testNewSample() {
- double[] design = new double[] {
- 1, 19, 22, 33,
- 2, 20, 30, 40,
- 3, 25, 35, 45,
- 4, 27, 37, 47
- };
-
- double[] y = new double[] {1, 2, 3, 4};
-
- double[][] x = new double[][] {
- {19, 22, 33},
- {20, 30, 40},
- {25, 35, 45},
- {27, 37, 47}
- };
-
- AbstractMultipleLinearRegression regression = createRegression();
- regression.newSampleData(design, 4, 3, new DenseLocalOnHeapMatrix());
-
- Matrix flatX = regression.getX().copy();
- Vector flatY = regression.getY().copy();
-
- regression.newXSampleData(new DenseLocalOnHeapMatrix(x));
- regression.newYSampleData(new DenseLocalOnHeapVector(y));
-
- Assert.assertEquals(flatX, regression.getX());
- Assert.assertEquals(flatY, regression.getY());
-
- // No intercept
- regression.setNoIntercept(true);
- regression.newSampleData(design, 4, 3, new DenseLocalOnHeapMatrix());
-
- flatX = regression.getX().copy();
- flatY = regression.getY().copy();
-
- regression.newXSampleData(new DenseLocalOnHeapMatrix(x));
- regression.newYSampleData(new DenseLocalOnHeapVector(y));
-
- Assert.assertEquals(flatX, regression.getX());
- Assert.assertEquals(flatY, regression.getY());
- }
-
- /** */
- @Test(expected = NullArgumentException.class)
- public void testNewSampleNullData() {
- double[] data = null;
- createRegression().newSampleData(data, 2, 3, new DenseLocalOnHeapMatrix());
- }
-
- /** */
- @Test(expected = MathIllegalArgumentException.class)
- public void testNewSampleInvalidData() {
- double[] data = new double[] {1, 2, 3, 4};
- createRegression().newSampleData(data, 2, 3, new DenseLocalOnHeapMatrix());
- }
-
- /** */
- @Test(expected = MathIllegalArgumentException.class)
- public void testNewSampleInsufficientData() {
- double[] data = new double[] {1, 2, 3, 4};
- createRegression().newSampleData(data, 1, 3, new DenseLocalOnHeapMatrix());
- }
-
- /** */
- @Test(expected = NullArgumentException.class)
- public void testXSampleDataNull() {
- createRegression().newXSampleData(null);
- }
-
- /** */
- @Test(expected = NullArgumentException.class)
- public void testYSampleDataNull() {
- createRegression().newYSampleData(null);
- }
-
-}