You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2016/12/03 00:27:36 UTC

[1/2] incubator-systemml git commit: [SYSTEMML-1116] Make SystemML Python DSL NumPy-friendly

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 398490e3e -> 23ccab85c


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/systemml/defmatrix.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/defmatrix.py b/src/main/python/systemml/defmatrix.py
index be9bc5f..6a56690 100644
--- a/src/main/python/systemml/defmatrix.py
+++ b/src/main/python/systemml/defmatrix.py
@@ -19,19 +19,26 @@
 #
 #-------------------------------------------------------------
 
-trigFn = [ 'exp', 'log', 'abs', 'sqrt', 'round', 'floor', 'ceil', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'sign' ]
-__all__ = [ 'setSparkContext', 'matrix', 'eval', 'solve', 'DMLOp' ] + trigFn
-
-
-from pyspark import SparkContext
-from pyspark.sql import DataFrame, SQLContext
-
-from . import MLContext, pydml
+__all__ = [ 'setSparkContext', 'matrix', 'eval', 'solve', 'DMLOp', 'set_lazy', 'debug_array_conversion', 'load', 'full', 'seq' ]
+
+import numpy as np
+import pandas as pd
+from scipy.sparse import coo_matrix, spmatrix
+try:
+    import py4j.java_gateway
+    from py4j.java_gateway import JavaObject
+    from pyspark import SparkContext
+    from pyspark.sql import DataFrame, SQLContext
+    import pyspark.mllib.common
+except ImportError:
+    raise ImportError('Unable to import `pyspark`. Hint: Make sure you are running with PySpark.')
+
+from . import MLContext, pydml, _java2py, Matrix
 from .converters import *
 
 def setSparkContext(sc):
     """
-    Before using the matrix, the user needs to invoke this function.
+    Call this function before using the matrix if no SparkContext has been created in the session yet.
 
     Parameters
     ----------
@@ -43,7 +50,7 @@ def setSparkContext(sc):
     matrix.ml = MLContext(matrix.sc)
 
 
-def checkIfMLContextIsSet():
+def check_MLContext():
     if matrix.ml is None:
         if SparkContext._active_spark_context is not None:
             setSparkContext(SparkContext._active_spark_context)
@@ -60,24 +67,41 @@ class DMLOp(object):
         self.inputs = inputs
         self.dml = dml
         self.ID = None
+        self.depth = 1
         for m in self.inputs:
             m.referenced = m.referenced + [ self ]
+            if isinstance(m, matrix) and m.op is not None:
+                self.depth = max(self.depth, m.op.depth + 1)
 
+    MAX_DEPTH = 0
+    
     def _visit(self, execute=True):
         matrix.dml = matrix.dml + self.dml
 
-    # Don't use this method instead use matrix's printAST()
-    def printAST(self, numSpaces):
+    def _print_ast(self, numSpaces):
         ret = []
         for m in self.inputs:
-            ret = [ m.printAST(numSpaces+2) ]
+            ret = ret + [ m._print_ast(numSpaces+2) ]  # append every input's subtree, not only the last one
         return ''.join(ret)
 
 # Special object used internally to specify the placeholder which will be replaced by output ID
-# This helps to provide dml containing output ID in constructIntermediateNode
+# This helps to provide dml containing output ID in construct_intermediate_node
 OUTPUT_ID = '$$OutputID$$'
 
-def constructIntermediateNode(inputs, dml):
+def set_lazy(isLazy):
+    """
+    Set whether matrix operations are evaluated lazily or eagerly.
+    
+    Parameters
+    ----------
+    isLazy: True to build the expression graph and evaluate on demand; False to evaluate each operation as soon as it is created.
+    """
+    if isLazy:
+        DMLOp.MAX_DEPTH = 0
+    else:
+        DMLOp.MAX_DEPTH = 1
+    
+def construct_intermediate_node(inputs, dml):
     """
     Convenient utility to create an intermediate node of AST.
 
@@ -89,8 +113,32 @@ def constructIntermediateNode(inputs, dml):
     dmlOp = DMLOp(inputs)
     out = matrix(None, op=dmlOp)
     dmlOp.dml = [out.ID if x==OUTPUT_ID else x for x in dml]
+    if DMLOp.MAX_DEPTH > 0 and out.op.depth >= DMLOp.MAX_DEPTH:
+        out.eval()
     return out
 
+def load(file, format='csv'):
+    """
+    Load a matrix from the filesystem.
+
+    Parameters
+    ----------
+    file: filepath
+    format: one of csv, text, binary or mm (default: csv)
+    """
+    return construct_intermediate_node([], [OUTPUT_ID, ' = load(\"', file, '\", format=\"', format, '\")\n'])
+
+def full(shape, fill_value):
+    """
+    Return a new matrix of the given shape filled with fill_value.
+
+    Parameters
+    ----------
+    shape: tuple of length 2 (rows, cols)
+    fill_value: float or int
+    """
+    return construct_intermediate_node([], [OUTPUT_ID, ' = full(', str(fill_value), ', rows=', str(shape[0]), ', cols=', str(shape[1]), ')\n'])    
+
 def reset():
     """
     Resets the visited status of matrix and the operators in the generated AST.
@@ -102,7 +150,7 @@ def reset():
     matrix.dml = []
     matrix.script = pydml('')
 
-def performDFS(outputs, execute):
+def perform_dfs(outputs, execute):
     """
     Traverses the forest of nodes rooted at outputs nodes and returns the DML script to execute
     """
@@ -116,50 +164,74 @@ def performDFS(outputs, execute):
 
 ########################## Utility functions ##################################
 
-
-def binaryOp(lhs, rhs, opStr):
+def _log_base(val, base):
+    if not isinstance(val, str):
+        raise ValueError('The val to _log_base should be of type string')
+    return '(log(' + val + ')/log(' + str(base) + '))' 
+    
+def _matricize(lhs, inputs):
     """
-    Common function called by all the binary operators in matrix class
+    Return (DML string for lhs, updated inputs): supported array types are wrapped in a new matrix,
+    matrices contribute their ID and are appended to inputs, and floats/ints become literals.
     """
-    inputs = []
+    if isinstance(lhs, SUPPORTED_TYPES):
+        lhs = matrix(lhs)
     if isinstance(lhs, matrix):
         lhsStr = lhs.ID
-        inputs = [lhs]
+        inputs = inputs + [lhs]
     elif isinstance(lhs, float) or isinstance(lhs, int):
         lhsStr = str(lhs)
     else:
         raise TypeError('Incorrect type')
-    if isinstance(rhs, matrix):
-        rhsStr = rhs.ID
-        inputs = inputs + [rhs]
-    elif isinstance(rhs, float) or isinstance(rhs, int):
-        rhsStr = str(rhs)
-    else:
-        raise TypeError('Incorrect type')
-    return constructIntermediateNode(inputs, [OUTPUT_ID, ' = ', lhsStr, opStr, rhsStr, '\n'])
-
-def getValue(obj):
-    if isinstance(obj, matrix):
-        return obj.ID
-    elif isinstance(obj, float) or isinstance(obj, int):
-        return str(obj)
-    else:
-        raise TypeError('Unsupported type for ' + s)
+    return lhsStr, inputs
+    
+def binary_op(lhs, rhs, opStr):
+    """
+    Common function called by all the binary operators in matrix class
+    """
+    inputs = []
+    lhsStr, inputs = _matricize(lhs, inputs)
+    rhsStr, inputs = _matricize(rhs, inputs)
+    return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', lhsStr, opStr, rhsStr, '\n'])
 
 def binaryMatrixFunction(X, Y, fnName):
     """
     Common function called by supported PyDML built-in function that has two arguments.
     """
-    return constructIntermediateNode([X, Y], [OUTPUT_ID, ' = ', fnName,'(', getValue(X), ', ', getValue(Y), ')\n'])
+    inputs = []
+    lhsStr, inputs = _matricize(X, inputs)
+    rhsStr, inputs = _matricize(Y, inputs)
+    return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', fnName,'(', lhsStr, ', ', rhsStr, ')\n'])
 
 def unaryMatrixFunction(X, fnName):
     """
     Common function called by supported PyDML built-in function that has one argument.
     """
-    return constructIntermediateNode([X], [OUTPUT_ID, ' = ', fnName,'(', getValue(X), ')\n'])
+    inputs = []
+    lhsStr, inputs = _matricize(X, inputs)
+    return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', fnName,'(', lhsStr, ')\n'])

+def seq(start=None, stop=None, step=1):
+    """
+    Creates a single column vector with values starting from <start>, to <stop>, in increments of <step>.
+    Note: Unlike NumPy's arange, which returns a 1-D array, this returns a column vector.
+    Also, unlike NumPy's arange, which excludes stop, this method includes stop in the interval.
+    
+    Parameters
+    ----------
+    start: int or float [Optional: default = 0]
+    stop: int or float
+    step : int or float [Optional: default = 1]
+    """
+    if start is None and stop is None:
+        raise ValueError('Both start and stop cannot be None')
+    elif stop is None:
+        start, stop = 0, start  # seq(n) means seq(0, n), mirroring arange(n)
+    start = 0 if start is None else start  # seq(stop=n): start defaults to 0 as documented, never the literal 'None'
+    return construct_intermediate_node([], [OUTPUT_ID, ' = seq(', str(start), ',', str(stop), ',',  str(step), ')\n'])
+    
 # utility function that converts 1:3 into DML string
-def convertSeqToDML(s):
+def convert_seq_to_dml(s):
     ret = []
     if s is None:
         return ''
@@ -182,14 +254,14 @@ def convertSeqToDML(s):
 def getIndexingDML(index):
     ret = [ '[' ]
     if isinstance(index, tuple) and len(index) == 1:
-        ret = ret + [ convertSeqToDML(index[0]), ',' ]
+        ret = ret + [ convert_seq_to_dml(index[0]), ',' ]
     elif isinstance(index, tuple) and len(index) == 2:
-        ret = ret + [ convertSeqToDML(index[0]), ',', convertSeqToDML(index[1]) ]
+        ret = ret + [ convert_seq_to_dml(index[0]), ',', convert_seq_to_dml(index[1]) ]
     else:
         raise TypeError('matrix indexes can only be tuple of length 2. For example: m[1,1], m[0:1,], m[:, 0:1]')
     return ret + [ ']' ]
 
-def convertOutputsToList(outputs):
+def convert_outputs_to_list(outputs):
     if isinstance(outputs, matrix):
         return [ outputs ]
     elif isinstance(outputs, list):
@@ -200,69 +272,15 @@ def convertOutputsToList(outputs):
     else:
         raise TypeError('Only matrix or list of matrix allowed')
 
-def resetOutputFlag(outputs):
+def reset_output_flag(outputs):
     for m in outputs:
         m.output = False
-
-def populateOutputs(outputs, results, outputDF):
-    """
-    Set the attribute 'data' of the matrix by fetching it from MLResults class
-    """
-    for m in outputs:
-        if outputDF:
-            m.data = results.get(m.ID).toDF()
-        else:
-            m.data = results.get(m.ID).toNumPy()
+    
 
 ###############################################################################
 
 ########################## Global user-facing functions #######################
 
-def exp(X):
-    return unaryMatrixFunction(X, 'exp')
-
-def log(X, y=None):
-    if y is None:
-        return unaryMatrixFunction(X, 'log')
-    else:
-        return binaryMatrixFunction(X, y, 'log')
-
-def abs(X):
-    return unaryMatrixFunction(X, 'abs')
-
-def sqrt(X):
-    return unaryMatrixFunction(X, 'sqrt')
-
-def round(X):
-    return unaryMatrixFunction(X, 'round')
-
-def floor(X):
-    return unaryMatrixFunction(X, 'floor')
-
-def ceil(X):
-    return unaryMatrixFunction(X, 'ceil')
-
-def sin(X):
-    return unaryMatrixFunction(X, 'sin')
-
-def cos(X):
-    return unaryMatrixFunction(X, 'cos')
-
-def tan(X):
-    return unaryMatrixFunction(X, 'tan')
-
-def asin(X):
-    return unaryMatrixFunction(X, 'asin')
-
-def acos(X):
-    return unaryMatrixFunction(X, 'acos')
-
-def atan(X):
-    return unaryMatrixFunction(X, 'atan')
-
-def sign(X):
-    return unaryMatrixFunction(X, 'sign')
-
 def solve(A, b):
     """
     Computes the least squares solution for system of linear equations A %*% x = b
@@ -291,25 +309,34 @@ def solve(A, b):
     """
     return binaryMatrixFunction(A, b, 'solve')
 
-def eval(outputs, outputDF=False, execute=True):
+def eval(outputs, execute=True):
     """
     Executes the unevaluated DML script and computes the matrices specified by outputs.
 
     Parameters
     ----------
     outputs: list of matrices or a matrix object
-    outputDF: back the data of matrix as PySpark DataFrame
+    execute: if True (default), run the script; if False, only return the generated script without executing it.
     """
-    checkIfMLContextIsSet()
+    check_MLContext()
     reset()
-    outputs = convertOutputsToList(outputs)
-    matrix.script.scriptString = performDFS(outputs, execute)
+    outputs = convert_outputs_to_list(outputs)
+    matrix.script.scriptString = perform_dfs(outputs, execute)
     if not execute:
-        resetOutputFlag(outputs)
+        reset_output_flag(outputs)
         return matrix.script.scriptString
     results = matrix.ml.execute(matrix.script)
-    populateOutputs(outputs, results, outputDF)
-    resetOutputFlag(outputs)
+    for m in outputs:
+        m.eval_data = results._java_results.get(m.ID)  # keep the raw Java-side result; toNumPy()/toDF()/toPandas() convert it on demand
+    reset_output_flag(outputs)
+
+
+def debug_array_conversion(throwError):  # throwError=True makes matrix.__array__ raise instead of printing a warning
+    matrix.THROW_ARRAY_CONVERSION_ERROR = throwError
+    
+def _get_new_var_id():
+    matrix.systemmlVarID += 1
+    return 'mVar' + str(matrix.systemmlVarID)
 
 ###############################################################################
 
@@ -331,6 +358,10 @@ class matrix(object):
     2. Aggregation functions: sum, mean, var, sd, max, min, argmin, argmax, cumsum
     3. Global statistical built-In functions: exp, log, abs, sqrt, round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve
     
+    All of the above functions always return a two-dimensional matrix, including aggregation functions called with an axis.
+    For example, if m1 is a (3, n) matrix, NumPy returns a 1-d vector of shape (3,) for m1.sum(axis=1),
+    whereas SystemML returns a 2-d matrix of shape (3, 1).
+    
     Note: an evaluated matrix contains a data field computed by eval method as DataFrame or NumPy array.
 
     Examples
@@ -346,7 +377,7 @@ class matrix(object):
     >>> m2 = m1 * (m2 + m1)
     >>> m4 = 1.0 - m2
     >>> m4
-    # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPyArray() or toDataFrame() or toPandas() methods.
+    # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
     mVar1 = load(" ", format="csv")
     mVar2 = load(" ", format="csv")
     mVar3 = mVar2 + mVar1
@@ -355,9 +386,9 @@ class matrix(object):
     save(mVar5, " ")
     >>> m2.eval()
     >>> m2
-    # This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPyArray() method.
+    # This matrix (mVar4) is backed by <class 'py4j.java_gateway.JavaObject'>. To fetch the DataFrame or NumPy array, invoke toDF() or toNumPy() method respectively.
     >>> m4
-    # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPyArray() or toDataFrame() or toPandas() methods.
+    # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
     mVar4 = load(" ", format="csv")
     mVar5 = 1.0 - mVar4
     save(mVar5, " ")
@@ -385,7 +416,7 @@ class matrix(object):
       Then the left-indexed matrix is set to be backed by DMLOp consisting of following pydml:
       left-indexed-matrix = new-deep-copied-matrix
       left-indexed-matrix[index] = value
-    8. Please use m.printAST() and/or  type `m` for debugging. Here is a sample session:
+    8. Please use m.print_ast() and/or type `m` for debugging. Here is a sample session:
     
       >>> npm = np.ones((3,3))
       >>> m1 = sml.matrix(npm + 3)
@@ -396,7 +427,7 @@ class matrix(object):
       mVar1 = load(" ", format="csv")
       mVar3 = mVar1 + mVar2
       save(mVar3, " ")
-       >>> m3.printAST()
+       >>> m3.print_ast()
       - [mVar3] (op).
         - [mVar1] (data).
         - [mVar2] (data).    
@@ -418,7 +449,7 @@ class matrix(object):
     # Contains list of nodes visited in Abstract Syntax Tree. This helps to avoid computation of matrix objects
     # that have been previously evaluated.
     visited = []
-
+    
     def __init__(self, data, op=None):
         """
         Constructs a lazy matrix
