You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/20 18:54:56 UTC

[systemds] branch master updated: [SYSTEMDS-3068] Fix robustness of Eigen decomposition (fallback)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 9387156  [SYSTEMDS-3068] Fix robustness of Eigen decomposition (fallback)
9387156 is described below

commit 9387156823b28632f640ae820e3fc08ba8af99bf
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Jul 20 20:52:16 2021 +0200

    [SYSTEMDS-3068] Fix robustness of Eigen decomposition (fallback)
    
    This patch adds a fallback strategy for when the internally used Commons
    Math library for Eigen decomposition throws a convergence-failure
    exception (due to hitting the maximum number of iterations). Specifically,
    we then add a small regularization term to each non-zero input value
    (positive or negative, according to the sign of that value) and retry.
---
 .../sysds/runtime/matrix/data/LibCommonsMath.java  | 45 +++++++++--
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 16 ++--
 .../test/functions/builtin/BuiltinPCATest.java     | 94 ++++++++++++++++++++++
 .../BuiltinTopkCleaningRegressionTest.java         |  3 -
 src/test/scripts/functions/builtin/pca.dml         | 39 +++++++++
 5 files changed, 182 insertions(+), 15 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index 8283352..a49d1ce 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.matrix.data;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.commons.math3.exception.MaxCountExceededException;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.BlockRealMatrix;
 import org.apache.commons.math3.linear.CholeskyDecomposition;
