You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@systemml.apache.org by ac...@apache.org on 2017/07/26 20:04:20 UTC

systemml git commit: [SYSTEMML-1703] Fix input data processing for Caffe VGG-19 model

Repository: systemml
Updated Branches:
  refs/heads/master 8f412ac5c -> 5fa84ccfa


[SYSTEMML-1703] Fix input data processing for Caffe VGG-19 model

Closes #588.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5fa84ccf
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5fa84ccf
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5fa84ccf

Branch: refs/heads/master
Commit: 5fa84ccfab89b6e2e43b1f4ad7a571c4d0e46cf1
Parents: 8f412ac
Author: Arvind Surve <ac...@yahoo.com>
Authored: Wed Jul 26 13:03:36 2017 -0700
Committer: Arvind Surve <ac...@yahoo.com>
Committed: Wed Jul 26 13:03:36 2017 -0700

----------------------------------------------------------------------
 src/main/python/systemml/converters.py | 90 +++++++++++++++++++++++------
 1 file changed, 71 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/5fa84ccf/src/main/python/systemml/converters.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/converters.py b/src/main/python/systemml/converters.py
index 87a9a45..1309cb9 100644
--- a/src/main/python/systemml/converters.py
+++ b/src/main/python/systemml/converters.py
@@ -19,18 +19,21 @@
 #
 #-------------------------------------------------------------
 
-__all__ = [ 'getNumCols', 'convertToMatrixBlock', 'convert_caffemodel', 'convert_lmdb_to_jpeg', 'convertToNumPyArr', 'convertToPandasDF', 'SUPPORTED_TYPES' , 'convertToLabeledDF', 'convertImageToNumPyArr']
+__all__ = [ 'getNumCols', 'convertToMatrixBlock', 'convert_caffemodel', 'convert_lmdb_to_jpeg', 'convertToNumPyArr', 'convertToPandasDF', 'SUPPORTED_TYPES' , 'convertToLabeledDF', 'convertImageToNumPyArr', 'getDatasetMean']
 
 import numpy as np
 import pandas as pd
 import os
 import math
+
 from pyspark.context import SparkContext
 from scipy.sparse import coo_matrix, spmatrix, csr_matrix
 from .classloader import *
 
 SUPPORTED_TYPES = (np.ndarray, pd.DataFrame, spmatrix)
 
+DATASET_MEAN = {'VGG_ILSVRC_19_2014':[103.939, 116.779, 123.68]}
+
 def getNumCols(numPyArr):
     if numPyArr.ndim == 1:
         return 1
@@ -39,7 +42,7 @@ def getNumCols(numPyArr):
 
 def get_pretty_str(key, value):
     return '\t"' + key + '": ' + str(value) + ',\n'
-        
+
 def save_tensor_csv(tensor, file_path, shouldTranspose):
     w = w.reshape(w.shape[0], -1)
     if shouldTranspose:
@@ -51,29 +54,29 @@ def save_tensor_csv(tensor, file_path, shouldTranspose):
         file.write(get_pretty_str('cols', w.shape[1]))
         file.write(get_pretty_str('nnz', np.count_nonzero(w)))
         file.write('\t"format": "csv",\n\t"description": {\n\t\t"author": "SystemML"\n\t}\n}\n')
-    
+
 def convert_caffemodel(sc, deploy_file, caffemodel_file, output_dir, format="binary", is_caffe_installed=False):
     """
-    Saves the weights and bias in the caffemodel file to output_dir in the specified format. 
+    Saves the weights and bias in the caffemodel file to output_dir in the specified format.
     This method does not requires caffe to be installed.
-    
+
     Parameters
     ----------
     sc: SparkContext
         SparkContext
-    
+
     deploy_file: string
         Path to the input network file
-        
+
     caffemodel_file: string
         Path to the input caffemodel file
-    
+
     output_dir: string
         Path to the output directory
-    
+
     format: string
         Format of the weights and bias (can be binary, csv or text)
-    
+
     is_caffe_installed: bool
         True if caffe is installed
     """