@@ -427,74 +458,107 @@ class matrix(object):
         ----------
         data: NumPy ndarray, Pandas DataFrame, scipy sparse matrix or PySpark DataFrame. (data cannot be None for external users, 'data=None' is used internally for lazy evaluation).
         """
-        checkIfMLContextIsSet()
+        self.dtype = np.double
+        check_MLContext()
         self.visited = False
-        matrix.systemmlVarID += 1
         self.output = False
-        self.ID = 'mVar' + str(matrix.systemmlVarID)
+        self.ID = _get_new_var_id()
         self.referenced = []
         # op refers to the node of Abstract Syntax Tree created internally for lazy evaluation
         self.op = op
-        self.data = data
+        self.eval_data = data
+        self._shape = None
+        if isinstance(data, SUPPORTED_TYPES):
+            self._shape = data.shape
         if not (isinstance(data, SUPPORTED_TYPES) or hasattr(data, '_jdf') or (data is None and op is not None)):
             raise TypeError('Unsupported input type')
 
-    def eval(self, outputDF=False):
+    def eval(self):
         """
         This is a convenience function that calls the global eval method
         """
-        eval([self], outputDF=False)
-
+        eval([self])
+        
     def toPandas(self):
         """
         This is a convenience function that calls the global eval method and then converts the matrix object into Pandas DataFrame.
         """
-        if self.data is None:
-            self.eval()
-        return convertToPandasDF(self.data)
-
-    def toNumPyArray(self):
+        self.eval()
+        if isinstance(self.eval_data, py4j.java_gateway.JavaObject):
+            self.eval_data = _java2py(SparkContext._active_spark_context, self.eval_data)
+        if isinstance(self.eval_data, Matrix):
+            self.eval_data = self.eval_data.toNumPy()
+        self.eval_data = convertToPandasDF(self.eval_data)
+        return self.eval_data
+
+    def toNumPy(self):
         """
         This is a convenience function that calls the global eval method and then converts the matrix object into NumPy array.
         """
-        if self.data is None:
-            self.eval()
-        if isinstance(self.data, DataFrame):
-            self.data = self.data.toPandas().as_matrix()
+        self.eval()
+        if isinstance(self.eval_data, py4j.java_gateway.JavaObject):
+            self.eval_data = _java2py(SparkContext._active_spark_context, self.eval_data)
+        if isinstance(self.eval_data, Matrix):
+            self.eval_data = self.eval_data.toNumPy()
+            return self.eval_data
+        if isinstance(self.eval_data, pd.DataFrame):
+            self.eval_data = self.eval_data.as_matrix()
+        elif isinstance(self.eval_data, DataFrame):
+            self.eval_data = self.eval_data.toPandas().as_matrix()
+        elif isinstance(self.eval_data, spmatrix):
+            self.eval_data = self.eval_data.toarray()
+        elif isinstance(self.eval_data, Matrix):  # NOTE(review): unreachable, Matrix is already handled and returned above
+            self.eval_data = self.eval_data.toNumPy()
         # Always keep default format as NumPy array if possible
-        return self.data
+        return self.eval_data
 
-    def toDataFrame(self):
+    def toDF(self):
         """
         This is a convenience function that calls the global eval method and then converts the matrix object into DataFrame.
         """
-        if self.data is None:
-            self.eval(outputDF=True)
-        if not isinstance(self.data, DataFrame):
-            self.data = matrix.sqlContext.createDataFrame(self.toPandas())
-        return self.data
-
-    def _markAsVisited(self):
+        if isinstance(self.eval_data, DataFrame):
+            return self.eval_data
+        if isinstance(self.eval_data, py4j.java_gateway.JavaObject):
+            self.eval_data = _java2py(SparkContext._active_spark_context, self.eval_data)
+        if isinstance(self.eval_data, Matrix):
+            self.eval_data = self.eval_data.toDF()
+            return self.eval_data
+        self.eval_data = matrix.sqlContext.createDataFrame(self.toPandas())
+        return self.eval_data
+
+    def save(self, file, format='csv'):
+        """
+        Save this matrix to the filesystem; the save runs immediately because this method calls eval().
+    
+        Parameters
+        ----------
+        file: filepath
+        format: one of csv, text, binary or mm (default: csv)
+        """
+        tmp = construct_intermediate_node([self], ['save(', self.ID , ',\"', file, '\", format=\"', format, '\")\n'])
+        construct_intermediate_node([tmp], [OUTPUT_ID, ' = full(0, rows=1, cols=1)\n']).eval()  # dummy node that depends on tmp; evaluating it emits the save statement
+    
+    def _mark_as_visited(self):
         self.visited = True
         # for cleanup
         matrix.visited = matrix.visited + [ self ]
         return self
 
-    def _registerAsInput(self, execute):
+    def _register_as_input(self, execute):
         # TODO: Remove this when automatic registration of frame is resolved
         matrix.dml = [ self.ID,  ' = load(\" \", format=\"csv\")\n'] + matrix.dml
-        if isinstance(self.data, DataFrame) and execute:
-            matrix.script.input(self.ID, self.data)
+        if isinstance(self.eval_data, SUPPORTED_TYPES) and execute:
+            matrix.script.input(self.ID, convertToMatrixBlock(matrix.sc, self.eval_data))
         elif execute:
-            matrix.script.input(self.ID, convertToMatrixBlock(matrix.sc, self.data))
+            matrix.script.input(self.ID, self.toDF())
         return self
 
-    def _registerAsOutput(self, execute):
+    def _register_as_output(self, execute):
         # TODO: Remove this when automatic registration of frame is resolved
         matrix.dml = matrix.dml + ['save(',  self.ID, ', \" \")\n']
         if execute:
             matrix.script.output(self.ID)
-
+        
     def _visit(self, execute=True):
         """
         This function is called for two scenarios:
@@ -504,9 +568,9 @@ class matrix(object):
         """
         if self.visited:
             return self
-        self._markAsVisited()
-        if self.data is not None:
-            self._registerAsInput(execute)
+        self._mark_as_visited()
+        if self.eval_data is not None:
+            self._register_as_input(execute)
         elif self.op is not None:
             # Traverse the AST
             for m in self.op.inputs:
@@ -514,13 +578,13 @@ class matrix(object):
             self.op._visit(execute=execute)
         else:
             raise Exception('Expected either op or data to be set')
-        if self.data is None and self.output:
-            self._registerAsOutput(execute)
+        if self.eval_data is None and self.output:
+            self._register_as_output(execute)
         return self
 
-    def printAST(self, numSpaces = 0):
+    def print_ast(self):
         """
-        Please use m.printAST() and/or  type `m` for debugging. Here is a sample session:
+        Please use m.print_ast() and/or type `m` for debugging. Here is a sample session:
         
         >>> npm = np.ones((3,3))
         >>> m1 = sml.matrix(npm + 3)
@@ -531,18 +595,21 @@ class matrix(object):
         mVar1 = load(" ", format="csv")
         mVar3 = mVar1 + mVar2
         save(mVar3, " ")
-        >>> m3.printAST()
+        >>> m3.print_ast()
         - [mVar3] (op).
           - [mVar1] (data).
           - [mVar2] (data).
         """
+        return self._print_ast(0)
+    
+    def _print_ast(self, numSpaces):
         head = ''.join([ ' ' ]*numSpaces + [ '- [', self.ID, '] ' ])
-        if self.data is not None:
+        if self.eval_data is not None:
             out = head + '(data).\n'
         elif self.op is not None:
             ret = [ head, '(op).\n' ]
             for m in self.op.inputs:
-                ret = ret + [ m.printAST(numSpaces + 2) ]
+                ret = ret + [ m._print_ast(numSpaces + 2) ]
             out = ''.join(ret)
         else:
             raise ValueError('Either op or data needs to be set')
@@ -555,66 +622,304 @@ class matrix(object):
         """
         This function helps to debug matrix class and also examine the generated PyDML script
         """
-        if self.data is None:
-            print('# This matrix (' + self.ID + ') is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPyArray() or toDataFrame() or toPandas() methods.\n' + eval([self], execute=False))
-        elif isinstance(self.data, DataFrame):
-            print('# This matrix (' + self.ID + ') is backed by PySpark DataFrame. To fetch the DataFrame, invoke toDataFrame() method.')
+        if self.eval_data is None:
+            print('# This matrix (' + self.ID + ') is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.\n' + eval([self], execute=False))
         else:
-            print('# This matrix (' + self.ID + ') is backed by NumPy array. To fetch the NumPy array, invoke toNumPyArray() method.')
+            print('# This matrix (' + self.ID + ') is backed by ' + str(type(self.eval_data)) + '. To fetch the DataFrame or NumPy array, invoke toDF() or toNumPy() method respectively.')
         return ''
+    
+    ######################### NumPy related methods ######################################
+    
+    __array_priority__ = 10.2
+    ndim = 2
+    
+    THROW_ARRAY_CONVERSION_ERROR = False
+    
+    def __array__(self, dtype=np.double):
+        """
+        NumPy array-interface hook (per the NumPy documentation):
+        This method is called to obtain an ndarray object when needed. You should always guarantee this returns an actual ndarray object.
+        
+        Using this method, you get back an ndarray object, and subsequent operations on the returned ndarray object will be single-node.
+        """
+        if not isinstance(self.eval_data, SUPPORTED_TYPES):
+            # Warn (or raise, if THROW_ARRAY_CONVERSION_ERROR is set) when the data is not a locally supported type, e.g. an unevaluated operation that could produce a large matrix
+            import inspect
+            frame,filename,line_number,function_name,lines,index = inspect.stack()[1]
+            msg = 'Conversion from SystemML matrix to NumPy array (occurs in ' + str(filename) + ':' + str(line_number) + ' ' + function_name + ")"
+            if matrix.THROW_ARRAY_CONVERSION_ERROR:
+                raise Exception('[ERROR]:' + msg)
+            else:
+                print('[WARN]:' + msg)
+        return np.array(self.toNumPy(), dtype)
+    
+    def astype(self, t):
+        # TODO: Throw error if incorrect type
+        return self
+    
+    def asfptype(self):
+        return self
+        
+    def set_shape(self,shape):
+        raise NotImplementedError('Reshaping is not implemented')
+    
+    def get_shape(self):
+        if self._shape is None:  # only known up front for locally supported data; otherwise compute it via a small DML script and cache it
+            lhsStr, inputs = _matricize(self, [])
+            rlen_ID = _get_new_var_id()
+            clen_ID = _get_new_var_id()
+            multiline_dml = [rlen_ID, ' = ', lhsStr, '.shape(0)\n']
+            multiline_dml = multiline_dml + [clen_ID, ' = ', lhsStr, '.shape(1)\n']
+            multiline_dml = multiline_dml + [OUTPUT_ID, ' = full(0, rows=2, cols=1)\n']
+            multiline_dml = multiline_dml + [ OUTPUT_ID, '[0,0] = ', rlen_ID, '\n' ]
+            multiline_dml = multiline_dml + [ OUTPUT_ID, '[1,0] = ', clen_ID, '\n' ]
+            ret = construct_intermediate_node(inputs, multiline_dml).toNumPy()
+            self._shape = tuple(np.array(ret, dtype=int).flatten())
+        return self._shape 
+    
+    shape = property(fget=get_shape, fset=set_shape)
+    
+    def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
+        """
+        Dispatch a NumPy ufunc call to the mapped SystemML function. NOTE(review): released NumPy (>= 1.13) calls __array_ufunc__, not __numpy_ufunc__; confirm this hook is reached.
+        
+        Parameters
+        ----------
+        func:  ufunc object that was called.
+        method: string indicating which Ufunc method was called (one of "__call__", "reduce", "reduceat", "accumulate", "outer", "inner").
+        pos: index of self in inputs.
+        inputs:  tuple of the input arguments to the ufunc
+        kwargs: dictionary containing the optional input arguments of the ufunc.
+        """
+        if method != '__call__' or kwargs:
+            return NotImplemented
+        if func in matrix._numpy_to_systeml_mapping:
+            fn = matrix._numpy_to_systeml_mapping[func]
+        else:
+            return NotImplemented
+        if len(inputs) == 2:
+            return fn(inputs[0], inputs[1])
+        elif  len(inputs) == 1:
+            return fn(inputs[0])
+        else:
+            raise ValueError('Unsupported number of inputs')
 
+    def hstack(self, other):
+        """
+        Stack matrices horizontally (column wise). Invokes cbind internally.
+        """
+        return binaryMatrixFunction(self, other, 'cbind')
+    
+    def vstack(self, other):
+        """
+        Stack matrices vertically (row wise). Invokes rbind internally.
+        """
+        return binaryMatrixFunction(self, other, 'rbind')
+            
     ######################### Arithmetic operators ######################################
 
+    def negative(self):
+        lhsStr, inputs = _matricize(self, [])
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = -', lhsStr, '\n'])
+                
+    def remainder(self, other):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        rhsStr, inputs = _matricize(other, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', lhsStr, ' - floor(', lhsStr, '/', rhsStr, ') * ', rhsStr, '\n'])  # numpy.remainder(x1, x2) == x1 - floor(x1/x2) * x2 (was floor(x1/x2) * x2)
+    
+    def ldexp(self, other):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        rhsStr, inputs = _matricize(other, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', lhsStr, '* (2**', rhsStr, ')\n'])  # x1 * 2**x2
+        
+    def mod(self, other):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        rhsStr, inputs = _matricize(other, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', lhsStr, ' - floor(', lhsStr, '/', rhsStr, ') * ', rhsStr, '\n'])  # x1 - floor(x1/x2) * x2
+    
+    def logaddexp(self, other):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        rhsStr, inputs = _matricize(other, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = log(exp(', lhsStr, ') + exp(', rhsStr, '))\n'])  # log(exp(x1) + exp(x2)), computed directly (no max-shift)
+    
+    def logaddexp2(self, other):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        rhsStr, inputs = _matricize(other, inputs)
+        opStr =  _log_base('2**' + lhsStr + ' + 2**' + rhsStr, 2)  # log2(2**x1 + 2**x2); the ' + ' between the two terms was missing
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', opStr, '\n'])
+
+    def log1p(self):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = log(1 + ', lhsStr, ')\n'])
+        
+    def exp(self):
+        return unaryMatrixFunction(self, 'exp')
+
+    def exp2(self):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = 2**', lhsStr, '\n'])
+    
+    def square(self):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', lhsStr, '**2\n'])    
+    
+    def reciprocal(self):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = 1/', lhsStr, '\n'])
+        
+    def expm1(self):
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = exp(', lhsStr, ') - 1\n'])
+    
+    def _full_like(self, fill):
+        """
+        Returns a matrix with as many rows and columns as self, every cell set to `fill` (a DML literal string).
+        """
+        lhsStr, inputs = _matricize(self, [])
+        nrowStr = lhsStr + '.shape(axis=0)'
+        ncolStr = lhsStr + '.shape(axis=1)'
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = full(' + fill + ', rows=', nrowStr, ', cols=', ncolStr, ')\n'])
+
+    def ones_like(self):
+        """
+        Matrix of ones with the same shape as self.
+        """
+        return self._full_like('1')
+
+    def zeros_like(self):
+        """
+        Matrix of zeros with the same shape as self.
+        """
+        return self._full_like('0')
+    
+    def log2(self):
+        """
+        Element-wise base-2 logarithm.
+        """
+        return self.log(2)
+
+    def log10(self):
+        """
+        Element-wise base-10 logarithm.
+        """
+        return self.log(10)
+
+    def log(self, y=None):
+        """
+        Element-wise logarithm: natural log when y is None, otherwise DML log(self, y).
+        """
+        if y is not None:
+            return binaryMatrixFunction(self, y, 'log')
+        return unaryMatrixFunction(self, 'log')
+
+    def abs(self):
+        """Element-wise absolute value."""
+        return unaryMatrixFunction(self, 'abs')
+
+    def sqrt(self):
+        """Element-wise square root."""
+        return unaryMatrixFunction(self, 'sqrt')
+
+    def round(self):
+        """Element-wise rounding via the DML round builtin."""
+        return unaryMatrixFunction(self, 'round')
+
+    def floor(self):
+        """Element-wise floor."""
+        return unaryMatrixFunction(self, 'floor')
+
+    def ceil(self):
+        """Element-wise ceiling."""
+        return unaryMatrixFunction(self, 'ceil')
+
+    def sin(self):
+        """Element-wise sine."""
+        return unaryMatrixFunction(self, 'sin')
+
+    def cos(self):
+        """Element-wise cosine."""
+        return unaryMatrixFunction(self, 'cos')
+
+    def tan(self):
+        """Element-wise tangent."""
+        return unaryMatrixFunction(self, 'tan')
+
+    def arcsin(self):
+        """NumPy-style name for asin."""
+        return self.asin()
+
+    def arccos(self):
+        """NumPy-style name for acos."""
+        return self.acos()
+
+    def arctan(self):
+        """NumPy-style name for atan."""
+        return self.atan()
+
+    def asin(self):
+        """Element-wise inverse sine."""
+        return unaryMatrixFunction(self, 'asin')
+
+    def acos(self):
+        """Element-wise inverse cosine."""
+        return unaryMatrixFunction(self, 'acos')
+
+    def atan(self):
+        """Element-wise inverse tangent."""
+        return unaryMatrixFunction(self, 'atan')
+
+    def rad2deg(self):
+        """
+        Convert angles from radians to degrees.
+        """
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        # 180/pi at full double precision (was truncated to 57.2957795131)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', lhsStr, '*57.29577951308232\n'])
+
+    def deg2rad(self):
+        """
+        Convert angles from degrees to radians.
+        """
+        inputs = []
+        lhsStr, inputs = _matricize(self, inputs)
+        # pi/180 at full double precision (was truncated to 0.01745329251)
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = ', lhsStr, '*0.017453292519943295\n'])
+
+    def sign(self):
+        """
+        Element-wise sign via the DML sign builtin.
+        """
+        return unaryMatrixFunction(self, 'sign')
+
     def __add__(self, other):
-        return binaryOp(self, other, ' + ')
+        return binary_op(self, other, ' + ')
 
     def __sub__(self, other):
-        return binaryOp(self, other, ' - ')
+        return binary_op(self, other, ' - ')
 
     def __mul__(self, other):
-        return binaryOp(self, other, ' * ')
+        return binary_op(self, other, ' * ')
 
     def __floordiv__(self, other):
-        return binaryOp(self, other, ' // ')
+        return binary_op(self, other, ' // ')
 
     def __div__(self, other):
         """
         Performs division (Python 2 way).
         """
-        return binaryOp(self, other, ' / ')
+        return binary_op(self, other, ' / ')
 
     def __truediv__(self, other):
         """
         Performs division (Python 3 way).
         """
-        return binaryOp(self, other, ' / ')
+        return binary_op(self, other, ' / ')
 
     def __mod__(self, other):
-        return binaryOp(self, other, ' % ')
+        return binary_op(self, other, ' % ')
 
     def __pow__(self, other):
-        return binaryOp(self, other, ' ** ')
+        return binary_op(self, other, ' ** ')
 
     def __radd__(self, other):
-        return binaryOp(other, self, ' + ')
+        return binary_op(other, self, ' + ')
 
     def __rsub__(self, other):
-        return binaryOp(other, self, ' - ')
+        return binary_op(other, self, ' - ')
 
     def __rmul__(self, other):
-        return binaryOp(other, self, ' * ')
+        return binary_op(other, self, ' * ')
 
     def __rfloordiv__(self, other):
-        return binaryOp(other, self, ' // ')
+        return binary_op(other, self, ' // ')
 
     def __rdiv__(self, other):
-        return binaryOp(other, self, ' / ')
+        return binary_op(other, self, ' / ')
 
     def __rmod__(self, other):
-        return binaryOp(other, self, ' % ')
+        return binary_op(other, self, ' % ')
 
     def __rpow__(self, other):
-        return binaryOp(other, self, ' ** ')
+        return binary_op(other, self, ' ** ')
 
     def dot(self, other):
         """
@@ -632,32 +937,98 @@ class matrix(object):
     ######################### Relational/Boolean operators ######################################
 
     def __lt__(self, other):
-        return binaryOp(other, self, ' < ')
+        return binary_op(self, other, ' < ')
 
     def __le__(self, other):
-        return binaryOp(other, self, ' <= ')
+        return binary_op(self, other, ' <= ')
 
     def __gt__(self, other):
