You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/12/14 17:49:58 UTC
systemml git commit: [SYSTEMML-540] Improved performance of prediction via Keras2DML
Repository: systemml
Updated Branches:
refs/heads/master 3b87c2ba9 -> 341a1dc78
[SYSTEMML-540] Improved performance of prediction via Keras2DML
- Reduced the model loading time of VGG by 1.7x by supporting exchange of float32 matrices.
- Eliminated an additional mlcontext execution for converting probability to predicted labels. This improved the performance of VGG prediction by 15%.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/341a1dc7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/341a1dc7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/341a1dc7
Branch: refs/heads/master
Commit: 341a1dc789396ff3e46cf952a75bbe6958b77671
Parents: 3b87c2b
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Dec 14 09:49:48 2018 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Dec 14 09:49:48 2018 -0800
----------------------------------------------------------------------
.../spark/utils/RDDConverterUtilsExt.java | 35 ++++++++++-----
src/main/python/systemml/converters.py | 27 +++++++++---
src/main/python/tests/test_mlcontext.py | 25 +++++++++++
.../sysml/api/ml/BaseSystemMLClassifier.scala | 45 +++++++++++---------
4 files changed, 95 insertions(+), 37 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/341a1dc7/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index 4871aee..8db7558 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -126,13 +126,19 @@ public class RDDConverterUtilsExt
return df.select(columns.get(0), scala.collection.JavaConversions.asScalaBuffer(columnToSelect).toList());
}
- public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen) {
- return convertPy4JArrayToMB(data, (int)rlen, (int)clen, false);
+ // data_type: 0: int, 1: float and 2: double
+ public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen, long dataType) {
+ return convertPy4JArrayToMB(data, (int)rlen, (int)clen, false, dataType);
}
- public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen) {
- return convertPy4JArrayToMB(data, rlen, clen, false);
+ public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen, int dataType) {
+ return convertPy4JArrayToMB(data, rlen, clen, false, dataType);
}
+
+ public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen, boolean isSparse, long dataType) {
+ return convertPy4JArrayToMB(data, (int) rlen, (int) clen, isSparse, dataType);
+ }
+
public static MatrixBlock convertSciPyCOOToMB(byte [] data, byte [] row, byte [] col, long rlen, long clen, long nnz) {
return convertSciPyCOOToMB(data, row, col, (int)rlen, (int)clen, (int)nnz);
@@ -158,10 +164,6 @@ public class RDDConverterUtilsExt
return mb;
}
- public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen, boolean isSparse) {
- return convertPy4JArrayToMB(data, (int) rlen, (int) clen, isSparse);
- }
-
public static MatrixBlock allocateDenseOrSparse(int rlen, int clen, boolean isSparse) {
MatrixBlock ret = new MatrixBlock(rlen, clen, isSparse);
ret.allocateBlock();
@@ -195,7 +197,8 @@ public class RDDConverterUtilsExt
ret.examSparsity();
}
- public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen, boolean isSparse) {
+ // data_type: 0: int, 1: float and 2: double
+ public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen, boolean isSparse, long dataType) {
MatrixBlock mb = new MatrixBlock(rlen, clen, isSparse, -1);
if(isSparse) {
throw new DMLRuntimeException("Convertion to sparse format not supported");
@@ -207,9 +210,19 @@ public class RDDConverterUtilsExt
double [] denseBlock = new double[(int) limit];
ByteBuffer buf = ByteBuffer.wrap(data);
buf.order(ByteOrder.nativeOrder());
- for(int i = 0; i < rlen*clen; i++) {
- denseBlock[i] = buf.getDouble();
+ if(dataType == 0) {
+ for(int i = 0; i < rlen*clen; i++)
+ denseBlock[i] = (double)buf.getInt();
+ }
+ else if(dataType == 1) {
+ for(int i = 0; i < rlen*clen; i++)
+ denseBlock[i] = (double)buf.getFloat();
+ }
+ else if(dataType == 2) {
+ for(int i = 0; i < rlen*clen; i++)
+ denseBlock[i] = buf.getDouble();
}
+
mb.init( denseBlock, rlen, clen );
}
mb.recomputeNonZeros();
http://git-wip-us.apache.org/repos/asf/systemml/blob/341a1dc7/src/main/python/systemml/converters.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/converters.py b/src/main/python/systemml/converters.py
index 5954a30..1fc624a 100644
--- a/src/main/python/systemml/converters.py
+++ b/src/main/python/systemml/converters.py
@@ -221,11 +221,21 @@ def _convertSPMatrixToMB(sc, src):
def _convertDenseMatrixToMB(sc, src):
numCols = getNumCols(src)
numRows = src.shape[0]
- arr = src.ravel().astype(np.float64)
+ src = np.asarray(src, dtype=np.float64) if not isinstance(src, np.ndarray) else src
+ # data_type: 0: int, 1: float and 2: double
+ if src.dtype is np.dtype(np.int32):
+ arr = src.ravel().astype(np.int32)
+ dataType = 0
+ elif src.dtype is np.dtype(np.float32):
+ arr = src.ravel().astype(np.float32)
+ dataType = 1
+ else:
+ arr = src.ravel().astype(np.float64)
+ dataType = 2
buf = bytearray(arr.tostring())
createJavaObject(sc, 'dummy')
return sc._jvm.org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.convertPy4JArrayToMB(
- buf, numRows, numCols)
+ buf, numRows, numCols, dataType)
def _copyRowBlock(i, sc, ret, src, numRowsPerBlock, rlen, clen):
@@ -243,11 +253,14 @@ def _copyRowBlock(i, sc, ret, src, numRowsPerBlock, rlen, clen):
return i
-def convertToMatrixBlock(sc, src, maxSizeBlockInMB=8):
+def convertToMatrixBlock(sc, src, maxSizeBlockInMB=128):
if not isinstance(sc, SparkContext):
raise TypeError('sc needs to be of type SparkContext')
- isSparse = True if isinstance(src, spmatrix) else False
- src = np.asarray(src, dtype=np.float64) if not isSparse else src
+ if isinstance(src, spmatrix):
+ isSparse = True
+ else:
+ isSparse = False
+ src = np.asarray(src, dtype=np.float64) if not isinstance(src, np.ndarray) else src
if len(src.shape) != 2:
src_type = str(type(src).__name__)
raise TypeError('Expected 2-dimensional ' +
@@ -256,11 +269,11 @@ def convertToMatrixBlock(sc, src, maxSizeBlockInMB=8):
str(len(src.shape)) +
'-dimensional ' +
src_type)
+ worstCaseSizeInMB = (8*(src.getnnz()*3 if isSparse else src.shape[0]*src.shape[1])) / 1000000
# Ignoring sparsity for computing numRowsPerBlock for now
numRowsPerBlock = int(
math.ceil((maxSizeBlockInMB * 1000000) / (src.shape[1] * 8)))
- multiBlockTransfer = False if numRowsPerBlock >= src.shape[0] else True
- if not multiBlockTransfer:
+ if worstCaseSizeInMB <= maxSizeBlockInMB:
return _convertSPMatrixToMB(
sc, src) if isSparse else _convertDenseMatrixToMB(sc, src)
else:
http://git-wip-us.apache.org/repos/asf/systemml/blob/341a1dc7/src/main/python/tests/test_mlcontext.py
----------------------------------------------------------------------
diff --git a/src/main/python/tests/test_mlcontext.py b/src/main/python/tests/test_mlcontext.py
index e0db346..a144a15 100644
--- a/src/main/python/tests/test_mlcontext.py
+++ b/src/main/python/tests/test_mlcontext.py
@@ -28,6 +28,7 @@ import os
import sys
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
sys.path.insert(0, path)
+import numpy as np
import unittest
@@ -99,6 +100,30 @@ class TestAPI(unittest.TestCase):
script = dml(script).input(x1=5, x2=3).output("x3")
self.assertEqual(ml.execute(script).get("x3"), 8)
+ def test_numpy_float64(self):
+ script = """
+ x2 = x1 + 2.15
+ """
+ numpy_x1 = np.random.rand(5, 10).astype(np.float64)
+ script = dml(script).input(x1=numpy_x1).output("x2")
+ self.assertTrue(np.allclose(ml.execute(script).get("x2").toNumPy(), numpy_x1 + 2.15))
+
+ def test_numpy_float32(self):
+ script = """
+ x2 = x1 + 2.15
+ """
+ numpy_x1 = np.random.rand(5, 10).astype(np.float32)
+ script = dml(script).input(x1=numpy_x1).output("x2")
+ self.assertTrue(np.allclose(ml.execute(script).get("x2").toNumPy(), numpy_x1 + 2.15))
+
+ def test_numpy_int32(self):
+ script = """
+ x2 = x1 + 2
+ """
+ numpy_x1 = np.random.randint(1000, size=(5, 10)).astype(np.int32)
+ script = dml(script).input(x1=numpy_x1).output("x2")
+ self.assertTrue(np.allclose(ml.execute(script).get("x2").toNumPy(), numpy_x1 + 2))
+
def test_rdd(self):
sums = """
s1 = sum(m1)
http://git-wip-us.apache.org/repos/asf/systemml/blob/341a1dc7/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index 5d22c46..c1146d1 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -278,29 +278,39 @@ trait BaseSystemMLClassifierModel extends BaseSystemMLEstimatorModel {
val ml = new MLContext(sc)
updateML(ml)
val readScript = dml(dmlRead("X", X_file)).out("X")
- val res = ml.execute(readScript)
+ val res = ml.execute(readScript)
val script = getPredictionScript(isSingleNode)
val modelPredict = ml.execute(script._1.in(script._2, res.getMatrix("X")))
return modelPredict.getMatrix(probVar)
}
+
+ def replacePredictionWithProb(script: (Script, String), probVar: String, C: Int, H: Int, W: Int): Unit = {
+ // Append prediction code:
+ val newDML = "source(\"nn/util.dml\") as util;\n" +
+ script._1.getScriptString +
+ "\nPrediction = util::predict_class(" + probVar + ", " + C + ", " + H + ", " + W + ");"
+ script._1.setScriptString(newDML)
+
+ // Modify the output variables -> remove probability matrix and add Prediction
+ val outputVariables = new java.util.HashSet[String](script._1.getOutputVariables)
+ outputVariables.remove(probVar)
+ outputVariables.add("Prediction")
+ script._1.clearOutputs()
+ script._1.out(outputVariables.toList)
+ }
def baseTransform(X: MatrixBlock, sc: SparkContext, probVar: String, C: Int, H: Int, W: Int): MatrixBlock = {
- val Prob = baseTransformHelper(X, sc, probVar, C, H, W)
- val script1 = dml("source(\"nn/util.dml\") as util; Prediction = util::predict_class(Prob, C, H, W);")
- .out("Prediction")
- .in("Prob", Prob.toMatrixBlock, Prob.getMatrixMetadata)
- .in("C", C)
- .in("H", H)
- .in("W", W)
+ val isSingleNode = true
+ val ml = new MLContext(sc)
+ updateML(ml)
+ val script = getPredictionScript(isSingleNode)
- System.gc();
- val freeMem = Runtime.getRuntime().freeMemory();
- if(freeMem < OptimizerUtils.getLocalMemBudget()) {
- val LOG = LogFactory.getLog(classOf[BaseSystemMLClassifierModel].getName())
- LOG.warn("SystemML local memory budget:" + OptimizerUtils.toMB(OptimizerUtils.getLocalMemBudget()) + " mb. Approximate free memory available:" + OptimizerUtils.toMB(freeMem));
- }
- val ret = (new MLContext(sc)).execute(script1).getMatrix("Prediction").toMatrixBlock
-
+ replacePredictionWithProb(script, probVar, C, H, W)
+
+ // Now execute the prediction script directly
+ val ret = ml.execute(script._1.in(script._2, X, new MatrixMetadata(X.getNumRows, X.getNumColumns, X.getNonZeros)))
+ .getMatrix("Prediction").toMatrixBlock
+
if (ret.getNumColumns != 1 && H == 1 && W == 1) {
throw new RuntimeException("Expected predicted label to be a column vector")
}
@@ -312,9 +322,6 @@ trait BaseSystemMLClassifierModel extends BaseSystemMLEstimatorModel {
val ml = new MLContext(sc)
updateML(ml)
val script = getPredictionScript(isSingleNode)
- // Uncomment for debugging
- // ml.setExplainLevel(ExplainLevel.RECOMPILE_RUNTIME)
-
val modelPredict = ml.execute(script._1.in(script._2, X, new MatrixMetadata(X.getNumRows, X.getNumColumns, X.getNonZeros)))
return modelPredict.getMatrix(probVar)
}