You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/07 18:59:53 UTC
[1/2] incubator-systemml git commit: [SYSTEMML-268] New rev builtin function, incl rewrites/tests
Repository: incubator-systemml
Updated Branches:
refs/heads/master 53221df4c -> 895610547
[SYSTEMML-268] New rev builtin function, incl rewrites/tests
This change adds a new rev builtin function for column-wise matrix
reversal (i.e., reversing the order of the rows of a matrix). Besides the
required parser and compiler/runtime changes for all backends, this also
includes extended rewrites like 'rev(rev(X)) -> X' and new rewrites like
'table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)'.
Further, it also includes a minor fix to the sign unary operation tests
(wrong error message in the opcode check).
https://issues.apache.org/jira/browse/SYSTEMML-268
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/1c9fef12
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/1c9fef12
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/1c9fef12
Branch: refs/heads/master
Commit: 1c9fef12d3b13cb8dc44359ba0fefa453f8c699f
Parents: 53221df
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Wed Jan 6 20:05:10 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Thu Jan 7 09:58:44 2016 -0800
----------------------------------------------------------------------
src/main/java/org/apache/sysml/hops/Hop.java | 3 +-
.../java/org/apache/sysml/hops/ReorgOp.java | 49 ++++-
.../sysml/hops/rewrite/HopRewriteUtils.java | 45 ++++-
.../RewriteAlgebraicSimplificationDynamic.java | 5 +-
.../RewriteAlgebraicSimplificationStatic.java | 70 ++++++-
.../java/org/apache/sysml/lops/Transform.java | 7 +-
.../sysml/parser/BuiltinFunctionExpression.java | 12 ++
.../org/apache/sysml/parser/DMLTranslator.java | 5 +
.../org/apache/sysml/parser/Expression.java | 1 +
.../sysml/runtime/functionobjects/RevIndex.java | 65 +++++++
.../instructions/CPInstructionParser.java | 1 +
.../instructions/MRInstructionParser.java | 1 +
.../instructions/SPInstructionParser.java | 1 +
.../instructions/cp/ReorgCPInstruction.java | 16 +-
.../instructions/mr/ReorgInstruction.java | 22 ++-
.../instructions/spark/ReorgSPInstruction.java | 49 ++++-
.../runtime/matrix/data/LibMatrixReorg.java | 153 +++++++++++++++
.../sysml/runtime/matrix/data/MatrixBlock.java | 5 +-
.../functions/reorg/FullReverseTest.java | 190 +++++++++++++++++++
.../functions/unary/matrix/FullSignTest.java | 2 +-
src/test/scripts/functions/reorg/Reverse1.R | 41 ++++
src/test/scripts/functions/reorg/Reverse1.dml | 25 +++
src/test/scripts/functions/reorg/Reverse2.R | 41 ++++
src/test/scripts/functions/reorg/Reverse2.dml | 25 +++
24 files changed, 799 insertions(+), 35 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index ee815fe..b3bed99 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1068,7 +1068,7 @@ public abstract class Hop
};
public enum ReOrgOp {
- TRANSPOSE, DIAG, RESHAPE, SORT
+ TRANSPOSE, DIAG, RESHAPE, SORT, REV
//Note: Diag types are invalid because for unknown sizes this would
//create incorrect plans (now we try to infer it for memory estimates
//and rewrites but the final choice is made during runtime)
@@ -1130,6 +1130,7 @@ public abstract class Hop
static {
HopsTransf2Lops = new HashMap<ReOrgOp, org.apache.sysml.lops.Transform.OperationTypes>();
HopsTransf2Lops.put(ReOrgOp.TRANSPOSE, org.apache.sysml.lops.Transform.OperationTypes.Transpose);
+ HopsTransf2Lops.put(ReOrgOp.REV, org.apache.sysml.lops.Transform.OperationTypes.Rev);
HopsTransf2Lops.put(ReOrgOp.DIAG, org.apache.sysml.lops.Transform.OperationTypes.Diag);
HopsTransf2Lops.put(ReOrgOp.RESHAPE, org.apache.sysml.lops.Transform.OperationTypes.Reshape);
HopsTransf2Lops.put(ReOrgOp.SORT, org.apache.sysml.lops.Transform.OperationTypes.Sort);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/hops/ReorgOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index c7df98b..cfd2f65 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -120,6 +120,36 @@ public class ReorgOp extends Hop
break;
}
+ case REV:
+ {
+ Lop rev = null;
+
+ if( et == ExecType.MR ) {
+ Lop tmp = new Transform( getInput().get(0).constructLops(),
+ HopsTransf2Lops.get(op), getDataType(), getValueType(), et);
+ setOutputDimensions(tmp);
+ setLineNumbers(tmp);
+
+ Group group1 = new Group(tmp, Group.OperationTypes.Sort,
+ DataType.MATRIX, getValueType());
+ setOutputDimensions(group1);
+ setLineNumbers(group1);
+
+ rev = new Aggregate(group1, Aggregate.OperationTypes.Sum,
+ DataType.MATRIX, getValueType(), et);
+ }
+ else { //CP/SPARK
+
+ rev = new Transform( getInput().get(0).constructLops(),
+ HopsTransf2Lops.get(op), getDataType(), getValueType(), et);
+ }
+
+ setOutputDimensions(rev);
+ setLineNumbers(rev);
+ setLops(rev);
+
+ break;
+ }
case RESHAPE:
{
if( et==ExecType.MR )
@@ -356,7 +386,14 @@ public class ReorgOp extends Hop
if( mc.dimsKnown() )
ret = new long[]{ mc.getCols(), mc.getRows(), mc.getNonZeros() };
break;
- }
+ }
+ case REV:
+ {
+ // dims and nnz are exactly the same as in input
+ if( mc.dimsKnown() )
+ ret = new long[]{ mc.getRows(), mc.getCols(), mc.getNonZeros() };
+ break;
+ }
case DIAG:
{
// NOTE: diag is overloaded according to the number of columns of the input
@@ -470,7 +507,15 @@ public class ReorgOp extends Hop
setDim2(input1.getDim1());
setNnz(input1.getNnz());
break;
- }
+ }
+ case REV:
+ {
+ // dims and nnz are exactly the same as in input
+ setDim1(input1.getDim1());
+ setDim2(input1.getDim2());
+ setNnz(input1.getNnz());
+ break;
+ }
case DIAG:
{
// NOTE: diag is overloaded according to the number of columns of the input
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 1817c1f..95ddf0f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -478,9 +478,19 @@ public class HopRewriteUtils
* @param input
* @return
*/
- public static ReorgOp createTranspose(Hop input)
+ public static ReorgOp createTranspose(Hop input) {
+ return createReorg(input, ReOrgOp.TRANSPOSE);
+ }
+
+ /**
+ *
+ * @param input
+ * @param rop
+ * @return
+ */
+ public static ReorgOp createReorg(Hop input, ReOrgOp rop)
{
- ReorgOp transpose = new ReorgOp(input.getName(), input.getDataType(), input.getValueType(), ReOrgOp.TRANSPOSE, input);
+ ReorgOp transpose = new ReorgOp(input.getName(), input.getDataType(), input.getValueType(), rop, input);
HopRewriteUtils.setOutputBlocksizes(transpose, input.getRowsInBlock(), input.getColsInBlock());
HopRewriteUtils.copyLineNumbers(input, transpose);
transpose.refreshSizeInformation();
@@ -860,6 +870,29 @@ public class HopRewriteUtils
*
* @param hop
* @return
+ */
+ public static boolean isBasicN1Sequence(Hop hop)
+ {
+ boolean ret = false;
+
+ if( hop instanceof DataGenOp )
+ {
+ DataGenOp dgop = (DataGenOp) hop;
+ if( dgop.getOp() == DataGenMethod.SEQ ){
+ Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO));
+ Hop incr = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR));
+ ret = (to instanceof LiteralOp && getDoubleValueSafe((LiteralOp)to)==1)
+ &&(incr instanceof LiteralOp && getDoubleValueSafe((LiteralOp)incr)==-1);
+ }
+ }
+
+ return ret;
+ }
+
+ /**
+ *
+ * @param hop
+ * @return
* @throws HopsException
*/
public static double getBasic1NSequenceMax(Hop hop)
@@ -1060,6 +1093,14 @@ public class HopRewriteUtils
return false;
}
+ public static boolean isValidOp( ReOrgOp input, ReOrgOp[] validTab )
+ {
+ for( ReOrgOp valid : validTab )
+ if( valid == input )
+ return true;
+ return false;
+ }
+
public static int getValidOpPos( OpOp2 input, OpOp2[] validTab )
{
for( int i=0; i<validTab.length; i++ ) {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 3deb4c6..7c4a67a 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -24,7 +24,6 @@ import java.util.HashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
@@ -880,7 +879,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
Hop hnew = null;
if( rhi.getOp() == ReOrgOp.TRANSPOSE )
hnew = HopRewriteUtils.createDataGenOp(input, true, input, true, 0);
- else if( rhi.getOp() == ReOrgOp.DIAG ){
+ else if( rhi.getOp() == ReOrgOp.REV )
+ hnew = HopRewriteUtils.createDataGenOp(input, 0);
+ else if( rhi.getOp() == ReOrgOp.DIAG ) {
if( HopRewriteUtils.isDimsKnown(input) ){
if( input.getDim1()==1 ) //diagv2m
hnew = HopRewriteUtils.createDataGenOp(input, false, input, true, 0);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 20e94a0..3aea9f0 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -24,7 +24,6 @@ import java.util.HashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
@@ -37,6 +36,7 @@ import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
import org.apache.sysml.hops.Hop.ReOrgOp;
import org.apache.sysml.hops.HopsException;
@@ -134,7 +134,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = fuseDatagenAndBinaryOperation(hop, hi, i); //e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7)
hi = fuseDatagenAndMinusOperation(hop, hi, i); //e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2)
hi = simplifyBinaryToUnaryOperation(hop, hi, i); //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
- hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
+ hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
+ hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X)
@@ -145,7 +146,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq;
hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq;
- hi = removeUnnecessaryTranspose(hop, hi, i); //e.g., t(t(X))->X; potentially introduced by diag/trace_MM
+ hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplfiy binary or dyn rewrites
hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
hi = fuseMinusNzBinaryOperation(hop, hi, i); //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
@@ -537,6 +538,50 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return hi;
}
+
+ /**
+ * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static
+ * rewrite in order to apply it before splitting dags which would hide the table information
+ * if dimensions are not specified.
+ *
+ *
+ * @param parent
+ * @param hi
+ * @param pos
+ * @return
+ * @throws HopsException
+ */
+ private Hop simplifyReverseOperation( Hop parent, Hop hi, int pos )
+ throws HopsException
+ {
+ if( hi instanceof AggBinaryOp
+ && hi.getInput().get(0) instanceof TernaryOp )
+ {
+ TernaryOp top = (TernaryOp) hi.getInput().get(0);
+
+ if( top.getOp()==OpOp3.CTABLE
+ && HopRewriteUtils.isBasic1NSequence(top.getInput().get(0))
+ && HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1))
+ && top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1())
+ {
+ ReorgOp rop = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV);
+
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.removeAllChildReferences(hi);
+ HopRewriteUtils.addChildReference(parent, rop, pos);
+ if( top.getParent().isEmpty() )
+ HopRewriteUtils.removeAllChildReferences(top);
+
+ hi = rop;
+
+ LOG.debug("Applied simplifyReverseOperation.");
+ }
+ }
+
+ return hi;
+ }
+
+
/**
*
* @param hi
@@ -751,8 +796,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
&& hi.getInput().get(0) instanceof ReorgOp ) //reorg operation
{
ReorgOp rop = (ReorgOp)hi.getInput().get(0);
- if( (rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE) //valid reorg
- && rop.getParent().size()==1 ) //uagg only reorg consumer
+ if( (rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE
+ || rop.getOp() == ReOrgOp.REV ) //valid reorg
+ && rop.getParent().size()==1 ) //uagg only reorg consumer
{
Hop input = rop.getInput().get(0);
HopRewriteUtils.removeAllChildReferences(hi);
@@ -1271,17 +1317,20 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
}
/**
+ * Pattners: t(t(X)) -> X, rev(rev(X)) -> X
*
* @param parent
* @param hi
* @param pos
*/
- private Hop removeUnnecessaryTranspose(Hop parent, Hop hi, int pos)
+ private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
{
- if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE ) //first transpose
+ ReOrgOp[] lookup = new ReOrgOp[]{ReOrgOp.TRANSPOSE, ReOrgOp.REV};
+
+ if( hi instanceof ReorgOp && HopRewriteUtils.isValidOp(((ReorgOp)hi).getOp(), lookup) ) //first reorg
{
Hop hi2 = hi.getInput().get(0);
- if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==ReOrgOp.TRANSPOSE ) //second transpose
+ if( hi2 instanceof ReorgOp && HopRewriteUtils.isValidOp(((ReorgOp)hi2).getOp(), lookup) ) //second reorg
{
Hop hi3 = hi2.getInput().get(0);
//remove unnecessary chain of t(t())
@@ -1295,7 +1344,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
if( hi2.getParent().isEmpty() )
HopRewriteUtils.removeAllChildReferences( hi2 );
- LOG.debug("Applied removeUnecessaryTranspose");
+ LOG.debug("Applied removeUnecessaryReorgOperation.");
}
}
@@ -1376,7 +1425,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return hi;
}
-
+
+
/**
*
* @param parent
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/lops/Transform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Transform.java b/src/main/java/org/apache/sysml/lops/Transform.java
index 543ad6b..e666aa8 100644
--- a/src/main/java/org/apache/sysml/lops/Transform.java
+++ b/src/main/java/org/apache/sysml/lops/Transform.java
@@ -38,7 +38,8 @@ public class Transform extends Lop
Transpose,
Diag,
Reshape,
- Sort
+ Sort,
+ Rev
};
private boolean _bSortIndInMem = false;
@@ -129,6 +130,10 @@ public class Transform extends Lop
// Transpose a matrix
return "r'";
+ case Rev:
+ // Transpose a matrix
+ return "rev";
+
case Diag:
// Transform a vector into a diagonal matrix
return "rdiag";
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index a0fd56e..da5bcd9 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -480,6 +480,16 @@ public class BuiltinFunctionExpression extends DataIdentifier
output.setBlockDimensions (id.getColumnsInBlock(), id.getRowsInBlock());
output.setValueType(id.getValueType());
break;
+
+ case REV:
+ checkNumParameters(1);
+ checkMatrixParam(getFirstExpr());
+ output.setDataType(DataType.MATRIX);
+ output.setDimensions(id.getDim1(), id.getDim2());
+ output.setBlockDimensions (id.getColumnsInBlock(), id.getRowsInBlock());
+ output.setValueType(id.getValueType());
+ break;
+
case DIAG:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
@@ -1313,6 +1323,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
bifop = Expression.BuiltinFunctionOp.TRACE;
else if (functionName.equals("t"))
bifop = Expression.BuiltinFunctionOp.TRANS;
+ else if (functionName.equals("rev"))
+ bifop = Expression.BuiltinFunctionOp.REV;
else if (functionName.equals("cbind") || functionName.equals("append"))
bifop = Expression.BuiltinFunctionOp.CBIND;
else if (functionName.equals("rbind"))
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 5b6ee70..e10f3a8 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2342,6 +2342,11 @@ public class DMLTranslator
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.ReOrgOp.TRANSPOSE, expr);
break;
+
+ case REV:
+ currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(),
+ Hop.ReOrgOp.REV, expr);
+ break;
case CBIND:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index ad288de..7ec06aa 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -85,6 +85,7 @@ public abstract class Expression
PROD,
QUANTILE,
RANGE,
+ REV,
ROUND,
ROWINDEXMAX,
ROWMAX,
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/functionobjects/RevIndex.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/RevIndex.java b/src/main/java/org/apache/sysml/runtime/functionobjects/RevIndex.java
new file mode 100644
index 0000000..4308497
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/RevIndex.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.functionobjects;
+
+import java.io.Serializable;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+
+
+public class RevIndex extends IndexFunction implements Serializable
+{
+ private static final long serialVersionUID = -1002715543022547788L;
+
+ private static RevIndex singleObj = null;
+
+ private RevIndex() {
+ // nothing to do here
+ }
+
+ public static RevIndex getRevIndexFnObject() {
+ if ( singleObj == null )
+ singleObj = new RevIndex();
+ return singleObj;
+ }
+
+ public Object clone() throws CloneNotSupportedException {
+ // cloning is not supported for singleton classes
+ throw new CloneNotSupportedException();
+ }
+
+ @Override // for cp block operations
+ public boolean computeDimension(int row, int col, CellIndex retDim)
+ throws DMLRuntimeException
+ {
+ retDim.set(row, col);
+ return false;
+ }
+
+ @Override //for mr block operations
+ public boolean computeDimension(MatrixCharacteristics in, MatrixCharacteristics out)
+ throws DMLRuntimeException
+ {
+ out.set(in.getRows(), in.getCols(), in.getColsPerBlock(), in.getRowsPerBlock(), in.getNonZeros());
+ return false;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index b34dbc9..81b6550 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -196,6 +196,7 @@ public class CPInstructionParser extends InstructionParser
// Reorg Instruction Opcodes (repositioning of existing values)
String2CPInstructionType.put( "r'" , CPINSTRUCTION_TYPE.Reorg);
+ String2CPInstructionType.put( "rev" , CPINSTRUCTION_TYPE.Reorg);
String2CPInstructionType.put( "rdiag" , CPINSTRUCTION_TYPE.Reorg);
String2CPInstructionType.put( "rshape" , CPINSTRUCTION_TYPE.MatrixReshape);
String2CPInstructionType.put( "rsort" , CPINSTRUCTION_TYPE.Reorg);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
index 57051c2..60555f5 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
@@ -200,6 +200,7 @@ public class MRInstructionParser extends InstructionParser
// REORG Instruction Opcodes
String2MRInstructionType.put( "r'" , MRINSTRUCTION_TYPE.Reorg);
+ String2MRInstructionType.put( "rev" , MRINSTRUCTION_TYPE.Reorg);
String2MRInstructionType.put( "rdiag" , MRINSTRUCTION_TYPE.Reorg);
// REPLICATE Instruction Opcodes
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index 5d5450c..a64fc8f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -130,6 +130,7 @@ public class SPInstructionParser extends InstructionParser
// Reorg Instruction Opcodes (repositioning of existing values)
String2SPInstructionType.put( "r'" , SPINSTRUCTION_TYPE.Reorg);
+ String2SPInstructionType.put( "rev" , SPINSTRUCTION_TYPE.Reorg);
String2SPInstructionType.put( "rdiag" , SPINSTRUCTION_TYPE.Reorg);
String2SPInstructionType.put( "rshape" , SPINSTRUCTION_TYPE.MatrixReshape);
String2SPInstructionType.put( "rsort" , SPINSTRUCTION_TYPE.Reorg);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/instructions/cp/ReorgCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ReorgCPInstruction.java
index 69ded28..ba79858 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ReorgCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ReorgCPInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.functionobjects.DiagIndex;
+import org.apache.sysml.runtime.functionobjects.RevIndex;
import org.apache.sysml.runtime.functionobjects.SortIndex;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
@@ -35,7 +36,6 @@ import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
public class ReorgCPInstruction extends UnaryCPInstruction
{
-
//sort-specific attributes (to enable variable attributes)
private CPOperand _col = null;
private CPOperand _desc = null;
@@ -87,6 +87,10 @@ public class ReorgCPInstruction extends UnaryCPInstruction
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, str);
}
+ else if ( opcode.equalsIgnoreCase("rev") ) {
+ parseUnaryInstruction(str, in, out); //max 2 operands
+ return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
+ }
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -95,13 +99,9 @@ public class ReorgCPInstruction extends UnaryCPInstruction
InstructionUtils.checkNumFields(parts, 5);
in.split(parts[1]);
out.split(parts[5]);
- CPOperand col = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
- CPOperand desc = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
- CPOperand ixret = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
- col.split(parts[2]);
- desc.split(parts[3]);
- ixret.split(parts[4]);
-
+ CPOperand col = new CPOperand(parts[2]);
+ CPOperand desc = new CPOperand(parts[3]);
+ CPOperand ixret = new CPOperand(parts[4]);
return new ReorgCPInstruction(new ReorgOperator(SortIndex.getSortIndexFnObject(1,false,false)),
in, col, desc, ixret, out, opcode, str);
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/instructions/mr/ReorgInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/ReorgInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/ReorgInstruction.java
index eb67564..a1d7cdc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/mr/ReorgInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/ReorgInstruction.java
@@ -24,9 +24,11 @@ import java.util.ArrayList;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.DiagIndex;
+import org.apache.sysml.runtime.functionobjects.RevIndex;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
@@ -41,14 +43,12 @@ public class ReorgInstruction extends UnaryMRInstructionBase
//required for diag (size-based type, load-balance-aware output of empty blocks)
private MatrixCharacteristics _mcIn = null;
private boolean _outputEmptyBlocks = true;
- private boolean _isDiag = false;
public ReorgInstruction(ReorgOperator op, byte in, byte out, String istr)
{
super(op, in, out);
mrtype = MRINSTRUCTION_TYPE.Reorg;
instString = istr;
- _isDiag = (op.fn==DiagIndex.getDiagIndexFnObject());
}
public void setInputMatrixCharacteristics( MatrixCharacteristics in )
@@ -75,11 +75,12 @@ public class ReorgInstruction extends UnaryMRInstructionBase
if ( opcode.equalsIgnoreCase("r'") ) {
return new ReorgInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, str);
}
-
+ else if ( opcode.equalsIgnoreCase("rev") ) {
+ return new ReorgInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, str);
+ }
else if ( opcode.equalsIgnoreCase("rdiag") ) {
return new ReorgInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, str);
}
-
else {
throw new DMLRuntimeException("Unknown opcode while parsing a ReorgInstruction: " + str);
}
@@ -102,8 +103,9 @@ public class ReorgInstruction extends UnaryMRInstructionBase
int startRow=0, startColumn=0, length=0;
//process instruction
- if( _isDiag ) //special diag handling (overloaded, size-dependent operation; hence decided during runtime)
+ if( ((ReorgOperator)optr).fn instanceof DiagIndex )
{
+ //special diag handling (overloaded, size-dependent operation; hence decided during runtime)
boolean V2M = (_mcIn.getRows()==1 || _mcIn.getCols()==1);
long rlen = Math.max(_mcIn.getRows(), _mcIn.getCols()); //input can be row/column vector
@@ -149,6 +151,16 @@ public class ReorgInstruction extends UnaryMRInstructionBase
}
}
}
+ else if( ((ReorgOperator)optr).fn instanceof RevIndex )
+ {
+ //execute reverse operation
+ ArrayList<IndexedMatrixValue> out = new ArrayList<IndexedMatrixValue>();
+ LibMatrixReorg.rev(in, _mcIn.getRows(), _mcIn.getRowsPerBlock(), out);
+
+ //output indexed matrix values
+ for( IndexedMatrixValue outblk : out )
+ cachedValues.add(output, outblk);
+ }
else //general case (e.g., transpose)
{
//allocate space for the output value
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
index aeea61b..cdfdc54 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
@@ -34,17 +34,22 @@ import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.DiagIndex;
+import org.apache.sysml.runtime.functionobjects.RevIndex;
import org.apache.sysml.runtime.functionobjects.SortIndex;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.functions.FilterDiagBlocksFunction;
import org.apache.sysml.runtime.instructions.spark.functions.IsBlockInRange;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDSortUtils;
+import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.instructions.spark.functions.ReorgMapFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.UtilFunctions;
@@ -82,7 +87,11 @@ public class ReorgSPInstruction extends UnarySPInstruction
if ( opcode.equalsIgnoreCase("r'") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgSPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, str);
- }
+ }
+ else if ( opcode.equalsIgnoreCase("rev") ) {
+ parseUnaryInstruction(str, in, out); //max 2 operands
+ return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
+ }
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -125,6 +134,13 @@ public class ReorgSPInstruction extends UnarySPInstruction
//execute transpose reorg operation
out = in1.mapToPair(new ReorgMapFunction(opcode));
}
+ else if( opcode.equalsIgnoreCase("rev") ) //REVERSE
+ {
+ //execute reverse reorg operation
+ out = in1.flatMapToPair(new RDDRevFunction(mcIn));
+ if( mcIn.getRows() % mcIn.getRowsPerBlock() != 0 )
+ out = RDDAggregateUtils.mergeByKey(out);
+ }
else if ( opcode.equalsIgnoreCase("rdiag") ) // DIAG
{
if(mcIn.getCols() == 1) { // diagV2M
@@ -260,6 +276,37 @@ public class ReorgSPInstruction extends UnarySPInstruction
return ret;
}
}
+
+	/**
+	 * Spark flatMap function for the rev builtin: reverses the row order of a
+	 * single input block and emits the one or two output blocks it contributes to
+	 * (see LibMatrixReorg.rev(IndexedMatrixValue, long, int, ArrayList)).
+	 * Only the number of rows and the row block size are captured instead of the
+	 * full matrix characteristics, which keeps the serialized function small.
+	 */
+	private static class RDDRevFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
+	{
+		private static final long serialVersionUID = 1183373828539843938L;
+		
+		//total number of rows and number of rows per block of the input matrix
+		private final long _rlen;
+		private final int _brlen;
+		
+		public RDDRevFunction(MatrixCharacteristics mcIn) 
+		{
+			_rlen = mcIn.getRows();
+			_brlen = mcIn.getRowsPerBlock();
+		}
+		
+		@Override
+		public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
+			throws Exception 
+		{
+			//construct input
+			IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0);
+			
+			//execute reverse operation (for unaligned row blocks, the emitted
+			//partial blocks are merged by key in the calling instruction)
+			ArrayList<IndexedMatrixValue> out = new ArrayList<IndexedMatrixValue>();
+			LibMatrixReorg.rev(in, _rlen, _brlen, out);
+			
+			//construct output
+			return SparkUtils.fromIndexedMatrixBlock(out);
+		}
+	}
/**
*
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
index b10844f..545d425 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
@@ -31,6 +31,7 @@ import java.util.Map.Entry;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.DiagIndex;
+import org.apache.sysml.runtime.functionobjects.RevIndex;
import org.apache.sysml.runtime.functionobjects.SortIndex;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
@@ -61,6 +62,7 @@ public class LibMatrixReorg
private enum ReorgType {
TRANSPOSE,
+ REV,
DIAG,
RESHAPE,
SORT,
@@ -101,6 +103,8 @@ public class LibMatrixReorg
{
case TRANSPOSE:
return transpose(in, out);
+ case REV:
+ return rev(in, out);
case DIAG:
return diag(in, out);
case SORT:
@@ -141,7 +145,93 @@ public class LibMatrixReorg
return out;
}
+
+	/**
+	 * Reverses the row order of a matrix block (rev builtin): row i of the input
+	 * becomes row rlen-1-i of the output.
+	 * 
+	 * @param in input matrix block
+	 * @param out output matrix block; returned unchanged if the input is empty
+	 * @return the output block {@code out}
+	 * @throws DMLRuntimeException
+	 */
+	public static MatrixBlock rev( MatrixBlock in, MatrixBlock out ) 
+		throws DMLRuntimeException
+	{
+		//empty input: nothing to copy (sparse-safe operation)
+		if( in.isEmptyBlock(false) ) {
+			return out;
+		}
+		
+		if( in.rlen == 1 ) {
+			//a single row is its own reverse
+			out.copy(in);
+		}
+		else if( in.sparse ) {
+			reverseSparse(in, out);
+		}
+		else {
+			reverseDense(in, out);
+		}
+		
+		return out;
+	}
+
+	/**
+	 * Block-level reverse for the distributed backends. The input block is
+	 * reversed locally and placed at its mirrored row position in the overall
+	 * matrix. If the number of rows is a multiple of the row block size, this maps
+	 * to exactly one output block. Otherwise the reversed block may straddle two
+	 * output blocks, which are emitted as partially filled (zero elsewhere)
+	 * blocks that the caller has to merge by key.
+	 * 
+	 * @param in input indexed block
+	 * @param rlen total number of rows of the input matrix
+	 * @param brlen number of rows per block
+	 * @param out list to which the output indexed blocks are appended
+	 * @throws DMLRuntimeException
+	 * @throws DMLUnsupportedOperationException
+	 */
+	public static void rev( IndexedMatrixValue in, long rlen, int brlen, ArrayList<IndexedMatrixValue> out ) 
+		throws DMLRuntimeException, DMLUnsupportedOperationException
+	{
+		//input block reverse 
+		MatrixIndexes inix = in.getIndexes();
+		MatrixBlock inblk = (MatrixBlock) in.getValue(); 
+		MatrixBlock tmpblk = rev(inblk, new MatrixBlock(inblk.getNumRows(), inblk.getNumColumns(), inblk.isInSparseFormat()));
+		
+		//split and expand block if necessary (at most 2 blocks)
+		if( rlen % brlen == 0 ) //special case: aligned blocks 
+		{
+			int nrblks = (int)Math.ceil((double)rlen/brlen);
+			out.add(new IndexedMatrixValue(
+					new MatrixIndexes(nrblks-inix.getRowIndex()+1, inix.getColumnIndex()), tmpblk));
+		}
+		else //general case: unaligned blocks
+		{
+			//compute target positions (1-based global rows) and output block sizes
+			long pos1 = rlen - UtilFunctions.computeCellIndex(inix.getRowIndex(), brlen, tmpblk.getNumRows()-1) + 1;
+			long pos2 = pos1 + tmpblk.getNumRows() - 1;
+			int ipos1 = UtilFunctions.computeCellInBlock(pos1, brlen);
+			int blkix1 = (int)UtilFunctions.computeBlockIndex(pos1, brlen);
+			int blkix2 = (int)UtilFunctions.computeBlockIndex(pos2, brlen);
+			int blklen1 = (int)UtilFunctions.computeBlockSize(rlen, blkix1, brlen);
+			int blklen2 = (int)UtilFunctions.computeBlockSize(rlen, blkix2, brlen);
+			
+			//last row of tmpblk that still falls into the first output block
+			int iposCut = Math.min(tmpblk.getNumRows(), blklen1 - ipos1) - 1;
+			
+			//slice first block; the target range must match the slice size, which
+			//may be smaller than the output block (e.g., last, partial input block)
+			MatrixIndexes outix1 = new MatrixIndexes(blkix1, inix.getColumnIndex());
+			MatrixBlock outblk1 = new MatrixBlock(blklen1, inblk.getNumColumns(), inblk.isInSparseFormat());
+			MatrixBlock tmp1 = tmpblk.sliceOperations(0, iposCut, 0, tmpblk.getNumColumns()-1, new MatrixBlock());
+			outblk1.leftIndexingOperations(tmp1, ipos1, ipos1+tmp1.getNumRows()-1, 0, tmpblk.getNumColumns()-1, outblk1, true);
+			out.add(new IndexedMatrixValue(outix1, outblk1));
+			
+			//slice second block (if necessary)
+			if( blkix1 != blkix2 ) {
+				MatrixIndexes outix2 = new MatrixIndexes(blkix2, inix.getColumnIndex());
+				MatrixBlock outblk2 = new MatrixBlock(blklen2, inblk.getNumColumns(), inblk.isInSparseFormat());
+				MatrixBlock tmp2 = tmpblk.sliceOperations(iposCut+1, tmpblk.getNumRows()-1, 0, tmpblk.getNumColumns()-1, new MatrixBlock());
+				outblk2.leftIndexingOperations(tmp2, 0, tmp2.getNumRows()-1, 0, tmpblk.getNumColumns()-1, outblk2, true);
+				out.add(new IndexedMatrixValue(outix2, outblk2));
+			}
+		}
+	}
+
/**
*
* @param in
@@ -619,6 +709,9 @@ public class LibMatrixReorg
if( op.fn instanceof SwapIndex ) //transpose
return ReorgType.TRANSPOSE;
+ else if( op.fn instanceof RevIndex ) //rev
+ return ReorgType.REV;
+
else if( op.fn instanceof DiagIndex ) //diag
return ReorgType.DIAG;
@@ -910,6 +1003,66 @@ public class LibMatrixReorg
}
/**
+	 * Reverses the rows of a dense input block into a dense output block:
+	 * row i of the input is copied to row m-1-i of the output.
+	 * 
+	 * @param in dense input block (rev calls this only for non-empty inputs)
+	 * @param out output block, set to dense and allocated here
+	 * @throws DMLRuntimeException
+	 */
+	private static void reverseDense(MatrixBlock in, MatrixBlock out) 
+		throws DMLRuntimeException
+	{
+		final int rows = in.rlen;
+		final int cols = in.clen;
+		
+		//output meta data (reverse preserves the number of non-zeros)
+		out.sparse = false;
+		out.nonZeros = in.nonZeros;
+		out.allocateDenseBlock(false);
+		
+		final double[] src = in.getDenseArray();
+		final double[] dst = out.getDenseArray();
+		
+		if( cols == 1 ) {
+			//column vector: plain element-wise reverse
+			for( int i = 0, j = rows-1; i < rows; i++, j-- )
+				dst[j] = src[i];
+		}
+		else {
+			//matrix: copy each row to its mirrored row position
+			for( int i = 0; i < rows; i++ )
+				System.arraycopy(src, i*cols, dst, (rows-1-i)*cols, cols);
+		}
+	}
+
+	/**
+	 * Reverses the rows of a sparse input block into a sparse output block:
+	 * row i of the input is copied (via the SparseRow copy constructor) to row
+	 * m-1-i of the output; null or empty input rows are skipped.
+	 * 
+	 * @param in sparse input block
+	 * @param out output block, set to sparse and allocated here
+	 * @throws DMLRuntimeException
+	 */
+	private static void reverseSparse(MatrixBlock in, MatrixBlock out) 
+		throws DMLRuntimeException
+	{
+		final int rows = in.rlen;
+		
+		//output meta data (reverse preserves the number of non-zeros)
+		out.sparse = true;
+		out.nonZeros = in.nonZeros;
+		out.allocateSparseRowsBlock(false);
+		
+		final SparseRow[] src = in.getSparseRows();
+		final SparseRow[] dst = out.getSparseRows();
+		
+		//walk input rows bottom-up while filling output rows top-down
+		for( int i = rows-1, j = 0; i >= 0; i--, j++ ) {
+			SparseRow row = src[i];
+			if( row == null || row.isEmpty() )
+				continue; //nothing to copy
+			dst[j] = new SparseRow(row);
+		}
+	}
+
+ /**
* Generic implementation diagV2M (non-performance critical)
* (in most-likely DENSE, out most likely SPARSE)
*
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 7b20502..45b666b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -35,7 +35,6 @@ import java.util.Map;
import org.apache.commons.math3.random.Well1024a;
import org.apache.hadoop.io.DataInputBuffer;
-
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.Hop.OpOp2;
@@ -58,6 +57,7 @@ import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
+import org.apache.sysml.runtime.functionobjects.RevIndex;
import org.apache.sysml.runtime.functionobjects.SortIndex;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
@@ -3597,7 +3597,8 @@ public class MatrixBlock extends MatrixValue implements Externalizable
public MatrixValue reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length)
throws DMLRuntimeException
{
- if ( !( op.fn instanceof SwapIndex || op.fn instanceof DiagIndex || op.fn instanceof SortIndex) )
+ if ( !( op.fn instanceof SwapIndex || op.fn instanceof DiagIndex
+ || op.fn instanceof SortIndex || op.fn instanceof RevIndex ) )
throw new DMLRuntimeException("the current reorgOperations cannot support: "+op.fn.getClass()+".");
MatrixBlock result = checkType(ret);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/test/java/org/apache/sysml/test/integration/functions/reorg/FullReverseTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/reorg/FullReverseTest.java b/src/test/java/org/apache/sysml/test/integration/functions/reorg/FullReverseTest.java
new file mode 100644
index 0000000..35b7391
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/reorg/FullReverseTest.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.reorg;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+
+/**
+ * Integration tests for the rev builtin (row-order reverse) across CP, MR and
+ * SPARK, for dense/sparse column vectors and matrices. Results are compared
+ * against an R reference implementation.
+ */
+public class FullReverseTest extends AutomatedTestBase 
+{
+	private final static String TEST_NAME1 = "Reverse1"; //rev(A)
+	private final static String TEST_NAME2 = "Reverse2"; //table(seq,seq) %*% A, rewritten to rev(A)
+	
+	private final static String TEST_DIR = "functions/reorg/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest.class.getSimpleName() + "/";
+	
+	//NOTE(review): presumably chosen as a non-multiple of the row block size
+	//to exercise the unaligned-blocks code path - confirm
+	private final static int rows1 = 2017;
+	private final static int cols1 = 1001;
+	private final static double sparsity1 = 0.7;
+	private final static double sparsity2 = 0.1;
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B"}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"B"}));
+	}
+	
+	@Test
+	public void testReverseVectorDenseCP() {
+		runReverseTest(TEST_NAME1, false, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testReverseVectorSparseCP() {
+		runReverseTest(TEST_NAME1, false, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testReverseVectorDenseMR() {
+		runReverseTest(TEST_NAME1, false, false, ExecType.MR);
+	}
+	
+	@Test
+	public void testReverseVectorSparseMR() {
+		runReverseTest(TEST_NAME1, false, true, ExecType.MR);
+	}
+	
+	@Test
+	public void testReverseVectorDenseSP() {
+		runReverseTest(TEST_NAME1, false, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testReverseVectorSparseSP() {
+		runReverseTest(TEST_NAME1, false, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testReverseMatrixDenseCP() {
+		runReverseTest(TEST_NAME1, true, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testReverseMatrixSparseCP() {
+		runReverseTest(TEST_NAME1, true, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testReverseMatrixDenseMR() {
+		runReverseTest(TEST_NAME1, true, false, ExecType.MR);
+	}
+	
+	@Test
+	public void testReverseMatrixSparseMR() {
+		runReverseTest(TEST_NAME1, true, true, ExecType.MR);
+	}
+	
+	@Test
+	public void testReverseMatrixDenseSP() {
+		runReverseTest(TEST_NAME1, true, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testReverseMatrixSparseSP() {
+		runReverseTest(TEST_NAME1, true, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testReverseVectorDenseRewriteCP() {
+		runReverseTest(TEST_NAME2, false, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testReverseMatrixDenseRewriteCP() {
+		runReverseTest(TEST_NAME2, true, false, ExecType.CP);
+	}
+	
+	/**
+	 * Runs the given DML script and its R counterpart on a random input, compares
+	 * the results, and (for CP and SPARK) checks that a rev instruction was executed.
+	 * 
+	 * @param testname name of the test script (Reverse1: rev builtin, Reverse2: table-based formulation)
+	 * @param matrix true for a rows1 x cols1 matrix, false for a rows1 x 1 column vector
+	 * @param sparse true for sparse input (sparsity2), false for dense input (sparsity1)
+	 * @param instType execution type, which determines the runtime platform
+	 */
+	private void runReverseTest(String testname, boolean matrix, boolean sparse, ExecType instType)
+	{
+		//set runtime platform according to the execution type
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( instType ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+		}
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		
+		try
+		{
+			int cols = matrix ? cols1 : 1;
+			double sparsity = sparse ? sparsity2 : sparsity1;
+			getAndLoadTestConfiguration(testname);
+			
+			/* This is for running the junit test the new way, i.e., construct the arguments directly */
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") };
+			
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+			
+			//generate actual dataset 
+			double[][] A = getRandomMatrix(rows1, cols, -1, 1, sparsity, 7);
+			writeInputMatrixWithMTD("A", A, true);
+			
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare matrices 
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("B");
+			TestUtils.compareMatrices(dmlfile, rfile, 0, "Stat-DML", "Stat-R");
+			
+			//check generated opcode (no opcode check for MR)
+			if( instType == ExecType.CP )
+				Assert.assertTrue("Missing opcode: rev", Statistics.getCPHeavyHitterOpCodes().contains("rev"));
+			else if ( instType == ExecType.SPARK )
+				Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+"rev", Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+"rev"));
+		}
+		finally
+		{
+			//reset flags
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java
index 30106ab..db4598f 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java
@@ -196,7 +196,7 @@ public class FullSignTest extends AutomatedTestBase
if( instType == ExecType.CP )
Assert.assertTrue("Missing opcode: sign", Statistics.getCPHeavyHitterOpCodes().contains("sign"));
else if ( instType == ExecType.SPARK )
- Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+"sel+", Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+"sign"));
+ Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+"sign", Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+"sign"));
}
finally
{
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/test/scripts/functions/reorg/Reverse1.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/reorg/Reverse1.R b/src/test/scripts/functions/reorg/Reverse1.R
new file mode 100644
index 0000000..7537fe9
--- /dev/null
+++ b/src/test/scripts/functions/reorg/Reverse1.R
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+# reference result: reverse the row order, i.e., reverse every column of A
+B = A[rev(seq_len(nrow(A))), , drop=FALSE];
+
+writeMM(as(B,"CsparseMatrix"), paste(args[2], "B", sep=""))
+
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/test/scripts/functions/reorg/Reverse1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/reorg/Reverse1.dml b/src/test/scripts/functions/reorg/Reverse1.dml
new file mode 100644
index 0000000..586d05a
--- /dev/null
+++ b/src/test/scripts/functions/reorg/Reverse1.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# row-order reverse via the rev builtin
+X = read($1);
+write(rev(X), $2);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/test/scripts/functions/reorg/Reverse2.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/reorg/Reverse2.R b/src/test/scripts/functions/reorg/Reverse2.R
new file mode 100644
index 0000000..7537fe9
--- /dev/null
+++ b/src/test/scripts/functions/reorg/Reverse2.R
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+# reference result for the permutation-matrix formulation: reverse the
+# row order, i.e., reverse every column of A
+B = A[rev(seq_len(nrow(A))), , drop=FALSE];
+
+writeMM(as(B,"CsparseMatrix"), paste(args[2], "B", sep=""))
+
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1c9fef12/src/test/scripts/functions/reorg/Reverse2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/reorg/Reverse2.dml b/src/test/scripts/functions/reorg/Reverse2.dml
new file mode 100644
index 0000000..b1d796c
--- /dev/null
+++ b/src/test/scripts/functions/reorg/Reverse2.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# rev(A) expressed via an anti-diagonal permutation matrix; kept in exactly
+# this form since it is the pattern of the simplification rewrite
+# table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
+A = read($1);
+B = table(seq(1,nrow(A),1),seq(nrow(A),1,-1)) %*% A;
+write(B, $2);
\ No newline at end of file
[2/2] incubator-systemml git commit: [SYSTEMML-268] Improved cox
script (exploit rev builtin, cse prep)
Posted by mb...@apache.org.
[SYSTEMML-268] Improved cox script (exploit rev builtin, cse prep)
https://issues.apache.org/jira/browse/SYSTEMML-268
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/89561054
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/89561054
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/89561054
Branch: refs/heads/master
Commit: 89561054704df5bacf894a9a34605d3d877e4425
Parents: 1c9fef1
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Wed Jan 6 20:05:28 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Thu Jan 7 09:58:53 2016 -0800
----------------------------------------------------------------------
scripts/algorithms/Cox.dml | 36 +++++++++++++++++-------------------
1 file changed, 17 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/89561054/scripts/algorithms/Cox.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/Cox.dml b/scripts/algorithms/Cox.dml
index 30cfd63..3cab74d 100644
--- a/scripts/algorithms/Cox.dml
+++ b/scripts/algorithms/Cox.dml
@@ -202,10 +202,8 @@ e_r = aggregate (target = RT, groups = RT, fn = "count");
# computing initial loss function value o
num_distinct = nrow (d_r); # no. of distinct timestamps
-I_rev = table (seq (1, num_distinct, 1), seq (num_distinct, 1, -1));
-I_rev_all = table (seq (1, N, 1), seq (N, 1, -1));
-e_r_rev_agg = cumsum (I_rev %*% e_r);
-d_r_rev = I_rev %*% d_r;
+e_r_rev_agg = cumsum (rev(e_r));
+d_r_rev = rev(d_r);
o = sum (d_r_rev * log (e_r_rev_agg));
o_init = o;
if (ncol (X_orig) < 3) {
@@ -223,7 +221,7 @@ if (ncol (X_orig) < 3) {
# part 1 g0_1
g0_1 = - t (colSums (X * E)); # g_1
# part 2 g0_2
-X_rev_agg = cumsum (I_rev_all %*% X);
+X_rev_agg = cumsum (rev(X));
select = table (seq (1, num_distinct), e_r_rev_agg);
X_agg = select %*% X_rev_agg;
g0_2 = t (colSums ((X_agg * d_r_rev)/ e_r_rev_agg));
@@ -251,9 +249,9 @@ while (sum_g2 > exit_g2 & i < maxiter) {
exp_Xb = exp (X %*% b);
exp_Xb_agg = aggregate (target = exp_Xb, groups = RT, fn = "sum");
- D_r_rev = cumsum (I_rev %*% exp_Xb_agg); # denominator
+ D_r_rev = cumsum (rev(exp_Xb_agg)); # denominator
X_exp_Xb = X * exp_Xb;
- X_exp_Xb_rev_agg = cumsum (I_rev_all %*% X_exp_Xb);
+ X_exp_Xb_rev_agg = cumsum (rev(X_exp_Xb));
X_exp_Xb_rev_agg = select %*% X_exp_Xb_rev_agg;
while (r2 > exit_r2 & (! trust_bound_reached) & j < maxinneriter) {
@@ -261,15 +259,15 @@ while (sum_g2 > exit_g2 & i < maxiter) {
# computing Hessian times d (Hd)
# part 1 Hd_1
Xd = X %*% d;
- X_Xd_exp_Xb = X * (Xd) * exp_Xb;
- X_Xd_exp_Xb_rev_agg = cumsum (I_rev_all %*% X_Xd_exp_Xb);
+ X_Xd_exp_Xb = X * (Xd * exp_Xb);
+ X_Xd_exp_Xb_rev_agg = cumsum (rev(X_Xd_exp_Xb));
X_Xd_exp_Xb_rev_agg = select %*% X_Xd_exp_Xb_rev_agg;
Hd_1 = X_Xd_exp_Xb_rev_agg / D_r_rev;
# part 2 Hd_2
Xd_exp_Xb = Xd * exp_Xb;
- Xd_exp_Xb_rev_agg = cumsum (I_rev_all %*% Xd_exp_Xb);
+ Xd_exp_Xb_rev_agg = cumsum (rev(Xd_exp_Xb));
Xd_exp_Xb_rev_agg = select %*% Xd_exp_Xb_rev_agg;
Hd_2_num = X_exp_Xb_rev_agg * Xd_exp_Xb_rev_agg; # numerator
@@ -292,7 +290,7 @@ while (sum_g2 > exit_g2 & i < maxiter) {
# part 2 so_2
exp_Xbsb = exp (X %*% (b + sb));
exp_Xbsb_agg = aggregate (target = exp_Xbsb, groups = RT, fn = "sum");
- so_2 = sum (d_r_rev * log (cumsum (I_rev %*% exp_Xbsb_agg)));
+ so_2 = sum (d_r_rev * log (cumsum (rev(exp_Xbsb_agg))));
#
so = so_1 + so_2;
so = so - o;
@@ -305,10 +303,10 @@ while (sum_g2 > exit_g2 & i < maxiter) {
exp_Xb = exp (X %*% b);
exp_Xb_agg = aggregate (target = exp_Xb, groups = RT, fn = "sum");
X_exp_Xb = X * exp_Xb;
- X_exp_Xb_rev_agg = cumsum (I_rev_all %*% X_exp_Xb);
+ X_exp_Xb_rev_agg = cumsum (rev(X_exp_Xb));
X_exp_Xb_rev_agg = select %*% X_exp_Xb_rev_agg;
- D_r_rev = cumsum (I_rev %*% exp_Xb_agg); # denominator
+ D_r_rev = cumsum (rev(exp_Xb_agg)); # denominator
g_2 = t (colSums ((X_exp_Xb_rev_agg / D_r_rev) * d_r_rev));
g = g0_1 + g_2;
sum_g2 = sum (g ^ 2);
@@ -325,15 +323,15 @@ print ("COMPUTING HESSIAN...");
H0 = matrix (0, rows = D, cols = D);
H = matrix (0, rows = D, cols = D);
-X_exp_Xb_rev_2 = I_rev_all %*% X_exp_Xb;
-X_rev_2 = I_rev_all %*% X;
+X_exp_Xb_rev_2 = rev(X_exp_Xb);
+X_rev_2 = rev(X);
-X_exp_Xb_rev_agg = cumsum (I_rev_all %*% X_exp_Xb);
+X_exp_Xb_rev_agg = cumsum (rev(X_exp_Xb));
X_exp_Xb_rev_agg = select %*% X_exp_Xb_rev_agg;
parfor (i in 1:D, check = 0) {
Xi = X[,i];
- Xi_rev = I_rev_all %*% Xi;
+ Xi_rev = rev(Xi);
## ----------Start calculating H0--------------
# part 1 H0_1
@@ -344,7 +342,7 @@ parfor (i in 1:D, check = 0) {
# part 2 H0_2
Xi_agg = aggregate (target = Xi, groups = RT, fn = "sum");
- Xi_agg_rev_agg = cumsum (I_rev %*% Xi_agg);
+ Xi_agg_rev_agg = cumsum (rev(Xi_agg));
H0_2_num = X_agg[,i:D] * Xi_agg_rev_agg; # numerator
H0_2 = H0_2_num / (e_r_rev_agg ^ 2);
@@ -362,7 +360,7 @@ parfor (i in 1:D, check = 0) {
Xi_exp_Xb = exp_Xb * Xi;
Xi_exp_Xb_agg = aggregate (target = Xi_exp_Xb, groups = RT, fn = "sum");
- Xi_exp_Xb_agg_rev_agg = cumsum (I_rev %*% Xi_exp_Xb_agg);
+ Xi_exp_Xb_agg_rev_agg = cumsum (rev(Xi_exp_Xb_agg));
H_2_num = X_exp_Xb_rev_agg[,i:D] * Xi_exp_Xb_agg_rev_agg; # numerator
H_2 = H_2_num / (D_r_rev ^ 2);
H[i,i:D] = colSums ((H_1 - H_2) * d_r_rev);