-        return binaryOp(other, self, ' > ')
+        return binary_op(self, other, ' > ')
 
     def __ge__(self, other):
-        return binaryOp(other, self, ' >= ')
+        return binary_op(self, other,' >= ')
 
     def __eq__(self, other):
-        return binaryOp(other, self, ' == ')
+        return binary_op(self, other, ' == ')
 
     def __ne__(self, other):
-        return binaryOp(other, self, ' != ')
-
+        return binary_op(self, other, ' != ')
+    
     # TODO: Cast the output back into scalar and return boolean results
     def __and__(self, other):
-        return binaryOp(other, self, ' & ')
+        return binary_op(other, self, ' & ')
 
     def __or__(self, other):
-        return binaryOp(other, self, ' | ')
+        return binary_op(other, self, ' | ')
 
+    def logical_not(self):
+        """
+        Element-wise logical negation, emitted as the DML '!' operator.
+        """
+        lhsStr, inputs = _matricize(self, [])
+        return construct_intermediate_node(inputs, [OUTPUT_ID, ' = !', lhsStr, '\n'])
+    
+    def remove_empty(self, axis=None):
+        """
+        Removes all empty rows or columns from the matrix according to the specified axis.
+
+        Parameters
+        ----------
+        axis : int (0 or 1)
+            0 removes empty rows (margin="rows"), 1 removes empty columns (margin="cols").
+        """
+        if axis is None:
+            raise ValueError('axis is a mandatory argument for remove_empty')
+        if axis == 0:
+            return self._parameterized_helper_fn('removeEmpty', target=self, margin='rows')
+        elif axis == 1:
+            return self._parameterized_helper_fn('removeEmpty', target=self, margin='cols')
+        else:
+            raise ValueError('axis for remove_empty needs to be either 0 or 1.')
+
+    def replace(self, pattern=None, replacement=None):
+        """
+        Replaces every cell equal to `pattern` with `replacement`.
+
+        Parameters
+        ----------
+        pattern : float or int
+        replacement : float or int
+        """
+        if pattern is None or not isinstance(pattern, (float, int)):
+            raise ValueError('pattern should be of type float or int')
+        if replacement is None or not isinstance(replacement, (float, int)):
+            raise ValueError('replacement should be of type float or int')
+        return self._parameterized_helper_fn('replace', target=self, pattern=pattern, replacement=replacement)
+
+    def _parameterized_helper_fn(self, fnName, **kwargs):
+        """
+        Helper to invoke a parameterized builtin function using named arguments only,
+        e.g. fnName(target=X, margin="rows").
+
+        String values are emitted quoted, matrix values are referenced through _matricize
+        (so they are registered as inputs of the new node), and any other value via str().
+        """
+        inputs = []
+        args = []
+        for key, v in kwargs.items():
+            if isinstance(v, matrix):
+                vStr, inputs = _matricize(v, inputs)
+                args.append(key + '=' + vStr)
+            elif isinstance(v, str):
+                args.append(key + '="' + v + '"')
+            else:
+                args.append(key + '=' + str(v))
+        # Parameterized builtins accept only named arguments, joined by ', '.
+        dml_script = [OUTPUT_ID, ' = ', fnName, '(', ', '.join(args), ')\n']
+        return construct_intermediate_node(inputs, dml_script)
+            
     ######################### Aggregation functions ######################################
 
+    def prod(self):
+        """
+        Return the product of all cells in the matrix as a single aggregate (no axis argument).
+        """
+        return self._aggFn('prod', None)
+        
     def sum(self, axis=None):
         """
         Compute the sum along the specified axis
@@ -680,14 +1051,53 @@ class matrix(object):
 
     def var(self, axis=None):
         """
-        Compute the variance along the specified axis
+        Compute the variance along the specified axis.
+        We assume that delta degrees of freedom (ddof) is 1 (unlike NumPy which assumes ddof=0).
         
         Parameters
         ----------
         axis : int, optional
         """
         return self._aggFn('var', axis)
-
+
+    def moment(self, moment=1, axis=None):
+        """
+        Calculates the nth moment.
+
+        Parameters
+        ----------
+        moment : int
+            can be 1, 2, 3 or 4. moment=1 returns mean(axis) and moment=2 returns var(axis)
+            (i.e. ddof=1); moment=3 and moment=4 use the DML moment builtin.
+        axis : int, optional
+            None: one value over all cells; 0: one value per column; 1: one value per row
+            (same convention as NumPy and as sum/var).
+        """
+        # NOTE(review): for moment=1 and moment=2 this is not the central moment that
+        # scipy.stats.moment returns (0 and the ddof=0 variance respectively) - confirm intent.
+        if moment == 1:
+            return self.mean(axis)
+        elif moment == 2:
+            return self.var(axis)
+        elif moment == 3 or moment == 4:
+            return self._moment_helper(moment, axis)
+        else:
+            raise ValueError('The specified moment is not supported:' + str(moment))
+
+    def _moment_helper(self, k, axis=None):
+        """
+        Generates DML computing moment(., k) over all cells (axis=None), per column (axis=0)
+        or per row (axis=1).
+        """
+        lhsStr, inputs = _matricize(self, [])
+        if axis is None:
+            # Reshape into a single column so that moment() sees every cell.
+            dml_script = [OUTPUT_ID, ' = moment(full(', lhsStr, ', rows=length(', lhsStr, '), cols=1), ', str(k), ')\n' ]
+        elif axis == 0:
+            # One value per column -> 1 x ncol row vector.
+            dml_script = [OUTPUT_ID, ' = full(0, rows=1, cols=ncol(', lhsStr, '))\n' ]
+            dml_script = dml_script + [ 'parfor(i in 1:ncol(', lhsStr, '), check=0):\n' ]
+            dml_script = dml_script + [ '\t', OUTPUT_ID, '[0, i-1] = moment(', lhsStr, '[,i-1], ', str(k), ')\n\n' ]
+        elif axis == 1:
+            # One value per row -> nrow x 1 column vector; each row is reshaped into a column first.
+            dml_script = [OUTPUT_ID, ' = full(0, rows=nrow(', lhsStr, '), cols=1)\n' ]
+            dml_script = dml_script + [ 'parfor(i in 1:nrow(', lhsStr, '), check=0):\n' ]
+            dml_script = dml_script + [ '\t', OUTPUT_ID, '[i-1, 0] = moment(full(', lhsStr, '[i-1,], rows=ncol(', lhsStr, '), cols=1), ', str(k), ')\n\n' ]
+        else:
+            # str() so a non-string axis produces this message instead of a TypeError.
+            raise ValueError('Incorrect axis:' + str(axis))
+        return construct_intermediate_node(inputs, dml_script)
+        
     def sd(self, axis=None):
         """
         Compute the standard deviation along the specified axis
@@ -698,25 +1108,37 @@ class matrix(object):
         """
         return self._aggFn('sd', axis)
 
-    def max(self, axis=None):
+    def max(self, other=None, axis=None):
         """
         Compute the maximum value along the specified axis
         
         Parameters
         ----------
+        other: matrix or numpy array (& other supported types) or scalar, optional
+            If given, the element-wise maximum of self and other is returned (like numpy.maximum).
         axis : int, optional
+            Used only when other is None; None aggregates over all cells.
         """
-        return self._aggFn('max', axis)
+        if other is not None and axis is not None:
+            raise ValueError('Both axis and other cannot be not None')
+        elif other is None:
+            # Aggregation, including the no-argument case (same as the previous max(axis=None)).
+            return self._aggFn('max', axis)
+        else:
+            return binaryMatrixFunction(self, other, 'max')
 
-    def min(self, axis=None):
+    def min(self, other=None, axis=None):
         """
         Compute the minimum value along the specified axis
         
         Parameters
         ----------
+        other: matrix or numpy array (& other supported types) or scalar, optional
+            If given, the element-wise minimum of self and other is returned (like numpy.minimum).
         axis : int, optional
+            Used only when other is None; None aggregates over all cells.
         """
-        return self._aggFn('min', axis)
+        if other is not None and axis is not None:
+            raise ValueError('Both axis and other cannot be not None')
+        elif other is None:
+            # Aggregation, including the no-argument case (same as the previous min(axis=None)).
+            return self._aggFn('min', axis)
+        else:
+            return binaryMatrixFunction(self, other, 'min')
 
     def argmin(self, axis=None):
         """
@@ -764,13 +1186,13 @@ class matrix(object):
         """
         Common function that is called for functions that have axis as parameter.
         """
-        dmlOp = DMLOp([self])
-        out = matrix(None, op=dmlOp)
+        dml_script = ''
+        lhsStr, inputs = _matricize(self, [])
         if axis is None:
-            dmlOp.dml = [out.ID, ' = ', fnName, '(', self.ID, ')\n']
+            dml_script = [OUTPUT_ID, ' = ', fnName, '(', lhsStr, ')\n']
         else:
-            dmlOp.dml = [out.ID, ' = ', fnName, '(', self.ID, ', axis=', str(axis) ,')\n']
-        return out
+            dml_script = [OUTPUT_ID, ' = ', fnName, '(', lhsStr, ', axis=', str(axis) ,')\n']
+        return construct_intermediate_node(inputs, dml_script)
 
     ######################### Indexing operators ######################################
 
@@ -778,19 +1200,16 @@ class matrix(object):
         """
         Implements evaluation of right indexing operations such as m[1,1], m[0:1,], m[:, 0:1]
         """
-        dmlOp = DMLOp([self])
-        out = matrix(None, op=dmlOp)
-        dmlOp.dml = [out.ID, ' = ', self.ID ] + getIndexingDML(index) + [ '\n' ]
-        return out
+        return construct_intermediate_node([self], [OUTPUT_ID, ' = ', self.ID ] + getIndexingDML(index) + [ '\n' ])
 
     # Performs deep copy if the matrix is backed by data
     def _prepareForInPlaceUpdate(self):
+        # temp takes over self's current data and op; every op that already consumed
+        # self is re-pointed to temp so a later in-place write to self does not affect it.
-        temp = matrix(self.data, op=self.op)
+        temp = matrix(self.eval_data, op=self.op)
         for op in self.referenced:
             op.inputs = [temp if x.ID==self.ID else x for x in op.inputs]
         self.ID, temp.ID = temp.ID, self.ID # Copy even the IDs as the IDs might be used to create DML
+        # self is now defined as a plain copy of temp ("self = temp") in DML.
         self.op = DMLOp([temp], dml=[self.ID, " = ", temp.ID])
-        self.data = None
+        self.eval_data = None
+        # temp is consumed by the previous consumers of self plus self's new copy op.
         temp.referenced = self.referenced + [ self.op ]
         self.referenced = []
 
@@ -804,3 +1223,6 @@ class matrix(object):
         if isinstance(value, matrix):
             value.referenced = value.referenced + [ self.op ]
         self.op.dml = self.op.dml + [ '\n', self.ID ] + getIndexingDML(index) + [ ' = ',  getValue(value), '\n']
+
+    def signbit(self):
+        """
+        Element-wise negativity test (1 where self < 0, else 0), backing numpy.signbit.
+        NOTE(review): unlike numpy.signbit, -0.0 yields 0 here.
+        """
+        return self < 0
+
+    # Not implemented: conj, hyperbolic/inverse-hyperbolic functions(i.e. sinh, arcsinh, cosh, ...), bitwise operator, xor operator, isreal, iscomplex, isfinite, isinf, isnan, copysign, nextafter, modf, frexp, trunc
+    # Maps NumPy functions to the matrix methods that implement them.
+    # NOTE(review): np.fmod (sign of the dividend) is mapped to __mod__; confirm that DML '%' matches fmod for negative operands.
+    _numpy_to_systeml_mapping = {
+        np.add: __add__, np.subtract: __sub__, np.multiply: __mul__, np.divide: __div__,
+        np.logaddexp: logaddexp, np.true_divide: __truediv__, np.floor_divide: __floordiv__,
+        np.negative: negative, np.power: __pow__, np.remainder: remainder, np.mod: mod, np.fmod: __mod__,
+        np.absolute: abs, np.rint: round, np.sign: sign, np.exp: exp, np.exp2: exp2, np.log: log,
+        np.log2: log2, np.log10: log10, np.expm1: expm1, np.log1p: log1p, np.sqrt: sqrt, np.square: square,
+        np.reciprocal: reciprocal, np.ones_like: ones_like, np.zeros_like: zeros_like,
+        np.sin: sin, np.cos: cos, np.tan: tan, np.arcsin: arcsin, np.arccos: arccos, np.arctan: arctan,
+        np.deg2rad: deg2rad, np.rad2deg: rad2deg,
+        np.greater: __gt__, np.greater_equal: __ge__, np.less: __lt__, np.less_equal: __le__,
+        np.not_equal: __ne__, np.equal: __eq__,
+        np.logical_not: logical_not, np.logical_and: __and__, np.logical_or: __or__,
+        np.maximum: max, np.minimum: min,
+        # signbit is a negativity mask, not sign() (which returns -1/0/1)
+        np.signbit: signbit,
+        np.ldexp: ldexp, np.dot: dot}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/systemml/mlcontext.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mlcontext.py b/src/main/python/systemml/mlcontext.py
index 63631aa..4f769d5 100644
--- a/src/main/python/systemml/mlcontext.py
+++ b/src/main/python/systemml/mlcontext.py
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-__all__ = ['MLResults', 'MLContext', 'Script', 'dml', 'pydml']
+__all__ = ['MLResults', 'MLContext', 'Script', 'dml', 'pydml', '_java2py', 'Matrix']
 
 import os
 
@@ -288,7 +288,10 @@ class MLContext(object):
             # representing `script_java.in`, and then call it with the arguments.  This is in
             # lieu of adding a new `input` method on the JVM side, as that would complicate use
             # from Scala/Java.
-            py4j.java_gateway.get_method(script_java, "in")(key, _py2java(self._sc, val))
+            if isinstance(val, py4j.java_gateway.JavaObject):
+                py4j.java_gateway.get_method(script_java, "in")(key, val)
+            else:
+                py4j.java_gateway.get_method(script_java, "in")(key, _py2java(self._sc, val))
         for val in script._output:
             script_java.out(val)
         return MLResults(self._ml.execute(script_java), self._sc)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/systemml/random/sampling.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/random/sampling.py b/src/main/python/systemml/random/sampling.py
index 02408e5..d320536 100644
--- a/src/main/python/systemml/random/sampling.py
+++ b/src/main/python/systemml/random/sampling.py
@@ -71,7 +71,7 @@ def normal(loc=0.0, scale=1.0, size=(1,1), sparsity=1.0):
     >>> sml.setSparkContext(sc)
     >>> from systemml import random
     >>> m1 = sml.random.normal(loc=3, scale=2, size=(3,3))
-    >>> m1.toNumPyArray()
+    >>> m1.toNumPy()
     array([[ 3.48857226,  6.17261819,  2.51167259],
            [ 3.60506708, -1.90266305,  3.97601633],
            [ 3.62245706,  5.9430881 ,  2.53070413]])
@@ -107,7 +107,7 @@ def uniform(low=0.0, high=1.0, size=(1,1), sparsity=1.0):
     >>> sml.setSparkContext(sc)
     >>> from systemml import random
     >>> m1 = sml.random.uniform(size=(3,3))
-    >>> m1.toNumPyArray()
+    >>> m1.toNumPy()
     array([[ 0.54511396,  0.11937437,  0.72975775],
            [ 0.14135946,  0.01944448,  0.52544478],
            [ 0.67582422,  0.87068849,  0.02766852]])
@@ -142,7 +142,7 @@ def poisson(lam=1.0, size=(1,1), sparsity=1.0):
     >>> sml.setSparkContext(sc)
     >>> from systemml import random
     >>> m1 = sml.random.poisson(lam=1, size=(3,3))
