You are viewing a plain-text version of this content. The canonical link was a hyperlink in the original page and is not preserved in this text.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/05/30 20:40:13 UTC
systemml git commit: [SYSTEMML-2350] Fix missing support for lists in as.scalar casts
Repository: systemml
Updated Branches:
refs/heads/master 97018d4e6 -> a62b65c8f
[SYSTEMML-2350] Fix missing support for lists in as.scalar casts
This patch adds the missing support for list inputs in as.scalar casts.
This is needed to extract scalars from unnamed or named lists, because
list indexing itself still returns a list of one element. Furthermore,
this patch improves the error handling of the related runtime
instruction (unsupported input data types now raise an explicit error)
and renames DataType.isComposite() to isList().
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a62b65c8
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a62b65c8
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a62b65c8
Branch: refs/heads/master
Commit: a62b65c8f61ed8cf0b009732f8cbdb7c8eda95e9
Parents: 97018d4
Author: Matthias Boehm <mb...@gmail.com>
Authored: Wed May 30 13:40:09 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Wed May 30 13:40:09 2018 -0700
----------------------------------------------------------------------
.../sysml/parser/BuiltinFunctionExpression.java | 14 ++++++++------
.../java/org/apache/sysml/parser/Expression.java | 2 +-
.../instructions/cp/VariableCPInstruction.java | 19 ++++++++++++++-----
.../mnist_lenet_paramserv_minimum_version.dml | 4 ++--
4 files changed, 25 insertions(+), 14 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 0e949d0..ea51bd1 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -496,18 +496,20 @@ public class BuiltinFunctionExpression extends DataIdentifier
output.setValueType(id.getValueType());
break;
-
case CAST_AS_SCALAR:
checkNumParameters(1);
- checkMatrixFrameParam(getFirstExpr());
- if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1) || ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) {
- raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() + " dim2 " + getFirstExpr().getOutput().getDim2(),
- conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+ checkDataTypeParam(getFirstExpr(),
+ DataType.MATRIX, DataType.FRAME, DataType.LIST);
+ if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1)
+ || ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) {
+ raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1()
+ + " dim2 " + getFirstExpr().getOutput().getDim2(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
- output.setValueType(id.getValueType());
+ output.setValueType((id.getValueType()!=ValueType.UNKNOWN) ?
+ id.getValueType() : ValueType.DOUBLE);
break;
case CAST_AS_MATRIX:
checkNumParameters(1);
http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index fd3f855..9a6ea64 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -194,7 +194,7 @@ public abstract class Expression implements ParseInfo
public boolean isScalar() {
return (this == SCALAR);
}
- public boolean isComposite() {
+ public boolean isList() {
return (this == LIST);
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
index 5786e87..b46f4df 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
@@ -553,7 +553,7 @@ public class VariableCPInstruction extends CPInstruction {
break;
case CastAsScalarVariable: //castAsScalarVariable
- if( getInput1().getDataType()==DataType.FRAME ) {
+ if( getInput1().getDataType().isFrame() ) {
FrameBlock fBlock = ec.getFrameInput(getInput1().getName());
if( fBlock.getNumRows()!=1 || fBlock.getNumColumns()!=1 )
throw new DMLRuntimeException("Dimension mismatch - unable to cast frame '"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x "+fBlock.getNumColumns()+") to scalar.");
@@ -562,7 +562,7 @@ public class VariableCPInstruction extends CPInstruction {
ec.setScalarOutput(output.getName(),
ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value));
}
- else { //assume DataType.MATRIX otherwise
+ else if( getInput1().getDataType().isMatrix() ) {
MatrixBlock mBlock = ec.getMatrixInput(getInput1().getName(), getExtendedOpcode());
if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 )
throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix '"+getInput1().getName()+"' of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar.");
@@ -570,21 +570,30 @@ public class VariableCPInstruction extends CPInstruction {
ec.releaseMatrixInput(getInput1().getName(), getExtendedOpcode());
ec.setScalarOutput(output.getName(), new DoubleObject(value));
}
+ else if( getInput1().getDataType().isList() ) {
+ //TODO handling of cleanup status, potentially new object
+ ListObject list = (ListObject)ec.getVariable(getInput1().getName());
+ ec.setVariable(output.getName(), list.slice(0));
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported data type "
+ + "in as.scalar(): "+getInput1().getDataType().name());
+ }
break;
case CastAsMatrixVariable:{
- if( getInput1().getDataType()==DataType.FRAME ) {
+ if( getInput1().getDataType().isFrame() ) {
FrameBlock fin = ec.getFrameInput(getInput1().getName());
MatrixBlock out = DataConverter.convertToMatrixBlock(fin);
ec.releaseFrameInput(getInput1().getName());
ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
}
- else if( getInput1().getDataType()==DataType.SCALAR ) {
+ else if( getInput1().getDataType().isScalar() ) {
ScalarObject scalarInput = ec.getScalarInput(
getInput1().getName(), getInput1().getValueType(), getInput1().isLiteral());
MatrixBlock out = new MatrixBlock(scalarInput.getDoubleValue());
ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
}
- else if( getInput1().getDataType()==DataType.LIST ) {
+ else if( getInput1().getDataType().isList() ) {
//TODO handling of cleanup status, potentially new object
ListObject list = (ListObject)ec.getVariable(getInput1().getName());
ec.setVariable(output.getName(), list.slice(0));
http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index 2ef7411..8811c36 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -234,8 +234,8 @@ aggregation = function(list[unknown] model,
vb2 = as.matrix(model["vb2"])
vb3 = as.matrix(model["vb3"])
vb4 = as.matrix(model["vb4"])
- lr = 0.01
- mu = 0.9
+ lr = as.scalar(hyperparams['lr']);
+ mu = as.scalar(hyperparams['mu']);
# Optimize with SGD w/ Nesterov momentum
[W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)