You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by jm...@apache.org on 2011/10/25 04:00:00 UTC
svn commit: r1188491 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/math/hadoop/solver/
core/src/test/java/org/apache/mahout/math/hadoop/solver/
math/src/main/java/org/apache/mahout/math/solver/
math/src/test/java/org/apache/mahout/math/solver/
Author: jmannix
Date: Tue Oct 25 01:59:58 2011
New Revision: 1188491
URL: http://svn.apache.org/viewvc?rev=1188491&view=rev
Log:
MAHOUT-672 on behalf of jtraupman
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/
mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/
mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,164 @@
+package org.apache.mahout.math.hadoop.solver;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.solver.ConjugateGradientSolver;
+import org.apache.mahout.math.solver.Preconditioner;
+
+/**
+ * Distributed implementation of the conjugate gradient solver. More or less, this is just the
+ * standard {@link ConjugateGradientSolver} wrapped with some methods that make it easy to run it
+ * on a {@link DistributedRowMatrix}: programmatically through {@link #runJob}, or from the command
+ * line through {@link #main(String[])}, which parses options with
+ * {@link DistributedConjugateGradientSolverJob}.
+ *
+ * <p>Instances are not thread-safe: the Hadoop configuration, the parsed command-line arguments and
+ * the statistics of the most recent solve are all kept in instance fields.
+ */
+public class DistributedConjugateGradientSolver extends ConjugateGradientSolver implements Tool {
+
+  private Configuration conf;
+
+  /** Options parsed by {@link DistributedConjugateGradientSolverJob}; null until that job has run. */
+  private Map<String, String> parsedArgs;
+
+  /**
+   * Runs the distributed conjugate gradient solver programmatically to solve the system Ax = b.
+   * A must be square, symmetric and positive definite; only squareness is checked by the solver.
+   * No lambda*I term is added.
+   *
+   * <p>If no configuration has been set on this solver, a default {@link Configuration} is created
+   * and retained.
+   *
+   * @param inputPath Path to the matrix A
+   * @param tempPath Scratch path handed to the {@link DistributedRowMatrix}. This method does not
+   *     delete it; the command-line entry point {@link #run(String[])} does.
+   * @param numRows Number of rows in A
+   * @param numCols Number of columns in A
+   * @param b Vector b
+   * @param preconditioner Optional preconditioner for the system; may be null
+   * @param maxIterations Maximum number of iterations to run (the command-line entry point defaults
+   *     this to numCols)
+   * @param maxError Maximum error tolerated in the result. If the norm of the residual falls below
+   *     this, then the algorithm stops and returns.
+   * @return The vector that solves the system.
+   */
+  public Vector runJob(Path inputPath,
+                       Path tempPath,
+                       int numRows,
+                       int numCols,
+                       Vector b,
+                       Preconditioner preconditioner,
+                       int maxIterations,
+                       double maxError) {
+    if (conf == null) {
+      // Fall back to a default configuration rather than handing null to DistributedRowMatrix.
+      setConf(new Configuration());
+    }
+    DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, tempPath, numRows, numCols);
+    matrix.setConf(conf);
+
+    return solve(matrix, b, preconditioner, maxIterations, maxError);
+  }
+
+  @Override
+  public Configuration getConf() {
+    return conf;
+  }
+
+  @Override
+  public void setConf(Configuration conf) {
+    this.conf = conf;
+  }
+
+  /**
+   * Command-line entry point. It reads the options previously parsed by
+   * {@link DistributedConjugateGradientSolverJob}; the {@code strings} argument itself is ignored.
+   * It loads b from {@code --vector}, solves, writes x to {@code --output} and always removes
+   * {@code --tempDir}, even when the solve fails.
+   *
+   * @return 0 on success
+   * @throws IllegalStateException if no arguments have been parsed (call {@link #job()} instead)
+   */
+  @Override
+  public int run(String[] strings) throws Exception {
+    if (parsedArgs == null) {
+      throw new IllegalStateException("No parsed arguments; run this solver through job()");
+    }
+
+    Path inputPath = new Path(parsedArgs.get("--input"));
+    Path outputPath = new Path(parsedArgs.get("--output"));
+    Path tempPath = new Path(parsedArgs.get("--tempDir"));
+    Path vectorPath = new Path(parsedArgs.get("--vector"));
+    int numRows = Integer.parseInt(parsedArgs.get("--numRows"));
+    int numCols = Integer.parseInt(parsedArgs.get("--numCols"));
+    int maxIterations = parsedArgs.containsKey("--maxIter")
+        ? Integer.parseInt(parsedArgs.get("--maxIter"))
+        : numCols;
+    double maxError = parsedArgs.containsKey("--maxError")
+        ? Double.parseDouble(parsedArgs.get("--maxError"))
+        : ConjugateGradientSolver.DEFAULT_MAX_ERROR;
+
+    Vector b = loadInputVector(vectorPath);
+    try {
+      Vector x = runJob(inputPath, tempPath, numRows, numCols, b, null, maxIterations, maxError);
+      saveOutputVector(outputPath, x);
+    } finally {
+      // Remove scratch output whether or not the solve succeeded.
+      tempPath.getFileSystem(conf).delete(tempPath, true);
+    }
+
+    return 0;
+  }
+
+  /**
+   * Returns a command-line front end bound to this solver instance.
+   *
+   * @return a new job that shares this solver's configuration
+   */
+  public DistributedConjugateGradientSolverJob job() {
+    return new DistributedConjugateGradientSolverJob();
+  }
+
+  /** Reads the first (IntWritable, VectorWritable) record of the SequenceFile at {@code path}. */
+  private Vector loadInputVector(Path path) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+    IntWritable key = new IntWritable();
+    VectorWritable value = new VectorWritable();
+
+    try {
+      if (!reader.next(key, value)) {
+        throw new IOException("Input vector file is empty.");
+      }
+      return value.get();
+    } finally {
+      reader.close();
+    }
+  }
+
+  /** Writes {@code v} as the single record (key 0) of a SequenceFile at {@code path}. */
+  private void saveOutputVector(Path path, Vector v) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
+
+    try {
+      writer.append(new IntWritable(0), new VectorWritable(v));
+    } finally {
+      writer.close();
+    }
+  }
+
+  /**
+   * {@link AbstractJob} front end that parses command-line options and then delegates to the
+   * enclosing solver's {@link DistributedConjugateGradientSolver#run(String[])}. Its configuration
+   * is the enclosing solver's configuration.
+   *
+   * <p>NOTE: the {@code --lambda} and {@code --symmetric} options are accepted but are currently
+   * not used by the solver.
+   */
+  public class DistributedConjugateGradientSolverJob extends AbstractJob {
+    @Override
+    public void setConf(Configuration conf) {
+      DistributedConjugateGradientSolver.this.setConf(conf);
+    }
+
+    @Override
+    public Configuration getConf() {
+      return DistributedConjugateGradientSolver.this.getConf();
+    }
+
+    @Override
+    public int run(String[] args) throws Exception {
+      addInputOption();
+      addOutputOption();
+      addOption("numRows", "nr", "Number of rows in the input matrix", true);
+      addOption("numCols", "nc", "Number of columns in the input matrix", true);
+      addOption("vector", "b", "Vector to solve against", true);
+      addOption("lambda", "l", "Scalar in A + lambda * I [default = 0]", "0.0");
+      addOption("symmetric", "sym", "Is the input matrix square and symmetric?", "true");
+      addOption("maxIter", "x", "Maximum number of iterations to run");
+      addOption("maxError", "err", "Maximum residual error to allow before stopping");
+
+      DistributedConjugateGradientSolver solver = DistributedConjugateGradientSolver.this;
+      solver.parsedArgs = parseArguments(args);
+      if (solver.parsedArgs == null) {
+        return -1;
+      }
+      // Keep a configuration installed by ToolRunner (it carries generic -D options); only create a
+      // default one when this job is invoked directly.
+      if (solver.getConf() == null) {
+        solver.setConf(new Configuration());
+      }
+      return solver.run(args);
+    }
+  }
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new DistributedConjugateGradientSolver().job(), args);
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,43 @@
+package org.apache.mahout.math.hadoop.solver;
+
+import java.io.File;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
+import org.junit.Test;
+
+
+public class TestDistributedConjugateGradientSolver extends MahoutTestCase {
+
+  /**
+   * Returns a reproducible vector of {@code size} Gaussian entries with mean 0 and standard
+   * deviation {@code scale}. The fixed seed makes every call return the same vector.
+   */
+  private static Vector randomVector(int size, double scale) {
+    Vector v = new DenseVector(size);
+    Random r = new Random(1234L);
+    for (int i = 0; i < size; ++i) {
+      v.setQuick(i, r.nextGaussian() * scale);
+    }
+    return v;
+  }
+
+  /**
+   * Solves Ax = b for a small random distributed matrix and checks that A times the solution
+   * reproduces b to within EPSILON.
+   */
+  @Test
+  public void testSolver() throws Exception {
+    File testData = getTestTempDir("testdata");
+    DistributedRowMatrix matrix = new TestDistributedRowMatrix().randomDistributedMatrix(
+        10, 10, 10, 10, 10.0, true, testData.getAbsolutePath());
+    matrix.setConf(new Configuration());
+    Vector vector = randomVector(matrix.numCols(), 10.0);
+
+    DistributedConjugateGradientSolver solver = new DistributedConjugateGradientSolver();
+    Vector x = solver.solve(matrix, vector);
+
+    Vector solvedVector = matrix.times(x);
+    double distance = Math.sqrt(vector.getDistanceSquared(solvedVector));
+    assertEquals(0.0, distance, EPSILON);
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,94 @@
+package org.apache.mahout.math.hadoop.solver;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
+import org.junit.Test;
+
+public class TestDistributedConjugateGradientSolverCLI extends MahoutTestCase {
+
+  /**
+   * Returns a reproducible vector of {@code size} Gaussian entries with mean 0 and standard
+   * deviation {@code scale}. The fixed seed makes every call return the same vector.
+   */
+  private static Vector randomVector(int size, double scale) {
+    Vector v = new DenseVector(size);
+    Random r = new Random(1234L);
+    for (int i = 0; i < size; ++i) {
+      v.setQuick(i, r.nextGaussian() * scale);
+    }
+    return v;
+  }
+
+  /** Writes {@code v} as the single record (key 0) of a SequenceFile at {@code path}; returns {@code path}. */
+  private static Path saveVector(Configuration conf, Path path, Vector v) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
+    try {
+      writer.append(new IntWritable(0), new VectorWritable(v));
+    } finally {
+      writer.close();
+    }
+    return path;
+  }
+
+  /** Reads the first vector of the SequenceFile at {@code path}, failing if the file is empty. */
+  private static Vector loadVector(Configuration conf, Path path) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+    IntWritable key = new IntWritable();
+    VectorWritable value = new VectorWritable();
+    try {
+      if (!reader.next(key, value)) {
+        throw new IOException("Input vector file is empty.");
+      }
+      return value.get();
+    } finally {
+      reader.close();
+    }
+  }
+
+  /**
+   * Runs the solver end-to-end through its command-line job and checks that A times the solution
+   * written to the output path reproduces b to within EPSILON.
+   */
+  @Test
+  public void testSolver() throws Exception {
+    Configuration conf = new Configuration();
+    Path testData = getTestTempDirPath("testdata");
+    DistributedRowMatrix matrix = new TestDistributedRowMatrix().randomDistributedMatrix(
+        10, 10, 10, 10, 10.0, true, testData.toString());
+    matrix.setConf(conf);
+    Path output = getTestTempFilePath("output");
+    Path vectorPath = getTestTempFilePath("vector");
+    Path tempPath = getTestTempDirPath("tmp");
+
+    Vector vector = randomVector(matrix.numCols(), 10.0);
+    saveVector(conf, vectorPath, vector);
+
+    String[] args = {
+        "-i", matrix.getRowPath().toString(),
+        "-o", output.toString(),
+        "--tempDir", tempPath.toString(),
+        "--vector", vectorPath.toString(),
+        "--numRows", "10",
+        "--numCols", "10",
+        "--symmetric", "true"
+    };
+
+    DistributedConjugateGradientSolver solver = new DistributedConjugateGradientSolver();
+    // A non-zero exit code means option parsing failed; fail here rather than on a missing output file.
+    int exitCode = solver.job().run(args);
+    assertEquals("solver job exit code", 0, exitCode);
+
+    Vector x = loadVector(conf, output);
+
+    Vector solvedVector = matrix.times(x);
+    double distance = Math.sqrt(vector.getDistanceSquared(solvedVector));
+    assertEquals(0.0, distance, EPSILON);
+  }
+}
Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,199 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.PlusMult;
+import org.apache.mahout.math.function.TimesFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>Implementation of a conjugate gradient iterative solver for linear systems. Implements both
+ * standard conjugate gradient and pre-conditioned conjugate gradient.
+ *
+ * <p>Conjugate gradient requires the matrix A in the linear system Ax = b to be symmetric and
+ * positive definite. This implementation only checks that A is square and that its size matches b.
+ * Symmetry and positive definiteness are not verified because testing them is too expensive, so an
+ * invalid matrix may yield an invalid (possibly non-finite) result.
+ *
+ * <p>This class does not itself form the normal equations A'Ax = b for a non-square A, nor does it
+ * add a ridge term lambda*I. Callers that need either must pass an operator that already
+ * represents A'A or A + lambda*I.
+ *
+ * <p>If only an approximate solution is required, the maximum number of iterations or the error
+ * threshold may be specified to end the algorithm early at the expense of accuracy. When the matrix
+ * A is ill conditioned, it may sometimes be necessary to increase the maximum number of iterations
+ * above the default of A.numCols() due to numerical issues.
+ *
+ * <p>By default the solver will run A.numCols() iterations or until the norm of the residual falls
+ * below {@link #DEFAULT_MAX_ERROR} (1E-9).
+ *
+ * <p>Statistics about the most recent solve ({@link #getIterations()}, {@link #getResidualNorm()})
+ * are stored in instance fields, so one instance should not be used for concurrent solves.
+ *
+ * <p>For more information on the conjugate gradient algorithm, see Golub &amp; van Loan, "Matrix
+ * Computations", sections 10.2 and 10.3 or the
+ * <a href="http://en.wikipedia.org/wiki/Conjugate_gradient">conjugate gradient wikipedia article</a>.
+ */
+public class ConjugateGradientSolver {
+
+  /** Default threshold on the residual norm below which the solver stops. */
+  public static final double DEFAULT_MAX_ERROR = 1e-9;
+
+  private static final Logger log = LoggerFactory.getLogger(ConjugateGradientSolver.class);
+
+  /** Number of iterations performed by the most recent call to solve(). */
+  private int iterations;
+
+  /** Squared norm of the residual b - Ax after the most recent solve(); NaN before any solve. */
+  private double residualNormSquared;
+
+  /** Creates a solver; {@link #getResidualNorm()} returns NaN until solve() has been called. */
+  public ConjugateGradientSolver() {
+    this.iterations = 0;
+    this.residualNormSquared = Double.NaN;
+  }
+
+  /**
+   * Solves the system Ax = b with default termination criteria. A must be symmetric, square, and
+   * positive definite. Only the squareness of a is checked, since testing for symmetry and positive
+   * definiteness are too expensive. If an invalid matrix is specified, then the algorithm may not
+   * yield a valid result.
+   *
+   * @param a The linear operator A.
+   * @param b The vector b.
+   * @return The result x of solving the system.
+   * @throws IllegalArgumentException if a is not square.
+   * @throws CardinalityException if the size of b is not equal to the number of columns of a.
+   */
+  public Vector solve(VectorIterable a, Vector b) {
+    return solve(a, b, null, b.size(), DEFAULT_MAX_ERROR);
+  }
+
+  /**
+   * Solves the system Ax = b with default termination criteria using the specified preconditioner.
+   * A must be symmetric, square, and positive definite. Only the squareness of a is checked, since
+   * testing for symmetry and positive definiteness are too expensive. If an invalid matrix is
+   * specified, then the algorithm may not yield a valid result.
+   *
+   * @param a The linear operator A.
+   * @param b The vector b.
+   * @param precond A preconditioner to use on A during the solution process; may be null.
+   * @return The result x of solving the system.
+   * @throws IllegalArgumentException if a is not square.
+   * @throws CardinalityException if the size of b is not equal to the number of columns of a.
+   */
+  public Vector solve(VectorIterable a, Vector b, Preconditioner precond) {
+    return solve(a, b, precond, b.size(), DEFAULT_MAX_ERROR);
+  }
+
+  /**
+   * Solves the system Ax = b, where A is a linear operator and b is a vector. Uses the specified
+   * preconditioner to improve numeric stability and possibly speed convergence. This version of
+   * solve() allows control over the termination and iteration parameters. The iteration starts
+   * from x = 0.
+   *
+   * @param a The matrix A.
+   * @param b The vector b.
+   * @param preconditioner The preconditioner to apply, or null for plain conjugate gradient.
+   * @param maxIterations The maximum number of iterations to run.
+   * @param maxError The maximum amount of residual error to tolerate. The algorithm will run until
+   *     the residual norm falls below this value or until maxIterations are completed.
+   * @return The result x of solving the system.
+   * @throws IllegalArgumentException if the matrix is not square, if maxError is less than zero, or
+   *     if maxIterations is not positive.
+   * @throws CardinalityException if the size of b is not equal to the number of columns of A.
+   */
+  public Vector solve(VectorIterable a,
+                      Vector b,
+                      Preconditioner preconditioner,
+                      int maxIterations,
+                      double maxError) {
+
+    if (a.numRows() != a.numCols()) {
+      throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite.");
+    }
+
+    if (a.numCols() != b.size()) {
+      throw new CardinalityException(a.numCols(), b.size());
+    }
+
+    if (maxIterations <= 0) {
+      throw new IllegalArgumentException("Max iterations must be positive.");
+    }
+
+    if (maxError < 0.0) {
+      throw new IllegalArgumentException("Max error must be non-negative.");
+    }
+
+    Vector x = new DenseVector(b.size());
+
+    iterations = 0;
+    Vector residual = b.minus(a.times(x));
+    residualNormSquared = residual.dot(residual);
+
+    // PlusMult is mutable (setMultiplicator), so each call gets its own instance. A shared static
+    // instance would let concurrent solves on different solver objects corrupt each other's updates.
+    PlusMult plusMult = new PlusMult(1.0);
+
+    double previousConditionedNormSqr = 0.0;
+    Vector updateDirection = null;
+
+    log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(residualNormSquared));
+    while (Math.sqrt(residualNormSquared) > maxError && iterations < maxIterations) {
+      // conditionedResidual is z = M^-1 * r with a preconditioner, or r itself without one.
+      Vector conditionedResidual;
+      double conditionedNormSqr;
+      if (preconditioner == null) {
+        conditionedResidual = residual;
+        conditionedNormSqr = residualNormSquared;
+      } else {
+        conditionedResidual = preconditioner.precondition(residual);
+        conditionedNormSqr = residual.dot(conditionedResidual);
+      }
+
+      ++iterations;
+
+      if (iterations == 1) {
+        // Copy: updateDirection is modified in place below and must not alias the residual.
+        updateDirection = new DenseVector(conditionedResidual);
+      } else {
+        double beta = conditionedNormSqr / previousConditionedNormSqr;
+
+        // updateDirection = conditionedResidual + beta * updateDirection
+        updateDirection.assign(Functions.MULT, beta);
+        updateDirection.assign(conditionedResidual, Functions.PLUS);
+      }
+
+      Vector aTimesUpdate = a.times(updateDirection);
+
+      // Step length along updateDirection: (r . z) / (p . Ap).
+      double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate);
+
+      // x = x + alpha * updateDirection
+      plusMult.setMultiplicator(alpha);
+      x.assign(updateDirection, plusMult);
+
+      // residual = residual - alpha * A * updateDirection
+      plusMult.setMultiplicator(-alpha);
+      residual.assign(aTimesUpdate, plusMult);
+
+      previousConditionedNormSqr = conditionedNormSqr;
+      residualNormSquared = residual.dot(residual);
+
+      log.info("Conjugate gradient iteration {} residual norm = {}", iterations, Math.sqrt(residualNormSquared));
+    }
+    return x;
+  }
+
+  /**
+   * Returns the number of iterations run once the solver is complete.
+   *
+   * @return The number of iterations run by the most recent solve.
+   */
+  public int getIterations() {
+    return iterations;
+  }
+
+  /**
+   * Returns the norm of the residual at the completion of the solver. Usually this should be close
+   * to zero. Exceptions are a non positive definite matrix A, which results in an unsolvable system,
+   * and an ill conditioned A, for which more iterations than the default may be needed.
+   *
+   * @return The norm of the residual in the solution, or NaN if solve() has not been called.
+   */
+  public double getResidualNorm() {
+    return Math.sqrt(residualNormSquared);
+  }
+}
Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,33 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ *
+ * Implements the Jacobi preconditioner for a matrix A. This is defined as inv(diag(A)).
+ *
+ */
+public class JacobiConditioner implements Preconditioner
+{
+ private DenseVector inverseDiagonal;
+
+ public JacobiConditioner(Matrix a) {
+ if (a.numCols() != a.numRows()) {
+ throw new IllegalArgumentException("Matrix must be square.");
+ }
+
+ inverseDiagonal = new DenseVector(a.numCols());
+ for (int i = 0; i < a.numCols(); ++i) {
+ inverseDiagonal.setQuick(i, 1.0 / a.getQuick(i, i));
+ }
+ }
+
+ @Override
+ public Vector precondition(Vector v)
+ {
+ return v.times(inverseDiagonal);
+ }
+
+}
Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,582 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Solves sparse least-squares using the LSMR algorithm.
+ * <p/>
+ * LSMR solves the system of linear equations A * X = B. If the system is inconsistent, it solves
+ * the least-squares problem min ||b - Ax||_2. A is a rectangular matrix of dimension m-by-n, where
+ * all cases are allowed: m=n, m>n, or m<n. B is a vector of length m. The matrix A may be dense
+ * or sparse (usually sparse).
+ * <p/>
+ * Some additional configurable properties adjust the behavior of the algorithm.
+ * <p/>
+ * If you set lambda to a non-zero value then LSMR solves the regularized least-squares problem
+ * min_x || [b; 0] - [A; lambda*I] x ||_2, which equals min_x sqrt(||b - Ax||_2^2 + lambda^2 ||x||_2^2),
+ * where lambda is a scalar. If lambda is not set, the system is solved without regularization.
+ * <p/>
+ * You can also set aTolerance (ATOL below) and bTolerance (BTOL below). These cause LSMR to iterate
+ * until a certain backward error estimate is smaller than some quantity depending on ATOL and BTOL. Let RES = B - A*X be
+ * the residual vector for the current approximate solution X. If A*X = B seems to be consistent,
+ * LSMR terminates when NORM(RES) <= ATOL*NORM(A)*NORM(X) + BTOL*NORM(B). Otherwise, LSMR terminates
+ * when NORM(A'*RES) <= ATOL*NORM(A)*NORM(RES). If both tolerances are 1.0e-6 (say), the final
+ * NORM(RES) should be accurate to about 6 digits. (The final X will usually have fewer correct
+ * digits, depending on cond(A) and the size of LAMBDA.)
+ * <p/>
+ * The default value for ATOL and BTOL is 1e-6.
+ * <p/>
+ * Ideally, they should be estimates of the relative error in the entries of A and B respectively.
+ * For example, if the entries of A have 7 correct digits, set ATOL = 1e-7. This prevents the
+ * algorithm from doing unnecessary work beyond the uncertainty of the input data.
+ * <p/>
+ * You can also set conditionLimit. In that case, LSMR terminates if an estimate of cond(A) exceeds
+ * conditionLimit. For compatible systems Ax = b, conditionLimit could be as large as 1.0e+12 (say).
+ * For least-squares problems, conditionLimit should be less than 1.0e+8. If conditionLimit is not
+ * set, the default value is 1e+8. Maximum precision can be obtained by setting aTolerance =
+ * bTolerance = conditionLimit = 0, but the number of iterations may then be excessive.
+ * <p/>
+ * Setting iterationLimit causes LSMR to terminate if the number of iterations reaches
+ * iterationLimit. The default is iterationLimit = min(m,n). For ill-conditioned systems, a
+ * larger value of iterationLimit may be needed.
+ * <p/>
+ * Setting localSize causes LSMR to run with reorthogonalization on the last localSize v_k's
+ * (the v-vectors generated by Golub-Kahan bidiagonalization). If localSize is not set, LSMR runs
+ * without reorthogonalization. A localSize > max(n,m) performs reorthogonalization on all v_k's.
+ * Reorthogonalizing only the u_k, or both the u_k and v_k, is not an option here. Details are
+ * discussed in the SIAM paper.
+ * <p/>
+ * getTerminationReason() gives the reason for termination. ISTOP = 0 means X=0 is a solution. = 1
+ * means X is an approximate solution to A*X = B, according to ATOL and BTOL. = 2 means X
+ * approximately solves the least-squares problem according to ATOL. = 3 means COND(A) seems to be
+ * greater than CONLIM. = 4 is the same as 1 with ATOL = BTOL = EPS. = 5 is the same as 2 with ATOL
+ * = EPS. = 6 is the same as 3 with CONLIM = 1/EPS. = 7 means ITN reached ITNLIM before the other
+ * stopping conditions were satisfied.
+ * <p/>
+ * getIterationCount() gives ITN = the number of LSMR iterations.
+ * <p/>
+ * getResidualNorm() gives an estimate of the residual norm: NORMR = norm(B-A*X).
+ * <p/>
+ * getNormalEquationResidual() gives an estimate of the residual for the normal equation: NORMAR =
+ * NORM(A'*(B-A*X)).
+ * <p/>
+ * getANorm() gives an estimate of the Frobenius norm of A.
+ * <p/>
+ * getCondition() gives an estimate of the condition number of A.
+ * <p/>
+ * getXNorm() gives an estimate of NORM(X).
+ * <p/>
+ * LSMR uses an iterative method. For further information, see D. C.-L. Fong and M. A. Saunders
+ * LSMR: An iterative algorithm for least-square problems Draft of 03 Apr 2010, to be submitted to
+ * SISC.
+ * <p/>
+ * David Chin-lung Fong clfong@stanford.edu Institute for Computational and Mathematical
+ * Engineering Stanford University
+ * <p/>
+ * Michael Saunders saunders@stanford.edu Systems Optimization Laboratory Dept of
+ * MS&E, Stanford University. -----------------------------------------------------------------------
+ */
+public class LSMR {
+ private Logger log = LoggerFactory.getLogger(LSMR.class);
+ private double lambda;
+ private int localSize;
+ private int iterationLimit;
+ private double conditionLimit;
+ private double bTolerance;
+ private double aTolerance;
+ private int localPointer;
+ private Vector v;
+ private boolean localVQueueFull;
+ private Vector[] localV;
+ private double residualNorm;
+ private double normalEquationResidual;
+ private double aNorm;
+ private double xNorm;
+ private int iteration;
+ private double normA;
+ private double condA;
+
+ /** Returns ITN, the number of LSMR iterations performed by the most recent solve. */
+ public int getIterationCount() {
+ return this.iteration;
+ }
+
+ /** Returns the estimate of the residual norm NORMR = norm(B - A*X). */
+ public double getResidualNorm() {
+ return this.residualNorm;
+ }
+
+ /** Returns the estimate of the normal-equation residual NORMAR = norm(A'*(B - A*X)). */
+ public double getNormalEquationResidual() {
+ return this.normalEquationResidual;
+ }
+
+ /** Returns the estimate of the Frobenius norm of A. */
+ public double getANorm() {
+ return this.normA;
+ }
+
+ /** Returns the estimate of the condition number of A. */
+ public double getCondition() {
+ return this.condA;
+ }
+
+ /** Returns the estimate of norm(X). */
+ public double getXNorm() {
+ return this.xNorm;
+ }
+
+ /**
+ * Creates a solver with the default parameters:
+ * <ul>
+ * <li>lambda = 0 (no regularization)</li>
+ * <li>aTolerance = bTolerance = 1e-6</li>
+ * <li>conditionLimit = 1e8</li>
+ * <li>iterationLimit = -1, a sentinel that solve() replaces with min(m, n)</li>
+ * <li>localSize = 0 (no local reorthogonalization)</li>
+ * </ul>
+ * Change them through the corresponding setters before calling solve().
+ * <p/>
+ * Based on the Matlab LSMR code; for further information, see D. C.-L. Fong and M. A. Saunders,
+ * "LSMR: An iterative algorithm for least-square problems", draft of 03 Apr 2010, to be submitted
+ * to SISC.
+ * <p/>
+ * Matlab version history: 08 Dec 2009: First release version of LSMR. 09 Apr 2010: Updated
+ * documentation and default parameters. 14 Apr 2010: Updated documentation. 03 Jun 2010: LSMR
+ * with local reorthogonalization (full reorthogonalization is also implemented).
+ * <p/>
+ * David Chin-lung Fong clfong@stanford.edu Institute for Computational and
+ * Mathematical Engineering Stanford University
+ * <p/>
+ * Michael Saunders saunders@stanford.edu Systems Optimization Laboratory Dept of
+ * MS&E, Stanford University.
+ */
+
+ public LSMR() {
+ // Set default parameters.
+ setLambda(0);
+ setAtolerance(1e-6);
+ setBtolerance(1e-6);
+ setConditionLimit(1e8);
+ // -1 means "not set": solve() then uses min(numRows, numCols).
+ setIterationLimit(-1);
+ setLocalSize(0);
+ }
+
+ /**
+  * Computes a least-squares solution of {@code Ax = b} with LSMR.
+  * <p/>
+  * After this returns, {@code getResidualNorm()}, {@link #getNormalEquationResidual()},
+  * {@link #getANorm()}, {@link #getCondition()} and {@link #getXNorm()} report the estimates from
+  * this run.
+  *
+  * @param A the coefficient matrix (may be rectangular)
+  * @param b the right-hand side; expected to have {@code A.numRows()} entries
+  * @return the final iterate {@code x}; the zero vector when {@code b = 0} or {@code A'b = 0}
+  */
+ public Vector solve(Matrix A, Vector b) {
+   log.debug(" itn x(1) norm r norm A'r");
+   log.debug(" compatible LS norm A cond A");
+
+   Matrix transposedA = A.transpose();
+
+   // Form the first vectors u and v of the bidiagonalization.
+   // These satisfy beta*u = b, alpha*v = A'u.
+   Vector u = b;
+   double beta = u.norm(2);
+   if (beta > 0) {
+     u = u.divide(beta);
+   }
+
+   v = transposedA.times(u);
+   int m = A.numRows();
+   int n = A.numCols();
+
+   int minDim = Math.min(m, n);
+   // An iteration limit of -1 means "choose automatically". Resolve it locally so the configured
+   // value is not overwritten; a later call on a different matrix then gets its own default.
+   int maxIterations = iterationLimit == -1 ? minDim : iterationLimit;
+
+   if (log.isDebugEnabled()) {
+     log.debug("LSMR - Least-squares solution of Ax = b, based on Matlab Version 1.02, 14 Apr 2010, Mahout version {}",
+         this.getClass().getPackage().getImplementationVersion());
+     log.debug(String.format("The matrix A has %d rows and %d cols, lambda = %.4g, atol = %g, btol = %g",
+         m, n, getLambda(), getAtolerance(), getBtolerance()));
+   }
+
+   double alpha = v.norm(2);
+   if (alpha > 0) {
+     v.assign(Functions.div(alpha));
+   }
+
+   // Initialization for local reorthogonalization.
+   localPointer = 0;
+   localVQueueFull = false;
+
+   // Preallocate storage for the last few v_k. With orthogonal v_k's a Krylov subspace method
+   // converges in no more iterations than the number of singular values, so more space is not
+   // necessary. Clamp at zero so a non-positive local size simply disables reorthogonalization
+   // (and an empty matrix dimension cannot index into a zero-length buffer).
+   localV = new Vector[Math.max(0, Math.min(localSize, minDim))];
+   boolean localOrtho = localV.length > 0;
+   if (localOrtho) {
+     localV[0] = v;
+   }
+
+   // Initialize variables for the first iteration.
+   iteration = 0;
+   double zetabar = alpha * beta;
+   double alphabar = alpha;
+   double rho = 1;
+   double rhobar = 1;
+   double cbar = 1;
+   double sbar = 0;
+
+   Vector h = v;
+   Vector hbar = zeros(n);
+   Vector x = zeros(n);
+
+   // Initialize variables for estimation of ||r||.
+   double betadd = beta;
+   double betad = 0;
+   double rhodold = 1;
+   double tautildeold = 0;
+   double thetatilde = 0;
+   double zeta = 0;
+   double d = 0;
+
+   // Initialize variables for estimation of ||A|| and cond(A). The reported estimates are reset
+   // too, so an early exit below never exposes values left over from a previous call.
+   aNorm = alpha * alpha;
+   normA = Math.sqrt(aNorm);
+   condA = 1;
+   xNorm = 0;
+   double maxrbar = 0;
+   double minrbar = 1e+100;
+
+   // Items for use in stopping rules.
+   double normb = beta;
+   StopCode stop = StopCode.CONTINUE;
+
+   double ctol = 0;
+   if (conditionLimit > 0) {
+     ctol = 1 / conditionLimit;
+   }
+   residualNorm = beta;
+
+   // Exit if b = 0 or A'b = 0.
+   normalEquationResidual = alpha * beta;
+   if (normalEquationResidual == 0) {
+     return x;
+   }
+
+   if (log.isDebugEnabled()) {
+     double test1 = 1;
+     double test2 = alpha / beta;
+     log.debug("{} {}", iteration, x.get(0));
+     log.debug("{} {}", residualNorm, normalEquationResidual);
+     log.debug("{} {}", test1, test2);
+   }
+
+   //------------------------------------------------------------------
+   // Main iteration loop.
+   // NOTE(review): the limit is checked after the increment with '>', so up to
+   // maxIterations + 1 iterations run; the Matlab lsmr.m reference stops at itn >= maxitn.
+   // Confirm whether the extra iteration is intended.
+   //------------------------------------------------------------------
+   while (iteration <= maxIterations && stop == StopCode.CONTINUE) {
+
+     iteration = iteration + 1;
+
+     // Perform the next step of the bidiagonalization to obtain the
+     // next beta, u, alpha, v. These satisfy the relations
+     // beta*u = A*v - alpha*u,
+     // alpha*v = A'*u - beta*v.
+     u = A.times(v).minus(u.times(alpha));
+     beta = u.norm(2);
+     if (beta > 0) {
+       u.assign(Functions.div(beta));
+
+       // store data for local-reorthogonalization of V
+       if (localOrtho) {
+         localVEnqueue(v);
+       }
+       v = transposedA.times(u).minus(v.times(beta));
+       // local-reorthogonalization of V
+       if (localOrtho) {
+         v = localVOrtho(v);
+       }
+       alpha = v.norm(2);
+       if (alpha > 0) {
+         v.assign(Functions.div(alpha));
+       }
+     }
+
+     // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
+
+     // Construct rotation Qhat_{k,2k+1}.
+     double alphahat = Math.hypot(alphabar, lambda);
+     double chat = alphabar / alphahat;
+     double shat = lambda / alphahat;
+
+     // Use a plane rotation (Q_i) to turn B_i to R_i.
+     double rhoold = rho;
+     rho = Math.hypot(alphahat, beta);
+     double c = alphahat / rho;
+     double s = beta / rho;
+     double thetanew = s * alpha;
+     alphabar = c * alpha;
+
+     // Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar.
+     double rhobarold = rhobar;
+     double zetaold = zeta;
+     double thetabar = sbar * rho;
+     double rhotemp = cbar * rho;
+     rhobar = Math.hypot(cbar * rho, thetanew);
+     cbar = cbar * rho / rhobar;
+     sbar = thetanew / rhobar;
+     zeta = cbar * zetabar;
+     zetabar = -sbar * zetabar;
+
+     // Update h, h_hat, x.
+     hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold)));
+     x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS);
+     h = v.minus(h.times(thetanew / rho));
+
+     // Estimate of ||r||.
+
+     // Apply rotation Qhat_{k,2k+1}.
+     double betaacute = chat * betadd;
+     double betacheck = -shat * betadd;
+
+     // Apply rotation Q_{k,k+1}.
+     double betahat = c * betaacute;
+     betadd = -s * betaacute;
+
+     // Apply rotation Qtilde_{k-1}.
+     // betad = betad_{k-1} here.
+     double thetatildeold = thetatilde;
+     double rhotildeold = Math.hypot(rhodold, thetabar);
+     double ctildeold = rhodold / rhotildeold;
+     double stildeold = thetabar / rhotildeold;
+     thetatilde = stildeold * rhobar;
+     rhodold = ctildeold * rhobar;
+     betad = -stildeold * betad + ctildeold * betahat;
+
+     // betad = betad_k here.
+     // rhodold = rhod_k here.
+     tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold;
+     double taud = (zeta - thetatilde * tautildeold) / rhodold;
+     d = d + betacheck * betacheck;
+     residualNorm = Math.sqrt(d + (betad - taud) * (betad - taud) + betadd * betadd);
+
+     // Estimate ||A||.
+     aNorm = aNorm + beta * beta;
+     normA = Math.sqrt(aNorm);
+     aNorm = aNorm + alpha * alpha;
+
+     // Estimate cond(A).
+     maxrbar = Math.max(maxrbar, rhobarold);
+     if (iteration > 1) {
+       minrbar = Math.min(minrbar, rhobarold);
+     }
+     condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp);
+
+     // Test for convergence.
+
+     // Compute norms for convergence testing.
+     normalEquationResidual = Math.abs(zetabar);
+     xNorm = x.norm(2);
+
+     // Now use these norms to estimate certain other quantities,
+     // some of which will be small near a solution.
+     double test1 = residualNorm / normb;
+     double test2 = normalEquationResidual / (normA * residualNorm);
+     double test3 = 1 / condA;
+     double t1 = test1 / (1 + normA * xNorm / normb);
+     double rtol = bTolerance + aTolerance * normA * xNorm / normb;
+
+     // The following tests guard against extremely small values of
+     // atol, btol or ctol. (The user may have set any or all of
+     // the parameters atol, btol, conlim to 0.)
+     // The effect is equivalent to the normal tests using
+     // atol = eps, btol = eps, conlim = 1/eps.
+     // Later tests deliberately override earlier ones (same precedence as the Matlab istop codes).
+     if (iteration > maxIterations) {
+       stop = StopCode.ITERATION_LIMIT;
+     }
+     if (1 + test3 <= 1) {
+       stop = StopCode.CONDITION_MACHINE_TOLERANCE;
+     }
+     if (1 + test2 <= 1) {
+       stop = StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE;
+     }
+     if (1 + t1 <= 1) {
+       stop = StopCode.CONVERGED_MACHINE_TOLERANCE;
+     }
+
+     // Allow for tolerances set by the user.
+     if (test3 <= ctol) {
+       stop = StopCode.CONDITION;
+     }
+     if (test2 <= aTolerance) {
+       // ||A'r|| is small relative to ||A|| ||r||: the least-squares solution is good enough.
+       stop = StopCode.LEAST_SQUARE_CONVERGED;
+     }
+     if (test1 <= rtol) {
+       // ||r|| is small relative to btol*||b|| + atol*||A|| ||x||: Ax - b is small enough.
+       stop = StopCode.CONVERGED;
+     }
+
+     // See if it is time to print something.
+     if (log.isDebugEnabled()) {
+       boolean print = n <= 40
+           || iteration <= 10
+           || iteration >= maxIterations - 10
+           || iteration % 10 == 0
+           || test3 <= 1.1 * ctol
+           || test2 <= 1.1 * aTolerance
+           || test1 <= 1.1 * rtol
+           || stop != StopCode.CONTINUE;
+       if (print) {
+         statusDump(x, normA, condA, test1, test2);
+       }
+     }
+   } // iteration loop
+
+   // Print the stopping condition.
+   log.debug("Finished: {}", stop.getMessage());
+
+   return x;
+ }
+
+ /**
+  * Writes diagnostics for the current iterate to the debug log, one pair per line:
+  * <ol>
+  * <li>the residual estimates;</li>
+  * <li>the iteration number and first component of the iterate;</li>
+  * <li>the two relative stopping-test values;</li>
+  * <li>the ||A|| and cond(A) estimates.</li>
+  * </ol>
+  */
+ private void statusDump(Vector iterate, double normEstimate, double condEstimate,
+     double relResidual, double relNormalResidual) {
+   log.debug("{} {}", this.residualNorm, this.normalEquationResidual);
+   log.debug("{} {}", iteration, iterate.get(0));
+   log.debug("{} {}", relResidual, relNormalResidual);
+   log.debug("{} {}", normEstimate, condEstimate);
+ }
+
+ /** Allocates a new dense vector of the given cardinality (Matlab's {@code zeros(n, 1)}). */
+ private Vector zeros(int size) {
+   Vector zeroVector = new DenseVector(size);
+   return zeroVector;
+ }
+
+ /**
+  * Stores {@code v} in the circular buffer {@code localV}, overwriting the oldest entry once every
+  * slot has been used. Does nothing when the buffer has zero capacity.
+  *
+  * @param v the vector to remember for local reorthogonalization
+  */
+ private void localVEnqueue(Vector v) {
+   if (localV.length == 0) {
+     return;
+   }
+   localV[localPointer] = v;
+   localPointer = (localPointer + 1) % localV.length;
+   if (localPointer == 0) {
+     // The pointer wrapped around, so every slot now holds a stored vector. solve() resets this
+     // flag to false, and previously nothing ever set it back to true.
+     localVQueueFull = true;
+   }
+ }
+
+ /**
+  * Performs local reorthogonalization of V.
+  * <p/>
+  * Subtracts from {@code v} its component along each vector stored in {@code localV}, one stored
+  * vector at a time in buffer order. Each projection uses the already-updated vector. Slots that
+  * are still {@code null} are skipped.
+  *
+  * @param v the vector to reorthogonalize
+  * @return the reorthogonalized vector
+  */
+ private Vector localVOrtho(Vector v) {
+   Vector result = v;
+   for (int slot = 0; slot < localV.length; slot++) {
+     Vector stored = localV[slot];
+     if (stored == null) {
+       continue;
+     }
+     double projection = result.dot(stored);
+     result = result.minus(stored.times(projection));
+   }
+   return result;
+ }
+
+ /**
+  * Termination states of the LSMR iteration, each with a human-readable explanation.
+  * <p/>
+  * Keep the declaration order: the constants follow the Matlab istop codes, with
+  * {@code ordinal() == istop + 1} and {@code CONTINUE} meaning the iteration has not stopped.
+  */
+ private enum StopCode {
+   CONTINUE("Not done"),
+   TRIVIAL("The exact solution is x = 0"),
+   CONVERGED("Ax - b is small enough, given atol, btol"),
+   LEAST_SQUARE_CONVERGED("The least-squares solution is good enough, given atol"),
+   CONDITION("The estimate of cond(Abar) has exceeded condition limit"),
+   CONVERGED_MACHINE_TOLERANCE("Ax - b is small enough for this machine"),
+   LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE("The least-squares solution is good enough for this machine"),
+   CONDITION_MACHINE_TOLERANCE("Cond(Abar) seems to be too large for this machine"),
+   ITERATION_LIMIT("The iteration limit has been reached");
+
+   private final String message;
+
+   StopCode(String message) {
+     this.message = message;
+   }
+
+   /** Returns the explanation logged when the solver stops in this state. */
+   public String getMessage() {
+     return this.message;
+   }
+ }
+
+ /**
+  * Sets {@code atol}. It is used by the stopping rule {@code ||A'r|| / (||A|| ||r||) <= atol} and
+  * scales the {@code ||A|| ||x|| / ||b||} term of the residual tolerance.
+  *
+  * @param aTolerance the new {@code atol}
+  */
+ public void setAtolerance(double aTolerance) {
+   this.aTolerance = aTolerance;
+ }
+
+ /**
+  * Sets {@code btol}, the constant part of the residual tolerance
+  * {@code rtol = btol + atol ||A|| ||x|| / ||b||}.
+  *
+  * @param bTolerance the new {@code btol}
+  */
+ public void setBtolerance(double bTolerance) {
+   this.bTolerance = bTolerance;
+ }
+
+ /**
+  * Sets the limit on the estimated condition number of A. Iteration stops once
+  * {@code 1 / cond(A) <= 1 / conditionLimit}. A non-positive value disables this user test.
+  *
+  * @param conditionLimit the new condition limit
+  */
+ public void setConditionLimit(double conditionLimit) {
+   this.conditionLimit = conditionLimit;
+ }
+
+ /**
+  * Sets the maximum number of iterations. The value {@code -1} asks {@code solve} to use
+  * {@code min(rows, cols)} of the matrix being solved.
+  *
+  * @param iterationLimit the iteration limit, or {@code -1} for the automatic choice
+  */
+ public void setIterationLimit(int iterationLimit) {
+   this.iterationLimit = iterationLimit;
+ }
+
+ /**
+  * Sets how many recent {@code v} vectors are kept for local reorthogonalization. A value of
+  * {@code 0} (the default) disables it.
+  *
+  * @param localSize the number of stored vectors
+  */
+ public void setLocalSize(int localSize) {
+   this.localSize = localSize;
+ }
+
+ /**
+  * Sets the damping parameter {@code lambda} used when forming {@code alphahat}. This method is
+  * private, and the only caller visible in this file is the constructor, which sets it to 0.
+  *
+  * @param lambda the damping parameter
+  */
+ private void setLambda(double lambda) {
+   this.lambda = lambda;
+ }
+
+ /** Returns the damping parameter {@code lambda}. */
+ public double getLambda() {
+   return this.lambda;
+ }
+
+ /** Returns the current {@code atol} stopping tolerance. */
+ public double getAtolerance() {
+   return this.aTolerance;
+ }
+
+ /** Returns the current {@code btol} stopping tolerance. */
+ public double getBtolerance() {
+   return this.bTolerance;
+ }
+}
Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,20 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * A preconditioner transforms a vector so that a linear system solver (for example the conjugate
+ * gradient solver, which accepts one) converges faster and/or more stably.
+ * <p/>
+ * NOTE(review): whether implementations may modify the argument in place is not specified here.
+ * Confirm against the implementations.
+ */
+public interface Preconditioner {
+
+  /**
+   * Preconditions the specified vector.
+   *
+   * @param v the vector to precondition
+   * @return the preconditioned vector
+   */
+  Vector precondition(Vector v);
+}
Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.solver.LSMR;
+import org.junit.Test;
+
+import java.util.Random;
+
+/**
+ * Created by IntelliJ IDEA. User: tdunning Date: Sep 15, 2010 Time: 7:32:27 PM To change this
+ * template use File | Settings | File Templates.
+ */
+public class LSMRTest extends MahoutTestCase {
+ @Test
+ public void basics() {
+ Matrix m = hilbert(5);
+
+ // make sure it is the hilbert matrix we know and love
+ assertEquals(1, m.get(0, 0), 0);
+ assertEquals(0.5, m.get(0, 1), 0);
+ assertEquals(1 / 6.0, m.get(2, 3), 1e-9);
+
+ Vector x = new DenseVector(new double[]{5, -120, 630, -1120, 630});
+
+ Vector b = new DenseVector(5);
+ b.assign(1);
+
+ assertEquals(0, m.times(x).minus(b).norm(2), 1e-9);
+
+ LSMR r = new LSMR();
+ Vector x1 = r.solve(m, b);
+
+ // the ideal solution is [5 -120 630 -1120 630] but the 5x5 hilbert matrix
+ // has a condition number of almost 500,000 and the normal equation condition
+ // number is that squared. This means that we don't get the exact answer with
+ // a fast iterative solution.
+ // Thus, we have to check the residuals rather than testing that the answer matched
+ // the ideal.
+ assertEquals(m.times(x1).minus(b).norm(2), 0, 1e-2);
+ assertEquals(0, m.transpose().times(m).times(x1).minus(m.transpose().times(b)).norm(2), 1e-7);
+
+ // and we need to check that the error estimates are pretty good.
+ assertEquals(m.times(x1).minus(b).norm(2), r.getResidualNorm(), 1e-5);
+ assertEquals(m.transpose().times(m).times(x1).minus(m.transpose().times(b)).norm(2), r.getNormalEquationResidual(), 1e-9);
+ }
+
+ private Matrix hilbert(int n) {
+ Matrix r = new DenseMatrix(n, n);
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < n; j++) {
+ r.set(i, j, 1.0 / (i + j + 1));
+ }
+ }
+ return r;
+ }
+
+ private Matrix overDetermined(int n) {
+ Random rand = RandomUtils.getRandom();
+ Matrix r = new DenseMatrix(2 * n, n);
+ for (int i = 0; i < 2 * n; i++) {
+ for (int j = 0; j < n; j++) {
+ r.set(i, j, rand.nextGaussian());
+ }
+ }
+ return r;
+ }
+}
Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,212 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link ConjugateGradientSolver}, with and without a Jacobi preconditioner.
+ */
+public class TestConjugateGradientSolver extends MahoutTestCase {
+
+  @Test
+  public void testConjugateGradientSolver() {
+    Matrix a = getA();
+    Vector b = getB();
+
+    ConjugateGradientSolver solver = new ConjugateGradientSolver();
+    Vector x = solver.solve(a, b);
+
+    assertEquals(0.0, Math.sqrt(a.times(x).getDistanceSquared(b)), EPSILON);
+    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    assertEquals(10, solver.getIterations());
+  }
+
+  @Test
+  public void testConditionedConjugateGradientSolver() {
+    Matrix a = getIllConditionedMatrix();
+    Vector b = getB();
+    Preconditioner conditioner = new JacobiConditioner(a);
+    ConjugateGradientSolver solver = new ConjugateGradientSolver();
+
+    Vector x = solver.solve(a, b, null, 100, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+
+    double distance = Math.sqrt(a.times(x).getDistanceSquared(b));
+    assertEquals(0.0, distance, EPSILON);
+    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    assertEquals(16, solver.getIterations());
+
+    Vector x2 = solver.solve(a, b, conditioner, 100, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+
+    // the Jacobi preconditioner isn't very good, but it does result in one less iteration to converge
+    distance = Math.sqrt(a.times(x2).getDistanceSquared(b));
+    assertEquals(0.0, distance, EPSILON);
+    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    assertEquals(15, solver.getIterations());
+  }
+
+  @Test
+  public void testEarlyStop() {
+    Matrix a = getA();
+    Vector b = getB();
+    ConjugateGradientSolver solver = new ConjugateGradientSolver();
+
+    // specifying a looser max error will result in fewer iterations but less accurate results
+    Vector x = solver.solve(a, b, null, 10, 0.1);
+    double distance = Math.sqrt(a.times(x).getDistanceSquared(b));
+    assertTrue(distance > EPSILON);
+    assertEquals(0.0, distance, 0.1); // should be equal to within the error specified
+    assertEquals(7, solver.getIterations()); // should have taken fewer iterations
+
+    // can get a similar effect by bounding the number of iterations
+    x = solver.solve(a, b, null, 7, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    distance = Math.sqrt(a.times(x).getDistanceSquared(b));
+    assertTrue(distance > EPSILON);
+    assertEquals(0.0, distance, 0.1);
+    assertEquals(7, solver.getIterations());
+  }
+
+  /** A 10x10 symmetric test matrix (values listed column by column). */
+  private static Matrix getA() {
+    return reshape(new double[] {
+        11.7155649822793997, -0.7125253363083646, 4.6473613961860183, 1.6020939468348456, -4.6789817799137134,
+        -0.8140416763434970, -4.5995617505618345, -1.1749070042775340, -1.6747995811678336, 3.1922255171058342,
+        -0.7125253363083646, 12.3400579683994867, -2.6498099427000645, 0.5264507222630669, 0.3783428369189767,
+        -2.1170186159188811, 2.3695134252190528, 3.8182131490333013, 6.5285942298270347, 2.8564814419366353,
+        4.6473613961860183, -2.6498099427000645, 16.1317933921668484, -0.0409475448061225, 1.4805687075608227,
+        -2.9958076484628950, -2.5288893025027264, -0.9614557539842487, -2.2974738351519077, -1.5516184284572598,
+        1.6020939468348456, 0.5264507222630669, -0.0409475448061225, 4.1946802122694482, -2.5210038046912198,
+        0.6634899962909317, 0.4036187419205338, -0.2829211393003727, -0.2283091172980954, 1.1253516563552464,
+        -4.6789817799137134, 0.3783428369189767, 1.4805687075608227, -2.5210038046912198, 19.4307361862733430,
+        -2.5200132222091787, 2.3748511971444510, 11.6426598443305522, -0.1508136510863874, 4.3471343888063512,
+        -0.8140416763434970, -2.1170186159188811, -2.9958076484628950, 0.6634899962909317, -2.5200132222091787,
+        7.6712334419700747, -3.8687773629502851, -3.0453418711591529, -0.1155580876143619, -2.4025459467422121,
+        -4.5995617505618345, 2.3695134252190528, -2.5288893025027264, 0.4036187419205338, 2.3748511971444510,
+        -3.8687773629502851, 10.4681666057470082, 1.6527180866171229, 2.9341795819365384, -2.1708176372763099,
+        -1.1749070042775340, 3.8182131490333013, -0.9614557539842487, -0.2829211393003727, 11.6426598443305522,
+        -3.0453418711591529, 1.6527180866171229, 16.0050616934176233, 1.1689747208793086, 1.6665090945954870,
+        -1.6747995811678336, 6.5285942298270347, -2.2974738351519077, -0.2283091172980954, -0.1508136510863874,
+        -0.1155580876143619, 2.9341795819365384, 1.1689747208793086, 6.4794329751637481, -1.9197339981871877,
+        3.1922255171058342, 2.8564814419366353, -1.5516184284572598, 1.1253516563552464, 4.3471343888063512,
+        -2.4025459467422121, -2.1708176372763099, 1.6665090945954870, -1.9197339981871877, 18.9149021356344598
+    }, 10, 10);
+  }
+
+  /** Right-hand side with 10 entries, used with both 10x10 test matrices. */
+  private static Vector getB() {
+    return new DenseVector(new double[] {
+        -0.552252, 0.038430, 0.058392, -1.234496, 1.240369, 0.373649, 0.505113, 0.503723, 1.215340, -0.391908
+    });
+  }
+
+  /** A 10x10 ill-conditioned test matrix (values listed column by column). */
+  private static Matrix getIllConditionedMatrix() {
+    return reshape(new double[] {
+        0.00695278043678842, 0.09911830022078683, 0.01309584636255063, 0.00652917453032394, 0.04337631487735064,
+        0.14232165273321387, 0.05808722912361313, -0.06591965049732287, 0.06055771542862332, 0.00577423310349649,
+        0.09911830022078683, 1.50071402418061428, 0.14988743575884242, 0.07195514527480981, 0.63747362341752722,
+        1.30711819020414688, 0.82151609385115953, -0.72616125524587938, 1.03490136002022948, 0.12800239664439328,
+        0.01309584636255063, 0.14988743575884242, 0.04068462583124965, 0.02147022047006482, 0.07388113580146650,
+        0.58070223915076002, 0.11280336266257514, -0.21690068430020618, 0.04065087561300068, -0.00876895259593769,
+        0.00652917453032394, 0.07195514527480981, 0.02147022047006482, 0.01140105250542524, 0.03624164348693958,
+        0.31291554581393255, 0.05648457235205666, -0.11507583016077780, 0.01475756130709823, -0.00584453679519805,
+        0.04337631487735064, 0.63747362341752722, 0.07388113580146649, 0.03624164348693959, 0.27491543200760571,
+        0.73410543168748121, 0.36120630002843257, -0.36583546331208316, 0.41472509341940017, 0.04581458758255480,
+        0.14232165273321387, 1.30711819020414666, 0.58070223915076002, 0.31291554581393255, 0.73410543168748121,
+        9.02536073121807014, 1.25426385582883104, -3.16186335125594642, -0.19740140818905436, -0.26613760880058035,
+        0.05808722912361314, 0.82151609385115953, 0.11280336266257514, 0.05648457235205667, 0.36120630002843257,
+        1.25426385582883126, 0.48661058451606820, -0.57030511336562195, 0.49151280464818098, 0.04428280690189127,
+        -0.06591965049732286, -0.72616125524587938, -0.21690068430020618, -0.11507583016077781, -0.36583546331208316,
+        -3.16186335125594642, -0.57030511336562195, 1.16270815038078945, -0.14837898963724327, 0.05917203395002889,
+        0.06055771542862331, 1.03490136002022926, 0.04065087561300068, 0.01475756130709823, 0.41472509341940023,
+        -0.19740140818905436, 0.49151280464818103, -0.14837898963724327, 0.86693820682049716, 0.14089688752570340,
+        0.00577423310349649, 0.12800239664439328, -0.00876895259593769, -0.00584453679519805, 0.04581458758255480,
+        -0.26613760880058035, 0.04428280690189126, 0.05917203395002889, 0.14089688752570340, 0.02901858439788401
+    }, 10, 10);
+  }
+
+  /**
+   * A 20x5 dense test matrix (values listed column by column).
+   * NOTE(review): not used by any test in this class yet.
+   */
+  private static Matrix getAsymmetricMatrix() {
+    return reshape(new double[] {
+        0.1586493402398226, -0.8668244036239467, 0.4335233711065471, -1.1025223577469705, 1.1344100191664601,
+        -0.1399944083742454, 0.8879750333144295, -1.2139664527957903, 0.7154591081557057, -0.6320890356949669,
+        -2.4546945723009581, 0.6354748667295935, -0.1931993736354496, -0.1210449542073575, -1.0668745874463414,
+        0.6539061600017384, 2.4045520271091063, -0.3387572116155693, 0.1575188740437142, 1.1791073500243496,
+        -0.6418745429181755, 0.6836410530720005, -1.2447493564334062, -1.8840081252627843, 0.5663864914859502,
+        0.0819203791124956, 0.2004407540793239, 0.7350145066687849, 1.6525377683305262, -0.3156915229969668,
+        -0.1866701463141060, -0.3929673444397022, -0.4440946700501859, 0.1366803303987421, -0.2138101381625466,
+        0.5399874351478779, -1.0088091882703056, 0.0978023083150833, 1.8795777615527958, 0.3782417618354363,
+        -0.4564752186043173, 0.4014814252832269, 1.9691150950571501, 0.2424686682362568, 1.0965758964799504,
+        0.2751725463132324, -0.6652756564294597, -0.6256564536463288, 1.0332457212107204, -0.0330851504958215,
+        -1.0402096493279287, -0.6850389655533707, -1.8896839974451625, 1.1533231017445102, -0.5387306882127710,
+        0.0181850207098213, -0.2416652193929706, -0.9868171673047287, -1.5872573189377035, -0.8492253650362955,
+        1.1949977792951225, 0.7901168665120927, 0.9832676055718492, -0.0752834029327588, 1.0555006468941126,
+        0.6842531633106009, 0.2589700378872499, 0.3565253337268334, 0.1869608474650344, -0.1696524825242293,
+        0.6919898638809949, -1.4937187919435133, 1.0039151841775080, -0.2580993333173019, 0.1243386429912411,
+        1.3945380460721688, 0.3078165489952902, 1.1248734111054359, 0.5613308856003306, -0.9013329415656699,
+        -0.9197179846787753, 0.1167372728291174, -0.7807620712716467, 0.2210918047063067, -0.4813869727362010,
+        0.3870067788770671, 1.1974416632199159, 2.4676804711420330, 1.8492990765211168, -1.3089887830472471,
+        -0.7587845769668021, -1.0354138253278353, -0.3907902473275445, -2.1292895670916168, -0.7544686049709807,
+        -0.3431317172534703, 1.4959721683724390, 0.6004852467523584, 1.2140230344223786, 0.1279148299232956
+    }, 20, 5);
+  }
+
+  /**
+   * A right-hand side with 5 entries.
+   * NOTE(review): not used by any test in this class yet.
+   */
+  private static Vector getSmallB() {
+    return new DenseVector(new double[] {
+        0.114065955249272,
+        0.953981568944476,
+        -2.611106316607759,
+        0.652190962446307,
+        1.298055218126384
+    });
+  }
+
+  /**
+   * A 5x5 symmetric matrix built as the sum of the outer products u u' and v v', so it has rank
+   * at most 2.
+   * NOTE(review): not used by any test in this class yet.
+   */
+  private static Matrix getLowrankSymmetricMatrix() {
+    Matrix m = new DenseMatrix(5, 5);
+    Vector u = new DenseVector(new double[] {
+        -0.0364638798936962,
+        1.0219291133418171,
+        -0.5649933120375343,
+        -1.0050553315595800,
+        -0.5264178580727512
+    });
+    Vector v = new DenseVector(new double[] {
+        -1.345847117891187,
+        0.553386426498032,
+        1.912020072696648,
+        -0.820959934779948,
+        1.223358044171859
+    });
+
+    return m.plus(u.cross(u)).plus(v.cross(v));
+  }
+
+  /**
+   * A 20x5 matrix whose first two rows are u and v and whose remaining rows are zero, so it has
+   * rank 2.
+   * NOTE(review): not used by any test in this class yet.
+   */
+  private static Matrix getLowrankAsymmetricMatrix() {
+    Matrix m = new DenseMatrix(20, 5);
+    Vector u = new DenseVector(new double[] {
+        -0.0364638798936962,
+        1.0219291133418171,
+        -0.5649933120375343,
+        -1.0050553315595800,
+        -0.5264178580727512
+    });
+    Vector v = new DenseVector(new double[] {
+        -1.345847117891187,
+        0.553386426498032,
+        1.912020072696648,
+        -0.820959934779948,
+        1.223358044171859
+    });
+
+    m.assignRow(0, u);
+    // Previously v was also written to row 0, overwriting u and leaving a rank-1 matrix.
+    m.assignRow(1, v);
+
+    return m;
+  }
+
+  /**
+   * Builds a rows x columns matrix from values given in column-major order: value i goes to row
+   * {@code i % rows}, column {@code i / rows}.
+   */
+  private static Matrix reshape(double[] values, int rows, int columns) {
+    Matrix m = new DenseMatrix(rows, columns);
+    int i = 0;
+    for (double v : values) {
+      m.set(i % rows, i / rows, v);
+      i++;
+    }
+    return m;
+  }
+}