-    >>> m1.toNumPyArray()
+    >>> m1.toNumPy()
     array([[ 1.,  0.,  2.],
            [ 1.,  0.,  0.],
            [ 0.,  0.,  0.]])

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/tests/test_matrix_agg_fn.py
----------------------------------------------------------------------
diff --git a/src/main/python/tests/test_matrix_agg_fn.py b/src/main/python/tests/test_matrix_agg_fn.py
new file mode 100644
index 0000000..be3df14
--- /dev/null
+++ b/src/main/python/tests/test_matrix_agg_fn.py
@@ -0,0 +1,95 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# To run:
+#   - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-class-path SystemML.jar test_matrix_agg_fn.py`
+#   - Python 3: `PYSPARK_PYTHON=python3 spark-submit --master local[*] --driver-class-path SystemML.jar test_matrix_agg_fn.py`
+
+# Make the `systemml` package importable
+import os
+import sys
+path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
+sys.path.insert(0, path)
+
+import unittest
+import systemml as sml
+import numpy as np
+from scipy.stats import kurtosis, skew, moment
+from pyspark.context import SparkContext
+sc = SparkContext()
+
+dim = 5
+m1 = np.array(np.random.randint(100, size=dim*dim) + 1.01, dtype=np.double)
+m1.shape = (dim, dim)
+m2 = np.array(np.random.randint(5, size=dim*dim) + 1, dtype=np.double)
+m2.shape = (dim, dim)
+s = 3.02
+
+class TestMatrixAggFn(unittest.TestCase):
+    """Compares SystemML aggregations against NumPy/SciPy on random dim x dim data."""
+
+    def test_sum1(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).sum(), m1.sum()))
+
+    def test_sum2(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).sum(axis=0), m1.sum(axis=0)))
+
+    def test_sum3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).sum(axis=1), m1.sum(axis=1).reshape(dim, 1)))
+
+    def test_mean1(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).mean(), m1.mean()))
+
+    def test_mean2(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).mean(axis=0), m1.mean(axis=0).reshape(1, dim)))
+
+    def test_mean3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).mean(axis=1), m1.mean(axis=1).reshape(dim, 1)))
+
+    def test_hstack(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).hstack(sml.matrix(m1)), np.hstack((m1, m1))))
+
+    def test_vstack(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).vstack(sml.matrix(m1)), np.vstack((m1, m1))))
+
+    def test_full(self):
+        self.assertTrue(np.allclose(sml.full((2, 3), 10.1), np.full((2, 3), 10.1)))
+
+    def test_seq(self):
+        self.assertTrue(np.allclose(sml.seq(3), np.arange(3+1).reshape(4, 1)))
+
+    def test_var1(self):
+        # SystemML's var uses ddof=1, so compare against NumPy with the same setting.
+        self.assertTrue(np.allclose(sml.matrix(m1).var(), m1.var(ddof=1)))
+
+    def test_var2(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).var(axis=0), m1.var(axis=0, ddof=1).reshape(1, dim)))
+
+    def test_var3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).var(axis=1), m1.var(axis=1, ddof=1).reshape(dim, 1)))
+
+    def test_moment3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).moment(moment=3, axis=None), moment(m1, moment=3, axis=None)))
+
+    def test_moment4(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).moment(moment=4, axis=None), moment(m1, moment=4, axis=None)))
+
+if __name__ == "__main__":
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/tests/test_matrix_binary_op.py
----------------------------------------------------------------------
diff --git a/src/main/python/tests/test_matrix_binary_op.py b/src/main/python/tests/test_matrix_binary_op.py
new file mode 100644
index 0000000..6bba3e9
--- /dev/null
+++ b/src/main/python/tests/test_matrix_binary_op.py
@@ -0,0 +1,138 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# To run:
+#   - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-class-path SystemML.jar test_matrix_binary_op.py`
+#   - Python 3: `PYSPARK_PYTHON=python3 spark-submit --master local[*] --driver-class-path SystemML.jar test_matrix_binary_op.py`
+
+# Make the `systemml` package importable
+import os
+import sys
+path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
+sys.path.insert(0, path)
+
+import unittest
+import systemml as sml
+import numpy as np
+from pyspark.context import SparkContext
+sc = SparkContext()
+
+dim = 5
+m1 = np.array(np.random.randint(100, size=dim*dim) + 1.01, dtype=np.double)
+m1.shape = (dim, dim)
+m2 = np.array(np.random.randint(5, size=dim*dim) + 1, dtype=np.double)
+m2.shape = (dim, dim)
+s = 3.02
+
+class TestBinaryOp(unittest.TestCase):
+    """Compares SystemML binary operators against NumPy for matrix/matrix,
+    matrix/ndarray, ndarray/matrix, matrix/scalar and scalar/matrix operands."""
+
+    def test_plus(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) + sml.matrix(m2), m1 + m2))
+
+    def test_minus(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) - sml.matrix(m2), m1 - m2))
+
+    def test_mul(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) * sml.matrix(m2), m1 * m2))
+
+    def test_div(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) / sml.matrix(m2), m1 / m2))
+
+    #def test_power(self):
+    #    self.assertTrue(np.allclose(sml.matrix(m1) ** sml.matrix(m2), m1 ** m2))
+
+    def test_plus1(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) + m2, m1 + m2))
+
+    def test_minus1(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) - m2, m1 - m2))
+
+    def test_mul1(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) * m2, m1 * m2))
+
+    def test_div1(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) / m2, m1 / m2))
+
+    def test_power1(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) ** m2, m1 ** m2))
+
+    def test_plus2(self):
+        self.assertTrue(np.allclose(m1 + sml.matrix(m2), m1 + m2))
+
+    def test_minus2(self):
+        self.assertTrue(np.allclose(m1 - sml.matrix(m2), m1 - m2))
+
+    def test_mul2(self):
+        self.assertTrue(np.allclose(m1 * sml.matrix(m2), m1 * m2))
+
+    def test_div2(self):
+        self.assertTrue(np.allclose(m1 / sml.matrix(m2), m1 / m2))
+
+    def test_power2(self):
+        self.assertTrue(np.allclose(m1 ** sml.matrix(m2), m1 ** m2))
+
+    def test_plus3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) + s, m1 + s))
+
+    def test_minus3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) - s, m1 - s))
+
+    def test_mul3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) * s, m1 * s))
+
+    def test_div3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) / s, m1 / s))
+
+    def test_power3(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) ** s, m1 ** s))
+
+    def test_plus4(self):
+        self.assertTrue(np.allclose(s + sml.matrix(m2), s + m2))
+
+    def test_minus4(self):
+        self.assertTrue(np.allclose(s - sml.matrix(m2), s - m2))
+
+    def test_mul4(self):
+        self.assertTrue(np.allclose(s * sml.matrix(m2), s * m2))
+
+    def test_div4(self):
+        self.assertTrue(np.allclose(s / sml.matrix(m2), s / m2))
+
+    def test_power4(self):
+        self.assertTrue(np.allclose(s ** sml.matrix(m2), s ** m2))
+
+    def test_lt(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) < sml.matrix(m2), m1 < m2))
+
+    def test_gt(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) > sml.matrix(m2), m1 > m2))
+
+    def test_le(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) <= sml.matrix(m2), m1 <= m2))
+
+    def test_ge(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) >= sml.matrix(m2), m1 >= m2))
+
+    def test_eq(self):
+        # Compare m1 with itself so the expected mask is non-trivial (all ones).
+        self.assertTrue(np.allclose(sml.matrix(m1) == sml.matrix(m1), m1 == m1))
+
+    def test_ne(self):
+        self.assertTrue(np.allclose(sml.matrix(m1) != sml.matrix(m2), m1 != m2))
+
+    def test_abs(self):
+        self.assertTrue(np.allclose(sml.matrix(m1).abs(), np.abs(m1)))
+
+if __name__ == "__main__":
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/tests/test_mllearn.py
----------------------------------------------------------------------
diff --git a/src/main/python/tests/test_mllearn.py b/src/main/python/tests/test_mllearn.py
deleted file mode 100644
index 532d450..0000000
--- a/src/main/python/tests/test_mllearn.py
+++ /dev/null
@@ -1,190 +0,0 @@
-#!/usr/bin/python
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-# To run:
-#   - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn.py`
-#   - Python 3: `PYSPARK_PYTHON=python3 spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn.py`
-
-# Make the `systemml` package importable
-import os
-import sys
-path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
-sys.path.insert(0, path)
-
-import unittest
-
-import numpy as np
-from pyspark.context import SparkContext
-from pyspark.ml import Pipeline
-from pyspark.ml.feature import HashingTF, Tokenizer
-from pyspark.sql import SQLContext
-from sklearn import datasets, metrics, neighbors
-from sklearn.datasets import fetch_20newsgroups
-from sklearn.feature_extraction.text import TfidfVectorizer
-
-from systemml.mllearn import LinearRegression, LogisticRegression, NaiveBayes, SVM
-
-sc = SparkContext()
-sqlCtx = SQLContext(sc)
-
-# Currently not integrated with JUnit test
-# ~/spark-1.6.1-scala-2.11/bin/spark-submit --master local[*] --driver-class-path SystemML.jar test.py
-class TestMLLearn(unittest.TestCase):
-    def testLogisticSK1(self):
-        digits = datasets.load_digits()
-        X_digits = digits.data
-        y_digits = digits.target
-        n_samples = len(X_digits)
-        X_train = X_digits[:.9 * n_samples]
-        y_train = y_digits[:.9 * n_samples]
-        X_test = X_digits[.9 * n_samples:]
-        y_test = y_digits[.9 * n_samples:]
-        logistic = LogisticRegression(sqlCtx)
-        score = logistic.fit(X_train, y_train).score(X_test, y_test)
-        self.failUnless(score > 0.9)
-
-    def testLogisticSK2(self):
-        digits = datasets.load_digits()
-        X_digits = digits.data
-        y_digits = digits.target
-        n_samples = len(X_digits)
-        X_train = X_digits[:.9 * n_samples]
-        y_train = y_digits[:.9 * n_samples]
-        X_test = X_digits[.9 * n_samples:]
-        y_test = y_digits[.9 * n_samples:]
-        # Convert to DataFrame for i/o: current way to transfer data
-        logistic = LogisticRegression(sqlCtx, transferUsingDF=True)
-        score = logistic.fit(X_train, y_train).score(X_test, y_test)
-        self.failUnless(score > 0.9)
-
-    def testLogisticMLPipeline1(self):
-        training = sqlCtx.createDataFrame([
-            ("a b c d e spark", 1.0),
-            ("b d", 2.0),
-            ("spark f g h", 1.0),
-            ("hadoop mapreduce", 2.0),
-            ("b spark who", 1.0),
-            ("g d a y", 2.0),
-            ("spark fly", 1.0),
-            ("was mapreduce", 2.0),
-            ("e spark program", 1.0),
-            ("a e c l", 2.0),
-            ("spark compile", 1.0),
-            ("hadoop software", 2.0)
-            ], ["text", "label"])
-        tokenizer = Tokenizer(inputCol="text", outputCol="words")
-        hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
-        lr = LogisticRegression(sqlCtx)
-        pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
-        model = pipeline.fit(training)
-        test = sqlCtx.createDataFrame([
-            ("spark i j k", 1.0),
-            ("l m n", 2.0),
-            ("mapreduce spark", 1.0),
-            ("apache hadoop", 2.0)], ["text", "label"])
-        result = model.transform(test)
-        predictionAndLabels = result.select("prediction", "label")
-        from pyspark.ml.evaluation import MulticlassClassificationEvaluator
-        evaluator = MulticlassClassificationEvaluator()
-        score = evaluator.evaluate(predictionAndLabels)
-        self.failUnless(score == 1.0)
-
-    def testLinearRegressionSK1(self):
-        diabetes = datasets.load_diabetes()
-        diabetes_X = diabetes.data[:, np.newaxis, 2]
-        diabetes_X_train = diabetes_X[:-20]
-        diabetes_X_test = diabetes_X[-20:]
-        diabetes_y_train = diabetes.target[:-20]
-        diabetes_y_test = diabetes.target[-20:]
-        regr = LinearRegression(sqlCtx)
-        regr.fit(diabetes_X_train, diabetes_y_train)
-        score = regr.score(diabetes_X_test, diabetes_y_test)
-        self.failUnless(score > 0.4) # TODO: Improve r2-score (may be I am using it incorrectly)
-
-    def testLinearRegressionSK2(self):
-        diabetes = datasets.load_diabetes()
-        diabetes_X = diabetes.data[:, np.newaxis, 2]
-        diabetes_X_train = diabetes_X[:-20]
-        diabetes_X_test = diabetes_X[-20:]
-        diabetes_y_train = diabetes.target[:-20]
-        diabetes_y_test = diabetes.target[-20:]
-        regr = LinearRegression(sqlCtx, transferUsingDF=True)
-        regr.fit(diabetes_X_train, diabetes_y_train)
-        score = regr.score(diabetes_X_test, diabetes_y_test)
-        self.failUnless(score > 0.4) # TODO: Improve r2-score (may be I am using it incorrectly)
-
-    def testSVMSK1(self):
-        digits = datasets.load_digits()
-        X_digits = digits.data
-        y_digits = digits.target
-        n_samples = len(X_digits)
-        X_train = X_digits[:.9 * n_samples]
-        y_train = y_digits[:.9 * n_samples]
-        X_test = X_digits[.9 * n_samples:]
-        y_test = y_digits[.9 * n_samples:]
-        svm = SVM(sqlCtx, is_multi_class=True)
-        score = svm.fit(X_train, y_train).score(X_test, y_test)
-        self.failUnless(score > 0.9)
-
-    def testSVMSK2(self):
-        digits = datasets.load_digits()
-        X_digits = digits.data
-        y_digits = digits.target
-        n_samples = len(X_digits)
-        X_train = X_digits[:.9 * n_samples]
-        y_train = y_digits[:.9 * n_samples]
-        X_test = X_digits[.9 * n_samples:]
-        y_test = y_digits[.9 * n_samples:]
-        svm = SVM(sqlCtx, is_multi_class=True, transferUsingDF=True)
-        score = svm.fit(X_train, y_train).score(X_test, y_test)
-        self.failUnless(score > 0.9)
-
-    def testNaiveBayesSK1(self):
-        digits = datasets.load_digits()
-        X_digits = digits.data
-        y_digits = digits.target
-        n_samples = len(X_digits)
-        X_train = X_digits[:.9 * n_samples]
-        y_train = y_digits[:.9 * n_samples]
-        X_test = X_digits[.9 * n_samples:]
-        y_test = y_digits[.9 * n_samples:]
-        nb = NaiveBayes(sqlCtx)
-        score = nb.fit(X_train, y_train).score(X_test, y_test)
-        self.failUnless(score > 0.85)
-
-    def testNaiveBayesSK2(self):
-        categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
-        newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
-        newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
-        vectorizer = TfidfVectorizer()
-        # Both vectors and vectors_test are SciPy CSR matrix
-        vectors = vectorizer.fit_transform(newsgroups_train.data)
-        vectors_test = vectorizer.transform(newsgroups_test.data)
-        nb = NaiveBayes(sqlCtx)
-        nb.fit(vectors, newsgroups_train.target)
-        pred = nb.predict(vectors_test)
-        score = metrics.f1_score(newsgroups_test.target, pred, average='weighted')
-        self.failUnless(score > 0.8)
-
-
-if __name__ == '__main__':
-    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/tests/test_mllearn_df.py
----------------------------------------------------------------------
diff --git a/src/main/python/tests/test_mllearn_df.py b/src/main/python/tests/test_mllearn_df.py
new file mode 100644
index 0000000..0d6a4b4
--- /dev/null
+++ b/src/main/python/tests/test_mllearn_df.py
@@ -0,0 +1,108 @@
+#!/usr/bin/python
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# To run:
+#   - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn_df.py`
+#   - Python 3: `PYSPARK_PYTHON=python3 spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn_df.py`
+
+# Make the `systemml` package importable
+import os
+import sys
+path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")  # parent of this tests/ directory
+sys.path.insert(0, path)  # prepend so this source tree's `systemml` is found first
+
+import unittest
+
+import numpy as np
+from pyspark.context import SparkContext
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.sql import SQLContext
+from sklearn import datasets, metrics, neighbors
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+from systemml.mllearn import LinearRegression, LogisticRegression, NaiveBayes, SVM
+
+sc = SparkContext()  # created at import time; shared by every test in this module
+sqlCtx = SQLContext(sc)  # handed to each mllearn estimator below
+
+# Also runnable from JUnit via org.apache.sysml.test.integration.functions.python.PythonTestRunner (opt-in)
+# ~/spark-1.6.1-scala-2.11/bin/spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn_df.py
+class TestMLLearn(unittest.TestCase):
+    """systemml.mllearn estimators with transferUsingDF=True (data moved as DataFrames)."""
+    def test_logistic_sk2(self):  # digits, 90/10 train/test split; expects score > 0.9
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_train = int(.9 * len(X_digits))  # first 90% of the rows train, the rest test
+        X_train = X_digits[:n_train]
+        y_train = y_digits[:n_train]
+        X_test = X_digits[n_train:]
+        y_test = y_digits[n_train:]
+        # Convert to DataFrame for i/o: current way to transfer data
+        logistic = LogisticRegression(sqlCtx, transferUsingDF=True)
+        score = logistic.fit(X_train, y_train).score(X_test, y_test)
+        self.assertGreater(score, 0.9)
+
+    def test_linear_regression_sk2(self):  # diabetes feature 2 only, last 20 rows held out
+        diabetes = datasets.load_diabetes()
+        diabetes_X = diabetes.data[:, np.newaxis, 2]
+        diabetes_X_train = diabetes_X[:-20]
+        diabetes_X_test = diabetes_X[-20:]
+        diabetes_y_train = diabetes.target[:-20]
+        diabetes_y_test = diabetes.target[-20:]
+        regr = LinearRegression(sqlCtx, transferUsingDF=True)
+        regr.fit(diabetes_X_train, diabetes_y_train)
+        score = regr.score(diabetes_X_test, diabetes_y_test)
+        self.assertGreater(score, 0.4) # TODO: Improve r2-score (may be I am using it incorrectly)
+
+    def test_svm_sk2(self):  # multi-class SVM on digits, 90/10 split; expects score > 0.9
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_train = int(.9 * len(X_digits))  # first 90% of the rows train, the rest test
+        X_train = X_digits[:n_train]
+        y_train = y_digits[:n_train]
+        X_test = X_digits[n_train:]
+        y_test = y_digits[n_train:]
+        svm = SVM(sqlCtx, is_multi_class=True, transferUsingDF=True)
+        score = svm.fit(X_train, y_train).score(X_test, y_test)
+        self.assertGreater(score, 0.9)
+
+    #def test_naive_bayes_sk2(self):
+    #    categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
+    #    newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
+    #    newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
+    #    vectorizer = TfidfVectorizer()
+    #    # Both vectors and vectors_test are SciPy CSR matrix
+    #    vectors = vectorizer.fit_transform(newsgroups_train.data)
+    #    vectors_test = vectorizer.transform(newsgroups_test.data)
+    #    nb = NaiveBayes(sqlCtx)
+    #    nb.fit(vectors, newsgroups_train.target)
+    #    pred = nb.predict(vectors_test)
+    #    score = metrics.f1_score(newsgroups_test.target, pred, average='weighted')
+    #    self.failUnless(score > 0.8)
+
+
+if __name__ == '__main__':  # entry point when launched as a script (e.g. via spark-submit)
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/tests/test_mllearn_numpy.py
----------------------------------------------------------------------
diff --git a/src/main/python/tests/test_mllearn_numpy.py b/src/main/python/tests/test_mllearn_numpy.py
new file mode 100644
index 0000000..d030837
--- /dev/null
+++ b/src/main/python/tests/test_mllearn_numpy.py
@@ -0,0 +1,151 @@
+#!/usr/bin/python
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# To run:
+#   - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn_numpy.py`
+#   - Python 3: `PYSPARK_PYTHON=python3 spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn_numpy.py`
+
+# Make the `systemml` package importable
+import os
+import sys
+path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")  # parent of this tests/ directory
+sys.path.insert(0, path)  # prepend so this source tree's `systemml` is found first
+
+import unittest
+
+import numpy as np
+from pyspark.context import SparkContext
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.sql import SQLContext
+from sklearn import datasets, metrics, neighbors
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+from systemml.mllearn import LinearRegression, LogisticRegression, NaiveBayes, SVM
+
+sc = SparkContext()  # created at import time; shared by every test in this module
+sqlCtx = SQLContext(sc)  # used by the tests below for estimators and DataFrame creation
+
+# Also runnable from JUnit via org.apache.sysml.test.integration.functions.python.PythonTestRunner (opt-in)
+# ~/spark-1.6.1-scala-2.11/bin/spark-submit --master local[*] --driver-class-path SystemML.jar test_mllearn_numpy.py
+class TestMLLearn(unittest.TestCase):  # mllearn estimators fit on NumPy arrays, plus one pyspark.ml Pipeline
+    def test_logistic(self):  # digits, 90/10 train/test split; expects score > 0.9
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_train = int(.9 * len(X_digits))  # first 90% of the rows train, the rest test
+        X_train = X_digits[:n_train]
+        y_train = y_digits[:n_train]
+        X_test = X_digits[n_train:]
+        y_test = y_digits[n_train:]
+        logistic = LogisticRegression(sqlCtx)
+        score = logistic.fit(X_train, y_train).score(X_test, y_test)
+        self.assertGreater(score, 0.9)
+
+    def test_logistic_mlpipeline(self):  # LogisticRegression as the last stage of a pyspark.ml Pipeline
+        training = sqlCtx.createDataFrame([
+            ("a b c d e spark", 1.0),
+            ("b d", 2.0),
+            ("spark f g h", 1.0),
+            ("hadoop mapreduce", 2.0),
+            ("b spark who", 1.0),
+            ("g d a y", 2.0),
+            ("spark fly", 1.0),
+            ("was mapreduce", 2.0),
+            ("e spark program", 1.0),
+            ("a e c l", 2.0),
+            ("spark compile", 1.0),
+            ("hadoop software", 2.0)
+            ], ["text", "label"])
+        tokenizer = Tokenizer(inputCol="text", outputCol="words")
+        hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
+        lr = LogisticRegression(sqlCtx)
+        pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+        model = pipeline.fit(training)
+        test = sqlCtx.createDataFrame([
+            ("spark i j k", 1.0),
+            ("l m n", 2.0),
+            ("mapreduce spark", 1.0),
+            ("apache hadoop", 2.0)], ["text", "label"])
+        result = model.transform(test)
+        predictionAndLabels = result.select("prediction", "label")
+        from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+        evaluator = MulticlassClassificationEvaluator()
+        score = evaluator.evaluate(predictionAndLabels)
+        self.assertEqual(score, 1.0)
+
+    def test_linear_regression(self):  # diabetes feature 2 only, last 20 rows held out
+        diabetes = datasets.load_diabetes()
+        diabetes_X = diabetes.data[:, np.newaxis, 2]
+        diabetes_X_train = diabetes_X[:-20]
+        diabetes_X_test = diabetes_X[-20:]
+        diabetes_y_train = diabetes.target[:-20]
+        diabetes_y_test = diabetes.target[-20:]
+        regr = LinearRegression(sqlCtx)
+        regr.fit(diabetes_X_train, diabetes_y_train)
+        score = regr.score(diabetes_X_test, diabetes_y_test)
+        self.assertGreater(score, 0.4) # TODO: Improve r2-score (may be I am using it incorrectly)
+
+    def test_svm(self):  # multi-class SVM on digits, 90/10 split; expects score > 0.9
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_train = int(.9 * len(X_digits))  # first 90% of the rows train, the rest test
+        X_train = X_digits[:n_train]
+        y_train = y_digits[:n_train]
+        X_test = X_digits[n_train:]
+        y_test = y_digits[n_train:]
+        svm = SVM(sqlCtx, is_multi_class=True)
+        score = svm.fit(X_train, y_train).score(X_test, y_test)
+        self.assertGreater(score, 0.9)
+
+    def test_naive_bayes(self):  # NaiveBayes on digits, 90/10 split; expects score > 0.8
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_train = int(.9 * len(X_digits))  # first 90% of the rows train, the rest test
+        X_train = X_digits[:n_train]
+        y_train = y_digits[:n_train]
+        X_test = X_digits[n_train:]
+        y_test = y_digits[n_train:]
+        nb = NaiveBayes(sqlCtx)
+        score = nb.fit(X_train, y_train).score(X_test, y_test)
+        self.assertGreater(score, 0.8)
+
+    #def test_naive_bayes1(self):
+    #    categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
+    #    newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
+    #    newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
+    #    vectorizer = TfidfVectorizer()
+    #    # Both vectors and vectors_test are SciPy CSR matrix
+    #    vectors = vectorizer.fit_transform(newsgroups_train.data)
+    #    vectors_test = vectorizer.transform(newsgroups_test.data)
+    #    nb = NaiveBayes(sqlCtx)
+    #    nb.fit(vectors, newsgroups_train.target)
+    #    pred = nb.predict(vectors_test)
+    #    score = metrics.f1_score(newsgroups_test.target, pred, average='weighted')
+    #    self.failUnless(score > 0.8)
+
+
+if __name__ == '__main__':  # entry point when launched as a script (e.g. via spark-submit)
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/test/config/hadoop_bin_windows/bin/.gitignore
----------------------------------------------------------------------
diff --git a/src/test/config/hadoop_bin_windows/bin/.gitignore b/src/test/config/hadoop_bin_windows/bin/.gitignore
new file mode 100644
index 0000000..e9d2125
--- /dev/null
+++ b/src/test/config/hadoop_bin_windows/bin/.gitignore
@@ -0,0 +1,2 @@
+/libiomp5md.dll
+/systemml.dll

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/test/java/org/apache/sysml/test/integration/functions/python/PythonTestRunner.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/python/PythonTestRunner.java b/src/test/java/org/apache/sysml/test/integration/functions/python/PythonTestRunner.java
new file mode 100644
index 0000000..afe6f5f
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/python/PythonTestRunner.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.test.integration.functions.python;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.lang.ProcessBuilder.Redirect;
+import java.util.Map;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.junit.Test;
+
+/**
+ * Runs the Python tests in src/main/python/tests via spark-submit; a script passes only if it
+ * exits with status 0 and prints unittest's "OK" summary. To run them, please:
+ * 1. Set the RUN_PYTHON_TEST flag to true. 2. Set the SPARK_HOME environment variable.
+ * 3. Build SystemML so that target/SystemML.jar exists.
+ */
+public class PythonTestRunner extends AutomatedTestBase
+{
+	// Opt-in switch: these tests need a local Spark installation (SPARK_HOME) and Python.
+	private static boolean RUN_PYTHON_TEST = false;
+	
+	private final static String TEST_NAME = "PythonTestRunner";
+	private final static String TEST_DIR = "functions/python/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + PythonTestRunner.class.getSimpleName() + "/";
+	
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, 
+				new String[] {"B"}));
+	}
+	
+	@Test
+	public void testMLContext() throws DMLRuntimeException, IOException, InterruptedException  {
+		runPythonTest("test_mlcontext.py");
+	}
+	
+	@Test
+	public void testMatrixBinaryOp() throws DMLRuntimeException, IOException, InterruptedException  {
+		runPythonTest("test_matrix_binary_op.py");
+	}
+	
+	@Test
+	public void testMatrixAggFn() throws DMLRuntimeException, IOException, InterruptedException  {
+		runPythonTest("test_matrix_agg_fn.py");
+	}
+	
+	@Test
+	public void testMLLearn_df() throws DMLRuntimeException, IOException, InterruptedException  {
+		runPythonTest("test_mllearn_df.py");
+	}
+	
+	@Test
+	public void testMLLearn_numpy() throws DMLRuntimeException, IOException, InterruptedException  {
+		runPythonTest("test_mllearn_numpy.py");
+	}
+	
+	/** Runs {@code src/main/python/tests/pythonFileName} via spark-submit; a no-op unless RUN_PYTHON_TEST is true. */
+	public void runPythonTest(String pythonFileName) throws IOException, DMLRuntimeException, InterruptedException {
+		if(!RUN_PYTHON_TEST)
+			return;
+		// Fail fast with a clear message when a prerequisite is missing.
+		if(!new File("target/SystemML.jar").exists()) {
+			throw new DMLRuntimeException("Please build the project before running PythonTestRunner");
+		}
+		Map<String, String> env = System.getenv();
+		if(!env.containsKey("SPARK_HOME")) {
+			throw new DMLRuntimeException("Please set the SPARK_HOME environment variable");
+		}
+		String spark_submit = env.get("SPARK_HOME") + File.separator + "bin" + File.separator + "spark-submit";
+		if (System.getProperty("os.name").contains("Windows")) {
+			spark_submit += ".cmd";
+		}
+		// Python's unittest writes its "OK"/"FAILED" summary to stderr, so merge stderr into stdout;
+		// if stderr were inherited instead, the reader below could never see the "OK" line.
+		Process p = new ProcessBuilder(spark_submit, "--master", "local[*]", 
+				"--driver-class-path", "target/SystemML.jar", "src/main/python/tests/" + pythonFileName)
+				.redirectErrorStream(true)
+				.start();
+		
+		boolean sawOK = false;
+		// Echo all child output; reading to EOF also keeps the child from blocking on a full pipe.
+		try (BufferedReader in = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
+			String line;
+			while ((line = in.readLine()) != null) {
+				if(line.trim().startsWith("OK")) // "OK" or e.g. "OK (skipped=1)"
+					sawOK = true;
+				System.out.println(line);
+			}
+		}
+		// unittest.main() exits non-zero when any test fails or errors, and spark-submit returns
+		// the Python driver's exit status, so the exit code is the primary verdict.
+		int exitCode = p.waitFor();
+		
+		if(exitCode != 0 || !sawOK) {
+			throw new DMLRuntimeException("The python test failed (exit code " + exitCode + "): " + pythonFileName);
+		}
+	}
+}


