Posted to commits@systemds.apache.org by ar...@apache.org on 2021/05/01 19:19:42 UTC

[systemds] branch master updated: [SYSTEMDS-2959] Add a predict function for Naive Bayes

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 624a612  [SYSTEMDS-2959] Add a predict function for Naive Bayes
624a612 is described below

commit 624a61229f9f744958faa310a9e80a86518b3c44
Author: sfathollahzadeh <s....@gmail.com>
AuthorDate: Sat May 1 21:12:17 2021 +0200

    [SYSTEMDS-2959] Add a predict function for Naive Bayes
    
    Closes #1217
---
 docs/site/builtins-reference.md                    | 102 ++++++++++++++-------
 scripts/builtin/{naivebayes.dml => naiveBayes.dml} |   2 +-
 .../builtin/naiveBayesPredict.dml                  |  16 +++-
 .../java/org/apache/sysds/common/Builtins.java     |   3 +-
 .../builtin/BuiltinNaiveBayesPredictTest.java      | 102 +++++++++++++++++++++
 src/test/scripts/functions/builtin/NaiveBayes.dml  |   2 +-
 .../scripts/functions/builtin/NaiveBayesPredict.R  |  66 +++++++++++++
 .../{NaiveBayes.dml => NaiveBayesPredict.dml}      |  24 ++++-
 src/test/scripts/installDependencies.R             |   1 +
 9 files changed, 270 insertions(+), 48 deletions(-)

diff --git a/docs/site/builtins-reference.md b/docs/site/builtins-reference.md
index c3d21ca..c46de37 100644
--- a/docs/site/builtins-reference.md
+++ b/docs/site/builtins-reference.md
@@ -61,13 +61,14 @@ limitations under the License.
     * [`gnmf`-Function](#gnmf-function)
     * [`mdedup`-Function](#mdedup-function)
     * [`msvm`-Function](#msvm-function)
-    * [`naivebayes`-Function](#naivebayes-function)
+    * [`naiveBayes`-Function](#naivebayes-function)
+    * [`naiveBayesPredict`-Function](#naivebayespredict-function)
     * [`outlier`-Function](#outlier-function)
     * [`toOneHot`-Function](#toOneHOt-function)
     * [`winsorize`-Function](#winsorize-function)
     * [`gmm`-Function](#gmm-function)
     * [`correctTypos`-Function](#correcttypos-function)
-    
+
 # Introduction
 
 The DML (Declarative Machine Learning) language has built-in functions which enable access to both low- and high-level functions
@@ -106,7 +107,7 @@ Note that this function is highly **unstable** and will be overworked and might
 
 ##### `data`-Argument
 
-The `data`-argument can be a `Matrix` of any datatype from which the elements will be taken and placed in the tensor 
+The `data`-argument can be a `Matrix` of any datatype from which the elements will be taken and placed in the tensor
 until filled. If given as a `Tensor` the same procedure takes place. We iterate through `Matrix` and `Tensor` by starting
 with each dimension index at `0` and then incrementing the lowest one, until we made a complete pass over the dimension,
 and then increasing the dimension index above. This will be done until the `Tensor` is completely filled.
@@ -167,14 +168,14 @@ confusionMatrix(P, Y)
 | Y    | Matrix[Double] | ---     | vector of Golden standard One Hot Encoded |
 
 ### Returns
- 
+
 | Name         | Type           | Description |
 | :----------- | :------------- | :---------- |
 | ConfusionSum | Matrix[Double] | The Confusion Matrix Sums of classifications |
 | ConfusionAvg | Matrix[Double] | The Confusion Matrix averages of each true class |
 
 ### Example
- 
+
 ```r
 numClasses = 1
 z = rand(rows = 5, cols = 1, min = 1, max = 9)
@@ -385,7 +386,7 @@ beta = glm(X=X,Y=y)
 ## `gridSearch`-Function
 
 The `gridSearch`-function is used to find the optimal hyper-parameters of a model which results in the most _accurate_
-predictions. This function takes `train` and `eval` functions by name. 
+predictions. This function takes `train` and `eval` functions by name.
 
 ### Usage
 ```r
@@ -423,13 +424,13 @@ paramRanges = list(10^seq(0,-4), 10^seq(-5,-9), 10^seq(1,3))
 
 ## `hyperband`-Function
 
-The `hyperband`-function is used for hyper parameter optimization and is based on multi-armed bandits and early elimination. 
+The `hyperband`-function is used for hyper parameter optimization and is based on multi-armed bandits and early elimination.
 Through multiple parallel brackets and consecutive trials it will return the hyper parameter combination which performed best
 on a validation dataset. A set of hyper parameter combinations is drawn from uniform distributions with given ranges; Those
 make up the candidates for `hyperband`.
-Notes: 
+Notes:
 * `hyperband` is hard-coded for `lmCG`, and uses `lmPredict` for validation
-* `hyperband` is hard-coded to use the number of iterations as a resource 
+* `hyperband` is hard-coded to use the number of iterations as a resource
 * `hyperband` can only optimize continuous hyperparameters
 
 ### Usage
@@ -542,7 +543,7 @@ B = img_crop(img_in = A, w = 20, h = 10, x_offset = 0, y_offset = 0)
 ## `img_mirror`-Function
 
 The `img_mirror`-function is an image data augumentation function.
-It flips an image on the `X` (horizontal) or `Y` (vertical) axis. 
+It flips an image on the `X` (horizontal) or `Y` (vertical) axis.
 
 ### Usage
 
@@ -679,7 +680,7 @@ kmeans(X = X, k = 20, runs = 10, max_iter = 5000, eps = 0.000001, is_verbose = F
 ## `lm`-Function
 
 The `lm`-function solves linear regression using either the **direct solve method** or the **conjugate gradient algorithm**
-depending on the input size of the matrices (See [`lmDS`-function](#lmds-function) and 
+depending on the input size of the matrices (See [`lmDS`-function](#lmds-function) and
 [`lmCG`-function](#lmcg-function) respectively).
 
 ### Usage
@@ -711,10 +712,10 @@ is called internally and parameters `tol` and `maxi` are ignored.
 ##### `icpt`-Argument
 
 The *icpt-argument* can be set to 3 modes:
- 
-  * 0 = no intercept, no shifting, no rescaling
-  * 1 = add intercept, but neither shift nor rescale X
-  * 2 = add intercept, shift & rescale X columns to mean = 0, variance = 1
+
+* 0 = no intercept, no shifting, no rescaling
+* 1 = add intercept, but neither shift nor rescale X
+* 2 = add intercept, shift & rescale X columns to mean = 0, variance = 1
 
 ### Example
 
@@ -999,7 +1000,7 @@ Y= scale(X,center,scale)
 
 Implements training phase of Sherlock: A Deep Learning Approach to Semantic Data Type Detection
 
-[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection." 
+[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection."
 Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining., 2019]
 ### Usage
 
@@ -1055,7 +1056,7 @@ write(label_encoding, "weights/label_encoding")
 
 Implements prediction and evaluation phase of Sherlock: A Deep Learning Approach to Semantic Data Type Detection
 
-[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection." 
+[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection."
 Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining., 2019]
 ### Usage
 
@@ -1085,7 +1086,7 @@ sherlockPredict(X, cW1, cb1, cW2, cb2, cW3, cb3, wW1, wb1, wW2, wb2, wW3, wb3,
 | Type           | Description |
 | :------------- | :---------- |
 | Matrix[Double] | Class probabilities of shape (N, K). |
-### Example 
+### Example
 
 ```r
 # preprocessed validation data taken from sherlock corpus
@@ -1116,7 +1117,7 @@ fW3,  fb3)
 
 ## `sigmoid`-Function
 
-The Sigmoid function is a type of activation function, and also defined as a squashing function which limit the output 
+The Sigmoid function is a type of activation function, also defined as a squashing function which limits the output
 to a range between 0 and 1, which will make these functions useful in the prediction of probabilities.
 
 ### Usage
@@ -1149,9 +1150,9 @@ Y = sigmoid(X)
 The `smote`-function (Synthetic Minority Oversampling Technique) implements a classical techniques for handling class imbalance.
 The  built-in takes the samples from minority class and over-sample them by generating the synthesized samples.
 The built-in accepts two parameters s and k. The parameter s define the number of synthesized samples to be generated
- i.e., over-sample the minority class by s time, where s is the multiple of 100. For given 40 samples of minority class and
- s = 200 the smote will generate the 80 synthesized samples to over-sample the class by 200 percent. The parameter k is used to generate the 
- k nearest neighbours for each minority class sample and then the neighbours are chosen randomly in synthesis process.
+i.e., over-sample the minority class by s percent, where s is a multiple of 100. Given 40 samples of the minority class and
+s = 200, smote will generate 80 synthesized samples to over-sample the class by 200 percent. The parameter k is used to find the
+k nearest neighbours for each minority class sample; the neighbours are then chosen randomly in the synthesis process.
 
 ### Usage
 
@@ -1172,7 +1173,7 @@ smote(X, s, k, verbose);
 
 | Type           | Description |
 | :------------- | :---------- |
-| Matrix[Double] | Matrix of (N/100) * X synthetic minority class samples 
+| Matrix[Double] | Matrix of (N/100) * X synthetic minority class samples
 
 
 ### Example
@@ -1184,7 +1185,7 @@ B = smote(X = X, s=200, k=3, verbose=TRUE);
 ## `steplm`-Function
 
 The `steplm`-function (stepwise linear regression) implements a classical forward feature selection method.
-This method iteratively runs what-if scenarios and greedily selects the next best feature until the Akaike 
+This method iteratively runs what-if scenarios and greedily selects the next best feature until the Akaike
 information criterion (AIC) does not improve anymore. Each configuration trains a regression model via `lm`,
 which in turn calls either the closed form `lmDS` or iterative `lmCG`.
 
@@ -1216,13 +1217,13 @@ steplm(X, y, icpt);
 ##### `icpt`-Argument
 
 The *icpt-arg* can be set to 2 modes:
- 
-  * 0 = no intercept, no shifting, no rescaling
-  * 1 = add intercept, but neither shift nor rescale X
+
+* 0 = no intercept, no shifting, no rescaling
+* 1 = add intercept, but neither shift nor rescale X
 
 ##### `selected`-Output
 
-If the best AIC is achieved without any features the matrix of *selected* features contains 0. Moreover, in this case no further statistics will be produced 
+If the best AIC is achieved without any features, the matrix of *selected* features contains 0. Moreover, in this case no further statistics will be produced.
 
 ### Example
 
@@ -1271,7 +1272,7 @@ ress = slicefinder(X = X,W = w, Y = y,  k = 5, paq = 1, S = 2);
 ## `normalize`-Function
 
 The `normalize`-function normalises the values of a matrix by changing the dataset to use a common scale.
-This is done while preserving differences in the ranges of values. 
+This is done while preserving differences in the ranges of values.
 The output is a matrix of values in range [0,1].
 
 ### Usage
@@ -1341,14 +1342,14 @@ H = rand(rows = 2, cols = ncol(X), min = -0.05, max = 0.05);
 gnmf(X = X, rnk = 2, eps = 10^-8, maxi = 10)
 ```
 
-## `naivebayes`-Function
+## `naiveBayes`-Function
 
-The `naivebayes`-function computes the class conditional probabilities and class priors.
+The `naiveBayes`-function computes the class conditional probabilities and class priors.
 
 ### Usage
 
 ```r
-naivebayes(D, C, laplace, verbose)
+naiveBayes(D, C, laplace, verbose)
 ```
 
 ### Arguments
@@ -1372,7 +1373,38 @@ naivebayes(D, C, laplace, verbose)
 ```r
 D=rand(rows=10,cols=1,min=10)
 C=rand(rows=10,cols=1,min=10)
-[prior, classConditionals] = naivebayes(D, C, laplace = 1, verbose = TRUE)
+[prior, classConditionals] = naiveBayes(D, C, laplace = 1, verbose = TRUE)
+```
+
+## `naiveBayesPredict`-Function
+
+The `naiveBayesPredict`-function predicts the class labels for test data, given the class priors and class conditional probabilities of a trained naive Bayes model.
+
+### Usage
+
+```r
+naiveBayesPredict(X=X, P=P, C=C)
+```
+
+### Arguments
+
+| Name    | Type           | Default  | Description |
+| :------ | :------------- | -------- | :---------- |
+| X       | Matrix[Double] | required | Matrix of test data with N rows. |
+| P       | Matrix[Double] | required | Class priors; column vector with one row per class. |
+| C       | Matrix[Double] | required | Class conditional probabilities; matrix with one row per class and one column per feature. |
+
+### Returns
+
+| Type           | Description |
+| :------------- | :---------- |
+| Matrix[Double] | YRaw: matrix of raw (log-scale) class scores, one row per test record and one column per class. |
+| Matrix[Double] | Y: column vector with the predicted class label (index of the highest raw score) per test record. |
+
+### Example
+
+```r
+[YRaw, Y] = naiveBayesPredict(X=data, P=model_prior, C=model_conditionals)
 ```
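+
+As a fuller illustration, here is a minimal end-to-end sketch on synthetic data
+(the random label construction is for illustration only): train with
+`naiveBayes`, then score held-out rows with `naiveBayesPredict`.
+
+```r
+X = rand(rows=100, cols=10, min=0, max=1)
+y = round(rand(rows=100, cols=1, min=1, max=5))  # integer class labels in 1..5
+[prior, conditionals] = naiveBayes(D=X[1:80,], C=y[1:80,], laplace=1, verbose=FALSE)
+[YRaw, Y] = naiveBayesPredict(X=X[81:100,], P=prior, C=conditionals)
+print(toString(Y))  # one predicted class label per held-out row
+```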
 
 ## `outlier`-Function
@@ -1439,7 +1471,7 @@ y = toOneHot(X,numClasses)
 
 ## `mdedup`-Function
 
-The `mdedup`-function implements builtin for deduplication using matching dependencies 
+The `mdedup`-function implements a built-in for deduplication using matching dependencies
 (e.g. Street 0.95, City 0.90 -> ZIP 1.0) by Jaccard distance.
 
 ### Usage
@@ -1553,7 +1585,7 @@ Y = winsorize(X=X)
 
 ## `gmm`-Function
 
-The `gmm`-function implements builtin Gaussian Mixture Model with four different types of 
+The `gmm`-function implements a built-in Gaussian Mixture Model with four different types of
 covariance matrices i.e., VVV, EEE, VVI, VII and two initialization methods namely "kmeans" and "random".
 
 ### Usage
diff --git a/scripts/builtin/naivebayes.dml b/scripts/builtin/naiveBayes.dml
similarity index 96%
rename from scripts/builtin/naivebayes.dml
rename to scripts/builtin/naiveBayes.dml
index 1be291a..313ec09 100644
--- a/scripts/builtin/naivebayes.dml
+++ b/scripts/builtin/naiveBayes.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-m_naivebayes = function(Matrix[Double] D, Matrix[Double] C, Double laplace = 1, Boolean verbose = TRUE)
+m_naiveBayes = function(Matrix[Double] D, Matrix[Double] C, Double laplace = 1, Boolean verbose = TRUE)
   return (Matrix[Double] prior, Matrix[Double] classConditionals) 
 {
   laplaceCorrection = laplace;
diff --git a/src/test/scripts/functions/builtin/NaiveBayes.dml b/scripts/builtin/naiveBayesPredict.dml
similarity index 74%
copy from src/test/scripts/functions/builtin/NaiveBayes.dml
copy to scripts/builtin/naiveBayesPredict.dml
index 1af2672..8ae938e 100644
--- a/src/test/scripts/functions/builtin/NaiveBayes.dml
+++ b/scripts/builtin/naiveBayesPredict.dml
@@ -19,8 +19,14 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-y = read($2);
-[prior, conditionals] = naivebayes(D=X, C=y, laplace=$4);
-write(prior, $5);
-write(conditionals, $6);
+m_naiveBayesPredict = function(Matrix[Double] X, Matrix[Double] P, Matrix[Double] C)
+ return (Matrix[Double] YRaw, Matrix[Double] Y)
+{
+  numRows = nrow(X)
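+  # append the class priors P as an extra column of the model; together with a
+  # column of ones appended to X below, this adds log(P) to every class score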
+  model = cbind(C, P)
+
+  ones = matrix(1, rows=numRows, cols=1);
+  X_w_ones = cbind(X, ones);
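+  # raw per-class scores: test features weighted by the log conditionals, plus the log prior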
+  YRaw = X_w_ones %*% t(log(model));
+  Y = rowIndexMax(YRaw);
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 897aadc..27a0d73 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -181,7 +181,8 @@ public enum Builtins {
 	NCOL("ncol", false),
 	NORMALIZE("normalize", true),
 	NROW("nrow", false),
-	NAIVEBAYES("naivebayes", true, false),
+	NAIVEBAYES("naiveBayes", true, false),
+	NAIVEBAYESPREDICT("naiveBayesPredict", true, false),
 	OUTER("outer", false),
 	OUTLIER("outlier", true, false), //TODO parameterize opposite
 	OUTLIER_SD("outlierBySd", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinNaiveBayesPredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinNaiveBayesPredictTest.java
new file mode 100644
index 0000000..9318f45
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinNaiveBayesPredictTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+public class BuiltinNaiveBayesPredictTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "NaiveBayesPredict";
+	private final static String TEST_DIR = "functions/builtin/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinNaiveBayesPredictTest.class.getSimpleName() + "/";
+	private final static int numClasses = 10;
+
+	public double eps = 1e-7;
+
+	@Override public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"YRaw", "Y"}));
+	}
+
+	@Test public void testSmallDense() {
+		testNaiveBayesPredict(100, 50, 0.7);
+	}
+
+	@Test public void testLargeDense() {
+		testNaiveBayesPredict(10000, 750, 0.7);
+	}
+
+	@Test public void testSmallSparse() {
+		testNaiveBayesPredict(100, 50, 0.01);
+	}
+
+	@Test public void testLargeSparse() {
+		testNaiveBayesPredict(10000, 750, 0.01);
+	}
+
+	public void testNaiveBayesPredict(int rows, int cols, double sparsity) {
+		loadTestConfiguration(getTestConfiguration(TEST_NAME));
+		String HOME = SCRIPT_DIR + TEST_DIR;
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+		int classes = numClasses;
+		double laplace = 1;
+
+		List<String> proArgs = new ArrayList<>();
+		proArgs.add("-args");
+		proArgs.add(input("D"));
+		proArgs.add(input("C"));
+		proArgs.add(String.valueOf(classes));
+		proArgs.add(String.valueOf(laplace));
+		proArgs.add(output("YRaw"));
+		proArgs.add(output("Y"));
+		programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+		rCmd = getRCmd(inputDir(), Integer.toString(classes), Double.toString(laplace), expectedDir());
+
+		double[][] D = getRandomMatrix(rows, cols, 0, 1, sparsity, -1);
+		double[][] C = getRandomMatrix(rows, 1, 0, 1, 1, -1);
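+		// map the uniform [0,1) draws to integer class labels in {1, ..., classes}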
+		for(int i = 0; i < rows; i++) {
+			C[i][0] = (int) (C[i][0] * classes) + 1;
+			C[i][0] = (C[i][0] > classes) ? classes : C[i][0];
+		}
+
+		writeInputMatrixWithMTD("D", D, true);
+		writeInputMatrixWithMTD("C", C, true);
+
+		runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+		runRScript(true);
+
+		HashMap<CellIndex, Double> YRawR = readRMatrixFromExpectedDir("YRaw");
+		HashMap<CellIndex, Double> YR = readRMatrixFromExpectedDir("Y");
+		HashMap<CellIndex, Double> YRawSYSTEMDS = readDMLMatrixFromOutputDir("YRaw");
+		HashMap<CellIndex, Double> YSYSTEMDS = readDMLMatrixFromOutputDir("Y");
+		TestUtils.compareMatrices(YRawR, YRawSYSTEMDS, eps, "YRawR", "YRawSYSTEMDS");
+		TestUtils.compareMatrices(YR, YSYSTEMDS, eps, "YR", "YSYSTEMDS");
+	}
+}
diff --git a/src/test/scripts/functions/builtin/NaiveBayes.dml b/src/test/scripts/functions/builtin/NaiveBayes.dml
index 1af2672..65303d3 100644
--- a/src/test/scripts/functions/builtin/NaiveBayes.dml
+++ b/src/test/scripts/functions/builtin/NaiveBayes.dml
@@ -21,6 +21,6 @@
 
 X = read($1);
 y = read($2);
-[prior, conditionals] = naivebayes(D=X, C=y, laplace=$4);
+[prior, conditionals] = naiveBayes(D=X, C=y, laplace=$4);
 write(prior, $5);
 write(conditionals, $6);
diff --git a/src/test/scripts/functions/builtin/NaiveBayesPredict.R b/src/test/scripts/functions/builtin/NaiveBayesPredict.R
new file mode 100644
index 0000000..18fcb9e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/NaiveBayesPredict.R
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("naivebayes")
+
+D = as.matrix(readMM(paste(args[1], "D.mtx", sep="")))
+C = as.matrix(readMM(paste(args[1], "C.mtx", sep="")))
+laplace <- as.numeric(args[3])
+
+# divide D into "train" and "test" data
+numRows = nrow(D)
+trainSize = numRows * 0.8
+
+trainData = D[1:trainSize, ]
+testData = D[(trainSize+1):numRows, ]
+y <- factor(C[1:trainSize])
+
+# The naivebayes package functions require unique feature column names
+features <- paste0("V", seq_len(ncol(trainData)))
+colnames(trainData) <- features
+colnames(testData) <- features
+
+# Create the model based on the training data
+model <- multinomial_naive_bayes(x = trainData, y = y, laplace = laplace)
+
+# The SystemDS DML script produces raw (log-scale) scores, whereas the
+# predict function of the R "naivebayes" package returns a probability
+# matrix, e.g.: YRaw <- predict(model, newdata = testData, type = "prob")
+
+# Hence, compute the raw scores manually from the model parameters
+lev <- model$levels
+prior <- model$prior
+params <- t(model$params)
+YRaw <- tcrossprod(testData, log(params))
+
+for (ith_class in seq_along(lev)) {
+  YRaw[ ,ith_class] <- YRaw[ ,ith_class] + log(prior[ith_class])
+}
+
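+# ties.method = "last" is chosen to match the tie-breaking of rowIndexMax on the DML side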
+Y <- max.col(YRaw, ties.method="last")
+
+# write out the predictions
+writeMM(as(YRaw, "CsparseMatrix"), paste(args[4], "YRaw", sep=""))
+writeMM(as(Y, "CsparseMatrix"), paste(args[4], "Y", sep=""))
diff --git a/src/test/scripts/functions/builtin/NaiveBayes.dml b/src/test/scripts/functions/builtin/NaiveBayesPredict.dml
similarity index 64%
copy from src/test/scripts/functions/builtin/NaiveBayes.dml
copy to src/test/scripts/functions/builtin/NaiveBayesPredict.dml
index 1af2672..4943827 100644
--- a/src/test/scripts/functions/builtin/NaiveBayes.dml
+++ b/src/test/scripts/functions/builtin/NaiveBayesPredict.dml
@@ -19,8 +19,22 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-y = read($2);
-[prior, conditionals] = naivebayes(D=X, C=y, laplace=$4);
-write(prior, $5);
-write(conditionals, $6);
+D = read($1);
+C = read($2);
+
+# divide data into "train" and "test" subsets
+numRows = nrow(D);
+trainSize = numRows * 0.8;
+trainData = D[1:trainSize,];
+testData = D[(trainSize+1):numRows,];
+C = C[1:trainSize,];
+
+# calc "prior" and "conditionals" with naiveBayes build-in function
+[prior, conditionals] = naiveBayes(D=trainData, C=C, laplace=$4, verbose=FALSE);
+
+# compute the predictions
+[YRaw,Y] = naiveBayesPredict(X=testData, P=prior, C=conditionals);
+
+# write the results
+write(YRaw, $5);
+write(Y, $6);
diff --git a/src/test/scripts/installDependencies.R b/src/test/scripts/installDependencies.R
index 8482f1e..46d6f65 100644
--- a/src/test/scripts/installDependencies.R
+++ b/src/test/scripts/installDependencies.R
@@ -61,6 +61,7 @@ custom_install("imputeTS");
 custom_install("FNN");
 custom_install("class");
 custom_install("unbalanced");
+custom_install("naivebayes");
 
 print("Installation Done")