@@ -104,17 +107,17 @@ def convert_caffemodel(sc, deploy_file, caffemodel_file, output_dir, format="bin
         utilObj = sc._jvm.org.apache.sysml.api.dl.Utils()
         utilObj.saveCaffeModelFile(sc._jsc, deploy_file, caffemodel_file, output_dir, format)
 
-    
+
 def convert_lmdb_to_jpeg(lmdb_img_file, output_dir):
     """
     Saves the images in the lmdb file as jpeg in the output_dir. This method requires caffe to be installed along with lmdb and cv2 package.
     To install cv2 package, do `pip install opencv-python`.
-    
+
     Parameters
     ----------
     lmdb_img_file: string
         Path to the input lmdb file
-    
+
     output_dir: string
         Output directory for images (local filesystem)
     """
@@ -163,7 +166,7 @@ def _convertSPMatrixToMB(sc, src):
     buf3 = bytearray(col.tostring())
     createJavaObject(sc, 'dummy')
     return sc._jvm.org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.convertSciPyCOOToMB(buf1, buf2, buf3, numRows, numCols, nnz)
-            
+
 def _convertDenseMatrixToMB(sc, src):
     numCols = getNumCols(src)
     numRows = src.shape[0]
@@ -178,7 +181,7 @@ def _copyRowBlock(i, sc, ret, src, numRowsPerBlock,  rlen, clen):
     mb = _convertSPMatrixToMB(sc, tmp) if isinstance(src, spmatrix) else _convertDenseMatrixToMB(sc, tmp)
     sc._jvm.org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.copyRowBlocks(mb, rowIndex, ret, numRowsPerBlock, rlen, clen)
     return i
-    
+
 def convertToMatrixBlock(sc, src, maxSizeBlockInMB=8):
     if not isinstance(sc, SparkContext):
         raise TypeError('sc needs to be of type SparkContext')
@@ -212,10 +215,38 @@ def convertToNumPyArr(sc, mb):
     else:
         raise TypeError('sc needs to be of type SparkContext') # TODO: We can generalize this by creating py4j gateway ourselves
 
+# Returns the mean of a model if defined otherwise None
+def getDatasetMean(dataset_name):
+    """
+    Input Parameters
+    ----------------
+    dataset_name: Name of the dataset used to train model. This name is artificial name based on dataset used to train the model.
+
+    Returns
+    -------
+    mean: Mean value of model if its defined in the list DATASET_MEAN else None.
+
+    """
+
+    try:
+        mean = DATASET_MEAN[dataset_name.upper()]
+    except:
+        mean = None
+    return mean
+
+
 # Example usage: convertImageToNumPyArr(im, img_shape=(3, 224, 224), add_rotated_images=True, add_mirrored_images=True)
 # The above call returns a numpy array of shape (6, 50176) in NCHW format
-def convertImageToNumPyArr(im, img_shape=None, add_rotated_images=False, add_mirrored_images=False):
-    from PIL import Image
+def convertImageToNumPyArr(im, img_shape=None, add_rotated_images=False, add_mirrored_images=False,
+    color_mode = 'RGB', mean=None):
+
+    ## Input Parameters
+
+    # color_mode: In case of VGG models which expect image data in BGR format instead of RGB for other most models,
+    # color_mode parameter is used to process image data in BGR format.
+
+    # mean: mean value is used to subtract from input data from every pixel value. By default value is None, so mean value not subtracted.
+
     if img_shape is not None:
         num_channels = img_shape[0]
         size = (img_shape[1], img_shape[2])
@@ -224,24 +255,45 @@ def convertImageToNumPyArr(im, img_shape=None, add_rotated_images=False, add_mir
         size = None
     if num_channels != 1 and num_channels != 3:
         raise ValueError('Expected the number of channels to be either 1 or 3')
+
+    from PIL import Image
+
     if size is not None:
         im = im.resize(size, Image.LANCZOS)
     expected_mode = 'L' if num_channels == 1 else 'RGB'
     if expected_mode is not im.mode:
         im = im.convert(expected_mode)
+
     def _im2NumPy(im):
         if expected_mode == 'L':
             return np.asarray(im.getdata()).reshape((1, -1))
         else:
-            # (H,W,C) --> (C,H,W) --> (1, C*H*W)
-            return np.asarray(im).transpose(2, 0, 1).reshape((1, -1))
+            im = (np.array(im).astype(np.float))
+
+            # (H,W,C) -> (C,H,W)
+            im = im.transpose(2, 0, 1)
+
+            # RGB -> BGR
+            if color_mode == 'BGR':
+                im = im[...,::-1]
+
+            # Subtract Mean
+            if mean is not None:
+                for c in range(3):
+                    im[:, :, c] = im[:, :, c] - mean[c]
+
+            # (C,H,W) --> (1, C*H*W)
+            return im.reshape((1, -1))
+
     ret = _im2NumPy(im)
+
     if add_rotated_images:
         ret = np.vstack((ret, _im2NumPy(im.rotate(90)), _im2NumPy(im.rotate(180)), _im2NumPy(im.rotate(270)) ))
     if add_mirrored_images:
         ret = np.vstack((ret, _im2NumPy(im.transpose(Image.FLIP_LEFT_RIGHT)), _im2NumPy(im.transpose(Image.FLIP_TOP_BOTTOM))))
     return ret
 
+
 def convertToPandasDF(X):
     if not isinstance(X, pd.DataFrame):
         return pd.DataFrame(X, columns=['C' + str(i) for i in range(getNumCols(X))])