[2/2] incubator-systemml git commit: [SYSTEMML-1116] Make SystemML Python DSL NumPy-friendly

Posted by ni...@apache.org.
[SYSTEMML-1116] Make SystemML Python DSL NumPy-friendly

1. Added Python test cases for matrix.
2. Added web documentation for all the Python APIs.
3. Added a set_lazy method to enable and disable lazy evaluation.
4. The matrix class itself now provides almost all basic linear algebra
operators supported by DML.
5. Updated SystemML.jar to *-incubating.jar.
6. Added Maven cleanup logic for Python artifacts.
7. Integrated Python test cases with Maven (see
org.apache.sysml.test.integration.functions.python.PythonTestRunner). This
requires SPARK_HOME to be set.

Closes #290.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/23ccab85
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/23ccab85
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/23ccab85

Branch: refs/heads/master
Commit: 23ccab85c6639dc5d8ce40a3f2352c691246e6b9
Parents: 398490e
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Dec 2 16:21:13 2016 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Dec 2 16:25:43 2016 -0800

----------------------------------------------------------------------
 docs/_layouts/global.html                       |   1 +
 docs/beginners-guide-python.md                  |  28 +-
 docs/devdocs/python_api.html                    |  40 +-
 docs/index.md                                   |   2 +
 docs/python-reference.md                        | 953 +++++++++++++++++++
 pom.xml                                         |   9 +
 .../spark/utils/RDDConverterUtilsExt.java       |  33 +-
 .../sysml/runtime/matrix/data/MatrixBlock.java  |   2 +-
 src/main/python/LICENSE                         |  46 +
 src/main/python/MANIFEST.in                     |   2 +-
 src/main/python/pre_setup.py                    |   5 +-
 src/main/python/setup.py                        |   2 +-
 src/main/python/systemml/classloader.py         |   8 +-
 src/main/python/systemml/defmatrix.py           | 822 ++++++++++++----
 src/main/python/systemml/mlcontext.py           |   7 +-
 src/main/python/systemml/random/sampling.py     |   6 +-
 src/main/python/tests/test_matrix_agg_fn.py     |  95 ++
 src/main/python/tests/test_matrix_binary_op.py  | 138 +++
 src/main/python/tests/test_mllearn.py           | 190 ----
 src/main/python/tests/test_mllearn_df.py        | 108 +++
 src/main/python/tests/test_mllearn_numpy.py     | 151 +++
 .../config/hadoop_bin_windows/bin/.gitignore    |   2 +
 .../functions/python/PythonTestRunner.java      | 119 +++
 23 files changed, 2320 insertions(+), 449 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/docs/_layouts/global.html
----------------------------------------------------------------------
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 516c7b4..f7cb969 100644
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -57,6 +57,7 @@
                                 <li><a href="dml-language-reference.html">DML Language Reference</a></li>
                                 <li><a href="beginners-guide-to-dml-and-pydml.html">Beginner's Guide to DML and PyDML</a></li>
                                 <li><a href="beginners-guide-python.html">Beginner's Guide for Python users</a></li>
+                                <li><a href="python-reference.html">Reference Guide for Python users</a></li>
                                 <li class="divider"></li>
                                 <li><b>ML Algorithms:</b></li>
                                 <li><a href="algorithms-reference.html">Algorithms Reference</a></li>

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/docs/beginners-guide-python.md
----------------------------------------------------------------------
diff --git a/docs/beginners-guide-python.md b/docs/beginners-guide-python.md
index 8d597bf..d0598aa 100644
--- a/docs/beginners-guide-python.md
+++ b/docs/beginners-guide-python.md
@@ -46,7 +46,7 @@ Before you get started on SystemML, make sure that your environment is set up an
 
 ### Install Java (need Java 8) and Apache Spark
 
-If you already have a Apache Spark installation, you can skip this step.
+If you already have an Apache Spark installation, you can skip this step.
   
 <div class="codetabs">
 <div data-lang="OSX" markdown="1">
@@ -70,19 +70,18 @@ brew install apache-spark16
 
 ### Install SystemML
 
-#### Step 1: Install SystemML Python package 
-
 We are working towards uploading the python package on pypi. Until then, please use following commands: 
 
 ```bash
 git checkout https://github.com/apache/incubator-systemml.git
 cd incubator-systemml
 mvn post-integration-test -P distribution -DskipTests
-pip install src/main/python/dist/systemml-incubating-0.11.0.dev1.tar.gz
+pip install src/main/python/dist/systemml-incubating-0.12.0.dev1.tar.gz
 ```
 
 The above commands will install Python package and place the corresponding Java binaries (along with algorithms) into the installed location.
 To find the location of the downloaded Java binaries, use the following command:
+
 ```bash
 python -c 'import imp; import os; print os.path.join(imp.find_module("systemml")[1], "systemml-java")'
 ```
