You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/05/30 20:40:13 UTC

systemml git commit: [SYSTEMML-2350] Fix missing support for lists in as.scalar casts

Repository: systemml
Updated Branches:
  refs/heads/master 97018d4e6 -> a62b65c8f


[SYSTEMML-2350] Fix missing support for lists in as.scalar casts

This patch adds the missing support for list inputs in as.scalar casts.
Such casts are needed to extract scalars from unnamed or named lists
because list indexing itself still returns a single-element list. The
patch also improves the error handling of the related runtime
instruction.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a62b65c8
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a62b65c8
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a62b65c8

Branch: refs/heads/master
Commit: a62b65c8f61ed8cf0b009732f8cbdb7c8eda95e9
Parents: 97018d4
Author: Matthias Boehm <mb...@gmail.com>
Authored: Wed May 30 13:40:09 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Wed May 30 13:40:09 2018 -0700

----------------------------------------------------------------------
 .../sysml/parser/BuiltinFunctionExpression.java  | 14 ++++++++------
 .../java/org/apache/sysml/parser/Expression.java |  2 +-
 .../instructions/cp/VariableCPInstruction.java   | 19 ++++++++++++++-----
 .../mnist_lenet_paramserv_minimum_version.dml    |  4 ++--
 4 files changed, 25 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 0e949d0..ea51bd1 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -496,18 +496,20 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			output.setValueType(id.getValueType());
 			
 			break;
-			
 		case CAST_AS_SCALAR:
 			checkNumParameters(1);
-			checkMatrixFrameParam(getFirstExpr());
-			if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1) || ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) {
-				raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() +  " dim2 " + getFirstExpr().getOutput().getDim2(), 
-				          conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+			checkDataTypeParam(getFirstExpr(),
+				DataType.MATRIX, DataType.FRAME, DataType.LIST);
+			if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1)
+				|| ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) {
+				raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() 
+					+  " dim2 " + getFirstExpr().getOutput().getDim2(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
 			}
 			output.setDataType(DataType.SCALAR);
 			output.setDimensions(0, 0);
 			output.setBlockDimensions (0, 0);
-			output.setValueType(id.getValueType());
+			output.setValueType((id.getValueType()!=ValueType.UNKNOWN) ?
+				id.getValueType() : ValueType.DOUBLE);
 			break;
 		case CAST_AS_MATRIX:
 			checkNumParameters(1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index fd3f855..9a6ea64 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -194,7 +194,7 @@ public abstract class Expression implements ParseInfo
 		public boolean isScalar() {
 			return (this == SCALAR);
 		}
-		public boolean isComposite() {
+		public boolean isList() {
 			return (this == LIST);
 		}
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
index 5786e87..b46f4df 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
@@ -553,7 +553,7 @@ public class VariableCPInstruction extends CPInstruction {
 			break;
 			
 		case CastAsScalarVariable: //castAsScalarVariable
-			if( getInput1().getDataType()==DataType.FRAME ) {
+			if( getInput1().getDataType().isFrame() ) {
 				FrameBlock fBlock = ec.getFrameInput(getInput1().getName());
 				if( fBlock.getNumRows()!=1 || fBlock.getNumColumns()!=1 )
 					throw new DMLRuntimeException("Dimension mismatch - unable to cast frame '"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x "+fBlock.getNumColumns()+") to scalar.");
@@ -562,7 +562,7 @@ public class VariableCPInstruction extends CPInstruction {
 				ec.setScalarOutput(output.getName(), 
 						ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value));
 			}
-			else { //assume DataType.MATRIX otherwise
+			else if( getInput1().getDataType().isMatrix() ) {
 				MatrixBlock mBlock = ec.getMatrixInput(getInput1().getName(), getExtendedOpcode());
 				if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 )
 					throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix '"+getInput1().getName()+"' of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar.");
@@ -570,21 +570,30 @@ public class VariableCPInstruction extends CPInstruction {
 				ec.releaseMatrixInput(getInput1().getName(), getExtendedOpcode());
 				ec.setScalarOutput(output.getName(), new DoubleObject(value));
 			}
+			else if( getInput1().getDataType().isList() ) {
+				//TODO handling of cleanup status, potentially new object
+				ListObject list = (ListObject)ec.getVariable(getInput1().getName());
+				ec.setVariable(output.getName(), list.slice(0));
+			}
+			else {
+				throw new DMLRuntimeException("Unsupported data type "
+					+ "in as.scalar(): "+getInput1().getDataType().name());
+			}
 			break;
 		case CastAsMatrixVariable:{
-			if( getInput1().getDataType()==DataType.FRAME ) {
+			if( getInput1().getDataType().isFrame() ) {
 				FrameBlock fin = ec.getFrameInput(getInput1().getName());
 				MatrixBlock out = DataConverter.convertToMatrixBlock(fin);
 				ec.releaseFrameInput(getInput1().getName());
 				ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
 			}
-			else if( getInput1().getDataType()==DataType.SCALAR ) {
+			else if( getInput1().getDataType().isScalar() ) {
 				ScalarObject scalarInput = ec.getScalarInput(
 					getInput1().getName(), getInput1().getValueType(), getInput1().isLiteral());
 				MatrixBlock out = new MatrixBlock(scalarInput.getDoubleValue());
 				ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
 			}
-			else if( getInput1().getDataType()==DataType.LIST ) {
+			else if( getInput1().getDataType().isList() ) {
 				//TODO handling of cleanup status, potentially new object
 				ListObject list = (ListObject)ec.getVariable(getInput1().getName());
 				ec.setVariable(output.getName(), list.slice(0));

http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index 2ef7411..8811c36 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -234,8 +234,8 @@ aggregation = function(list[unknown] model,
      vb2 = as.matrix(model["vb2"])
      vb3 = as.matrix(model["vb3"])
      vb4 = as.matrix(model["vb4"])
-     lr = 0.01
-     mu = 0.9
+     lr = as.scalar(hyperparams['lr']);
+     mu = as.scalar(hyperparams['mu']);
 
      # Optimize with SGD w/ Nesterov momentum
      [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)