You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/02/01 22:52:43 UTC

incubator-systemml git commit: [MINOR] Added external builtin functions for performing cumsumprod and rowclassmeet

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 4f8648593 -> 6fad65d1d


[MINOR] Added external builtin functions for performing cumsumprod and
rowclassmeet


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6fad65d1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6fad65d1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6fad65d1

Branch: refs/heads/master
Commit: 6fad65d1d5ae4f1e65bdf99a68faf8396f280331
Parents: 4f86485
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Wed Feb 1 14:48:33 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed Feb 1 14:50:12 2017 -0800

----------------------------------------------------------------------
 .../org/apache/sysml/udf/lib/CumSumProd.java    | 243 +++++++++++++++++++
 .../org/apache/sysml/udf/lib/RowClassMeet.java  | 230 ++++++++++++++++++
 2 files changed, 473 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6fad65d1/src/main/java/org/apache/sysml/udf/lib/CumSumProd.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/lib/CumSumProd.java b/src/main/java/org/apache/sysml/udf/lib/CumSumProd.java
new file mode 100644
index 0000000..bdb2231
--- /dev/null
+++ b/src/main/java/org/apache/sysml/udf/lib/CumSumProd.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.udf.lib;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheException;
+import org.apache.sysml.runtime.matrix.data.IJV;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.udf.FunctionParameter;
+import org.apache.sysml.udf.Matrix;
+import org.apache.sysml.udf.PackageFunction;
+import org.apache.sysml.udf.Scalar;
+import org.apache.sysml.udf.Matrix.ValueType;
+
+/**
+ * Variant of cumsum:
+ * Computes the following two functions:
+ * 
+ * cumsum_prod = function (Matrix[double] X, Matrix[double] C, double start)  return (Matrix[double] Y)
+ * # Computes the following recurrence in log-number of steps:
+ * # Y [1, ] = X [1, ] + C [1, ] * start;
+ * # Y [i+1, ] = X [i+1, ] + C [i+1, ] * Y [i, ]
+ * {
+ * 		Y = X; P = C; m = nrow(X); k = 1;
+ * 		Y [1, ] = Y [1, ] + C [1, ] * start;
+ * 		while (k < m) {
+ * 			Y [k+1 : m, ] = Y [k+1 : m, ] + Y [1 : m-k, ] * P [k+1 : m, ];
+ * 			P [k+1 : m, ] = P [1 : m-k, ] * P [k+1 : m, ];
+ * 			k = 2 * k;
+ * 		} 
+ * }
+ * 
+ * cumsum_prod_reverse = function (Matrix[double] X, Matrix[double] C, double start) return (Matrix[double] Y)
+ * # Computes the reverse recurrence in log-number of steps:
+ * # Y [m, ] = X [m, ] + C [m, ] * start;
+ * # Y [i-1, ] = X [i-1, ] + C [i-1, ] * Y [i, ]
+ * {
+ * 		Y = X; P = C; m = nrow(X); k = 1;
+ * 		Y [m, ] = Y [m, ] + C [m, ] * start;
+ * 		while (k < m) {
+ * 			Y [1 : m-k, ] = Y [1 : m-k, ] + Y [k+1 : m, ] * P [1 : m-k, ];
+ * 			P [1 : m-k, ] = P [k+1 : m, ] * P [1 : m-k, ];
+ * 			k = 2 * k;
+ * 		} 
+ * }
+ * 
+ * The API of this external built-in function is as follows:
+ * 
+ * func = externalFunction(matrix[double] X, matrix[double] C,  double start, boolean isReverse) return (matrix[double] Y) 
+ * implemented in (classname="org.apache.sysml.udf.lib.CumSumProd",exectype="mem");
+ */
+public class CumSumProd extends PackageFunction {
+
+	private static final long serialVersionUID = -7883258699548686065L;
+	private Matrix ret;
+	private MatrixBlock retMB, X, C;
+	private double start;
+	private boolean isReverse;
+
+	@Override
+	public int getNumFunctionOutputs() {
+		return 1;
+	}
+
+	@Override
+	public FunctionParameter getFunctionOutput(int pos) {
+		if(pos == 0)
+			return ret;
+		else
+			throw new RuntimeException("CumSumProd produces only one output");
+	}
+
+	@Override
+	public void execute() {
+		try {
+			X = ((Matrix) getFunctionInput(0)).getMatrixObject().acquireRead();
+			C = ((Matrix) getFunctionInput(1)).getMatrixObject().acquireRead();
+			if(X.getNumRows() != C.getNumRows())
+				throw new RuntimeException("Number of rows of X and C should match");
+			if( X.getNumColumns() != C.getNumColumns() && C.getNumColumns() != 1 )
+				throw new RuntimeException("Incorrect Number of columns of X and C (Expected C to be of same dimension or a vector)");
+			start = Double.parseDouble(((Scalar)getFunctionInput(2)).getValue());
+			isReverse = Boolean.parseBoolean(((Scalar)getFunctionInput(3)).getValue()); 
+			
+			numRetRows = X.getNumRows();
+			numRetCols = X.getNumColumns();
+			allocateOutput();
+			
+			// Copy X to Y
+			denseBlock = retMB.getDenseBlock();
+			if(X.isInSparseFormat()) {
+				Iterator<IJV> iter = X.getSparseBlockIterator();
+				while(iter.hasNext()) {
+					IJV ijv = iter.next();
+					denseBlock[ijv.getI()*numRetCols + ijv.getJ()] = ijv.getV();
+				}
+			}
+			else {
+				if(X.getDenseBlock() != null)
+					System.arraycopy(X.getDenseBlock(), 0, denseBlock, 0, denseBlock.length);
+			}
+			
+			if(!isReverse) {
+				// Y [1, ] = X [1, ] + C [1, ] * start;
+				// Y [i+1, ] = X [i+1, ] + C [i+1, ] * Y [i, ]
+				addCNConstant(0, start);
+				for(int i = 1; i < numRetRows; i++) {
+					addC(i, true);
+				}
+			}
+			else {
+				// Y [m, ] = X [m, ] + C [m, ] * start;
+				// Y [i-1, ] = X [i-1, ] + C [i-1, ] * Y [i, ]
+				addCNConstant(numRetRows-1, start);
+				for(int i = numRetRows - 2; i >= 0; i--) {
+					addC(i, false);
+				}
+			}
+			
+			((Matrix) getFunctionInput(1)).getMatrixObject().release();
+			((Matrix) getFunctionInput(0)).getMatrixObject().release();
+		} catch (CacheException e) {
+			throw new RuntimeException("Error while executing CumSumProd", e);
+		}
+		
+		retMB.recomputeNonZeros();
+		try {
+			retMB.examSparsity();
+			ret.setMatrixDoubleArray(retMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+		} catch (DMLRuntimeException e) {
+			throw new RuntimeException("Error while executing CumSumProd", e);
+		} catch (IOException e) {
+			throw new RuntimeException("Error while executing CumSumProd", e);
+		}
+	}
+	
+	int numRetRows; int numRetCols;
+	double [] denseBlock; 
+	
+	private void addCNConstant(int i, double constant) {
+		boolean isCVector = C.getNumColumns() != ret.getNumCols();
+		if(C.isInSparseFormat()) {
+			Iterator<IJV> iter = C.getSparseBlockIterator(i, i+1);
+			while(iter.hasNext()) {
+				IJV ijv = iter.next();
+				if(!isCVector)
+					denseBlock[ijv.getI()*numRetCols + ijv.getJ()] += ijv.getV() * constant;
+				else {
+					double val = ijv.getV();
+					for(int j = ijv.getI()*numRetCols; j < (ijv.getI()+1)*numRetCols; j++) {
+						denseBlock[j] += val*constant;
+					}
+				}
+			}
+		}
+		else {
+			double [] CBlk = C.getDenseBlock();
+			if(CBlk != null) {
+				if(!isCVector) {
+					for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
+						denseBlock[j] += CBlk[j]*constant;
+					}
+				}
+				else {
+					for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
+						denseBlock[j] += CBlk[i]*constant;
+					}
+				}
+			}
+		}
+	}
+	
+	private void addC(int i, boolean addPrevRow) {
+		boolean isCVector = C.getNumColumns() != ret.getNumCols();
+		if(C.isInSparseFormat()) {
+			Iterator<IJV> iter = C.getSparseBlockIterator(i, i+1);
+			while(iter.hasNext()) {
+				IJV ijv = iter.next();
+				if(!isCVector) {
+					if(addPrevRow)
+						denseBlock[ijv.getI()*numRetCols + ijv.getJ()] += ijv.getV() * denseBlock[(ijv.getI()-1)*numRetCols + ijv.getJ()];
+					else
+						denseBlock[ijv.getI()*numRetCols + ijv.getJ()] += ijv.getV() * denseBlock[(ijv.getI()+1)*numRetCols + ijv.getJ()];
+				}
+				else {
+					double val = ijv.getV();
+					for(int j = ijv.getI()*numRetCols; j < (ijv.getI()+1)*numRetCols; j++) {
+						double val1 = addPrevRow ? denseBlock[(ijv.getI()-1)*numRetCols + ijv.getJ()] : denseBlock[(ijv.getI()+1)*numRetCols + ijv.getJ()];
+						denseBlock[j] += val*val1;
+					}
+				}
+			}
+		}
+		else {
+			double [] CBlk = C.getDenseBlock();
+			if(CBlk != null) {
+				if(!isCVector) {
+					for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
+						double val1 = addPrevRow ? denseBlock[j-numRetCols] : denseBlock[j+numRetCols];
+						denseBlock[j] += CBlk[j]*val1;
+					}
+				}
+				else {
+					for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
+						double val1 = addPrevRow ? denseBlock[j-numRetCols] : denseBlock[j+numRetCols];
+						denseBlock[j] += CBlk[i]*val1;
+					}
+				}
+			}
+		}
+	}
+	
+	private void allocateOutput() {
+		String dir = createOutputFilePathAndName( "TMP" );
+		ret = new Matrix( dir, numRetRows, numRetCols, ValueType.Double );
+		retMB = new MatrixBlock((int) numRetRows, (int) numRetCols, false);
+		retMB.allocateDenseBlock();
+	}
+
+	
+	
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6fad65d1/src/main/java/org/apache/sysml/udf/lib/RowClassMeet.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/lib/RowClassMeet.java b/src/main/java/org/apache/sysml/udf/lib/RowClassMeet.java
new file mode 100644
index 0000000..d24d0e8
--- /dev/null
+++ b/src/main/java/org/apache/sysml/udf/lib/RowClassMeet.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.udf.lib;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.Map.Entry;
+import java.util.TreeMap;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheException;
+import org.apache.sysml.runtime.matrix.data.IJV;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.udf.FunctionParameter;
+import org.apache.sysml.udf.Matrix;
+import org.apache.sysml.udf.PackageFunction;
+import org.apache.sysml.udf.Matrix.ValueType;
+
+/**
+ * Performs the following operation:
+ * # Computes the intersection ("meet") of equivalence classes for
+ * # each row of A and B, excluding 0-valued cells.
+ * # INPUT:
+ * #   A, B = matrices whose rows contain that row's class labels;
+ * #          for each i, rows A [i, ] and B [i, ] define two
+ * #          equivalence relations on some of the columns, which
+ * #          we want to intersect
+ * #   A [i, j] == A [i, k] != 0 if and only if (j ~ k) as defined
+ * #          by row A [i, ];
+ * #   A [i, j] == 0 means that j is excluded by A [i, ]
+ * #   B [i, j] is analogous
+ * #   NOTE 1: Either nrow(A) == nrow(B), or exactly one of A or B
+ * #   has one row that "applies" to each row of the other matrix.
+ * #   NOTE 2: If ncol(A) != ncol(B), we pad extra 0-columns up to
+ * #   max (ncol(A), ncol(B)).
+ * # OUTPUT:
+ * #   Both C and N have the same size as (the max of) A and B.
+ * #   C = matrix whose rows contain class labels that represent
+ * #       the intersection (coarsest common refinement) of the
+ * #       corresponding rows of A and B.
+ * #   C [i, j] == C [i, k] != 0 if and only if (j ~ k) as defined
+ * #       by both A [i, ] and B [i, ]
+ * #   C [i, j] == 0 if and only if A [i, j] == 0 or B [i, j] == 0
+ * #       Additionally, we guarantee that non-0 labels in C [i, ]
+ * #       will be integers from 1 to max (C [i, ]) without gaps.
+ * #       For A and B the labels can be arbitrary.
+ * #   N = matrix with class-size information for C-cells
+ * #   N [i, j] = count of {C [i, k] | C [i, j] == C [i, k] != 0}
+ *
+ */
+public class RowClassMeet extends PackageFunction {
+
+	private static final long serialVersionUID = 1L;
+	private Matrix CMat, NMat;
+	private MatrixBlock A, B, C, N;
+	private int nr, nc;
+
+	@Override
+	public int getNumFunctionOutputs() {
+		return 2;
+	}
+
+	@Override
+	public FunctionParameter getFunctionOutput(int pos) {
+		if(pos == 0)
+			return CMat;
+		else if(pos == 1)
+			return NMat;
+		else
+			throw new RuntimeException("RowClassMeet produces only one output");
+	}
+	
+	
+	public class ClassLabels {
+		public double aVal;
+		public double bVal;
+		public ClassLabels(double aVal, double bVal) {
+			this.aVal = aVal;
+			this.bVal = bVal;
+		}
+	}
+	
+	public class ClassLabelComparator implements Comparator<ClassLabels> {
+		Integer tmp1, tmp2;
+		@Override
+		public int compare(ClassLabels o1, ClassLabels o2) {
+			if(o1.aVal != o2.aVal) {
+				tmp1 = (int) o1.aVal;
+				tmp2 = (int) o2.aVal;
+			}
+			else {
+				tmp1 = (int) o1.bVal;
+				tmp2 = (int) o2.bVal;
+			}
+			return tmp1.compareTo(tmp2);
+		}
+	}
+	
+	double [] getRow(MatrixBlock B, double [] bRow, int i) {
+		if(B.getNumRows() == 1) 
+			i = 0;
+		Arrays.fill(bRow, 0);
+		if(B.isInSparseFormat()) {
+			Iterator<IJV> iter = B.getSparseBlockIterator(i, i+1);
+			while(iter.hasNext()) {
+				IJV ijv = iter.next();
+				bRow[ijv.getJ()] = ijv.getV();
+			}
+		}
+		else {
+			double [] denseBlk = B.getDenseBlock();
+			if(denseBlk != null)
+				System.arraycopy(denseBlk, i*B.getNumColumns(), bRow, 0, B.getNumColumns());
+		}
+		return bRow;
+	}
+	
+	@Override
+	public void execute() {
+		try {
+			A = ((Matrix) getFunctionInput(0)).getMatrixObject().acquireRead();
+			B = ((Matrix) getFunctionInput(1)).getMatrixObject().acquireRead();
+			nr = Math.max(A.getNumRows(), B.getNumRows());
+			nc = Math.max(A.getNumColumns(), B.getNumColumns());
+			
+			double [] bRow = new double[B.getNumColumns()];
+			CMat = new Matrix( createOutputFilePathAndName( "TMP" ), nr, nc, ValueType.Double );
+			C = new MatrixBlock(nr, nc, false);
+			C.allocateDenseBlock();
+			NMat = new Matrix( createOutputFilePathAndName( "TMP" ), nr, nc, ValueType.Double );
+			N = new MatrixBlock(nr, nc, false);
+			N.allocateDenseBlock();
+			
+			double [] cBlk = C.getDenseBlock();
+			double [] nBlk = N.getDenseBlock();
+			
+			if(B.getNumRows() == 1)
+				getRow(B, bRow, 0);
+			
+			for(int i = 0; i < A.getNumRows(); i++) {
+				if(B.getNumRows() != 1)
+					getRow(B, bRow, i);
+				
+				// Create class labels
+				TreeMap<ClassLabels, ArrayList<Integer>> classLabelMapping = new TreeMap<ClassLabels, ArrayList<Integer>>(new ClassLabelComparator());
+				if(A.isInSparseFormat()) {
+					Iterator<IJV> iter = A.getSparseBlockIterator(i, i+1);
+					while(iter.hasNext()) {
+						IJV ijv = iter.next();
+						int j = ijv.getJ();
+						double aVal = ijv.getV();
+						if(aVal != 0 && bRow[j] != 0) {
+							ClassLabels key = new ClassLabels(aVal, bRow[j]);
+							if(!classLabelMapping.containsKey(key))
+								classLabelMapping.put(key, new ArrayList<Integer>());
+							classLabelMapping.get(key).add(j);
+						}
+					}
+				}
+				else {
+					double [] denseBlk = A.getDenseBlock();
+					if(denseBlk != null) {
+						int offset = i*A.getNumColumns();
+						for(int j = 0; j < A.getNumColumns(); j++) {
+							double aVal = denseBlk[offset + j];
+							if(aVal != 0 && bRow[j] != 0) {
+								ClassLabels key = new ClassLabels(aVal, bRow[j]);
+								if(!classLabelMapping.containsKey(key))
+									classLabelMapping.put(key, new ArrayList<Integer>());
+								classLabelMapping.get(key).add(j);
+							}
+						}
+					}
+				}
+				
+				
+				int labelID = 1;
+				for(Entry<ClassLabels, ArrayList<Integer>> entry : classLabelMapping.entrySet()) {
+					double nVal = entry.getValue().size();
+					for(Integer j : entry.getValue()) {
+						nBlk[i*nc + j] = nVal;
+						cBlk[i*nc + j] = labelID;
+					}
+					labelID++;
+				}
+			}
+			
+			((Matrix) getFunctionInput(0)).getMatrixObject().release();
+			((Matrix) getFunctionInput(1)).getMatrixObject().release();
+		} catch (CacheException e) {
+			throw new RuntimeException("Error while executing RowClassMeet", e);
+		} 
+		
+		try {
+			C.recomputeNonZeros();
+			C.examSparsity();
+			CMat.setMatrixDoubleArray(C, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+			N.recomputeNonZeros();
+			N.examSparsity();
+			NMat.setMatrixDoubleArray(N, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+		} catch (DMLRuntimeException e) {
+			throw new RuntimeException("Error while executing RowClassMeet", e);
+		} catch (IOException e) {
+			throw new RuntimeException("Error while executing RowClassMeet", e);
+		}
+	}
+	
+	
+}