Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/26 19:26:52 UTC

incubator-systemml git commit: [MINOR] Updated the documentation for mllearn and python dsl

Repository: incubator-systemml
Updated Branches:
  refs/heads/master fd41f0253 -> 3a1431c84


[MINOR] Updated the documentation for mllearn and python dsl

- Added design document for mllearn in BaseSystemMLClassifier class
- Add Python DSL documentation
- Updated Native backend documentation wrt MKL DNN
- Caffe2DML documentation: minor updates

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/3a1431c8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/3a1431c8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/3a1431c8

Branch: refs/heads/master
Commit: 3a1431c8486898c03bf2be5e18f33b4e35edad84
Parents: fd41f02
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri May 26 12:25:36 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri May 26 12:25:36 2017 -0700

----------------------------------------------------------------------
 docs/beginners-guide-caffe2dml.md               |   4 +-
 docs/native-backend.md                          |  79 ++++----
 docs/python-reference.md                        | 192 +++++++++++++------
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |  35 ++++
 4 files changed, 213 insertions(+), 97 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3a1431c8/docs/beginners-guide-caffe2dml.md
----------------------------------------------------------------------
diff --git a/docs/beginners-guide-caffe2dml.md b/docs/beginners-guide-caffe2dml.md
index 55eb154..429bce2 100644
--- a/docs/beginners-guide-caffe2dml.md
+++ b/docs/beginners-guide-caffe2dml.md
@@ -95,7 +95,7 @@ lenet.setStatistics(True).setExplain(True)
 # If you want to force GPU execution, please make sure the required dependencies are available.
 # lenet.setGPU(True).setForceGPU(True)
 
-# (Optional but recommended) Enable native BLAS. For more detail see http://apache.github.io/incubator-systemml/native-backend
+# (Optional but recommended) Enable native BLAS. 
 lenet.setConfigProperty("native.blas", "auto")
 
 # In case you want to enable experimental feature such as codegen
@@ -106,6 +106,8 @@ lenet.fit(X_train, y_train)
 lenet.predict(X_test)
 ```
 
+For more detail on enabling native BLAS, please see the documentation for the [native backend](http://apache.github.io/incubator-systemml/native-backend).
+
 ## Frequently asked questions
 
 #### How can I speedup the training with Caffe2DML ?

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3a1431c8/docs/native-backend.md
----------------------------------------------------------------------
diff --git a/docs/native-backend.md b/docs/native-backend.md
index d6a6228..33a1a02 100644
--- a/docs/native-backend.md
+++ b/docs/native-backend.md
@@ -74,19 +74,15 @@ sudo make install
 # After installation, you may also want to add `/opt/OpenBLAS/lib` to your LD_LIBRARY_PATH or `java.library.path`.
 ```
 
-You can check if the OpenBLAS on you system is compiled with OpenMP or not using following commands:
+We also depend on GNU OpenMP (gomp), which is installed along with GCC.
+To find the location of `gomp` on your system, please use the command `ldconfig -p | grep libgomp`.
+If gomp is available as `/lib64/libgomp.so.1` instead of `/lib64/libgomp.so`,
+please add a softlink to it:
 
 ```bash
-$ ldconfig -p | grep libopenblas.so
-libopenblas.so (libc6,x86-64) => /opt/OpenBLAS/lib/libopenblas.so
-$ ldd /opt/OpenBLAS/lib/libopenblas.so | grep libgomp
-libgomp.so.1 => /lib64/libgomp.so.1
+sudo ln -s /lib64/libgomp.so.1 /lib64/libgomp.so
 ```
 
-If you don't see any output after the second command, then OpenBLAS installed on your system is using its internal threading.
-In this case, we highly recommend that you reinstall OpenBLAS using the above commands.
-
-
 ## Step 2: Install other dependencies
 
 ```bash
@@ -95,61 +91,76 @@ sudo yum install gcc-c++
 # Ubuntu
 sudo apt-get install g++ 
 ```
-
-We also depend on GNU OpenMP (gomp) which will be installed by GCC.
-To find the location of `gomp` on your system, please use the command `ldconfig -p | grep libgomp`.
-If gomp is available as `/lib64/libgomp.so.1` instead of `/lib64/libgomp.so`,
-please add a softlink to it:
-
-```bash
-sudo ln -s /lib64/libgomp.so.1 /lib64/libgomp.so
-```
 	
 ## Step 3: Provide the location of the native libraries
 
 1. Pass the location of the native libraries using command-line options:
 
-- [Spark](http://spark.apache.org/docs/latest/configuration.html): `--conf spark.executorEnv.LD_LIBRARY_PATH=/path/to/blas-n-other-dependencies`
-- Java: `-Djava.library.path=/path/to/blas-n-other-dependencies`
+	- [Spark](http://spark.apache.org/docs/latest/configuration.html): `--conf spark.executorEnv.LD_LIBRARY_PATH=/path/to/blas-n-other-dependencies`
+	- Java: `-Djava.library.path=/path/to/blas-n-other-dependencies`
 
 2. Alternatively, you can add the location of the native libraries (i.e. BLAS and other dependencies) 
 to the environment variable `LD_LIBRARY_PATH` (on Linux). 
 If you want to use SystemML with Spark, please add the following line to `spark-env.sh` 
 (or to the bash profile).
 
-	```bash
 	export LD_LIBRARY_PATH=/path/to/blas-n-other-dependencies
-	```
-
+ 
 
 ## Common issues on Linux
 
-1. Unable to load `gomp`
+- Unable to load `gomp`.
 
 First, check whether gomp is available on your system.
 
-	```bash
 	ldconfig -p | grep  libgomp
-	```
 
 If the above command returns no results, then you may have to install `gcc`.
 On the other hand, if the above command only returns libgomp with a major-version suffix (such as `so.1`),
 then please execute the command below:
 
-	```bash
 	sudo ln -s /lib64/libgomp.so.1 /usr/lib64/libgomp.so
-	```
 
-2. Unable to load `mkl_rt`
+- Unable to load `mkl_rt`.
 
 By default, Intel MKL libraries will be installed in the location `/opt/intel/mkl/lib/intel64/`.
 Make sure that this path is accessible to Java as per instructions provided in the above section.
 
-3. Unable to load `openblas`
+- Unable to load `openblas`.
 
 By default, OpenBLAS libraries will be installed in the location `/opt/OpenBLAS/lib/`.
 Make sure that this path is accessible to Java as per instructions provided in the above section.
 
+- Using OpenBLAS without OpenMP can lead to performance degradation when using SystemML.
+ 
+You can check whether the OpenBLAS on your system is compiled with OpenMP using the following commands:
+
+	$ ldconfig -p | grep libopenblas.so
+	libopenblas.so (libc6,x86-64) => /opt/OpenBLAS/lib/libopenblas.so
+	$ ldd /opt/OpenBLAS/lib/libopenblas.so | grep libgomp
+	libgomp.so.1 => /lib64/libgomp.so.1
+
+If the second command produces no output, then the OpenBLAS installed on your system uses its internal threading; in this case, we highly recommend that you reinstall OpenBLAS using the commands above.
+
+- Using MKL can lead to slow performance for the convolution instruction.
+
+We noticed that the double-precision MKL DNN primitives for the convolution instruction
+are considerably slower than the corresponding single-precision MKL DNN primitives
+as of MKL 2017 Update 1. We anticipate that this performance bug will be fixed in future MKL versions.
+Until then, or until SystemML supports single-precision matrices, we recommend that you use OpenBLAS when running scripts that use `conv2d`.
+
+Here is the runtime (in seconds) of `conv2d` on 64 images of size 256x256 with sparsity 0.9
+and 32 filters of size 5x5 with stride=[1,1] and pad=[1,1]:
+  
+
+| Configuration                 | MKL (s) | OpenBLAS (s) |
+|-------------------------------|---------|--------------|
+| Single-precision, channels=3  | 5.144   | 7.918        |
+| Double-precision, channels=3  | 12.599  | 8.688        |
+| Single-precision, channels=32 | 10.765  | 21.963       |
+| Double-precision, channels=32 | 71.118  | 34.881       |
+
+
 # Developer Guide
 
 This section describes how to compile shared libraries in the folder `src/main/cpp/lib`. 
@@ -176,16 +187,13 @@ For this project, I typically make a directory in the `cpp` folder (this folder)
 
 3. Install cmake
 
-	```bash
 	# Centos/RedHat
 	sudo yum install cmake3
 	# Ubuntu
 	sudo apt-get install cmake
-	```
 
 4. Compile the libs using the below script. 
 
-	```bash
 	mkdir INTEL && cd INTEL
 	cmake -DUSE_INTEL_MKL=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_CXX_FLAGS="-DUSE_GNU_THREADING -m64" ..
 	make install
@@ -196,7 +204,7 @@ For this project, I typically make a directory in the `cpp` folder (this folder)
 	cd ..
 	# The below script helps maintain this document as well as avoid accidental inclusion of non-standard dependencies.
 	./check-dependency-linux-x86_64.sh
-	```
+	
 
 
 The generated library files are placed in src/main/cpp/lib. This location can be changed from the CMakeLists.txt file.
@@ -210,4 +218,5 @@ The current set of dependencies other than MKL and OpenBLAS, are as follows:
 - GCC OpenMP v3.0 shared support library: `libgomp.so.1`
 - Additional OpenBLAS dependencies: Fortran runtime (`libgfortran.so.3`) and GCC `__float128` shared support library (`libquadmath.so.0`)
 
-If CMake cannot detect your OpenBLAS installation, set the `OpenBLAS_HOME` environment variable to the OpenBLAS Home.
\ No newline at end of file
+If CMake cannot detect your OpenBLAS installation, set the `OpenBLAS_HOME` environment variable to the OpenBLAS installation directory.
+
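The troubleshooting steps above (`ldconfig -p | grep libgomp`, then pointing `LD_LIBRARY_PATH` or `java.library.path` at the right directory) can be scripted. The sketch below is an illustration only and is not part of SystemML: it assumes Python 3 on Linux and uses the standard-library `ctypes.util.find_library`, which on Linux consults `ldconfig -p` much like the commands above. The helper names `find_native_libs` and `library_path_options` are hypothetical. Note that `find_library` reports the versioned soname (for example `libgomp.so.1`), so it does not tell you whether the unversioned `libgomp.so` softlink exists.

```python
import ctypes.util

# Base names of the native libraries discussed in the guide (assumed list).
NATIVE_LIBS = ("gomp", "openblas", "mkl_rt")


def find_native_libs(names=NATIVE_LIBS):
    """Map each library base name to the soname found by ctypes (e.g. 'libgomp.so.1') or None."""
    return {name: ctypes.util.find_library(name) for name in names}


def library_path_options(lib_dir):
    """Build the Step 3 command-line options that point Spark or Java at lib_dir."""
    return {
        "spark": ["--conf", "spark.executorEnv.LD_LIBRARY_PATH=" + lib_dir],
        "java": ["-Djava.library.path=" + lib_dir],
    }


if __name__ == "__main__":
    for name, soname in find_native_libs().items():
        print("%-9s %s" % (name, soname or "NOT FOUND"))
    # /opt/OpenBLAS/lib is the default OpenBLAS location mentioned in the guide.
    print(library_path_options("/opt/OpenBLAS/lib"))
```

A "NOT FOUND" entry for `gomp` usually means `gcc` is missing; for `openblas` or `mkl_rt` it means the library directory is not on the loader path yet.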

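To make the MKL recommendation concrete, the short script below recomputes the double- versus single-precision slowdown from the `conv2d` table above. It only does arithmetic on the published numbers; the dictionary layout and the function names are illustrative.

```python
# conv2d runtimes in seconds, copied from the table above:
# 64 images of 256x256, sparsity 0.9, 32 filters of 5x5, stride=[1,1], pad=[1,1].
RUNTIMES = {
    ("single", 3): {"MKL": 5.144, "OpenBLAS": 7.918},
    ("double", 3): {"MKL": 12.599, "OpenBLAS": 8.688},
    ("single", 32): {"MKL": 10.765, "OpenBLAS": 21.963},
    ("double", 32): {"MKL": 71.118, "OpenBLAS": 34.881},
}


def double_vs_single_slowdown(backend, channels):
    """How many times slower double precision is than single precision for one backend."""
    return RUNTIMES[("double", channels)][backend] / RUNTIMES[("single", channels)][backend]


def fastest_backend(precision, channels):
    """Backend with the lowest runtime for the given precision and channel count."""
    row = RUNTIMES[(precision, channels)]
    return min(row, key=row.get)


if __name__ == "__main__":
    for channels in (3, 32):
        for backend in ("MKL", "OpenBLAS"):
            ratio = double_vs_single_slowdown(backend, channels)
            print("channels=%d %-8s double/single = %.2f" % (channels, backend, ratio))
        print("channels=%d fastest in double precision: %s" % (channels, fastest_backend("double", channels)))
```

On these numbers, moving from single to double precision slows MKL down by roughly 2.4x (3 channels) and 6.6x (32 channels), versus roughly 1.1x and 1.6x for OpenBLAS, which is why OpenBLAS is the faster choice for double-precision `conv2d` in both configurations.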
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3a1431c8/docs/python-reference.md
----------------------------------------------------------------------
diff --git a/docs/python-reference.md b/docs/python-reference.md
index 0d90ec3..2ebfc38 100644
--- a/docs/python-reference.md
+++ b/docs/python-reference.md
@@ -40,16 +40,16 @@ To understand more about DML and PyDML, we recommend that you read [Beginner's G
 For convenience of Python users, SystemML exposes several language-level APIs that allow Python users to use SystemML
 and its algorithms without the need to know DML or PyDML. We explain these APIs in the below sections.
 
-## matrix API
+## matrix class
 
-The matrix class allows users to perform linear algebra operations in SystemML using a NumPy-like interface.
-This class supports several arithmetic operators (such as +, -, *, /, ^, etc).
-
-matrix class is a python wrapper that implements basic matrix
-operators, matrix functions as well as converters to common Python
+The matrix class is an **experimental** feature that is often referred to as the Python DSL.
+It allows the user to perform linear algebra operations in SystemML using a NumPy-like interface.
+It implements basic matrix operators, matrix functions, as well as converters to common Python
 types (for example: Numpy arrays, PySpark DataFrame and Pandas
 DataFrame).
 
+### Operators
+ 
 The operators supported are:
 
 1.  Arithmetic operators: +, -, *, /, //, %, \** as well as dot
@@ -57,51 +57,24 @@ The operators supported are:
 2.  Indexing in the matrix
 3.  Relational/Boolean operators: \<, \<=, \>, \>=, ==, !=, &, \|
 
-In addition, following functions are supported for matrix:
-
-1.  transpose
-2.  Aggregation functions: sum, mean, var, sd, max, min, argmin,
-    argmax, cumsum
-3.  Global statistical built-In functions: exp, log, abs, sqrt,
-    round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve
-
-For all the above functions, we always return a two dimensional matrix, especially for aggregation functions with axis. 
-For example: Assuming m1 is a matrix of (3, n), NumPy returns a 1d vector of dimension (3,) for operation m1.sum(axis=1)
-whereas SystemML returns a 2d matrix of dimension (3, 1).
-
-Note: an evaluated matrix contains a data field computed by eval
-method as DataFrame or NumPy array.
-
-It is important to note that matrix class also supports most of NumPy's universal functions (i.e. ufuncs).
-The current version of NumPy explicitly disables overriding ufunc, but this should be enabled in next release. 
-Until then to test above code, please use:
-
-```bash
-git clone https://github.com/niketanpansare/numpy.git
-cd numpy
-python setup.py install
-```
+This class also supports several input/output formats such as NumPy arrays, Pandas DataFrame, SciPy sparse matrix and PySpark DataFrame.
 
-This will enable NumPy's functions to invoke matrix class:
+Here is a small example that demonstrates the usage:
 
 ```python
-import systemml as sml
-import numpy as np
-m1 = sml.matrix(np.ones((3,3)) + 2)
-m2 = sml.matrix(np.ones((3,3)) + 3)
-np.add(m1, m2)
-``` 
-
-The matrix class doesnot support following ufuncs:
-
-- Complex number related ufunc (for example: `conj`)
-- Hyperbolic/inverse-hyperbolic functions (for example: sinh, arcsinh, cosh, ...)
-- Bitwise operators
-- Xor operator
-- Infinite/Nan-checking (for example: isreal, iscomplex, isfinite, isinf, isnan)
-- Other ufuncs: copysign, nextafter, modf, frexp, trunc.
+>>> import systemml as sml
+>>> import numpy as np
+>>> m1 = sml.matrix(np.ones((3,3)) + 2)
+>>> m2 = sml.matrix(np.ones((3,3)) + 3)
+>>> m3 = m1 * (m2 + m1)
+>>> m4 = 1.0 - m3
+>>> m4.sum(axis=1).toNumPy()
+array([[-60.],
+       [-60.],
+       [-60.]])
+```
 
-This class also supports several input/output formats such as NumPy arrays, Pandas DataFrame, SciPy sparse matrix and PySpark DataFrame.
+### Lazy evaluation
 
 By default, the operations are evaluated lazily to avoid conversion overhead and also to maximize optimization scope.
 To disable lazy evaluation, please use the `set_lazy` method:
@@ -130,28 +103,123 @@ save(mVar4, " ")
 # This matrix (mVar8) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.
 ``` 
 
-### Usage:
+Since the matrix class relies on lazy evaluation and uses a recursive Depth First Search (DFS),
+you may run into `RuntimeError: maximum recursion depth exceeded`.
+Please see the [troubleshooting steps](http://apache.github.io/incubator-systemml/python-reference#maximum-recursion-depth-exceeded) below.
+
+
+### Built-in functions
+
+In addition to the above-mentioned operators, the following functions are supported.
+
+- transpose: Transposes the input matrix. 
+
+- Aggregation functions: prod, sum, mean, var, sd, max, min, argmin, argmax, cumsum
+
+|                                                      | Description                                                                                                                     | Parameters                                                                                                                                                                                                                  |
+|------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| prod(self)                                           | Return the product of all cells in matrix                                                                                       | self: input matrix object                                                                                                                                                                                                   |
+| sum(self, axis=None)                                 | Compute the sum along the specified axis                                                                                        | axis : int, optional                                                                                                                                                                                                        |
+| mean(self, axis=None)                                | Compute the arithmetic mean along the specified axis                                                                            | axis : int, optional                                                                                                                                                                                                        |
+| var(self, axis=None)                                 | Compute the variance along the specified axis. We assume that delta degree of freedom is 1 (unlike NumPy which assumes ddof=0). | axis : int, optional                                                                                                                                                                                                        |
+| moment(self, moment=1, axis=None)                    | Calculates the nth moment about the mean                                                                                        | moment : int (can be 1, 2, 3 or 4), axis : int, optional                                                                                                                                                                    |
+| sd(self, axis=None)                                  | Compute the standard deviation along the specified axis                                                                         | axis : int, optional                                                                                                                                                                                                        |
+| max(self, other=None, axis=None)                     | Compute the maximum value along the specified axis                                                                              | other: matrix or numpy array (& other supported types) or scalar, axis : int, optional                                                                                                                                      |
+| min(self, other=None, axis=None)                     | Compute the minimum value along the specified axis                                                                              | other: matrix or numpy array (& other supported types) or scalar, axis : int, optional                                                                                                                                      |
+| argmin(self, axis=None)                              | Returns the indices of the minimum values along an axis.                                                                        | axis : int, optional (only axis=1, i.e. rowIndexMin, is supported in this version)                                                                                                                                          |
+| argmax(self, axis=None)                              | Returns the indices of the maximum values along an axis.                                                                        | axis : int, optional (only axis=1, i.e. rowIndexMax, is supported in this version)                                                                                                                                          |
+| cumsum(self, axis=None)                              | Returns the cumulative sum of the elements along an axis.                                                                       | axis : int, optional (only axis=0, i.e. cumsum along the rows, is supported in this version)                                                                                                                                |
+
+- Global statistical built-in functions: exp, log, abs, sqrt, round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve
+
+|                                                      | Description                                                                                                                     | Parameters                                                                                                                                                                                              |
+|------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| solve(A, b)                                          | Computes the least squares solution for system of linear equations A %*% x = b                                                  | A, b: input matrices                                                                                                                                                                                    |
+
+
+- Built-in sampling functions: normal, uniform, poisson
+
+|                                                      | Description                                                                                                                     | Parameters                                                                                                                                                                                                                  |
+|------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| normal(loc=0.0, scale=1.0, size=(1,1), sparsity=1.0) | Draw random samples from a normal (Gaussian) distribution.                                                                      | loc: Mean ("centre") of the distribution, scale: Standard deviation (spread or "width") of the distribution, size: Output shape (only tuple of length 2, i.e. (m, n), supported), sparsity: Sparsity (between 0.0 and 1.0). |
+| uniform(low=0.0, high=1.0, size=(1,1), sparsity=1.0) | Draw samples from a uniform distribution.                                                                                       | low: Lower boundary of the output interval, high: Upper boundary of the output interval, size: Output shape (only tuple of length 2, i.e. (m, n), supported), sparsity: Sparsity (between 0.0 and 1.0).                     |
+| poisson(lam=1.0, size=(1,1), sparsity=1.0)           | Draw samples from a Poisson distribution.                                                                                       | lam: Expectation of interval, should be > 0, size: Output shape (only tuple of length 2, i.e. (m, n), supported), sparsity: Sparsity (between 0.0 and 1.0).                                                                 |
+
+- Other built-in functions: hstack, vstack, trace
+
+|                                                      | Description                                                                                                                     | Parameters                                                                                                                                                                                                                  |
+|------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| hstack(self, other)                                  | Stack matrices horizontally (column wise). Invokes cbind internally.                                                            | self: lhs matrix object, other: rhs matrix object                                                                                                                                                                           |
+| vstack(self, other)                                  | Stack matrices vertically (row wise). Invokes rbind internally.                                                                 | self: lhs matrix object, other: rhs matrix object                                                                                                                                                                           |
+| trace(self)                                          | Return the sum of the cells of the main diagonal square matrix                                                                  | self: input matrix                                                                                                                                                                                                          |
+
+Here is an example that uses the above functions and trains a simple linear regression model:
+
+```python
+>>> import numpy as np
+>>> from sklearn import datasets
+>>> import systemml as sml
+>>> # Load the diabetes dataset
+>>> diabetes = datasets.load_diabetes()
+>>> # Use only one feature
+>>> diabetes_X = diabetes.data[:, np.newaxis, 2]
+>>> # Split the data into training/testing sets
+>>> X_train = diabetes_X[:-20]
+>>> X_test = diabetes_X[-20:]
+>>> # Split the targets into training/testing sets
+>>> y_train = diabetes.target[:-20]
+>>> y_test = diabetes.target[-20:]
+>>> # Train Linear Regression model
+>>> X = sml.matrix(X_train)
+>>> y = sml.matrix(np.matrix(y_train).T)
+>>> A = X.transpose().dot(X)
+>>> b = X.transpose().dot(y)
+>>> beta = sml.solve(A, b).toNumPy()
+>>> y_predicted = X_test.dot(beta).ravel()  # flatten (20, 1) to (20,) so it does not broadcast against y_test
+>>> print('Mean squared error: %.2f' % np.mean((y_predicted - y_test) ** 2))
+Mean squared error: ...
+```
+
+All of the above functions return a two-dimensional matrix; this matters especially for aggregation functions with an axis argument.
+For example, assuming m1 is a matrix of shape (3, n), NumPy returns a 1-d vector of shape (3,) for the operation m1.sum(axis=1),
+whereas SystemML returns a 2-d matrix of shape (3, 1).
+
+Note: an evaluated matrix contains a data field, computed by the eval
+method, that holds a DataFrame or NumPy array.
+
+### Support for NumPy's universal functions
+
+The matrix class also supports most of NumPy's universal functions (i.e. ufuncs).
+The current version of NumPy explicitly disables overriding ufuncs, but this should be enabled in the next release.
+Until then, to test the code below, please install the following fork of NumPy:
+
+```bash
+git clone https://github.com/niketanpansare/numpy.git
+cd numpy
+python setup.py install
+```
+
+This enables NumPy's functions to invoke the matrix class:
 
 ```python
 import systemml as sml
 import numpy as np
 m1 = sml.matrix(np.ones((3,3)) + 2)
 m2 = sml.matrix(np.ones((3,3)) + 3)
-m2 = m1 * (m2 + m1)
-m4 = 1.0 - m2
-m4.sum(axis=1).toNumPy()
-```
+np.add(m1, m2)
+``` 
 
-Output:
+The matrix class does not support the following ufuncs:
 
-```bash
-array([[-60.],
-       [-60.],
-       [-60.]])
-```
+- Complex number related ufunc (for example: `conj`)
+- Hyperbolic/inverse-hyperbolic functions (for example: sinh, arcsinh, cosh, ...)
+- Bitwise operators
+- Xor operator
+- Infinite/Nan-checking (for example: isreal, iscomplex, isfinite, isinf, isnan)
+- Other ufuncs: copysign, nextafter, modf, frexp, trunc.
 
 
-### Design Decisions:
+### Design Decisions of matrix class (Developer documentation)
 
 1.  Until the eval() method is invoked, we create an AST (not exposed to
 the user) that consists of unevaluated operations and data
@@ -242,6 +310,10 @@ beta = ml.execute(script).get('B_out').toNumPy()
 
 ## mllearn API
 
+The mllearn API is designed to be compatible with scikit-learn and MLLib.
+The classes that are part of the mllearn API are LogisticRegression, LinearRegression, SVM, NaiveBayes
+and [Caffe2DML](http://apache.github.io/incubator-systemml/beginners-guide-caffe2dml).
+
 The code below shows how to use the mllearn API for training:
 
 <div class="codetabs">
@@ -412,8 +484,6 @@ Output:
 ```
 
 
-
-
 ## Troubleshooting Python APIs
 
 #### Unable to load SystemML.jar into current pyspark session.
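The python-reference changes above describe several semantic differences between the matrix class and NumPy. The following NumPy-only sketch (it does not import `systemml`) reproduces the small `m1`/`m2`/`m4` example and shows the two differences the text calls out: aggregations with an axis return a 2-d `(3, 1)` result, which NumPy matches with `keepdims=True`, and `var` uses a delta degrees of freedom of 1, which NumPy matches with `ddof=1`. Treat the NumPy calls as an assumed equivalent for checking results, not as the matrix class's implementation.

```python
import numpy as np

# 1. The elementwise example from the "matrix class" section, in plain NumPy.
m1 = np.ones((3, 3)) + 2        # every cell is 3
m2 = np.ones((3, 3)) + 3        # every cell is 4
m3 = m1 * (m2 + m1)             # 3 * (4 + 3) = 21 in every cell
m4 = 1.0 - m3                   # -20 in every cell
row_sums = m4.sum(axis=1, keepdims=True)
print(row_sums)                 # a (3, 1) array whose entries are all -60.0

# 2. Aggregations with an axis: NumPy drops the axis unless keepdims=True,
#    whereas the matrix class always returns a 2-d result.
a = np.arange(12.0).reshape(3, 4)
print(a.sum(axis=1).shape)                  # (3,)
print(a.sum(axis=1, keepdims=True).shape)   # (3, 1)

# 3. Variance: NumPy defaults to ddof=0; the matrix class assumes ddof=1.
x = np.array([1.0, 2.0, 3.0, 4.0])
print(np.var(x))            # 1.25 (sum of squared deviations 5.0 divided by n = 4)
print(np.var(x, ddof=1))    # 5.0 / 3 (divided by n - 1 = 3)
```

When comparing results from the matrix class against NumPy, passing `keepdims=True` and `ddof=1` on the NumPy side avoids spurious shape and value mismatches.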

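The linear regression example forms `A = X^T X` and `b = X^T y` and passes them to `sml.solve`. Below is a NumPy-only sketch of the same normal-equations computation on synthetic data; the data, the seed, and the `fit_normal_equations` helper are assumptions for illustration, and `np.linalg.solve` gives the least squares solution here only because `X` has full column rank. It also shows why the prediction must be flattened: multiplying the `(20, 1)` test matrix by the `(1, 1)` coefficient matrix yields a `(20, 1)` array, and subtracting a `(20,)` target vector from it silently broadcasts to `(20, 20)`.

```python
import numpy as np


def fit_normal_equations(X, y):
    """Solve (X^T X) beta = X^T y, the same A and b that the example passes to sml.solve."""
    A = X.T.dot(X)
    b = X.T.dot(y)
    return np.linalg.solve(A, b)  # assumes X has full column rank, so A is invertible


# Hypothetical synthetic data: y = 3 * x plus a little noise, no intercept.
rs = np.random.RandomState(0)
X_train = rs.uniform(-1.0, 1.0, size=(100, 1))
y_train = 3.0 * X_train[:, 0] + 0.01 * rs.randn(100)
X_test = rs.uniform(-1.0, 1.0, size=(20, 1))
y_test = 3.0 * X_test[:, 0]

# Column-vector target, like sml.matrix(np.matrix(y_train).T) in the example.
beta = fit_normal_equations(X_train, y_train.reshape(-1, 1))   # shape (1, 1)

unflattened = X_test.dot(beta)                  # shape (20, 1)
print((unflattened - y_test).shape)             # (20, 20): broadcasting, not residuals

y_predicted = X_test.dot(beta).ravel()          # shape (20,)
mse = np.mean((y_predicted - y_test) ** 2)
print('beta = %.4f, mean squared error = %.6f' % (beta[0, 0], mse))
```

Like the example, this fits no intercept term; adding a column of ones to `X` would be the usual way to include one.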
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3a1431c8/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index e601a7d..2ea305b 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -37,6 +37,41 @@ import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel
 import java.util.HashMap
 import scala.collection.JavaConversions._
 
+
+/****************************************************
+DESIGN DOCUMENT for MLLEARN API:
+The mllearn API supports LogisticRegression, LinearRegression, SVM, NaiveBayes
+and Caffe2DML. Every algorithm in this API has a Python wrapper (implemented in the mllearn Python package)
+and a Scala class where the actual logic is implemented.
+Both the wrapper and the Scala class follow the hierarchy below to reuse code and simplify the implementation.
+
+
+                  BaseSystemMLEstimator
+                          |
+      --------------------------------------------
+      |                                          |
+BaseSystemMLClassifier                  BaseSystemMLRegressor
+      ^                                          ^
+      |                                          |
+SVM, Caffe2DML, ...                          LinearRegression
+
+
+To conform with the MLLib API, we provide two classes for every algorithm:
+1. Estimator for training: for example, SVM extends Estimator[SVMModel].
+2. Model for prediction: for example, SVMModel extends Model[SVMModel].
+
+Both BaseSystemMLRegressor and BaseSystemMLClassifier implement the following methods for training:
+1. For compatibility with scikit-learn: baseFit(X_mb: MatrixBlock, y_mb: MatrixBlock, sc: SparkContext): MLResults
+2. For compatibility with MLLib: baseFit(df: ScriptsUtils.SparkDataType, sc: SparkContext): MLResults
+
+In the above methods, we execute the DML script for the given algorithm using MLContext.
+The missing piece of the puzzle is how the BaseSystemMLRegressor and BaseSystemMLClassifier interfaces
+get the DML script. To enable this, each wrapper class has to implement the following methods:
+1. getTrainingScript(isSingleNode:Boolean):(Script object of mlcontext, variable name of X in the script:String, variable name of y in the script:String)
+2. getPredictionScript(isSingleNode:Boolean): (Script object of mlcontext, variable name of X in the script:String)
+
+****************************************************/
+
 trait HasLaplace extends Params {
   final val laplace: Param[Double] = new Param[Double](this, "laplace", "Laplace smoothing specified by the user to avoid creation of 0 probabilities.")
   setDefault(laplace, 1.0)
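The design document above describes a template-method style hierarchy: the base classes implement the fitting logic once, and each algorithm only supplies its DML script together with the names of its `X` and `y` variables. The following Python sketch illustrates that division of labor. Every class and method name in it is hypothetical (they loosely mirror `getTrainingScript`, `getPredictionScript`, and `baseFit`), and instead of running the script through MLContext it simply returns the script and its input bindings.

```python
from abc import ABC, abstractmethod


class BaseEstimatorSketch(ABC):
    """Hypothetical analogue of BaseSystemMLEstimator: fit logic lives here once."""

    @abstractmethod
    def get_training_script(self, is_single_node):
        """Return (dml_script, name of X in the script, name of y in the script)."""

    @abstractmethod
    def get_prediction_script(self, is_single_node):
        """Return (dml_script, name of X in the script)."""

    def base_fit(self, X, y, is_single_node=True):
        script, x_name, y_name = self.get_training_script(is_single_node)
        # The real code would execute `script` with MLContext; here we only bind the inputs.
        return {"script": script, "inputs": {x_name: X, y_name: y}}

    def base_predict(self, X, is_single_node=True):
        script, x_name = self.get_prediction_script(is_single_node)
        return {"script": script, "inputs": {x_name: X}}


class SVMSketch(BaseEstimatorSketch):
    """Hypothetical subclass: only supplies scripts and variable names."""

    def get_training_script(self, is_single_node):
        return ("/* placeholder SVM training script */", "X", "Y")

    def get_prediction_script(self, is_single_node):
        return ("/* placeholder SVM prediction script */", "X")


if __name__ == "__main__":
    model = SVMSketch()
    print(model.base_fit([[1.0, 2.0]], [[1.0]])["inputs"])
    print(model.base_predict([[3.0, 4.0]])["script"])
```

The design choice this illustrates is that adding an algorithm only requires implementing the two script-providing methods; input binding and execution stay in the shared base class.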