@@ -92,24 +91,16 @@ or download them from [SystemML website](http://systemml.apache.org/download.htm
 or build them from the [source](https://github.com/apache/incubator-systemml).
 
 To uninstall SystemML, please use following command:
+
 ```bash
 pip uninstall systemml-incubating
 ```
 
 ### Start Pyspark shell
 
-<div class="codetabs">
-<div data-lang="OSX" markdown="1">
-```bash
-pyspark --master local[*]
-```
-</div>
-<div data-lang="Linux" markdown="1">
 ```bash
 pyspark --master local[*]
 ```
-</div>
-</div>
 
 ## Matrix operations
 
@@ -122,7 +113,7 @@ m1 = sml.matrix(np.ones((3,3)) + 2)
 m2 = sml.matrix(np.ones((3,3)) + 3)
 m2 = m1 * (m2 + m1)
 m4 = 1.0 - m2
-m4.sum(axis=1).toNumPyArray()
+m4.sum(axis=1).toNumPy()
 ```
 
 Output:
@@ -156,7 +147,7 @@ X = sml.matrix(X_train)
 y = sml.matrix(y_train)
 A = X.transpose().dot(X)
 b = X.transpose().dot(y)
-beta = sml.solve(A, b).toNumPyArray()
+beta = sml.solve(A, b).toNumPy()
 y_predicted = X_test.dot(beta)
 print('Residual sum of squares: %.2f' % np.mean((y_predicted - y_test) ** 2)) 
 ```
@@ -333,7 +324,7 @@ from sklearn import datasets, neighbors
 from pyspark.sql import DataFrame, SQLContext
 import systemml as sml
 import pandas as pd
-import os
+import os, imp
 sqlCtx = SQLContext(sc)
 digits = datasets.load_digits()
 X_digits = digits.data
@@ -343,7 +334,8 @@ n_samples = len(X_digits)
 X_df = sqlCtx.createDataFrame(pd.DataFrame(X_digits[:.9 * n_samples]))
 y_df = sqlCtx.createDataFrame(pd.DataFrame(y_digits[:.9 * n_samples]))
 ml = sml.MLContext(sc)
-script = os.path.join(os.environ['SYSTEMML_HOME'], 'scripts', 'algorithms', 'MultiLogReg.dml')
-script = sml.dml(script).input(X=X_df, Y_vec=y_df).output("B_out")
+# Get the path of MultiLogReg.dml
+scriptPath = os.path.join(imp.find_module("systemml")[1], 'systemml-java', 'scripts', 'algorithms', 'MultiLogReg.dml')
+script = sml.dml(scriptPath).input(X=X_df, Y_vec=y_df).output("B_out")
 beta = ml.execute(script).get('B_out').toNumPy()
 ```

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/docs/devdocs/python_api.html
----------------------------------------------------------------------
diff --git a/docs/devdocs/python_api.html b/docs/devdocs/python_api.html
index 41a8e3e..93ec624 100644
--- a/docs/devdocs/python_api.html
+++ b/docs/devdocs/python_api.html
@@ -391,7 +391,7 @@ sparsity: Sparsity (between 0.0 and 1.0).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">sml</span><span class="o">.</span><span class="n">setSparkContext</span><span class="p">(</span><span class="n">sc</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">systemml</span> <span class="k">import</span> <span class="n">random</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span> <span class="o">=</span> <span class="n">sml</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
-<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPyArray</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPy</span><span class="p">()</span>
 <span class="go">array([[ 3.48857226,  6.17261819,  2.51167259],</span>
 <span class="go">       [ 3.60506708, -1.90266305,  3.97601633],</span>
 <span class="go">       [ 3.62245706,  5.9430881 ,  2.53070413]])</span>
@@ -412,7 +412,7 @@ sparsity: Sparsity (between 0.0 and 1.0).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">sml</span><span class="o">.</span><span class="n">setSparkContext</span><span class="p">(</span><span class="n">sc</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">systemml</span> <span class="k">import</span> <span class="n">random</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span> <span class="o">=</span> <span class="n">sml</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
-<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPyArray</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPy</span><span class="p">()</span>
 <span class="go">array([[ 0.54511396,  0.11937437,  0.72975775],</span>
 <span class="go">       [ 0.14135946,  0.01944448,  0.52544478],</span>
 <span class="go">       [ 0.67582422,  0.87068849,  0.02766852]])</span>
@@ -432,7 +432,7 @@ sparsity: Sparsity (between 0.0 and 1.0).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">sml</span><span class="o">.</span><span class="n">setSparkContext</span><span class="p">(</span><span class="n">sc</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">systemml</span> <span class="k">import</span> <span class="n">random</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span> <span class="o">=</span> <span class="n">sml</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">poisson</span><span class="p">(</span><span class="n">lam</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
-<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPyArray</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPy</span><span class="p">()</span>
 <span class="go">array([[ 1.,  0.,  2.],</span>
 <span class="go">       [ 1.,  0.,  0.],</span>
 <span class="go">       [ 0.,  0.,  0.]])</span>
@@ -479,7 +479,7 @@ sparsity: Sparsity (between 0.0 and 1.0).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">sml</span><span class="o">.</span><span class="n">setSparkContext</span><span class="p">(</span><span class="n">sc</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">systemml</span> <span class="k">import</span> <span class="n">random</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span> <span class="o">=</span> <span class="n">sml</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
-<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPyArray</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPy</span><span class="p">()</span>
 <span class="go">array([[ 3.48857226,  6.17261819,  2.51167259],</span>
 <span class="go">       [ 3.60506708, -1.90266305,  3.97601633],</span>
 <span class="go">       [ 3.62245706,  5.9430881 ,  2.53070413]])</span>
@@ -500,7 +500,7 @@ sparsity: Sparsity (between 0.0 and 1.0).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">sml</span><span class="o">.</span><span class="n">setSparkContext</span><span class="p">(</span><span class="n">sc</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">systemml</span> <span class="k">import</span> <span class="n">random</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span> <span class="o">=</span> <span class="n">sml</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
-<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPyArray</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPy</span><span class="p">()</span>
 <span class="go">array([[ 0.54511396,  0.11937437,  0.72975775],</span>
 <span class="go">       [ 0.14135946,  0.01944448,  0.52544478],</span>
 <span class="go">       [ 0.67582422,  0.87068849,  0.02766852]])</span>
@@ -520,7 +520,7 @@ sparsity: Sparsity (between 0.0 and 1.0).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">sml</span><span class="o">.</span><span class="n">setSparkContext</span><span class="p">(</span><span class="n">sc</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">systemml</span> <span class="k">import</span> <span class="n">random</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span> <span class="o">=</span> <span class="n">sml</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">poisson</span><span class="p">(</span><span class="n">lam</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
-<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPyArray</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">m1</span><span class="o">.</span><span class="n">toNumPy</span><span class="p">()</span>
 <span class="go">array([[ 1.,  0.,  2.],</span>
 <span class="go">       [ 1.,  0.,  0.],</span>
 <span class="go">       [ 0.,  0.,  0.]])</span>
@@ -607,7 +607,7 @@ and Pandas DataFrame).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m2</span> <span class="o">=</span> <span class="n">m1</span> <span class="o">*</span> <span class="p">(</span><span class="n">m2</span> <span class="o">+</span> <span class="n">m1</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m4</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">m2</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m4</span>
-<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPyArray() or toDataFrame() or toPandas() methods.</span>
+<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.</span>
 <span class="go">mVar1 = load(&quot; &quot;, format=&quot;csv&quot;)</span>
 <span class="go">mVar2 = load(&quot; &quot;, format=&quot;csv&quot;)</span>
 <span class="go">mVar3 = mVar2 + mVar1</span>
@@ -616,9 +616,9 @@ and Pandas DataFrame).</p>
 <span class="go">save(mVar5, &quot; &quot;)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m2</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m2</span>
-<span class="go"># This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPyArray() method.</span>
+<span class="go"># This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m4</span>
-<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPyArray() or toDataFrame() or toPandas() methods.</span>
+<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.</span>
 <span class="go">mVar4 = load(&quot; &quot;, format=&quot;csv&quot;)</span>
 <span class="go">mVar5 = 1.0 - mVar4</span>
 <span class="go">save(mVar5, &quot; &quot;)</span>
@@ -780,14 +780,14 @@ left-indexed-matrix[index] = value</p>
 <dd></dd></dl>
 
 <dl class="method">
-<dt id="systemml.defmatrix.matrix.toDataFrame">
-<code class="descname">toDataFrame</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toDataFrame"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.defmatrix.matrix.toDataFrame" title="Permalink to this definition">�</a></dt>
+<dt id="systemml.defmatrix.matrix.toDF">
+<code class="descname">toDF</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toDF"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.defmatrix.matrix.toDF" title="Permalink to this definition">¶</a></dt>
 <dd><p>This is a convenience function that calls the global eval method and then converts the matrix object into DataFrame.</p>
 </dd></dl>
 
 <dl class="method">
-<dt id="systemml.defmatrix.matrix.toNumPyArray">
-<code class="descname">toNumPyArray</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toNumPyArray"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.defmatrix.matrix.toNumPyArray" title="Permalink to this definition">�</a></dt>
+<dt id="systemml.defmatrix.matrix.toNumPy">
+<code class="descname">toNumPy</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toNumPy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.defmatrix.matrix.toNumPy" title="Permalink to this definition">¶</a></dt>
 <dd><p>This is a convenience function that calls the global eval method and then converts the matrix object into NumPy array.</p>
 </dd></dl>
 
@@ -1282,7 +1282,7 @@ and Pandas DataFrame).</p>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m2</span> <span class="o">=</span> <span class="n">m1</span> <span class="o">*</span> <span class="p">(</span><span class="n">m2</span> <span class="o">+</span> <span class="n">m1</span><span class="p">)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m4</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">m2</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m4</span>
-<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPyArray() or toDataFrame() or toPandas() methods.</span>
+<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.</span>
 <span class="go">mVar1 = load(&quot; &quot;, format=&quot;csv&quot;)</span>
 <span class="go">mVar2 = load(&quot; &quot;, format=&quot;csv&quot;)</span>
 <span class="go">mVar3 = mVar2 + mVar1</span>
@@ -1291,9 +1291,9 @@ and Pandas DataFrame).</p>
 <span class="go">save(mVar5, &quot; &quot;)</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m2</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m2</span>
-<span class="go"># This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPyArray() method.</span>
+<span class="go"># This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">m4</span>
-<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPyArray() or toDataFrame() or toPandas() methods.</span>
+<span class="go"># This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.</span>
 <span class="go">mVar4 = load(&quot; &quot;, format=&quot;csv&quot;)</span>
 <span class="go">mVar5 = 1.0 - mVar4</span>
 <span class="go">save(mVar5, &quot; &quot;)</span>
@@ -1455,14 +1455,14 @@ left-indexed-matrix[index] = value</p>
 <dd></dd></dl>
 
 <dl class="method">
-<dt id="systemml.matrix.toDataFrame">
-<code class="descname">toDataFrame</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toDataFrame"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.matrix.toDataFrame" title="Permalink to this definition">�</a></dt>
+<dt id="systemml.matrix.toDF">
+<code class="descname">toDF</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toDF"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.matrix.toDF" title="Permalink to this definition">¶</a></dt>
 <dd><p>This is a convenience function that calls the global eval method and then converts the matrix object into DataFrame.</p>
 </dd></dl>
 
 <dl class="method">
-<dt id="systemml.matrix.toNumPyArray">
-<code class="descname">toNumPyArray</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toNumPyArray"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.matrix.toNumPyArray" title="Permalink to this definition">�</a></dt>
+<dt id="systemml.matrix.toNumPy">
+<code class="descname">toNumPy</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/systemml/defmatrix.html#matrix.toNumPy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#systemml.matrix.toNumPy" title="Permalink to this definition">¶</a></dt>
 <dd><p>This is a convenience function that calls the global eval method and then converts the matrix object into NumPy array.</p>
 </dd></dl>
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/docs/index.md
----------------------------------------------------------------------
diff --git a/docs/index.md b/docs/index.md
index 3fcece6..6b91654 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -70,6 +70,8 @@ PyDML is a high-level Python-like declarative language for machine learning.
 An introduction to the basics of DML and PyDML.
 * [Beginner's Guide for Python users](beginners-guide-python) -
 Beginner's Guide for Python users.
+* [Reference Guide for Python users](python-reference) -
+Reference Guide for Python users.
 
 ## ML Algorithms
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/docs/python-reference.md
----------------------------------------------------------------------
diff --git a/docs/python-reference.md b/docs/python-reference.md
new file mode 100644
index 0000000..3c2bbc3
--- /dev/null
+++ b/docs/python-reference.md
@@ -0,0 +1,953 @@
+---
+layout: global
+title: Reference Guide for Python users
+description: Reference Guide for Python users
+---
+<!--
+{% comment %}
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+{% endcomment %}
+-->
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+<br/>
+
+## Introduction
+
+SystemML enables flexible, scalable machine learning. This flexibility is achieved through the specification of a high-level declarative machine learning language that comes in two flavors, 
+one with an R-like syntax (DML) and one with a Python-like syntax (PyDML).
+
+Algorithm scripts written in DML and PyDML can be run on Hadoop, on Spark, or in Standalone mode. 
+No script modifications are required to change between modes. SystemML automatically performs advanced optimizations 
+based on data and cluster characteristics, so the need to manually tweak algorithms is largely reduced or eliminated.
+To understand more about DML and PyDML, we recommend that you read [Beginner's Guide to DML and PyDML](https://apache.github.io/incubator-systemml/beginners-guide-to-dml-and-pydml.html).
+
+For convenience of Python users, SystemML exposes several language-level APIs that allow Python users to use SystemML
+and its algorithms without the need to know DML or PyDML. We explain these APIs in the sections below.
+
+## matrix API
+
+The matrix class allows users to perform linear algebra operations in SystemML using a NumPy-like interface.
+This class supports several arithmetic operators (such as +, -, *, /, ^, etc.) and also supports most of NumPy's universal functions (i.e., ufuncs).
+
+The current version of NumPy explicitly disables overriding ufuncs, but this should be enabled in the next release.
+Until then, to test the code below, please use:
+
+```bash
+git clone https://github.com/niketanpansare/numpy.git
+cd numpy
+python setup.py install
+```
+
+This will enable NumPy's functions to invoke the matrix class:
+
+```python
+import systemml as sml
+import numpy as np
+m1 = sml.matrix(np.ones((3,3)) + 2)
+m2 = sml.matrix(np.ones((3,3)) + 3)
+np.add(m1, m2)
+``` 
+
+The matrix class does not support the following ufuncs:
+
+- Complex number related ufunc (for example: `conj`)
+- Hyperbolic/inverse-hyperbolic functions (for example: sinh, arcsinh, cosh, ...)
+- Bitwise operators
+- Xor operator
+- Infinite/Nan-checking (for example: isreal, iscomplex, isfinite, isinf, isnan)
+- Other ufuncs: copysign, nextafter, modf, frexp, trunc.
+
+This class also supports several input/output formats such as NumPy arrays, Pandas DataFrame, SciPy sparse matrix and PySpark DataFrame.
+
+By default, the operations are evaluated lazily to avoid conversion overhead and also to maximize optimization scope.
+To disable lazy evaluation, please use the `set_lazy` method:
+
+```python
+>>> import systemml as sml
+>>> import numpy as np
+>>> m1 = sml.matrix(np.ones((3,3)) + 2)
+
+Welcome to Apache SystemML!
+
+>>> m2 = sml.matrix(np.ones((3,3)) + 3)
+>>> np.add(m1, m2) + m1
+# This matrix (mVar4) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
+mVar2 = load(" ", format="csv")
+mVar1 = load(" ", format="csv")
+mVar3 = mVar1 + mVar2
+mVar4 = mVar3 + mVar1
+save(mVar4, " ")
+
+
+>>> sml.set_lazy(False)
+>>> m1 = sml.matrix(np.ones((3,3)) + 2)
+>>> m2 = sml.matrix(np.ones((3,3)) + 3)
+>>> np.add(m1, m2) + m1
+# This matrix (mVar8) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.
+``` 
+
+### Usage:
+
+```python
+import systemml as sml
+import numpy as np
+m1 = sml.matrix(np.ones((3,3)) + 2)
+m2 = sml.matrix(np.ones((3,3)) + 3)
+m2 = m1 * (m2 + m1)
+m4 = 1.0 - m2
+m4.sum(axis=1).toNumPy()
+```
+
+Output:
+
+```bash
+array([[-60.],
+       [-60.],
+       [-60.]])
+```
+
+
+### Reference Documentation:
+
+ *class*`systemml.defmatrix.matrix`(*data*, *op=None*)
+:   Bases: `object`
+
+    matrix class is a python wrapper that implements basic matrix
+    operators, matrix functions as well as converters to common Python
+    types (for example: Numpy arrays, PySpark DataFrame and Pandas
+    DataFrame).
+
+    The operators supported are:
+
+    1.  Arithmetic operators: +, -, *, /, //, %, \** as well as dot
+        (i.e. matrix multiplication)
+    2.  Indexing in the matrix
+    3.  Relational/Boolean operators: \<, \<=, \>, \>=, ==, !=, &, \|
+
+    In addition, following functions are supported for matrix:
+
+    1.  transpose
+    2.  Aggregation functions: sum, mean, var, sd, max, min, argmin,
+        argmax, cumsum
+    3.  Global statistical built-in functions: exp, log, abs, sqrt,
+        round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve
+
+    All the above functions always return a two-dimensional matrix, including aggregation functions invoked with an axis.
+    For example, assuming m1 is a matrix of shape (3, n), NumPy returns a 1-d vector of shape (3,) for the operation m1.sum(axis=1),
+    whereas SystemML returns a 2-d matrix of shape (3, 1).
+    
+    Note: an evaluated matrix contains a data field computed by the eval
+    method, stored as a DataFrame or NumPy array.
+
+        >>> import systemml as sml
+        >>> import numpy as np
+        >>> sml.setSparkContext(sc)
+
+    Welcome to Apache SystemML!
+
+        >>> m1 = sml.matrix(np.ones((3,3)) + 2)
+        >>> m2 = sml.matrix(np.ones((3,3)) + 3)
+        >>> m2 = m1 * (m2 + m1)
+        >>> m4 = 1.0 - m2
+        >>> m4
+        # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
+        mVar1 = load(" ", format="csv")
+        mVar2 = load(" ", format="csv")
+        mVar3 = mVar2 + mVar1
+        mVar4 = mVar1 * mVar3
+        mVar5 = 1.0 - mVar4
+        save(mVar5, " ")
+        >>> m2.eval()
+        >>> m2
+        # This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.
+        >>> m4
+        # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
+        mVar4 = load(" ", format="csv")
+        mVar5 = 1.0 - mVar4
+        save(mVar5, " ")
+        >>> m4.sum(axis=1).toNumPy()
+        array([[-60.],
+               [-60.],
+               [-60.]])
+
+    Design Decisions:
+
+    1.  Until the eval() method is invoked, we create an AST (not exposed to
+        the user) that consists of unevaluated operations and the data
+        required by those operations. As an analogy, a Spark user can
+        treat the eval() method as similar to calling RDD.persist() followed by
+        RDD.count().
+    2.  The AST consists of two kinds of nodes: either of type matrix or
+        of type DMLOp. Both these classes expose a \_visit method that
+        helps in traversing the AST in a DFS manner.
+    3.  A matrix object can either be evaluated or not. If evaluated,
+        the attribute 'data' is set to one of the supported types (for
+        example: NumPy array or DataFrame). In this case, the attribute
+        'op' is set to None. If not evaluated, the attribute 'op'
+        refers to one of the intermediate nodes of the AST and is of type
+        DMLOp. In this case, the attribute 'data' is set to None.
+
+    5.  DMLOp has an attribute 'inputs' which contains a list of matrix
+        objects or DMLOp objects.
+
+    6.  To simplify the traversal, every matrix object is considered
+        immutable, and every matrix operation creates a new matrix object.
+        As an example: m1 = sml.matrix(np.ones((3,3))) creates a matrix
+        object backed by 'data=np.ones((3,3))'. m1 = m1 \* 2 will
+        create a new matrix object which is now backed by 'op=DMLOp( ...
+        )' whose input is the earlier created matrix object.
+
+    7.  Left indexing (implemented in \_\_setitem\_\_ method) is a
+        special case, where Python expects the existing object to be
+        mutated. To ensure the above property, we make a deep copy of the
+        existing object and point any references to the left-indexed
+        matrix to the newly created object. Then the left-indexed matrix
+        is set to be backed by DMLOp consisting of following pydml:
+        left-indexed-matrix = new-deep-copied-matrix
+        left-indexed-matrix[index] = value
+
+    8.  Please use m.print\_ast() and/or type m for debugging. Here is a
+        sample session:
+
+            >>> npm = np.ones((3,3))
+            >>> m1 = sml.matrix(npm + 3)
+            >>> m2 = sml.matrix(npm + 5)
+            >>> m3 = m1 + m2
+            >>> m3
+            mVar2 = load(" ", format="csv")
+            mVar1 = load(" ", format="csv")
+            mVar3 = mVar1 + mVar2
+            save(mVar3, " ")
+            >>> m3.print_ast()
+            - [mVar3] (op).
+              - [mVar1] (data).
+              - [mVar2] (data).    
+
+ `abs`()
+:   
+
+ `acos`()
+:   
+
+ `arccos`()
+:   
+
+ `arcsin`()
+:   
+
+ `arctan`()
+:   
+
+ `argmax`(*axis=None*)
+:   Returns the indices of the maximum values along an axis.
+
+    axis : int, optional (only axis=1, i.e. rowIndexMax is supported
+    in this version)
+
+ `argmin`(*axis=None*)
+:   Returns the indices of the minimum values along an axis.
+
+    axis : int, optional (only axis=1, i.e. rowIndexMin is supported
+    in this version)
+
+ `asfptype`()
+:   
+
+ `asin`()
+:   
+
+ `astype`(*t*)
+:   
+
+ `atan`()
+:   
+
+ `ceil`()
+:   
+
+ `cos`()
+:   
+
+ `cumsum`(*axis=None*)
+:   Returns the cumulative sum of the elements along an axis.
+
+    axis : int, optional (only axis=0, i.e. cumsum along the rows is
+    supported in this version)
+
+ `deg2rad`()
+:   Convert angles from degrees to radians.
+
+ `dot`(*other*)[](#systemml.defmatrix.matrix.dot "Permalink to this definition")
+:   NumPy-style way of performing matrix multiplication
+
+ `eval`(*outputDF=False*)[](#systemml.defmatrix.matrix.eval "Permalink to this definition")
+:   This is a convenience function that calls the global eval method
+
+ `exp`()[](#systemml.defmatrix.matrix.exp "Permalink to this definition")
+:   
+
+ `exp2`()[](#systemml.defmatrix.matrix.exp2 "Permalink to this definition")
+:   
+
+ `expm1`()[](#systemml.defmatrix.matrix.expm1 "Permalink to this definition")
+:   
+
+ `floor`()[](#systemml.defmatrix.matrix.floor "Permalink to this definition")
+:   
+
+ `get_shape`()[](#systemml.defmatrix.matrix.get_shape "Permalink to this definition")
+:   
+
+ `ldexp`(*other*)[](#systemml.defmatrix.matrix.ldexp "Permalink to this definition")
+:   
+
+ `log`(*y=None*)[](#systemml.defmatrix.matrix.log "Permalink to this definition")
+:   
+
+ `log10`()[](#systemml.defmatrix.matrix.log10 "Permalink to this definition")
+:   
+
+ `log1p`()[](#systemml.defmatrix.matrix.log1p "Permalink to this definition")
+:   
+
+ `log2`()[](#systemml.defmatrix.matrix.log2 "Permalink to this definition")
+:   
+
+ `logaddexp`(*other*)[](#systemml.defmatrix.matrix.logaddexp "Permalink to this definition")
+:   
+
+ `logaddexp2`(*other*)[](#systemml.defmatrix.matrix.logaddexp2 "Permalink to this definition")
+:   
+
+ `logical_not`()[](#systemml.defmatrix.matrix.logical_not "Permalink to this definition")
+:   
+
+ `max`(*other=None*, *axis=None*)[](#systemml.defmatrix.matrix.max "Permalink to this definition")
+:   Compute the maximum value along the specified axis
+
+    other: matrix or numpy array (& other supported types) or scalar
+    axis : int, optional
+
+ `mean`(*axis=None*)[](#systemml.defmatrix.matrix.mean "Permalink to this definition")
+:   Compute the arithmetic mean along the specified axis
+
+    axis : int, optional
+
+ `min`(*other=None*, *axis=None*)[](#systemml.defmatrix.matrix.min "Permalink to this definition")
+:   Compute the minimum value along the specified axis
+
+    other: matrix or numpy array (& other supported types) or scalar
+    axis : int, optional
+
+ `mod`(*other*)[](#systemml.defmatrix.matrix.mod "Permalink to this definition")
+:   
+
+ `ndim`*= 2*[](#systemml.defmatrix.matrix.ndim "Permalink to this definition")
+:   
+
+ `negative`()[](#systemml.defmatrix.matrix.negative "Permalink to this definition")
+:   
+
+ `ones_like`()[](#systemml.defmatrix.matrix.ones_like "Permalink to this definition")
+:   
+
+ `print_ast`()[](#systemml.defmatrix.matrix.print_ast "Permalink to this definition")
+:   Please use m.print\_ast() and/or type m for debugging. Here is a
+    sample session:
+
+        >>> npm = np.ones((3,3))
+        >>> m1 = sml.matrix(npm + 3)
+        >>> m2 = sml.matrix(npm + 5)
+        >>> m3 = m1 + m2
+        >>> m3
+        mVar2 = load(" ", format="csv")
+        mVar1 = load(" ", format="csv")
+        mVar3 = mVar1 + mVar2
+        save(mVar3, " ")
+        >>> m3.print_ast()
+        - [mVar3] (op).
+          - [mVar1] (data).
+          - [mVar2] (data).
+
+ `rad2deg`()[](#systemml.defmatrix.matrix.rad2deg "Permalink to this definition")
+:   Convert angles from radians to degrees.
+
+ `reciprocal`()[](#systemml.defmatrix.matrix.reciprocal "Permalink to this definition")
+:   
+
+ `remainder`(*other*)[](#systemml.defmatrix.matrix.remainder "Permalink to this definition")
+:   
+
+ `round`()[](#systemml.defmatrix.matrix.round "Permalink to this definition")
+:   
+
+ `script`*= None*[](#systemml.defmatrix.matrix.script "Permalink to this definition")
+:   
+
+ `sd`(*axis=None*)[](#systemml.defmatrix.matrix.sd "Permalink to this definition")
+:   Compute the standard deviation along the specified axis
+
+    axis : int, optional
+
+ `set_shape`(*shape*)[](#systemml.defmatrix.matrix.set_shape "Permalink to this definition")
+:   
+
+ `shape`[](#systemml.defmatrix.matrix.shape "Permalink to this definition")
+:   
+
+ `sign`()[](#systemml.defmatrix.matrix.sign "Permalink to this definition")
+:   
+
+ `sin`()[](#systemml.defmatrix.matrix.sin "Permalink to this definition")
+:   
+
+ `sqrt`()[](#systemml.defmatrix.matrix.sqrt "Permalink to this definition")
+:   
+
+ `square`()[](#systemml.defmatrix.matrix.square "Permalink to this definition")
+:   
+
+ `sum`(*axis=None*)[](#systemml.defmatrix.matrix.sum "Permalink to this definition")
+:   Compute the sum along the specified axis. 
+
+    axis : int, optional
+
+ `systemmlVarID`*= 0*[](#systemml.defmatrix.matrix.systemmlVarID "Permalink to this definition")
+:   
+
+ `tan`()[](#systemml.defmatrix.matrix.tan "Permalink to this definition")
+:   
+
+ `toDF`()[](#systemml.defmatrix.matrix.toDF "Permalink to this definition")
+:   This is a convenience function that calls the global eval method
+    and then converts the matrix object into DataFrame.
+
+ `toNumPy`()[](#systemml.defmatrix.matrix.toNumPy "Permalink to this definition")
+:   This is a convenience function that calls the global eval method
+    and then converts the matrix object into NumPy array.
+
+ `toPandas`()[](#systemml.defmatrix.matrix.toPandas "Permalink to this definition")
+:   This is a convenience function that calls the global eval method
+    and then converts the matrix object into Pandas DataFrame.
+
+ `trace`()[](#systemml.defmatrix.matrix.trace "Permalink to this definition")
+:   Return the sum of the cells on the main diagonal of a square matrix.
+
+ `transpose`()[](#systemml.defmatrix.matrix.transpose "Permalink to this definition")
+:   Transposes the matrix.
+
+ `var`(*axis=None*)[](#systemml.defmatrix.matrix.var "Permalink to this definition")
+:   Compute the variance along the specified axis
+
+    axis : int, optional
+
+ `zeros_like`()[](#systemml.defmatrix.matrix.zeros_like "Permalink to this definition")
+:   
+
+ `systemml.defmatrix.eval`(*outputs*, *outputDF=False*, *execute=True*)[](#systemml.defmatrix.eval "Permalink to this definition")
+:   Executes the unevaluated DML script and computes the matrices
+    specified by outputs.
+
+    outputs: list of matrices or a matrix object.
+    outputDF: if True, return the data of the matrix as a PySpark DataFrame.
+
+ `systemml.defmatrix.solve`(*A*, *b*)[](#systemml.defmatrix.solve "Permalink to this definition")
+:   Computes the least squares solution for system of linear equations A
+    %\*% x = b
+
+        >>> import numpy as np
+        >>> from sklearn import datasets
+        >>> import systemml as sml
+        >>> from pyspark.sql import SQLContext
+        >>> diabetes = datasets.load_diabetes()
+        >>> diabetes_X = diabetes.data[:, np.newaxis, 2]
+        >>> X_train = diabetes_X[:-20]
+        >>> X_test = diabetes_X[-20:]
+        >>> y_train = diabetes.target[:-20]
+        >>> y_test = diabetes.target[-20:]
+        >>> sml.setSparkContext(sc)
+        >>> X = sml.matrix(X_train)
+        >>> y = sml.matrix(y_train)
+        >>> A = X.transpose().dot(X)
+        >>> b = X.transpose().dot(y)
+        >>> beta = sml.solve(A, b).toNumPy()
+        >>> y_predicted = X_test.dot(beta)
+        >>> print('Residual sum of squares: %.2f' % np.mean((y_predicted - y_test) ** 2))
+        Residual sum of squares: 25282.12
+
+ `systemml.defmatrix.set_lazy`(*isLazy*)[](#systemml.defmatrix.set_lazy "Permalink to this definition")
+:   This method allows users to set whether matrix operations should be executed lazily.
+
+    isLazy: True if matrix operations should be evaluated lazily.
+
+ `systemml.defmatrix.debug_array_conversion`(*throwError*)[](#systemml.defmatrix.debug_array_conversion "Permalink to this definition")
+:   
+
+ `systemml.random.sampling.normal`(*loc=0.0*, *scale=1.0*, *size=(1*, *1)*, *sparsity=1.0*)[](#systemml.random.sampling.normal "Permalink to this definition")
+:   Draw random samples from a normal (Gaussian) distribution.
+
+    loc: Mean ('centre') of the distribution. scale: Standard deviation
+    (spread or 'width') of the distribution. size: Output shape (only
+    tuple of length 2, i.e. (m, n), supported). sparsity: Sparsity
+    (between 0.0 and 1.0).
+
+        >>> import systemml as sml
+        >>> import numpy as np
+        >>> sml.setSparkContext(sc)
+        >>> from systemml import random
+        >>> m1 = sml.random.normal(loc=3, scale=2, size=(3,3))
+        >>> m1.toNumPy()
+        array([[ 3.48857226,  6.17261819,  2.51167259],
+               [ 3.60506708, -1.90266305,  3.97601633],
+               [ 3.62245706,  5.9430881 ,  2.53070413]])
+
+ `systemml.random.sampling.uniform`(*low=0.0*, *high=1.0*, *size=(1*, *1)*, *sparsity=1.0*)[](#systemml.random.sampling.uniform "Permalink to this definition")
+:   Draw samples from a uniform distribution.
+
+    low: Lower boundary of the output interval. high: Upper boundary of
+    the output interval. size: Output shape (only tuple of length 2,
+    i.e. (m, n), supported). sparsity: Sparsity (between 0.0 and 1.0).
+
+        >>> import systemml as sml
+        >>> import numpy as np
+        >>> sml.setSparkContext(sc)
+        >>> from systemml import random
+        >>> m1 = sml.random.uniform(size=(3,3))
+        >>> m1.toNumPy()
+        array([[ 0.54511396,  0.11937437,  0.72975775],
+               [ 0.14135946,  0.01944448,  0.52544478],
+               [ 0.67582422,  0.87068849,  0.02766852]])
+
+ `systemml.random.sampling.poisson`(*lam=1.0*, *size=(1*, *1)*, *sparsity=1.0*)[](#systemml.random.sampling.poisson "Permalink to this definition")
+:   Draw samples from a Poisson distribution.
+
+    lam: Expectation of interval, should be \> 0. size: Output shape
+    (only tuple of length 2, i.e. (m, n), supported). sparsity: Sparsity
+    (between 0.0 and 1.0).
+
+        >>> import systemml as sml
+        >>> import numpy as np
+        >>> sml.setSparkContext(sc)
+        >>> from systemml import random
+        >>> m1 = sml.random.poisson(lam=1, size=(3,3))
+        >>> m1.toNumPy()
+        array([[ 1.,  0.,  2.],
+               [ 1.,  0.,  0.],
+               [ 0.,  0.,  0.]])
+
+
+
+## MLContext API
+
+The Spark MLContext API offers a programmatic interface for interacting with SystemML from Spark using languages such as Scala, Java, and Python. 
+As a result, it offers a convenient way to interact with SystemML from the Spark Shell and from Notebooks such as Jupyter and Zeppelin.
+
+### Usage
+
+The below example demonstrates how to invoke the algorithm [scripts/algorithms/MultiLogReg.dml](https://github.com/apache/incubator-systemml/blob/master/scripts/algorithms/MultiLogReg.dml)
+using Python [MLContext API](https://apache.github.io/incubator-systemml/spark-mlcontext-programming-guide).
+
+```python
+from sklearn import datasets, neighbors
+from pyspark.sql import DataFrame, SQLContext
+import systemml as sml
+import pandas as pd
+import os, imp
+sqlCtx = SQLContext(sc)
+digits = datasets.load_digits()
+X_digits = digits.data
+y_digits = digits.target + 1
+n_samples = len(X_digits)
+# Split the data into training/testing sets and convert to PySpark DataFrame
+X_df = sqlCtx.createDataFrame(pd.DataFrame(X_digits[:int(.9 * n_samples)]))
+y_df = sqlCtx.createDataFrame(pd.DataFrame(y_digits[:int(.9 * n_samples)]))
+ml = sml.MLContext(sc)
+# Get the path of MultiLogReg.dml
+scriptPath = os.path.join(imp.find_module("systemml")[1], 'systemml-java', 'scripts', 'algorithms', 'MultiLogReg.dml')
+script = sml.dml(scriptPath).input(X=X_df, Y_vec=y_df).output("B_out")
+beta = ml.execute(script).get('B_out').toNumPy()
+```
+
+### Reference documentation
+
+ *class*`systemml.mlcontext.MLResults`(*results*, *sc*)[](#systemml.mlcontext.MLResults "Permalink to this definition")
+:   Bases: `object`{.xref .py .py-class .docutils .literal}
+
+    Wrapper around a Java ML Results object.
+
+    results: JavaObject
+    :   A Java MLResults object as returned by calling ml.execute().
+    sc: SparkContext
+    :   SparkContext
+
+     `get`(*\*outputs*)[](#systemml.mlcontext.MLResults.get "Permalink to this definition")
+    :   outputs: string, list of strings
+        :   Output variables as defined inside the DML script.
+
+ *class*`systemml.mlcontext.MLContext`(*sc*)[](#systemml.mlcontext.MLContext "Permalink to this definition")
+:   Bases: `object`{.xref .py .py-class .docutils .literal}
+
+    Wrapper around the new SystemML MLContext.
+
+    sc: SparkContext
+    :   SparkContext
+
+ `execute`(*script*)[](#systemml.mlcontext.MLContext.execute "Permalink to this definition")
+:   Execute a DML / PyDML script.
+
+    script: Script instance
+    :   Script instance defined with the appropriate input and
+        output variables.
+
+    ml\_results: MLResults
+    :   MLResults instance.
+
+ `setExplain`(*explain*)[](#systemml.mlcontext.MLContext.setExplain "Permalink to this definition")
+:   Explanation about the program. Mainly intended for developers.
+
+    explain: boolean
+
+ `setExplainLevel`(*explainLevel*)[](#systemml.mlcontext.MLContext.setExplainLevel "Permalink to this definition")
+:   Set explain level.
+
+    explainLevel: string
+    :   Can be one of 'hops', 'runtime', 'recompile\_hops',
+        'recompile\_runtime' or in the above in upper case.
+
+ `setStatistics`(*statistics*)[](#systemml.mlcontext.MLContext.setStatistics "Permalink to this definition")
+:   Whether or not to output statistics (such as execution time,
+    elapsed time) about script executions.
+
+    statistics: boolean
+
+ `setStatisticsMaxHeavyHitters`(*maxHeavyHitters*)[](#systemml.mlcontext.MLContext.setStatisticsMaxHeavyHitters "Permalink to this definition")
+:   The maximum number of heavy hitters that are printed as part of
+    the statistics.
+
+    maxHeavyHitters: int
+
+ *class*`systemml.mlcontext.Script`(*scriptString*, *scriptType='dml'*)[](#systemml.mlcontext.Script "Permalink to this definition")
+:   Bases: `object`{.xref .py .py-class .docutils .literal}
+
+    Instance of a DML/PyDML Script.
+
+    scriptString: string
+    :   Can be either a file path to a DML script or a DML script
+        itself.
+    scriptType: string
+    :   Script language, either 'dml' for DML (R-like) or 'pydml' for
+        PyDML (Python-like).
+
+ `input`(*\*args*, *\*\*kwargs*)[](#systemml.mlcontext.Script.input "Permalink to this definition")
+:   args: name, value tuple
+    :   where name is a string, and currently supported value
+        formats are double, string, dataframe, rdd, and list of such
+        object.
+    kwargs: dict of name, value pairs
+    :   To know what formats are supported for name and value, look
+        above.
+
+ `output`(*\*names*)[](#systemml.mlcontext.Script.output "Permalink to this definition")
+:   names: string, list of strings
+    :   Output variables as defined inside the DML script.
+
+ `systemml.mlcontext.dml`(*scriptString*)[](#systemml.mlcontext.dml "Permalink to this definition")
+:   Create a dml script object based on a string.
+
+    scriptString: string
+    :   Can be a path to a dml script or a dml script itself.
+
+    script: Script instance
+    :   Instance of a script object.
+
+ `systemml.mlcontext.pydml`(*scriptString*)[](#systemml.mlcontext.pydml "Permalink to this definition")
+:   Create a pydml script object based on a string.
+
+    scriptString: string
+    :   Can be a path to a pydml script or a pydml script itself.
+
+    script: Script instance
+    :   Instance of a script object.
+
+ `systemml.mlcontext.getNumCols`(*numPyArr*)[](#systemml.mlcontext.getNumCols "Permalink to this definition")
+:   
+
+ `systemml.mlcontext.convertToMatrixBlock`(*sc*, *src*)[](#systemml.mlcontext.convertToMatrixBlock "Permalink to this definition")
+:   
+
+ `systemml.mlcontext.convertToNumPyArr`(*sc*, *mb*)[](#systemml.mlcontext.convertToNumPyArr "Permalink to this definition")
+:   
+
+ `systemml.mlcontext.convertToPandasDF`(*X*)[](#systemml.mlcontext.convertToPandasDF "Permalink to this definition")
+:   
+
+ `systemml.mlcontext.convertToLabeledDF`(*sqlCtx*, *X*, *y=None*)[](#systemml.mlcontext.convertToLabeledDF "Permalink to this definition")
+:   
+
+
+## mllearn API
+
+### Usage
+
+```python
+# Scikit-learn way
+from sklearn import datasets, neighbors
+from systemml.mllearn import LogisticRegression
+from pyspark.sql import SQLContext
+sqlCtx = SQLContext(sc)
+digits = datasets.load_digits()
+X_digits = digits.data
+y_digits = digits.target 
+n_samples = len(X_digits)
+X_train = X_digits[:int(.9 * n_samples)]
+y_train = y_digits[:int(.9 * n_samples)]
+X_test = X_digits[int(.9 * n_samples):]
+y_test = y_digits[int(.9 * n_samples):]
+logistic = LogisticRegression(sqlCtx)
+print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))
+```
+
+Output:
+
+```bash
+LogisticRegression score: 0.922222
+```
+
+### Reference documentation
+
+ *class*`systemml.mllearn.estimators.LinearRegression`(*sqlCtx*, *fit\_intercept=True*, *max\_iter=100*, *tol=1e-06*, *C=1.0*, *solver='newton-cg'*, *transferUsingDF=False*)[](#systemml.mllearn.estimators.LinearRegression "Permalink to this definition")
+:   Bases: `systemml.mllearn.estimators.BaseSystemMLRegressor`{.xref .py
+    .py-class .docutils .literal}
+
+    Performs linear regression to model the relationship between one
+    numerical response variable and one or more explanatory (feature)
+    variables.
+
+        >>> import numpy as np
+        >>> from sklearn import datasets
+        >>> from systemml.mllearn import LinearRegression
+        >>> from pyspark.sql import SQLContext
+        >>> sqlCtx = SQLContext(sc)
+        >>> # Load the diabetes dataset
+        >>> diabetes = datasets.load_diabetes()
+        >>> # Use only one feature
+        >>> diabetes_X = diabetes.data[:, np.newaxis, 2]
+        >>> # Split the data into training/testing sets
+        >>> diabetes_X_train = diabetes_X[:-20]
+        >>> diabetes_X_test = diabetes_X[-20:]
+        >>> # Split the targets into training/testing sets
+        >>> diabetes_y_train = diabetes.target[:-20]
+        >>> diabetes_y_test = diabetes.target[-20:]
+        >>> # Create linear regression object
+        >>> regr = LinearRegression(sqlCtx, solver='newton-cg')
+        >>> # Train the model using the training sets
+        >>> regr.fit(diabetes_X_train, diabetes_y_train)
+        >>> # The mean square error
+        >>> print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))
+
+ *class*`systemml.mllearn.estimators.LogisticRegression`(*sqlCtx*, *penalty='l2'*, *fit\_intercept=True*, *max\_iter=100*, *max\_inner\_iter=0*, *tol=1e-06*, *C=1.0*, *solver='newton-cg'*, *transferUsingDF=False*)[](#systemml.mllearn.estimators.LogisticRegression "Permalink to this definition")
+:   Bases: `systemml.mllearn.estimators.BaseSystemMLClassifier`{.xref
+    .py .py-class .docutils .literal}
+
+    Performs both binomial and multinomial logistic regression.
+
+    Scikit-learn way
+
+        >>> from sklearn import datasets, neighbors
+        >>> from systemml.mllearn import LogisticRegression
+        >>> from pyspark.sql import SQLContext
+        >>> sqlCtx = SQLContext(sc)
+        >>> digits = datasets.load_digits()
+        >>> X_digits = digits.data
+        >>> y_digits = digits.target + 1
+        >>> n_samples = len(X_digits)
+        >>> X_train = X_digits[:int(.9 * n_samples)]
+        >>> y_train = y_digits[:int(.9 * n_samples)]
+        >>> X_test = X_digits[int(.9 * n_samples):]
+        >>> y_test = y_digits[int(.9 * n_samples):]
+        >>> logistic = LogisticRegression(sqlCtx)
+        >>> print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))
+
+    MLPipeline way
+
+        >>> from pyspark.ml import Pipeline
+        >>> from systemml.mllearn import LogisticRegression
+        >>> from pyspark.ml.feature import HashingTF, Tokenizer
+        >>> from pyspark.sql import SQLContext
+        >>> sqlCtx = SQLContext(sc)
+        >>> training = sqlCtx.createDataFrame([
+        >>>     (0, "a b c d e spark", 1.0),
+        >>>     (1, "b d", 2.0),
+        >>>     (2, "spark f g h", 1.0),
+        >>>     (3, "hadoop mapreduce", 2.0),
+        >>>     (4, "b spark who", 1.0),
+        >>>     (5, "g d a y", 2.0),
+        >>>     (6, "spark fly", 1.0),
+        >>>     (7, "was mapreduce", 2.0),
+        >>>     (8, "e spark program", 1.0),
+        >>>     (9, "a e c l", 2.0),
+        >>>     (10, "spark compile", 1.0),
+        >>>     (11, "hadoop software", 2.0)
+        >>> ], ["id", "text", "label"])
+        >>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
+        >>> hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
+        >>> lr = LogisticRegression(sqlCtx)
+        >>> pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+        >>> model = pipeline.fit(training)
+        >>> test = sqlCtx.createDataFrame([
+        >>>     (12, "spark i j k"),
+        >>>     (13, "l m n"),
+        >>>     (14, "mapreduce spark"),
+        >>>     (15, "apache hadoop")], ["id", "text"])
+        >>> prediction = model.transform(test)
+        >>> prediction.show()
+
+ *class*`systemml.mllearn.estimators.SVM`(*sqlCtx*, *fit\_intercept=True*, *max\_iter=100*, *tol=1e-06*, *C=1.0*, *is\_multi\_class=False*, *transferUsingDF=False*)[](#systemml.mllearn.estimators.SVM "Permalink to this definition")
+:   Bases: `systemml.mllearn.estimators.BaseSystemMLClassifier`{.xref
+    .py .py-class .docutils .literal}
+
+    Performs both binary-class and multiclass SVM (Support Vector
+    Machines).
+
+        >>> from sklearn import datasets, neighbors
+        >>> from systemml.mllearn import SVM
+        >>> from pyspark.sql import SQLContext
+        >>> sqlCtx = SQLContext(sc)
+        >>> digits = datasets.load_digits()
+        >>> X_digits = digits.data
+        >>> y_digits = digits.target 
+        >>> n_samples = len(X_digits)
+        >>> X_train = X_digits[:int(.9 * n_samples)]
+        >>> y_train = y_digits[:int(.9 * n_samples)]
+        >>> X_test = X_digits[int(.9 * n_samples):]
+        >>> y_test = y_digits[int(.9 * n_samples):]
+        >>> svm = SVM(sqlCtx, is_multi_class=True)
+        >>> print('SVM score: %f' % svm.fit(X_train, y_train).score(X_test, y_test))
+
+ *class*`systemml.mllearn.estimators.NaiveBayes`(*sqlCtx*, *laplace=1.0*, *transferUsingDF=False*)[](#systemml.mllearn.estimators.NaiveBayes "Permalink to this definition")
+:   Bases: `systemml.mllearn.estimators.BaseSystemMLClassifier`{.xref
+    .py .py-class .docutils .literal}
+
+    Performs Naive Bayes.
+
+        >>> from sklearn.datasets import fetch_20newsgroups
+        >>> from sklearn.feature_extraction.text import TfidfVectorizer
+        >>> from systemml.mllearn import NaiveBayes
+        >>> from sklearn import metrics
+        >>> from pyspark.sql import SQLContext
+        >>> sqlCtx = SQLContext(sc)
+        >>> categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
+        >>> newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
+        >>> newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
+        >>> vectorizer = TfidfVectorizer()
+        >>> # Both vectors and vectors_test are SciPy CSR matrix
+        >>> vectors = vectorizer.fit_transform(newsgroups_train.data)
+        >>> vectors_test = vectorizer.transform(newsgroups_test.data)
+        >>> nb = NaiveBayes(sqlCtx)
+        >>> nb.fit(vectors, newsgroups_train.target)
+        >>> pred = nb.predict(vectors_test)
+        >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
+
+
+## Utility classes (used internally)
+
+### systemml.classloader 
+
+ `systemml.classloader.createJavaObject`(*sc*, *obj\_type*)[](#systemml.classloader.createJavaObject "Permalink to this definition")
+:   Performs appropriate check if SystemML.jar is available and returns
+    the handle to MLContext object on JVM
+
+    sc: SparkContext
+    :   SparkContext
+
+    obj\_type: Type of object to create ('mlcontext' or 'dummy')
+
+### systemml.converters
+
+ `systemml.converters.getNumCols`(*numPyArr*)[](#systemml.converters.getNumCols "Permalink to this definition")
+:   
+
+ `systemml.converters.convertToMatrixBlock`(*sc*, *src*)[](#systemml.converters.convertToMatrixBlock "Permalink to this definition")
+:   
+
+ `systemml.converters.convertToNumPyArr`(*sc*, *mb*)[](#systemml.converters.convertToNumPyArr "Permalink to this definition")
+:   
+
+ `systemml.converters.convertToPandasDF`(*X*)[](#systemml.converters.convertToPandasDF "Permalink to this definition")
+:   
+
+ `systemml.converters.convertToLabeledDF`(*sqlCtx*, *X*, *y=None*)[](#systemml.converters.convertToLabeledDF "Permalink to this definition")
+:  
+
+### Other classes from systemml.defmatrix
+
+ *class*`systemml.defmatrix.DMLOp`(*inputs*, *dml=None*)[](#systemml.defmatrix.DMLOp "Permalink to this definition")
+:   Bases: `object`{.xref .py .py-class .docutils .literal}
+
+    Represents an intermediate node of the abstract syntax tree (AST)
+    that is built to generate the PyDML script.
+
+
+## Troubleshooting Python APIs
+
+#### Unable to load SystemML.jar into current pyspark session.
+
+When using SystemML's Python package through pyspark or a notebook (where a SparkContext has already been created in the session), the
+method below is not required. However, if the user wishes to use SystemML through spark-submit and has not previously created a SparkContext in the session, the following method must be invoked:
+
+ `systemml.defmatrix.setSparkContext`(*sc*)
+:   Before using the matrix, the user needs to invoke this function if a SparkContext has not previously been created in the session.
+
+    sc: SparkContext
+    :   SparkContext
+
+Example:
+
+```python
+import systemml as sml
+import numpy as np
+sml.setSparkContext(sc)
+m1 = sml.matrix(np.ones((3,3)) + 2)
+m2 = sml.matrix(np.ones((3,3)) + 3)
+m2 = m1 * (m2 + m1)
+m4 = 1.0 - m2
+m4.sum(axis=1).toNumPy()
+```
+
+If SystemML was not installed via pip, you may have to download SystemML.jar and provide it to pyspark via `--driver-class-path` and `--jars`. 
+
+#### matrix API is running slow when set_lazy(False) or when eval() is called often.
+
+This is a known issue. The matrix API is slow in this scenario due to slow Py4J conversion from Java MatrixObject or Java RDD to Python NumPy or DataFrame.
+To work around this for now, we recommend writing the matrix to the file system and reading it back with the `load` function.
+
+#### maximum recursion depth exceeded
+
+SystemML matrix is backed by lazy evaluation and uses a recursive Depth First Search (DFS).
+Python can throw `RuntimeError: maximum recursion depth exceeded` when the depth of the DFS recursion exceeds the limit
+set by Python. There are two ways to address it:
+
+1. Increase the limit in Python:
+ 
+	```python
+	import sys
+	some_large_number = 2000
+	sys.setrecursionlimit(some_large_number)
+	```
+
+2. Evaluate the intermediate matrix to cut off the deep recursion.
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 006a1f3..d78967e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -433,6 +433,15 @@
 										<include>original-*.jar</include>
 									</includes>
 								</fileset>
+								<fileset>
+									<directory>src/main/python/dist</directory>
+								</fileset>
+								<fileset>
+									<directory>src/main/python/systemml_incubating.egg-info</directory>
+								</fileset>
+								<fileset>
+									<directory>src/main/python/systemml/systemml-java</directory>
+								</fileset>
 							</filesets>
 						</configuration>
 					</execution>

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index 497136e..ab85f1e 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -47,6 +47,7 @@ import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 
@@ -201,11 +202,18 @@ public class RDDConverterUtilsExt
 		return df.select(columns.get(0), scala.collection.JavaConversions.asScalaBuffer(columnToSelect).toList());
 	}
 	
+	public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen) throws DMLRuntimeException {
+		return convertPy4JArrayToMB(data, (int)rlen, (int)clen, false);
+	}
 	
 	public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen) throws DMLRuntimeException {
 		return convertPy4JArrayToMB(data, rlen, clen, false);
 	}
 	
+	public static MatrixBlock convertSciPyCOOToMB(byte [] data, byte [] row, byte [] col, long rlen, long clen, long nnz) throws DMLRuntimeException {
+		return convertSciPyCOOToMB(data, row, col, (int)rlen, (int)clen, (int)nnz);
+	}
+	
 	public static MatrixBlock convertSciPyCOOToMB(byte [] data, byte [] row, byte [] col, int rlen, int clen, int nnz) throws DMLRuntimeException {
 		MatrixBlock mb = new MatrixBlock(rlen, clen, true);
 		mb.allocateSparseRowsBlock(false);
@@ -224,6 +232,10 @@ public class RDDConverterUtilsExt
 		return mb;
 	}
 	
+	public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen, boolean isSparse) throws DMLRuntimeException {
+		return convertPy4JArrayToMB(data, (int) rlen, (int) clen, isSparse);
+	}
+	
 	public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen, boolean isSparse) throws DMLRuntimeException {
 		MatrixBlock mb = new MatrixBlock(rlen, clen, isSparse, -1);
 		if(isSparse) {
@@ -245,19 +257,20 @@ public class RDDConverterUtilsExt
 	public static byte [] convertMBtoPy4JDenseArr(MatrixBlock mb) throws DMLRuntimeException {
 		byte [] ret = null;
 		if(mb.isInSparseFormat()) {
+			mb.sparseToDense();
+//			throw new DMLRuntimeException("Sparse to dense conversion is not yet implemented");
+		}
+		
+		double [] denseBlock = mb.getDenseBlock();
+		if(denseBlock == null) {
 			throw new DMLRuntimeException("Sparse to dense conversion is not yet implemented");
 		}
-		else {
-			double [] denseBlock = mb.getDenseBlock();
-			if(denseBlock == null) {
-				throw new DMLRuntimeException("Sparse to dense conversion is not yet implemented");
-			}
-			int times = Double.SIZE / Byte.SIZE;
-			ret = new byte[denseBlock.length * times];
-			for(int i=0;i < denseBlock.length;i++){
-		        ByteBuffer.wrap(ret, i*times, times).order(ByteOrder.nativeOrder()).putDouble(denseBlock[i]);
-			}
+		int times = Double.SIZE / Byte.SIZE;
+		ret = new byte[denseBlock.length * times];
+		for(int i=0;i < denseBlock.length;i++){
+	        ByteBuffer.wrap(ret, i*times, times).order(ByteOrder.nativeOrder()).putDouble(denseBlock[i]);
 		}
+		
 		return ret;
 	}
 	

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 01ff6d3..db53ef6 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -1088,7 +1088,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		denseBlock = null;
 	}
 
-	void sparseToDense() 
+	public void sparseToDense() 
 		throws DMLRuntimeException 
 	{	
 		//set target representation

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/LICENSE
----------------------------------------------------------------------
diff --git a/src/main/python/LICENSE b/src/main/python/LICENSE
index 8f71f43..9ae3ea8 100644
--- a/src/main/python/LICENSE
+++ b/src/main/python/LICENSE
@@ -200,3 +200,49 @@
    See the License for the specific language governing permissions and
    limitations under the License.
 
+===============================================================================
+
+The following compile-scope dependencies come under the Apache Software License 2.0.
+
+Apache Wink :: JSON4J (http://www.apache.org/wink/wink-json4j/) org.apache.wink:wink-json4j:1.4
+
+===============================================================================
+
+The following compile-scope ANTLR dependencies are distributed under the BSD license.
+
+ANTLR 4 Runtime (http://www.antlr.org/antlr4-runtime) org.antlr:antlr4-runtime:4.5.3
+
+Copyright (c) 2012 Terence Parr and Sam Harwell
+All rights reserved.
+
+BSD license:
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ - Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ - Neither the name of the author nor the names of
+   contributors may be used to endorse or promote products derived from this
+   software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
+WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY
+DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
+USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
+OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/MANIFEST.in
----------------------------------------------------------------------
diff --git a/src/main/python/MANIFEST.in b/src/main/python/MANIFEST.in
index 2dea648..6e78ded 100644
--- a/src/main/python/MANIFEST.in
+++ b/src/main/python/MANIFEST.in
@@ -20,9 +20,9 @@
 #-------------------------------------------------------------
 include LICENSE
 include NOTICE
-include systemml/systemml-java/SystemML.jar
 include DISCLAIMER
 include systemml/systemml-java/scripts/sparkDML.sh
+recursive-include systemml/systemml-java *
 recursive-include systemml/systemml-java/scripts/algorithms *
 recursive-include systemml/systemml-java/scripts/datagen *
 recursive-include systemml/systemml-java/scripts/utils *
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/pre_setup.py
----------------------------------------------------------------------
diff --git a/src/main/python/pre_setup.py b/src/main/python/pre_setup.py
index e02e5a4..202f7d1 100644
--- a/src/main/python/pre_setup.py
+++ b/src/main/python/pre_setup.py
@@ -20,6 +20,7 @@
 #-------------------------------------------------------------
 
 import os, shutil
+import fnmatch
 python_dir = 'systemml'
 java_dir='systemml-java'
 java_dir_full_path = os.path.join(python_dir, java_dir)
@@ -27,5 +28,7 @@ if os.path.exists(java_dir_full_path):
     shutil.rmtree(java_dir_full_path, True)
 os.mkdir(java_dir_full_path)
 root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
-shutil.copyfile(os.path.join(root_dir, 'target', 'SystemML.jar'), os.path.join(java_dir_full_path, 'SystemML.jar'))
+for file in os.listdir(os.path.join(root_dir, 'target')):
+    if fnmatch.fnmatch(file, 'systemml-*-incubating-SNAPSHOT.jar'):
+        shutil.copyfile(os.path.join(root_dir, 'target', file), os.path.join(java_dir_full_path, file))
 shutil.copytree(os.path.join(root_dir, 'scripts'), os.path.join(java_dir_full_path, 'scripts'))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/setup.py
----------------------------------------------------------------------
diff --git a/src/main/python/setup.py b/src/main/python/setup.py
index 48d5e19..1b3b620 100644
--- a/src/main/python/setup.py
+++ b/src/main/python/setup.py
@@ -23,7 +23,7 @@ import os
 from setuptools import find_packages, setup
 import time
 
-VERSION = '0.11.0.dev1'
+VERSION = '0.12.0.dev1'
 RELEASED_DATE = str(time.strftime("%m/%d/%Y"))
 numpy_version = '1.8.2'
 scipy_version = '0.15.1'

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/23ccab85/src/main/python/systemml/classloader.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/classloader.py b/src/main/python/systemml/classloader.py
index 75d0b0f..1085be4 100644
--- a/src/main/python/systemml/classloader.py
+++ b/src/main/python/systemml/classloader.py
@@ -51,8 +51,12 @@ def createJavaObject(sc, obj_type):
     try:
         return _createJavaObject(sc, obj_type)
     except (py4j.protocol.Py4JError, TypeError):
-        import imp
-        jar_file_name = os.path.join(imp.find_module("systemml")[1], "systemml-java", 'SystemML.jar')
+        import imp, fnmatch
+        jar_file_name = '_ignore.jar'
+        java_dir = os.path.join(imp.find_module("systemml")[1], "systemml-java")
+        for file in os.listdir(java_dir):
+            if fnmatch.fnmatch(file, 'systemml-*-incubating-SNAPSHOT.jar'):
+                jar_file_name = os.path.join(java_dir, file)
         err_msg = 'Unable to load SystemML.jar into current pyspark session.'
         hint = 'Provide the following argument to pyspark: --driver-class-path '
         if os.path.isfile(jar_file_name):