@@ -29,6 +32,7 @@ import org.apache.commons.math3.linear.QRDecomposition;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.SingularValueDecomposition;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.util.DataConverter;
 
 /**
@@ -40,9 +44,9 @@ import org.apache.sysds.runtime.util.DataConverter;
  */
 public class LibCommonsMath 
 {
-	// private static final Log LOG = LogFactory.getLog(LibCommonsMath.class.getName());
-
-	static final double RELATIVE_SYMMETRY_THRESHOLD = 1e-6;
+	private static final Log LOG = LogFactory.getLog(LibCommonsMath.class.getName());
+	private static final double RELATIVE_SYMMETRY_THRESHOLD = 1e-6;
+	private static final double EIGEN_LAMBDA = 1e-8;
 
 	private LibCommonsMath() {
 		//prevent instantiation via private constructor
@@ -173,12 +177,20 @@ public class LibCommonsMath
 	 */
 	private static MatrixBlock[] computeEigen(MatrixBlock in) {
 		if ( in.getNumRows() != in.getNumColumns() ) {
-			throw new DMLRuntimeException("Eigen Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols="+ in.getNumColumns() +")");
+			throw new DMLRuntimeException("Eigen Decomposition can only be done on a square matrix. "
+				+ "Input matrix is rectangular (rows=" + in.getNumRows() + ", cols="+ in.getNumColumns() +")");
 		}
 		
-		Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
+		EigenDecomposition eigendecompose = null;
+		try {
+			Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
+			eigendecompose = new EigenDecomposition(matrixInput);
+		}
+		catch(MaxCountExceededException ex) {
+			LOG.warn("Eigen: "+ ex.getMessage()+". Falling back to regularized eigen factorization.");
+			eigendecompose = computeEigenRegularized(in);
+		}
 		
-		EigenDecomposition eigendecompose = new EigenDecomposition(matrixInput);
 		RealMatrix eVectorsMatrix = eigendecompose.getV();
 		double[][] eVectors = eVectorsMatrix.getData();
 		double[] eValues = eigendecompose.getRealEigenvalues();
@@ -210,6 +222,27 @@ public class LibCommonsMath
 
 		return new MatrixBlock[] { mbValues, mbVectors };
 	}
+	
+	private static EigenDecomposition computeEigenRegularized(MatrixBlock in) {
+		if( in == null || in.isEmptyBlock(false) )
+			throw new DMLRuntimeException("Invalid empty block");
+		
+		//slightly modify input for regularization (pos/neg)
+		MatrixBlock in2 = new MatrixBlock(in, false);
+		DenseBlock a = in2.getDenseBlock();
+		for( int i=0; i<in2.rlen; i++ ) {
+			double[] avals = a.values(i);
+			int apos = a.pos(i);
+			for( int j=0; j<in2.clen; j++ ) {
+				double v = avals[apos+j];
+				avals[apos+j] += Math.signum(v) * EIGEN_LAMBDA;
+			}
+		}
+		
+		//run eigen decomposition
+		return new EigenDecomposition(
+			DataConverter.convertToArray2DRowRealMatrix(in2));
+	}
 
 
 	/**
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 2c71d53..22b7e5c 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -172,6 +172,10 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		copy(that);
 	}
 	
+	public MatrixBlock(MatrixBlock that, boolean sp) {
+		copy(that, sp);
+	}
+	
 	public MatrixBlock(double val) {
 		reset(1, 1, false, 1, val);
 	}
@@ -1390,15 +1394,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		if( this == that ) //prevent data loss (e.g., on sparse-dense conversion)
 			throw new RuntimeException( "Copy must not overwrite itself!" );
 		
-		this.rlen=that.rlen;
-		this.clen=that.clen;
-		this.sparse=sp;
+		rlen=that.rlen;
+		clen=that.clen;
+		sparse=sp;
 		estimatedNNzsPerRow=(int)Math.ceil((double)thatValue.getNonZeros()/(double)rlen);
-		if(this.sparse && that.sparse)
+		if(sparse && that.sparse)
 			copySparseToSparse(that);
-		else if(this.sparse && !that.sparse)
+		else if(sparse && !that.sparse)
 			copyDenseToSparse(that);
-		else if(!this.sparse && that.sparse)
+		else if(!sparse && that.sparse)
 			copySparseToDense(that);
 		else
 			copyDenseToDense(that);
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinPCATest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinPCATest.java
new file mode 100644
index 0000000..a4aaf03
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinPCATest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class BuiltinPCATest extends AutomatedTestBase {
+	private final static String TEST_NAME = "pca";
+	private final static String TEST_DIR = "functions/builtin/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinPCATest.class.getSimpleName() + "/";
+
+	//Note: fine for <110 columns, but eigen fails for >=110 columns with
+	// org.apache.commons.math3.exception.MaxCountExceededException: illegal state: convergence failed
+	private final static int rows = 3000;
+	private final static int cols = 110;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"PC","V"}));
+	}
+
+	@Test
+	public void testPca4Hybrid() {
+		runPCATest(4, ExecMode.HYBRID);
+	}
+	
+	@Test
+	public void testPca16Hybrid() {
+		runPCATest(16, ExecMode.HYBRID);
+	}
+	
+	@Test
+	public void testPca4Spark() {
+		runPCATest(4, ExecMode.SPARK);
+	}
+	
+	@Test
+	public void testPca16Spark() {
+		runPCATest(16, ExecMode.SPARK);
+	}
+
+	private void runPCATest(int k, ExecMode mode) {
+		ExecMode modeOld = setExecMode(mode);
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			List<String> proArgs = new ArrayList<>();
+	
+			proArgs.add("-args");
+			proArgs.add(input("X"));
+			proArgs.add(String.valueOf(k));
+			proArgs.add(output("PC"));
+			proArgs.add(output("V"));
+			programArgs = proArgs.toArray(new String[proArgs.size()]);
+			double[][] X = TestUtils.round(getRandomMatrix(rows, cols, 1, 5, 1.0, 7));
+			writeInputMatrixWithMTD("X", X, true);
+	
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+			MatrixCharacteristics mc = readDMLMetaDataFile("PC");
+			Assert.assertEquals(rows, mc.getRows());
+			Assert.assertEquals(k, mc.getCols());
+		}
+		finally {
+			resetExecMode(modeOld);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningRegressionTest.java b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningRegressionTest.java
index 996fa63..45ab4c3 100644
--- a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningRegressionTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningRegressionTest.java
@@ -23,7 +23,6 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 
 public class BuiltinTopkCleaningRegressionTest extends AutomatedTestBase{
@@ -55,7 +54,6 @@ public class BuiltinTopkCleaningRegressionTest extends AutomatedTestBase{
 			"lm", Types.ExecMode.HYBRID);
 	}
 
-
 	private void runFindPipelineTest(Double sample, int topk, int resources, int crossfold,
 		String target, Types.ExecMode et) {
 
@@ -78,5 +76,4 @@ public class BuiltinTopkCleaningRegressionTest extends AutomatedTestBase{
 			resetExecMode(modeOld);
 		}
 	}
-
 }
diff --git a/src/test/scripts/functions/builtin/pca.dml b/src/test/scripts/functions/builtin/pca.dml
new file mode 100644
index 0000000..6a6ec1a
--- /dev/null
+++ b/src/test/scripts/functions/builtin/pca.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1)
+k = $2;
+
+# one hot encoding
+m = nrow(X)
+n = ncol(X)
+fdom = colMaxs(X);
+foffb = t(cumsum(t(fdom))) - fdom;
+foffe = t(cumsum(t(fdom)))
+rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1)
+cix = matrix(X + foffb, m*n, 1);
+X2 = table(rix, cix); #one-hot encoded
+
+X2 = scale(X=X2, scale=TRUE, center=TRUE);
+[PC, V] = pca(X=X2, K=k)
+
+write(PC, $3)
+write(V, $4)