You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2018/10/24 06:09:35 UTC
[09/11] ignite git commit: IGNITE-9282: [ML] Add Naive Bayes
classifier
IGNITE-9282: [ML] Add Naive Bayes classifier
this closes #4869
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/e29a8cb9
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/e29a8cb9
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/e29a8cb9
Branch: refs/heads/ignite-9720
Commit: e29a8cb9380fb2c1f6815d40670315919af58d3b
Parents: 86f5437
Author: dehasi <rg...@gmail.com>
Authored: Tue Oct 23 19:11:23 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Tue Oct 23 19:11:23 2018 +0300
----------------------------------------------------------------------
.../GaussianNaiveBayesTrainerExample.java | 113 +++++++++++
.../examples/ml/naivebayes/package-info.java | 22 +++
.../ignite/examples/util/IrisDataset.java | 129 +++++++++++++
.../gaussian/GaussianNaiveBayesModel.java | 111 +++++++++++
.../gaussian/GaussianNaiveBayesSumsHolder.java | 55 ++++++
.../gaussian/GaussianNaiveBayesTrainer.java | 186 +++++++++++++++++++
.../ml/naivebayes/gaussian/package-info.java | 22 +++
.../ignite/ml/naivebayes/package-info.java | 22 +++
.../gaussian/GaussianNaiveBayesModelTest.java | 50 +++++
.../gaussian/GaussianNaiveBayesTest.java | 86 +++++++++
.../gaussian/GaussianNaiveBayesTrainerTest.java | 182 ++++++++++++++++++
11 files changed, 978 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
new file mode 100644
index 0000000..cd8383e
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.naivebayes;
+
+import java.util.Arrays;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.examples.ml.util.TestCache;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+
+import static org.apache.ignite.examples.util.IrisDataset.irisDatasetFirstAndSecondClasses;
+
+/**
+ * Demonstrates training and evaluating a
+ * <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier">naive Bayes classifier</a>
+ * ({@link GaussianNaiveBayesTrainer}) on data stored in a distributed cache.
+ * <p>
+ * The example starts an Ignite grid and loads the first two classes of the
+ * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a> into a cache.</p>
+ * <p>
+ * It then fits a naive Bayes model on that cache, taking the first element of every row as the label and the
+ * remaining elements as the features.</p>
+ * <p>
+ * Finally, every cached row is classified with the trained model, the prediction is compared with the expected
+ * outcome (ground truth), and the error count, the accuracy and the
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a> are printed.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class GaussianNaiveBayesTrainerExample {
+    /**
+     * Runs the example.
+     *
+     * @param args Command line arguments; not used.
+     * @throws InterruptedException Declared by the signature; no statement in this method throws it explicitly.
+     */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Naive Bayes classification model over partitioned dataset usage example started.");
+
+        // The Ignite instance is closed automatically when this try block exits.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, double[]> dataCache =
+                new TestCache(ignite).fillCacheWith(irisDatasetFirstAndSecondClasses);
+
+            System.out.println(">>> Create new naive Bayes classification trainer object.");
+            GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+
+            System.out.println(">>> Perform the training to get the model.");
+            GaussianNaiveBayesModel mdl = trainer.fit(
+                ignite,
+                dataCache,
+                // Features are everything after the first element of a row ...
+                (key, row) -> VectorUtils.of(Arrays.copyOfRange(row, 1, row.length)),
+                // ... and the first element is the label.
+                (key, row) -> row[0]
+            );
+
+            System.out.println(">>> Naive Bayes model: " + mdl);
+
+            int errCnt = 0;
+            int rowCnt = 0;
+
+            // Confusion matrix (https://en.wikipedia.org/wiki/Confusion_matrix):
+            // first index is the predicted class, second index is the actual class.
+            int[][] confusionMtx = new int[2][2];
+
+            try (QueryCursor<Cache.Entry<Integer, double[]>> cursor = dataCache.query(new ScanQuery<>())) {
+                for (Cache.Entry<Integer, double[]> entry : cursor) {
+                    double[] row = entry.getValue();
+
+                    Vector features = VectorUtils.of(Arrays.copyOfRange(row, 1, row.length));
+                    double expected = row[0];
+
+                    double predicted = mdl.apply(features);
+
+                    rowCnt++;
+
+                    if (expected != predicted)
+                        errCnt++;
+
+                    confusionMtx[(int)predicted][(int)expected]++;
+
+                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, expected);
+                }
+
+                System.out.println(">>> ---------------------------------");
+
+                System.out.println("\n>>> Absolute amount of errors " + errCnt);
+                System.out.println("\n>>> Accuracy " + (1 - errCnt / (double)rowCnt));
+            }
+
+            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+            System.out.println(">>> ---------------------------------");
+
+            System.out.println(">>> Naive bayes model over partitioned dataset usage example completed.");
+        }
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java
new file mode 100644
index 0000000..7f0420c
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML naive Bayes classifier examples.
+ */
+package org.apache.ignite.examples.ml.naivebayes;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java b/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java
new file mode 100644
index 0000000..53080e8
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ignite.examples.util;
+
+/** Contains data from the <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>. */
+public final class IrisDataset {
+
+ /**
+ * The 1st and 2nd classes from the Iris dataset. Every row holds five values: the first one is the class label
+ * ({@code 0} or {@code 1}) and the remaining four are the feature values of the observation.
+ */
+ public static final double[][] irisDatasetFirstAndSecondClasses = {
+ {0, 5.1, 3.5, 1.4, 0.2},
+ {0, 4.9, 3, 1.4, 0.2},
+ {0, 4.7, 3.2, 1.3, 0.2},
+ {0, 4.6, 3.1, 1.5, 0.2},
+ {0, 5, 3.6, 1.4, 0.2},
+ {0, 5.4, 3.9, 1.7, 0.4},
+ {0, 4.6, 3.4, 1.4, 0.3},
+ {0, 5, 3.4, 1.5, 0.2},
+ {0, 4.4, 2.9, 1.4, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 5.4, 3.7, 1.5, 0.2},
+ {0, 4.8, 3.4, 1.6, 0.2},
+ {0, 4.8, 3, 1.4, 0.1},
+ {0, 4.3, 3, 1.1, 0.1},
+ {0, 5.8, 4, 1.2, 0.2},
+ {0, 5.7, 4.4, 1.5, 0.4},
+ {0, 5.4, 3.9, 1.3, 0.4},
+ {0, 5.1, 3.5, 1.4, 0.3},
+ {0, 5.7, 3.8, 1.7, 0.3},
+ {0, 5.1, 3.8, 1.5, 0.3},
+ {0, 5.4, 3.4, 1.7, 0.2},
+ {0, 5.1, 3.7, 1.5, 0.4},
+ {0, 4.6, 3.6, 1, 0.2},
+ {0, 5.1, 3.3, 1.7, 0.5},
+ {0, 4.8, 3.4, 1.9, 0.2},
+ {0, 5, 3, 1.6, 0.2},
+ {0, 5, 3.4, 1.6, 0.4},
+ {0, 5.2, 3.5, 1.5, 0.2},
+ {0, 5.2, 3.4, 1.4, 0.2},
+ {0, 4.7, 3.2, 1.6, 0.2},
+ {0, 4.8, 3.1, 1.6, 0.2},
+ {0, 5.4, 3.4, 1.5, 0.4},
+ {0, 5.2, 4.1, 1.5, 0.1},
+ {0, 5.5, 4.2, 1.4, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 5, 3.2, 1.2, 0.2},
+ {0, 5.5, 3.5, 1.3, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 4.4, 3, 1.3, 0.2},
+ {0, 5.1, 3.4, 1.5, 0.2},
+ {0, 5, 3.5, 1.3, 0.3},
+ {0, 4.5, 2.3, 1.3, 0.3},
+ {0, 4.4, 3.2, 1.3, 0.2},
+ {0, 5, 3.5, 1.6, 0.6},
+ {0, 5.1, 3.8, 1.9, 0.4},
+ {0, 4.8, 3, 1.4, 0.3},
+ {0, 5.1, 3.8, 1.6, 0.2},
+ {0, 4.6, 3.2, 1.4, 0.2},
+ {0, 5.3, 3.7, 1.5, 0.2},
+ {0, 5, 3.3, 1.4, 0.2},
+ {1, 7, 3.2, 4.7, 1.4},
+ {1, 6.4, 3.2, 4.5, 1.5},
+ {1, 6.9, 3.1, 4.9, 1.5},
+ {1, 5.5, 2.3, 4, 1.3},
+ {1, 6.5, 2.8, 4.6, 1.5},
+ {1, 5.7, 2.8, 4.5, 1.3},
+ {1, 6.3, 3.3, 4.7, 1.6},
+ {1, 4.9, 2.4, 3.3, 1},
+ {1, 6.6, 2.9, 4.6, 1.3},
+ {1, 5.2, 2.7, 3.9, 1.4},
+ {1, 5, 2, 3.5, 1},
+ {1, 5.9, 3, 4.2, 1.5},
+ {1, 6, 2.2, 4, 1},
+ {1, 6.1, 2.9, 4.7, 1.4},
+ {1, 5.6, 2.9, 3.6, 1.3},
+ {1, 6.7, 3.1, 4.4, 1.4},
+ {1, 5.6, 3, 4.5, 1.5},
+ {1, 5.8, 2.7, 4.1, 1},
+ {1, 6.2, 2.2, 4.5, 1.5},
+ {1, 5.6, 2.5, 3.9, 1.1},
+ {1, 5.9, 3.2, 4.8, 1.8},
+ {1, 6.1, 2.8, 4, 1.3},
+ {1, 6.3, 2.5, 4.9, 1.5},
+ {1, 6.1, 2.8, 4.7, 1.2},
+ {1, 6.4, 2.9, 4.3, 1.3},
+ {1, 6.6, 3, 4.4, 1.4},
+ {1, 6.8, 2.8, 4.8, 1.4},
+ {1, 6.7, 3, 5, 1.7},
+ {1, 6, 2.9, 4.5, 1.5},
+ {1, 5.7, 2.6, 3.5, 1},
+ {1, 5.5, 2.4, 3.8, 1.1},
+ {1, 5.5, 2.4, 3.7, 1},
+ {1, 5.8, 2.7, 3.9, 1.2},
+ {1, 6, 2.7, 5.1, 1.6},
+ {1, 5.4, 3, 4.5, 1.5},
+ {1, 6, 3.4, 4.5, 1.6},
+ {1, 6.7, 3.1, 4.7, 1.5},
+ {1, 6.3, 2.3, 4.4, 1.3},
+ {1, 5.6, 3, 4.1, 1.3},
+ {1, 5.5, 2.5, 4, 1.3},
+ {1, 5.5, 2.6, 4.4, 1.2},
+ {1, 6.1, 3, 4.6, 1.4},
+ {1, 5.8, 2.6, 4, 1.2},
+ {1, 5, 2.3, 3.3, 1},
+ {1, 5.6, 2.7, 4.2, 1.3},
+ {1, 5.7, 3, 4.2, 1.2},
+ {1, 5.7, 2.9, 4.2, 1.3},
+ {1, 6.2, 2.9, 4.3, 1.3},
+ {1, 5.1, 2.5, 3, 1.1},
+ {1, 5.7, 2.8, 4.1, 1.3},
+ };
+
+ /** Private constructor: this class only exposes static data and is not meant to be instantiated. */
+ private IrisDataset() {
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
new file mode 100644
index 0000000..985d9fe
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Gaussian naive Bayes model. For an observation {@code y = (y_1, ..., y_n)} it picks the class {@code C_k} that
+ * maximizes {@code p(C_k) * p(y_1|C_k) * ... * p(y_n|C_k)}, where every {@code p(y_j|C_k)} is a normal density with
+ * the mean and variance stored for feature {@code j} of class {@code k}. The label of that class is returned.
+ */
+public class GaussianNaiveBayesModel implements Model<Vector, Double>, Exportable<GaussianNaiveBayesModel>, Serializable {
+    /** */
+    private static final long serialVersionUID = -127386523291350345L;
+
+    /** Means of features for all classes. kth row contains means for labels[k] class. */
+    private final double[][] means;
+
+    /** Variances of features for all classes. kth row contains variances for labels[k] class. */
+    private final double[][] variances;
+
+    /** Prior probabilities of each class. */
+    private final double[] classProbabilities;
+
+    /** Labels. */
+    private final double[] labels;
+
+    /** Feature sum, squared sum and count per label. */
+    private final GaussianNaiveBayesSumsHolder sumsHolder;
+
+    /**
+     * @param means Means of features for all classes.
+     * @param variances Variances of features for all classes.
+     * @param classProbabilities Probabilities for all classes.
+     * @param labels Labels.
+     * @param sumsHolder Feature sum, squared sum and count sum per label. This data is used for future model updating.
+     */
+    public GaussianNaiveBayesModel(double[][] means, double[][] variances,
+        double[] classProbabilities, double[] labels, GaussianNaiveBayesSumsHolder sumsHolder) {
+        this.means = means;
+        this.variances = variances;
+        this.classProbabilities = classProbabilities;
+        this.labels = labels;
+        this.sumsHolder = sumsHolder;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<GaussianNaiveBayesModel, P> exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /**
+     * Returns the label of the class the observation most probably belongs to.
+     * <p>
+     * Class scores are accumulated as sums of logarithms rather than products of densities. A product of many small
+     * densities underflows to zero, which would make all classes compare equal and always yield the first label.</p>
+     * <p>
+     * A class whose score is {@code NaN} (for example because one of its variances is zero) or negative infinity
+     * (zero prior) is never selected. If no class gets a usable score, the first label is returned.</p>
+     *
+     * @param vector Observation; its j-th element is matched against the j-th mean and variance of every class.
+     * @return Label of the class with the highest score.
+     */
+    @Override public Double apply(Vector vector) {
+        int bestIdx = 0;
+        double bestLogProb = Double.NEGATIVE_INFINITY;
+
+        for (int i = 0; i < classProbabilities.length; i++) {
+            double logProb = Math.log(classProbabilities[i]);
+
+            for (int j = 0; j < vector.size(); j++)
+                logProb += logGauss(vector.get(j), means[i][j], variances[i][j]);
+
+            // Strict comparison keeps the earliest class on ties and skips NaN scores.
+            if (logProb > bestLogProb) {
+                bestIdx = i;
+                bestLogProb = logProb;
+            }
+        }
+
+        return labels[bestIdx];
+    }
+
+    /** @return Means of features for all classes (the internal array, not a copy). */
+    public double[][] getMeans() {
+        return means;
+    }
+
+    /** @return Variances of features for all classes (the internal array, not a copy). */
+    public double[][] getVariances() {
+        return variances;
+    }
+
+    /** @return Prior probabilities of the classes (the internal array, not a copy). */
+    public double[] getClassProbabilities() {
+        return classProbabilities;
+    }
+
+    /** @return Per-label sums and counters this model was built from; may be {@code null}. */
+    public GaussianNaiveBayesSumsHolder getSumsHolder() {
+        return sumsHolder;
+    }
+
+    /**
+     * Natural logarithm of the normal (Gauss) density with the given mean and variance at point {@code x}.
+     *
+     * @param x Point.
+     * @param mean Mean of the distribution.
+     * @param variance Variance of the distribution.
+     * @return {@code ln} of the density at {@code x}.
+     */
+    private double logGauss(double x, double mean, double variance) {
+        double diff = x - mean;
+
+        return -diff * diff / (2. * variance) - .5 * Math.log(2. * Math.PI * variance);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
new file mode 100644
index 0000000..735bbd1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.math.util.MapUtil;
+
+/** Service class is used to calculate means and vaiances */
+class GaussianNaiveBayesSumsHolder implements Serializable, AutoCloseable {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 1L;
+ /** Sum of all values for all features for each label */
+ Map<Double, double[]> featureSumsPerLbl = new HashMap<>();
+ /** Sum of all squared values for all features for each label */
+ Map<Double, double[]> featureSquaredSumsPerLbl = new HashMap<>();
+ /** Rows count for each label */
+ Map<Double, Integer> featureCountersPerLbl = new HashMap<>();
+
+ /** Merge to current */
+ GaussianNaiveBayesSumsHolder merge(GaussianNaiveBayesSumsHolder other) {
+ featureSumsPerLbl = MapUtil.mergeMaps(featureSumsPerLbl, other.featureSumsPerLbl, this::sum, HashMap::new);
+ featureSquaredSumsPerLbl = MapUtil.mergeMaps(featureSquaredSumsPerLbl, other.featureSquaredSumsPerLbl, this::sum, HashMap::new);
+ featureCountersPerLbl = MapUtil.mergeMaps(featureCountersPerLbl, other.featureCountersPerLbl, (i1, i2) -> i1 + i2, HashMap::new);
+ return this;
+ }
+
+ /** In-place operation. Sums {@code arr2} to {@code arr1} element to element. */
+ private double[] sum(double[] arr1, double[] arr2) {
+ for (int i = 0; i < arr1.length; i++) {
+ arr1[i] += arr2[i];
+ }
+ return arr1;
+ }
+
+ /** */
+ @Override public void close() {
+ // Do nothing, GC will clean up.
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
new file mode 100644
index 0000000..1c1df83
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+/**
+ * Trainer for the Gaussian naive Bayes classification model. By default the prior probability of every class is
+ * calculated from the input dataset as the share of rows that carry its label. Prior probabilities can also be
+ * preset with {@link #setPriorProbabilities(double[])} or made equal with {@link #withEquiprobableClasses()}; in the
+ * latter case the probabilities of all classes will be {@code 1/k}, where {@code k} is classes count.
+ */
+public class GaussianNaiveBayesTrainer extends SingleLabelDatasetTrainer<GaussianNaiveBayesModel> {
+    /** Preset prior probabilities, indexed by the position of the label in ascending label order. */
+    private double[] priorProbabilities;
+
+    /** Sets equivalent probability for all classes. */
+    private boolean equiprobableClasses;
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return Model.
+     */
+    @Override public <K, V> GaussianNaiveBayesModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(GaussianNaiveBayesModel mdl) {
+        return true;
+    }
+
+    /**
+     * Collects per-label feature sums, squared sums and row counts from the dataset, merges them with the sums kept
+     * by {@code mdl} (if any) and builds a model from the merged statistics.
+     *
+     * @param mdl Previously trained model to update, or {@code null} to train from scratch.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return Model.
+     * @throws IllegalArgumentException If no labeled rows are available or the number of preset prior
+     *      probabilities differs from the number of labels.
+     * @throws RuntimeException If building, processing or closing the dataset fails.
+     */
+    @Override protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        assert datasetBuilder != null;
+
+        GaussianNaiveBayesSumsHolder sumsHolder;
+
+        try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            (upstream, upstreamSize, ctx) -> {
+                GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder();
+
+                while (upstream.hasNext()) {
+                    UpstreamEntry<K, V> entity = upstream.next();
+
+                    Vector features = featureExtractor.apply(entity.getKey(), entity.getValue());
+                    Double label = lbExtractor.apply(entity.getKey(), entity.getValue());
+
+                    // New arrays are zero-filled, so no explicit initialization is needed.
+                    double[] sums =
+                        res.featureSumsPerLbl.computeIfAbsent(label, key -> new double[features.size()]);
+                    double[] sqSums =
+                        res.featureSquaredSumsPerLbl.computeIfAbsent(label, key -> new double[features.size()]);
+
+                    res.featureCountersPerLbl.merge(label, 1, Integer::sum);
+
+                    for (int j = 0; j < features.size(); j++) {
+                        double x = features.get(j);
+
+                        sums[j] += x;
+                        sqSums[j] += x * x;
+                    }
+                }
+
+                return res;
+            }
+        )) {
+            sumsHolder = dataset.compute(t -> t, (a, b) -> {
+                if (a == null)
+                    return b == null ? new GaussianNaiveBayesSumsHolder() : b;
+                if (b == null)
+                    return a;
+                return a.merge(b);
+            });
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        // NOTE(review): guards against compute() yielding no result at all (e.g. no partitions) - confirm.
+        if (sumsHolder == null)
+            sumsHolder = new GaussianNaiveBayesSumsHolder();
+
+        if (mdl != null && mdl.getSumsHolder() != null)
+            sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
+
+        return buildModel(sumsHolder);
+    }
+
+    /**
+     * Builds a model from the accumulated statistics. Labels are processed in ascending order, so row {@code k} of
+     * the model's means and variances belongs to the k-th smallest label.
+     *
+     * @param sumsHolder Feature sums, squared sums and row counts per label.
+     * @return Model.
+     */
+    private GaussianNaiveBayesModel buildModel(GaussianNaiveBayesSumsHolder sumsHolder) {
+        List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
+        sortedLabels.sort(Double::compareTo);
+
+        if (sortedLabels.isEmpty())
+            throw new IllegalArgumentException("The dataset should contain at least one labeled row");
+
+        int labelCount = sortedLabels.size();
+
+        if (!equiprobableClasses && priorProbabilities != null && priorProbabilities.length != labelCount) {
+            throw new IllegalArgumentException("Expected " + labelCount + " prior probabilities but got "
+                + priorProbabilities.length);
+        }
+
+        int featureCount = sumsHolder.featureSumsPerLbl.get(sortedLabels.get(0)).length;
+
+        double[][] means = new double[labelCount][featureCount];
+        double[][] variances = new double[labelCount][featureCount];
+        double[] classProbabilities = new double[labelCount];
+        double[] labels = new double[labelCount];
+
+        // Summed as long so that the total can not overflow int.
+        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToLong(Integer::longValue).sum();
+
+        for (int lbl = 0; lbl < labelCount; lbl++) {
+            Double label = sortedLabels.get(lbl);
+
+            int count = sumsHolder.featureCountersPerLbl.get(label);
+            double[] sum = sumsHolder.featureSumsPerLbl.get(label);
+            double[] sqSum = sumsHolder.featureSquaredSumsPerLbl.get(label);
+
+            for (int i = 0; i < featureCount; i++) {
+                means[lbl][i] = sum[i] / count;
+
+                // Population variance; rounding can push a near-zero result slightly below zero, so clamp it.
+                variances[lbl][i] = Math.max(0., (sqSum[i] - sum[i] * sum[i] / count) / count);
+            }
+
+            if (equiprobableClasses)
+                classProbabilities[lbl] = 1. / labelCount;
+            else if (priorProbabilities != null)
+                classProbabilities[lbl] = priorProbabilities[lbl];
+            else
+                classProbabilities[lbl] = (double)count / datasetSize;
+
+            labels[lbl] = label;
+        }
+
+        return new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, sumsHolder);
+    }
+
+    /**
+     * Sets equal probability for all classes. Any preset prior probabilities are discarded.
+     *
+     * @return This trainer.
+     */
+    public GaussianNaiveBayesTrainer withEquiprobableClasses() {
+        resetSettings();
+        equiprobableClasses = true;
+        return this;
+    }
+
+    /**
+     * Sets prior probabilities. The array is copied; the equiprobable classes setting is discarded.
+     *
+     * @param priorProbabilities Prior probabilities, one per class, in ascending label order.
+     * @return This trainer.
+     */
+    public GaussianNaiveBayesTrainer setPriorProbabilities(double[] priorProbabilities) {
+        resetSettings();
+        this.priorProbabilities = priorProbabilities.clone();
+        return this;
+    }
+
+    /**
+     * Sets default settings: prior probabilities are calculated from the dataset.
+     *
+     * @return This trainer.
+     */
+    public GaussianNaiveBayesTrainer resetSettings() {
+        equiprobableClasses = false;
+        priorProbabilities = null;
+        return this;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java
new file mode 100644
index 0000000..4e572cf
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains Gaussian naive Bayes classifier.
+ */
+package org.apache.ignite.ml.naivebayes.gaussian;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java
new file mode 100644
index 0000000..fae5387
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains various naive Bayes classifiers.
+ */
+package org.apache.ignite.ml.naivebayes;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
new file mode 100644
index 0000000..c79c0d7
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GaussianNaiveBayesModel}.
+ */
+public class GaussianNaiveBayesModelTest {
+
+ /** Builds a two-class model directly from fixed means, variances and equal class probabilities, then checks that the observation {@code (6, 130, 8)} is assigned the second label. */
+ @Test
+ public void testPredictWithTwoClasses() {
+ double first = 1;
+ double second = 2;
+ double[][] means = new double[][] {
+ {5.855, 176.25, 11.25},
+ {5.4175, 132.5, 7.5},
+ };
+ double[][] variances = new double[][] {
+ {3.5033E-2, 1.2292E2, 9.1667E-1},
+ {9.7225E-2, 5.5833E2, 1.6667},
+ };
+ double[] probabilities = new double[] {.5, .5};
+ GaussianNaiveBayesModel mdl = new GaussianNaiveBayesModel(means, variances, probabilities, new double[] {first, second}, null); // NOTE(review): last argument is passed as null; presumably it is not needed for prediction - confirm.
+ Vector observation = VectorUtils.of(6, 130, 8);
+
+ Assert.assertEquals(second, mdl.apply(observation), 0.0001);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java
new file mode 100644
index 0000000..504b464
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Complex tests for naive Bayes algorithm with different datasets.
+ */
+public class GaussianNaiveBayesTest {
+ /** Precision in test checks. */
+ private static final double PRECISION = 1e-2;
+
+ /**
+ * Trains on the example data set from the Wikipedia article about Naive Bayes https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Sex_classification and expects the observation {@code (6, 130, 8)} to be labelled female.
+ */
+ @Test
+ public void wikipediaSexClassificationDataset() {
+ Map<Integer, double[]> data = new HashMap<>();
+ double male = 0.;
+ double female = 1.;
+ data.put(0, new double[] {male, 6, 180, 12});
+ data.put(2, new double[] {male, 5.92, 190, 11});
+ data.put(3, new double[] {male, 5.58, 170, 12});
+ data.put(4, new double[] {male, 5.92, 165, 10});
+ data.put(5, new double[] {female, 5, 100, 6});
+ data.put(6, new double[] {female, 5.5, 150, 8});
+ data.put(7, new double[] {female, 5.42, 130, 7});
+ data.put(8, new double[] {female, 5.75, 150, 9});
+ GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(data, 2),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // Features: every element after the first.
+ (k, v) -> v[0] // Label: the first element of the row.
+ );
+ Vector observation = VectorUtils.of(6, 130, 8);
+
+ Assert.assertEquals(female, model.apply(observation), PRECISION);
+ }
+
+ /** Trains on the data set from the Gaussian NB example in the scikit-learn documentation and expects the observation {@code (-0.8, -1)} to be labelled {@code one}. */
+ @Test
+ public void scikitLearnExample() {
+ Map<Integer, double[]> data = new HashMap<>();
+ double one = 1.;
+ double two = 2.;
+ data.put(0, new double[] {one, -1, 1}); // NOTE(review): the scikit-learn GaussianNB example uses (-1, -1) for this row - confirm whether (-1, 1) is intentional.
+ data.put(2, new double[] {one, -2, -1});
+ data.put(3, new double[] {one, -3, -2});
+ data.put(4, new double[] {two, 1, 1});
+ data.put(5, new double[] {two, 2, 1});
+ data.put(6, new double[] {two, 3, 2});
+ GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(data, 2),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // Features: every element after the first.
+ (k, v) -> v[0] // Label: the first element of the row.
+ );
+ Vector observation = VectorUtils.of(-0.8, -1);
+
+ Assert.assertEquals(one, model.apply(observation), PRECISION);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
new file mode 100644
index 0000000..f70f7c2
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GaussianNaiveBayesTrainer}.
+ */
+public class GaussianNaiveBayesTrainerTest extends TrainerTest {
+ /** Precision in test checks. */
+ private static final double PRECISION = 1e-2;
+ /** Label of the first class. */
+ private static final double LABEL_1 = 1.;
+ /** Label of the second class. */
+ private static final double LABEL_2 = 2.;
+
+ /** Mixed data set: three rows labelled {@link #LABEL_1} and two rows labelled {@link #LABEL_2}; each row is (feature, feature, label). */
+ private static final Map<Integer, double[]> data = new HashMap<>();
+ /** Only the {@link #LABEL_1} rows of {@link #data}. */
+ private static final Map<Integer, double[]> singleLabeldata1 = new HashMap<>();
+ /** Only the {@link #LABEL_2} rows of {@link #data}. */
+ private static final Map<Integer, double[]> singleLabeldata2 = new HashMap<>();
+
+ static {
+ data.put(0, new double[] {1.0, -1.0, LABEL_1});
+ data.put(1, new double[] {-1.0, 2.0, LABEL_1});
+ data.put(2, new double[] {6.0, 1.0, LABEL_1});
+ data.put(3, new double[] {-3.0, 2.0, LABEL_2});
+ data.put(4, new double[] {-5.0, -2.0, LABEL_2});
+
+ singleLabeldata1.put(0, new double[] {1.0, -1.0, LABEL_1});
+ singleLabeldata1.put(1, new double[] {-1.0, 2.0, LABEL_1});
+ singleLabeldata1.put(2, new double[] {6.0, 1.0, LABEL_1});
+
+ singleLabeldata2.put(0, new double[] {-3.0, 2.0, LABEL_2});
+ singleLabeldata2.put(1, new double[] {-5.0, -2.0, LABEL_2});
+ }
+
+ private GaussianNaiveBayesTrainer trainer;
+
+ /** Creates a fresh {@code GaussianNaiveBayesTrainer} before each test. */
+ @Before
+ public void createTrainer() {
+ trainer = new GaussianNaiveBayesTrainer();
+ }
+
+ /** Trains on {@code twoLinearlySeparableClasses} (label in column 0, features in the remaining columns) and checks the labels predicted for {@code (100, 10)} and {@code (10, 100)}. */
+ @Test
+ public void testWithLinearlySeparableData() {
+ Map<Integer, double[]> cacheMock = new HashMap<>();
+ for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+ cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+ GaussianNaiveBayesModel mdl = trainer.fit(
+ cacheMock,
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
+
+ TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
+ TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
+ }
+
+ /** Checks that the class probabilities equal each label's share of the rows in {@link #data} (3/5 and 2/5). */
+ @Test
+ public void testReturnsCorrectLabelProbalities() {
+
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(data, parts),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[2]
+ );
+
+ Assert.assertEquals(3. / data.size(), model.getClassProbabilities()[0], PRECISION);
+ Assert.assertEquals(2. / data.size(), model.getClassProbabilities()[1], PRECISION);
+ }
+
+ /** Checks that a trainer configured via {@code withEquiprobableClasses()} produces class probabilities of .5 each on the two-class {@link #data}. */
+ @Test
+ public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() {
+ GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer()
+ .withEquiprobableClasses();
+
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(data, parts),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[2]
+ );
+
+ Assert.assertEquals(.5, model.getClassProbabilities()[0], PRECISION);
+ Assert.assertEquals(.5, model.getClassProbabilities()[1], PRECISION);
+ }
+
+ /** Checks that probabilities passed to {@code setPriorProbabilities} are the class probabilities reported by the trained model. */
+ @Test
+ public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() {
+ double[] priorProbabilities = new double[] {.35, .65};
+ GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer()
+ .setPriorProbabilities(priorProbabilities);
+
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(data, parts),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[2]
+ );
+
+ Assert.assertEquals(priorProbabilities[0], model.getClassProbabilities()[0], PRECISION);
+ Assert.assertEquals(priorProbabilities[1], model.getClassProbabilities()[1], PRECISION);
+ }
+
+ /** Checks the per-feature means on {@link #singleLabeldata1}: (1 - 1 + 6) / 3 = 2 and (-1 + 2 + 1) / 3 = 2/3. */
+ @Test
+ public void testReturnsCorrectMeans() {
+
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(singleLabeldata1, parts),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[2]
+ );
+
+ Assert.assertArrayEquals(new double[] {2.0, 2. / 3.}, model.getMeans()[0], PRECISION);
+ }
+
+ /** Checks the per-feature variances on {@link #singleLabeldata1}; the expected values are the sums of squared deviations from the mean divided by the row count 3 (26/3 and 14/9). */
+ @Test
+ public void testReturnsCorrectVariances() {
+
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(singleLabeldata1, parts),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[2]
+ );
+
+ double[] expectedVars = {8.666666666666666, 1.5555555555555556};
+ Assert.assertArrayEquals(expectedVars, model.getVariances()[0], PRECISION);
+ }
+
+ /** Fits on the {@link #LABEL_1} rows, updates the model with the {@link #LABEL_2} rows and checks that the class probabilities match the shares in the full {@link #data} set (3/5 and 2/5). */
+ @Test
+ public void testUpdatigModel() {
+ GaussianNaiveBayesModel model = trainer.fit(
+ new LocalDatasetBuilder<>(singleLabeldata1, parts),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[2]
+ );
+
+ GaussianNaiveBayesModel updatedModel = trainer.updateModel(model,
+ new LocalDatasetBuilder<>(singleLabeldata2, parts),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[2]
+ );
+
+ Assert.assertEquals(3. / data.size(), updatedModel.getClassProbabilities()[0], PRECISION);
+ Assert.assertEquals(2. / data.size(), updatedModel.getClassProbabilities()[1], PRECISION);
+ }
+}