You are viewing a plain-text version of this content. (The canonical link referenced as "here" was not preserved in this text copy.)
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/05 17:04:15 UTC
[1/3] incubator-systemml git commit: New sign builtin function (parser/compiler/runtime), incl tests/rewrites
Repository: incubator-systemml
Updated Branches:
refs/heads/master edbab297b -> 86583584f
New sign builtin function (parser/compiler/runtime), incl tests/rewrites
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/87980ce2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/87980ce2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/87980ce2
Branch: refs/heads/master
Commit: 87980ce2dc0db74b6f1f28f3187ac931aa4a7c41
Parents: edbab29
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Mon Jan 4 13:39:31 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Mon Jan 4 21:08:45 2016 -0800
----------------------------------------------------------------------
src/main/java/org/apache/sysml/hops/Hop.java | 3 +-
.../java/org/apache/sysml/hops/UnaryOp.java | 1 -
.../hops/cost/CostEstimatorStaticRuntime.java | 4 +-
.../RewriteAlgebraicSimplificationStatic.java | 37 +++-
src/main/java/org/apache/sysml/lops/Unary.java | 7 +-
.../sysml/parser/BuiltinFunctionExpression.java | 4 +
.../org/apache/sysml/parser/DMLTranslator.java | 4 +
.../org/apache/sysml/parser/Expression.java | 1 +
.../sysml/runtime/functionobjects/Builtin.java | 105 +++++-----
.../instructions/CPInstructionParser.java | 1 +
.../instructions/MRInstructionParser.java | 1 +
.../instructions/SPInstructionParser.java | 1 +
.../runtime/matrix/operators/UnaryOperator.java | 3 +-
.../functions/unary/matrix/FullSignTest.java | 207 +++++++++++++++++++
src/test/scripts/functions/unary/matrix/Sign1.R | 33 +++
.../scripts/functions/unary/matrix/Sign1.dml | 26 +++
src/test/scripts/functions/unary/matrix/Sign2.R | 33 +++
.../scripts/functions/unary/matrix/Sign2.dml | 26 +++
18 files changed, 430 insertions(+), 67 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 4f520e4..ee815fe 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1024,7 +1024,7 @@ public abstract class Hop
}
public enum OpOp1 {
- NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SQRT, LOG, EXP,
+ NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SIGN, SQRT, LOG, EXP,
CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
PRINT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE,
//cumulative sums, products, extreme values
@@ -1231,6 +1231,7 @@ public abstract class Hop
HopsOpOp1LopsU.put(OpOp1.ASIN, org.apache.sysml.lops.Unary.OperationTypes.ASIN);
HopsOpOp1LopsU.put(OpOp1.ACOS, org.apache.sysml.lops.Unary.OperationTypes.ACOS);
HopsOpOp1LopsU.put(OpOp1.ATAN, org.apache.sysml.lops.Unary.OperationTypes.ATAN);
+ HopsOpOp1LopsU.put(OpOp1.SIGN, org.apache.sysml.lops.Unary.OperationTypes.SIGN);
HopsOpOp1LopsU.put(OpOp1.SQRT, org.apache.sysml.lops.Unary.OperationTypes.SQRT);
HopsOpOp1LopsU.put(OpOp1.EXP, org.apache.sysml.lops.Unary.OperationTypes.EXP);
HopsOpOp1LopsU.put(OpOp1.LOG, org.apache.sysml.lops.Unary.OperationTypes.LOG);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index 8654c46..a3ca530 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -93,7 +93,6 @@ public class UnaryOp extends Hop
public String getOpString() {
String s = new String("");
s += "u(" + _op.toString().toLowerCase() + ")";
- // s += HopsOpOp1String.get(_op) + ")";
return s;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
index 41fb506..15003c3 100644
--- a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
+++ b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
@@ -1002,7 +1002,7 @@ public class CostEstimatorStaticRuntime extends CostEstimator
return d3m * d3n;
- case BuiltinUnary: //opcodes: exp, abs, sin, cos, tan, sqrt, plogp, print, round, sprop, sigmoid
+ case BuiltinUnary: //opcodes: exp, abs, sin, cos, tan, sign, sqrt, plogp, print, round, sprop, sigmoid
if( optype.equals("print") ) //scalar only
return 1;
else
@@ -1013,7 +1013,7 @@ public class CostEstimatorStaticRuntime extends CostEstimator
if( optype.equals("sin") || optype.equals("tan") || optype.equals("round")
|| optype.equals("abs") || optype.equals("sqrt") || optype.equals("sprop")
- || optype.equals("sigmoid") ) //sparse-safe
+ || optype.equals("sigmoid") || optype.equals("sign") ) //sparse-safe
{
if( leftSparse ) //sparse
return xbu * d1m * d1n * d1s;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5483bed..20e94a0 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -133,7 +133,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = removeUnnecessaryBinaryOperation(hop, hi, i); //e.g., X*1 -> X (dep: should come after rm unnecessary vectorize)
hi = fuseDatagenAndBinaryOperation(hop, hi, i); //e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7)
hi = fuseDatagenAndMinusOperation(hop, hi, i); //e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2)
- hi = simplifyBinaryToUnaryOperation(hi); //e.g., X*X -> X^2 (pow2)
+ hi = simplifyBinaryToUnaryOperation(hop, hi, i); //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
@@ -470,15 +470,19 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
* handle simplification of binary operations
* (relies on previous common subexpression elimination)
*
- * X+X -> X*2 or X*X -> X^2
+ * X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X)
+ * @throws HopsException
*/
- private Hop simplifyBinaryToUnaryOperation( Hop hi )
+ private Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi, int pos )
+ throws HopsException
{
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
+
+ //patterns: X+X -> X*2, X*X -> X^2,
if( left == right && left.getDataType()==DataType.MATRIX )
{
//note: we simplify this to unary operations first (less mem and better MR plan),
@@ -504,6 +508,30 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
LOG.debug("Applied simplifyBinaryToUnaryOperation2");
}
}
+ //patterns: (X>0)-(X<0) -> sign(X)
+ else if( bop.getOp() == OpOp2.MINUS
+ && left instanceof BinaryOp && right instanceof BinaryOp
+ && ((BinaryOp)left).getOp()==OpOp2.GREATER && ((BinaryOp)right).getOp()==OpOp2.LESS
+ && left.getInput().get(0) == right.getInput().get(0)
+ && left.getInput().get(1) instanceof LiteralOp
+ && HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0
+ && right.getInput().get(1) instanceof LiteralOp
+ && HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 )
+ {
+ UnaryOp uop = HopRewriteUtils.createUnary(left.getInput().get(0), OpOp1.SIGN);
+
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.removeAllChildReferences(hi);
+ HopRewriteUtils.addChildReference(parent, uop, pos);
+ if( left.getParent().isEmpty() )
+ HopRewriteUtils.removeAllChildReferences(left);
+ if( right.getParent().isEmpty() )
+ HopRewriteUtils.removeAllChildReferences(right);
+
+ hi = uop;
+
+ LOG.debug("Applied simplifyBinaryToUnaryOperation3");
+ }
}
return hi;
@@ -752,7 +780,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
&& ((BinaryOp)hi.getInput().get(0)).isPPredOperation() )
{
UnaryOp uop = (UnaryOp) hi; //valid unary op
- if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.CEIL
+ if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN
+ || uop.getOp()==OpOp1.SELP || uop.getOp()==OpOp1.CEIL
|| uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND )
{
//clear link unary-binary
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/lops/Unary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Unary.java b/src/main/java/org/apache/sysml/lops/Unary.java
index 544e255..8375202 100644
--- a/src/main/java/org/apache/sysml/lops/Unary.java
+++ b/src/main/java/org/apache/sysml/lops/Unary.java
@@ -37,7 +37,10 @@ public class Unary extends Lop
{
public enum OperationTypes {
- ADD, SUBTRACT, SUBTRACTRIGHT, MULTIPLY, MULTIPLY2, DIVIDE, MODULUS, INTDIV, MINUS1_MULTIPLY, POW, POW2, LOG, MAX, MIN, NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SQRT, EXP, Over, LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS, ROUND, CEIL, FLOOR, MR_IQM, INVERSE,
+ ADD, SUBTRACT, SUBTRACTRIGHT, MULTIPLY, MULTIPLY2, DIVIDE, MODULUS, INTDIV, MINUS1_MULTIPLY,
+ POW, POW2, LOG, MAX, MIN, NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SIGN, SQRT, EXP, Over,
+ LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
+ ROUND, CEIL, FLOOR, MR_IQM, INVERSE,
CUMSUM, CUMPROD, CUMMIN, CUMMAX,
SPROP, SIGMOID, SELP, SUBTRACT_NZ, LOG_NZ,
NOTSUPPORTED
@@ -199,6 +202,8 @@ public class Unary extends Lop
return "acos";
case ATAN:
return "atan";
+ case SIGN:
+ return "sign";
case SQRT:
return "sqrt";
case EXP:
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 260e05d..a0fd56e 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -1031,6 +1031,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
case ACOS:
case ASIN:
case ATAN:
+ case SIGN:
case SQRT:
case ABS:
case LOG:
@@ -1053,6 +1054,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
case ACOS:
case ASIN:
case ATAN:
+ case SIGN:
case SQRT:
case ABS:
case EXP:
@@ -1299,6 +1301,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
bifop = Expression.BuiltinFunctionOp.NCOL;
else if (functionName.equals("nrow"))
bifop = Expression.BuiltinFunctionOp.NROW;
+ else if (functionName.equals("sign"))
+ bifop = Expression.BuiltinFunctionOp.SIGN;
else if (functionName.equals("sqrt"))
bifop = Expression.BuiltinFunctionOp.SQRT;
else if (functionName.equals("sum"))
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 42ada29..5b6ee70 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2426,6 +2426,7 @@ public class DMLTranslator
case ASIN:
case ACOS:
case ATAN:
+ case SIGN:
case SQRT:
case EXP:
case ROUND:
@@ -2458,6 +2459,9 @@ public class DMLTranslator
case ATAN:
mathOp1 = Hop.OpOp1.ATAN;
break;
+ case SIGN:
+ mathOp1 = Hop.OpOp1.SIGN;
+ break;
case SQRT:
mathOp1 = Hop.OpOp1.SQRT;
break;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index e74e256..ad288de 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -94,6 +94,7 @@ public abstract class Expression
ROWSUM,
SEQ,
SIN,
+ SIGN,
SQRT,
SUM,
TABLE,
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
index 533c057..beea376 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
@@ -50,7 +50,7 @@ public class Builtin extends ValueFunction
private static final long serialVersionUID = 3836744687789840574L;
- public enum BuiltinFunctionCode { INVALID, SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SQRT, EXP, PLOGP, PRINT, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP };
+ public enum BuiltinFunctionCode { INVALID, SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP };
public BuiltinFunctionCode bFunc;
private static final boolean FASTMATH = true;
@@ -72,6 +72,7 @@ public class Builtin extends ValueFunction
String2BuiltinFunctionCode.put( "maxindex" , BuiltinFunctionCode.MAXINDEX);
String2BuiltinFunctionCode.put( "minindex" , BuiltinFunctionCode.MININDEX);
String2BuiltinFunctionCode.put( "abs" , BuiltinFunctionCode.ABS);
+ String2BuiltinFunctionCode.put( "sign" , BuiltinFunctionCode.SIGN);
String2BuiltinFunctionCode.put( "sqrt" , BuiltinFunctionCode.SQRT);
String2BuiltinFunctionCode.put( "exp" , BuiltinFunctionCode.EXP);
String2BuiltinFunctionCode.put( "plogp" , BuiltinFunctionCode.PLOGP);
@@ -96,7 +97,7 @@ public class Builtin extends ValueFunction
// We should create one object for every builtin function that we support
private static Builtin sinObj = null, cosObj = null, tanObj = null, asinObj = null, acosObj = null, atanObj = null;
private static Builtin logObj = null, lognzObj = null, minObj = null, maxObj = null, maxindexObj = null, minindexObj=null;
- private static Builtin absObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null;
+ private static Builtin absObj = null, signObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null;
private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj=null, floorObj=null;
private static Builtin inverseObj=null, cumsumObj=null, cumprodObj=null, cumminObj=null, cummaxObj=null;
private static Builtin stopObj = null, spropObj = null, sigmoidObj = null, selpObj = null;
@@ -185,6 +186,10 @@ public class Builtin extends ValueFunction
if ( absObj == null )
absObj = new Builtin(BuiltinFunctionCode.ABS);
return absObj;
+ case SIGN:
+ if ( signObj == null )
+ signObj = new Builtin(BuiltinFunctionCode.SIGN);
+ return signObj;
case SQRT:
if ( sqrtObj == null )
sqrtObj = new Builtin(BuiltinFunctionCode.SQRT);
@@ -285,6 +290,7 @@ public class Builtin extends ValueFunction
case ASIN:
case ACOS:
case ATAN:
+ case SIGN:
case SQRT:
case EXP:
case PLOGP:
@@ -317,71 +323,58 @@ public class Builtin extends ValueFunction
}
}
- public double execute (double in) throws DMLRuntimeException {
+ public double execute (double in)
+ throws DMLRuntimeException
+ {
switch(bFunc) {
- case SIN: return FASTMATH ? FastMath.sin(in) : Math.sin(in);
- case COS: return FASTMATH ? FastMath.cos(in) : Math.cos(in);
- case TAN: return FASTMATH ? FastMath.tan(in) : Math.tan(in);
- case ASIN: return FASTMATH ? FastMath.asin(in) : Math.asin(in);
- case ACOS: return FASTMATH ? FastMath.acos(in) : Math.acos(in);
- case ATAN: return Math.atan(in); //faster in Math
- case CEIL: return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
- case FLOOR: return FASTMATH ? FastMath.floor(in) : Math.floor(in);
- case LOG:
- //if ( in <= 0 )
- // throw new DMLRuntimeException("Builtin.execute(): logarithm can only be computed for non-negative numbers (input = " + in + ").");
- // for negative numbers, Math.log will return NaN
- return FASTMATH ? FastMath.log(in) : Math.log(in);
- case LOG_NZ:
- return (in==0) ? 0 : FASTMATH ? FastMath.log(in) : Math.log(in);
-
- case ABS:
- return Math.abs(in); //no need for FastMath
-
- case SQRT:
- //if ( in < 0 )
- // throw new DMLRuntimeException("Builtin.execute(): squareroot can only be computed for non-negative numbers (input = " + in + ").");
- return Math.sqrt(in); //faster in Math
-
- case PLOGP:
- if (in == 0.0)
- return 0.0;
- else if (in < 0)
- return Double.NaN;
- else
- return (in * (FASTMATH ? FastMath.log(in) : Math.log(in)));
+ case SIN: return FASTMATH ? FastMath.sin(in) : Math.sin(in);
+ case COS: return FASTMATH ? FastMath.cos(in) : Math.cos(in);
+ case TAN: return FASTMATH ? FastMath.tan(in) : Math.tan(in);
+ case ASIN: return FASTMATH ? FastMath.asin(in) : Math.asin(in);
+ case ACOS: return FASTMATH ? FastMath.acos(in) : Math.acos(in);
+ case ATAN: return Math.atan(in); //faster in Math
+ case CEIL: return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
+ case FLOOR: return FASTMATH ? FastMath.floor(in) : Math.floor(in);
+ case LOG: return FASTMATH ? FastMath.log(in) : Math.log(in);
+ case LOG_NZ: return (in==0) ? 0 : FASTMATH ? FastMath.log(in) : Math.log(in);
+ case ABS: return Math.abs(in); //no need for FastMath
+ case SIGN: return FASTMATH ? FastMath.signum(in) : Math.signum(in);
+ case SQRT: return Math.sqrt(in); //faster in Math
+ case EXP: return FASTMATH ? FastMath.exp(in) : Math.exp(in);
+ case ROUND: return Math.round(in); //no need for FastMath
- case EXP:
- return FASTMATH ? FastMath.exp(in) : Math.exp(in);
-
- case ROUND:
- return Math.round(in); //no need for FastMath
+ case PLOGP:
+ if (in == 0.0)
+ return 0.0;
+ else if (in < 0)
+ return Double.NaN;
+ else
+ return (in * (FASTMATH ? FastMath.log(in) : Math.log(in)));
- case SPROP:
- //sample proportion: P*(1-P)
- return in * (1 - in);
-
- case SIGMOID:
- //sigmoid: 1/(1+exp(-x))
- return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));
-
- case SELP:
- //select positive: x*(x>0)
- return (in > 0) ? in : 0;
+ case SPROP:
+ //sample proportion: P*(1-P)
+ return in * (1 - in);
+
+ case SIGMOID:
+ //sigmoid: 1/(1+exp(-x))
+ return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));
- default:
- throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
+ case SELP:
+ //select positive: x*(x>0)
+ return (in > 0) ? in : 0;
+
+ default:
+ throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
}
}
public double execute (long in) throws DMLRuntimeException {
- return this.execute((double)in);
+ return execute((double)in);
}
/*
* Builtin functions with two inputs
- */
-
+ */
public double execute (double in1, double in2) throws DMLRuntimeException {
switch(bFunc) {
@@ -525,6 +518,4 @@ public class Builtin extends ValueFunction
throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
}
}
-
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 5b84ea2..b34dbc9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -151,6 +151,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "asin" , CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "acos" , CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "atan" , CPINSTRUCTION_TYPE.BuiltinUnary);
+ String2CPInstructionType.put( "sign" , CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "sqrt" , CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "plogp" , CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "print" , CPINSTRUCTION_TYPE.BuiltinUnary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
index 5aee907..57051c2 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
@@ -136,6 +136,7 @@ public class MRInstructionParser extends InstructionParser
String2MRInstructionType.put( "asin" , MRINSTRUCTION_TYPE.Unary);
String2MRInstructionType.put( "acos" , MRINSTRUCTION_TYPE.Unary);
String2MRInstructionType.put( "atan" , MRINSTRUCTION_TYPE.Unary);
+ String2MRInstructionType.put( "sign" , MRINSTRUCTION_TYPE.Unary);
String2MRInstructionType.put( "sqrt" , MRINSTRUCTION_TYPE.Unary);
String2MRInstructionType.put( "exp" , MRINSTRUCTION_TYPE.Unary);
String2MRInstructionType.put( "log" , MRINSTRUCTION_TYPE.Unary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index 2683c43..5d5450c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -191,6 +191,7 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "asin" , SPINSTRUCTION_TYPE.BuiltinUnary);
String2SPInstructionType.put( "acos" , SPINSTRUCTION_TYPE.BuiltinUnary);
String2SPInstructionType.put( "atan" , SPINSTRUCTION_TYPE.BuiltinUnary);
+ String2SPInstructionType.put( "sign" , SPINSTRUCTION_TYPE.BuiltinUnary);
String2SPInstructionType.put( "sqrt" , SPINSTRUCTION_TYPE.BuiltinUnary);
String2SPInstructionType.put( "plogp" , SPINSTRUCTION_TYPE.BuiltinUnary);
String2SPInstructionType.put( "round" , SPINSTRUCTION_TYPE.BuiltinUnary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
index e560003..1a45794 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
@@ -40,7 +40,8 @@ public class UnaryOperator extends Operator
if(f.bFunc==Builtin.BuiltinFunctionCode.SIN || f.bFunc==Builtin.BuiltinFunctionCode.TAN
|| f.bFunc==Builtin.BuiltinFunctionCode.ROUND || f.bFunc==Builtin.BuiltinFunctionCode.ABS
|| f.bFunc==Builtin.BuiltinFunctionCode.SQRT || f.bFunc==Builtin.BuiltinFunctionCode.SPROP
- || f.bFunc==Builtin.BuiltinFunctionCode.SELP || f.bFunc==Builtin.BuiltinFunctionCode.LOG_NZ )
+ || f.bFunc==Builtin.BuiltinFunctionCode.SELP || f.bFunc==Builtin.BuiltinFunctionCode.LOG_NZ
+ || f.bFunc==Builtin.BuiltinFunctionCode.SIGN )
{
sparseSafe = true;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java
new file mode 100644
index 0000000..30106ab
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSignTest.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.unary.matrix;
+
+import java.util.HashMap;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Integration tests for the sign() builtin, comparing DML results against R
+ * on dense/sparse inputs for CP, MR and Spark; TEST_NAME2 covers rewrite tests.
+ */
+public class FullSignTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME1 = "Sign1";
+ private final static String TEST_NAME2 = "Sign2";
+ private final static String TEST_DIR = "functions/unary/matrix/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FullSignTest.class.getSimpleName() + "/";
+
+ private final static int rows = 1108;
+ private final static int cols = 1001;
+ private final static double spSparse = 0.05;
+ private final static double spDense = 0.7;
+
+ @Override
+ public void setUp()
+ {
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"B"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"B"}));
+
+ if (TEST_CACHE_ENABLED) {
+ setOutAndExpectedDeletionDisabled(true);
+ }
+ }
+
+ @BeforeClass
+ public static void init()
+ {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @AfterClass
+ public static void cleanUp()
+ {
+ if (TEST_CACHE_ENABLED) {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+ }
+
+ @Test
+ public void testSignDenseCP() {
+ runSignTest(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testSignSparseCP() {
+ runSignTest(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testSignDenseMR() {
+ runSignTest(TEST_NAME1, false, ExecType.MR);
+ }
+
+ @Test
+ public void testSignSparseMR() {
+ runSignTest(TEST_NAME1, true, ExecType.MR);
+ }
+
+ @Test
+ public void testSignDenseSP() {
+ runSignTest(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testSignSparseSP() {
+ runSignTest(TEST_NAME1, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testRewriteSignDenseCP() {
+ runSignTest(TEST_NAME2, false, ExecType.CP);
+ }
+
+ @Test
+ public void testRewriteSignSparseCP() {
+ runSignTest(TEST_NAME2, true, ExecType.CP);
+ }
+
+ @Test
+ public void testRewriteSignDenseMR() {
+ runSignTest(TEST_NAME2, false, ExecType.MR);
+ }
+
+ @Test
+ public void testRewriteSignSparseMR() {
+ runSignTest(TEST_NAME2, true, ExecType.MR);
+ }
+
+ @Test
+ public void testRewriteSignDenseSP() {
+ runSignTest(TEST_NAME2, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testRewriteSignSparseSP() {
+ runSignTest(TEST_NAME2, true, ExecType.SPARK);
+ }
+
+ /**
+ * Runs the DML/R script pair for testname, compares outputs, and checks the opcode.
+ * @param testname name of the DML/R test scripts (without file extension)
+ * @param sparse true to generate the input with sparse sparsity, false for dense
+ * @param instType execution type selecting the runtime platform (CP, MR, SPARK)
+ */
+ private void runSignTest( String testname, boolean sparse, ExecType instType)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( instType ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ try
+ {
+ String TEST_NAME = testname;
+ double sparsity = (sparse) ? spSparse : spDense;
+
+ String TEST_CACHE_DIR = "";
+ if (TEST_CACHE_ENABLED)
+ {
+ TEST_CACHE_DIR = sparsity + "/";
+ }
+
+ TestConfiguration config = getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config, TEST_CACHE_DIR);
+
+ // This is for running the junit test the new way, i.e., construct the arguments directly
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+ //stats parameter required for opcode check
+ programArgs = new String[]{"-stats", "-args", input("A"), output("B") };
+
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+ //generate actual dataset
+ double[][] A = getRandomMatrix(rows, cols, -1, 1, sparsity, 7);
+ writeInputMatrixWithMTD("A", A, true);
+
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("B");
+ TestUtils.compareMatrices(dmlfile, rfile, 0, "Stat-DML", "Stat-R");
+
+ //check generated opcode (message must name the same opcode that is checked)
+ if( instType == ExecType.CP )
+ Assert.assertTrue("Missing opcode: sign", Statistics.getCPHeavyHitterOpCodes().contains("sign"));
+ else if ( instType == ExecType.SPARK )
+ Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+"sign", Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+"sign"));
+ }
+ finally
+ {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/test/scripts/functions/unary/matrix/Sign1.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/unary/matrix/Sign1.R b/src/test/scripts/functions/unary/matrix/Sign1.R
new file mode 100644
index 0000000..837ba80
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/Sign1.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = sign(A);
+
+writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep=""));
+
+
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/test/scripts/functions/unary/matrix/Sign1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/unary/matrix/Sign1.dml b/src/test/scripts/functions/unary/matrix/Sign1.dml
new file mode 100644
index 0000000..93c8534
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/Sign1.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+
+A = read($1);
+B = sign(A);
+write(B, $2);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/test/scripts/functions/unary/matrix/Sign2.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/unary/matrix/Sign2.R b/src/test/scripts/functions/unary/matrix/Sign2.R
new file mode 100644
index 0000000..837ba80
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/Sign2.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = sign(A);
+
+writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep=""));
+
+
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/87980ce2/src/test/scripts/functions/unary/matrix/Sign2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/unary/matrix/Sign2.dml b/src/test/scripts/functions/unary/matrix/Sign2.dml
new file mode 100644
index 0000000..40ed84c
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/Sign2.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+
+A = read($1);
+B = ppred(A, 0, ">") - ppred(A, 0, "<");
+write(B, $2);
[2/3] incubator-systemml git commit: Performance dense sparse-safe unary block operations (single loop, nnz)
Posted by mb...@apache.org.
Performance dense sparse-safe unary block operations (single loop, nnz)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/e80f94b2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/e80f94b2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/e80f94b2
Branch: refs/heads/master
Commit: e80f94b22819634fd96c9e1dcd7da143244d9b23
Parents: 87980ce
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Mon Jan 4 13:40:58 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Mon Jan 4 21:08:53 2016 -0800
----------------------------------------------------------------------
.../apache/sysml/runtime/matrix/data/MatrixBlock.java | 13 +++++--------
1 file changed, 5 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e80f94b2/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index afc4788..7b20502 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -3085,14 +3085,11 @@ public class MatrixBlock extends MatrixValue implements Externalizable
double[] c = ret.denseBlock;
//unary op, incl nnz maintenance
- for( int i=0, ix=0; i<m; i++ ) {
- for( int j=0; j<n; j++, ix++ ) {
- c[ix] = op.fn.execute(a[ix]);
- if( c[ix] != 0 )
- ret.nonZeros++;
- }
- }
-
+ int len = m*n;
+ for( int i=0; i<len; i++ ) {
+ c[i] = op.fn.execute(a[i]);
+ ret.nonZeros += (c[i] != 0) ? 1 : 0;
+ }
}
}
[3/3] incubator-systemml git commit: Improved lasso script (sign computation, cleanup indentation)
Posted by mb...@apache.org.
Improved lasso script (sign computation, cleanup indentation)
Incl cleanup test suite packages (missing tests, ordering)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/86583584
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/86583584
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/86583584
Branch: refs/heads/master
Commit: 86583584fccf09da320c54cf3cdacd9134bc3465
Parents: e80f94b
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Tue Jan 5 00:07:17 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Tue Jan 5 00:07:17 2016 -0800
----------------------------------------------------------------------
scripts/staging/regression/lasso/lasso.dml | 38 ++++++++++----------
.../functions/unary/matrix/ZPackageSuite.java | 20 +++++------
2 files changed, 29 insertions(+), 29 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/86583584/scripts/staging/regression/lasso/lasso.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/regression/lasso/lasso.dml b/scripts/staging/regression/lasso/lasso.dml
index 16761f3..fb520df 100644
--- a/scripts/staging/regression/lasso/lasso.dml
+++ b/scripts/staging/regression/lasso/lasso.dml
@@ -23,6 +23,8 @@
X = read($X)
y = read($Y)
+n = nrow(X)
+m = ncol(X)
#params
tol = 10^(-15)
@@ -30,9 +32,6 @@ M = 5
tau = 1
maxiter = 1000
-n = nrow(X)
-m = ncol(X)
-
#constants
eta = 2
sigma = 0.01
@@ -55,7 +54,7 @@ history[M,1] = obj
inactive_set = matrix(1, rows=m, cols=1)
iter = 0
continue = TRUE
-while(iter < maxiter & continue){
+while(iter < maxiter & continue) {
dw = matrix(0, rows=m, cols=1)
dg = matrix(0, rows=m, cols=1)
relChangeObj = -1.0
@@ -63,22 +62,20 @@ while(iter < maxiter & continue){
inner_iter = 0
inner_continue = TRUE
inner_maxiter = 100
- while(inner_iter < inner_maxiter & inner_continue){
+ while(inner_iter < inner_maxiter & inner_continue) {
u = w - g/alpha
lambda = tau/alpha
- signum_u = ppred(u, 0, ">") - ppred(u, 0, "<")
- wnew = signum_u * (abs(u) - lambda) * ppred(abs(u) - lambda, 0, ">")
-
+ wnew = sign(u) * (abs(u) - lambda) * ppred(abs(u) - lambda, 0, ">")
dw = wnew - w
dw2 = sum(dw*dw)
r = X %*% wnew - y
gnew = t(X) %*% r
- objnew = 0.5 * sum(r*r) + tau*sum(abs(wnew))
-
+ objnew = 0.5 * sum(r*r) + tau*sum(abs(wnew))
obj_threshold = max(history) - 0.5*sigma*alpha*dw2
- if(objnew <= obj_threshold){
+
+ if(objnew <= obj_threshold) {
w = wnew
dg = gnew - g
g = gnew
@@ -88,25 +85,28 @@ while(iter < maxiter & continue){
history[M,1] = objnew
relChangeObj = abs(objnew - obj)/obj
obj = objnew
- }else alpha = eta*alpha
+ }
+ else
+ alpha = eta*alpha
inner_iter = inner_iter + 1
}
- if(inner_continue) print("Inner loop did not converge")
+ if(inner_continue)
+ print("Inner loop did not converge")
alphanew = sum(dw*dg)/sum(dw*dw)
alpha = max(alpha_min, min(alpha_max, alphanew))
old_inactive_set = inactive_set
- inactive_set = ppred(w, 0, "!=")
- diff = sum(abs(old_inactive_set - inactive_set))
-
- if(diff == 0 & relChangeObj < tol) continue = FALSE
+ inactive_set = ppred(w, 0, "!=")
+ diff = sum(abs(old_inactive_set - inactive_set))
- num_inactive = sum(ppred(w, 0, "!="))
- print("ITER=" + iter + " OBJ=" + obj + " relative change=" + relChangeObj + " num_inactive=" + num_inactive)
+ if(diff == 0 & relChangeObj < tol)
+ continue = FALSE
+ num_inactive = sum(ppred(w, 0, "!="))
+ print("ITER=" + iter + " OBJ=" + obj + " relative change=" + relChangeObj + " num_inactive=" + num_inactive)
iter = iter + 1
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/86583584/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java
index 137eba3..dac067d 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java
@@ -33,28 +33,28 @@ import org.junit.runners.Suite;
CastAsScalarTest.class,
CosTest.class,
DiagTest.class,
+ EigenFactorizeTest.class,
+ FullCummaxTest.class,
+ FullCumminTest.class,
+ FullCumprodTest.class,
+ FullCumsumTest.class,
+ FullSelectPosTest.class,
+ FullSignTest.class,
IQMTest.class,
+ LUFactorizeTest.class,
MatrixInverseTest.class,
MinusTest.class,
+ MLUnaryBuiltinTest.class,
NegationTest.class,
PrintTest.class,
QRSolverTest.class,
+ RemoveEmptyTest.class,
ReplaceTest.class,
RoundTest.class,
SinTest.class,
SqrtTest.class,
TanTest.class,
TransposeTest.class,
-
- EigenFactorizeTest.class,
- FullCumminTest.class,
- FullCummaxTest.class,
- FullCumprodTest.class,
- FullCumsumTest.class,
- FullSelectPosTest.class,
- LUFactorizeTest.class,
- RemoveEmptyTest.class,
- MLUnaryBuiltinTest.class
})