You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/18 10:20:13 UTC

[ignite] branch master updated: IGNITE-10546: [ML] GMM with adding and removal of components

This is an automated email from the ASF dual-hosted git repository.

chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new 6db24b0  IGNITE-10546: [ML] GMM with adding and removal of components
6db24b0 is described below

commit 6db24b0148c4b394b074973ee9425a2af9aa6213
Author: Alexey Platonov <ap...@gmail.com>
AuthorDate: Mon Feb 18 13:19:28 2019 +0300

    IGNITE-10546: [ML] GMM with adding and removal of components
    
    This closes #6113
---
 .../ml/clustering/GmmClusterizationExample.java    |  16 +-
 .../ignite/ml/clustering/gmm/GmmPartitionData.java |  45 ++++-
 .../ignite/ml/clustering/gmm/GmmTrainer.java       | 223 ++++++++++++++++-----
 .../gmm/MeanWithClusterProbAggregator.java         |  13 +-
 .../gmm/NewComponentStatisticsAggregator.java      | 169 ++++++++++++++++
 .../DefaultLearningEnvironmentBuilder.java         |   2 +-
 .../ignite/ml/math/stat/DistributionMixture.java   |   5 +-
 .../org/apache/ignite/ml/math/util/MatrixUtil.java |  15 +-
 .../logistic/LogisticRegressionSGDTrainer.java     |   3 +-
 .../ignite/ml/clustering/ClusteringTestSuite.java  |   4 +-
 .../gmm/MeanWithClusterProbAggregatorTest.java     |   2 +-
 .../gmm/NewComponentStatisticsAggregatorTest.java  | 147 ++++++++++++++
 12 files changed, 566 insertions(+), 78 deletions(-)

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
index c15b839..d9f03c1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.examples.ml.clustering;
 
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
@@ -34,8 +35,6 @@ import org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProduce
 import org.apache.ignite.ml.util.generators.primitives.scalar.RandomProducer;
 import org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily;
 
-import java.util.UUID;
-
 /**
  * Example of using GMM clusterization algorithm. Gaussian Mixture Algorithm (GMM, see {@link GmmModel}, {@link
  * GmmTrainer}) can be used for input dataset data distribution representation as mixture of multivariance gaussians.
@@ -57,8 +56,8 @@ public class GmmClusterizationExample {
             System.out.println(">>> Ignite grid started.");
 
             long seed = 0;
-            IgniteCache<UUID, LabeledVector<Double>> dataCache = ignite.getOrCreateCache(
-                new CacheConfiguration<UUID, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
+            IgniteCache<Integer, LabeledVector<Double>> dataCache = ignite.getOrCreateCache(
+                new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
                     .setAffinity(new RendezvousAffinityFunction(false, 10))
             );
 
@@ -78,11 +77,14 @@ public class GmmClusterizationExample {
                 ).move(VectorUtils.of(0., -10.))
             ).build(seed++).asDataStream();
 
-            dataStream.fillCacheWithVecUUIDAsKey(50000, dataCache);
-            GmmTrainer trainer = new GmmTrainer(3);
+            AtomicInteger keyGen = new AtomicInteger();
+            dataStream.fillCacheWithCustomKey(50000, dataCache, v -> keyGen.getAndIncrement());
+            GmmTrainer trainer = new GmmTrainer(1);
 
             GmmModel mdl = trainer
-                .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed++))
+                .withMaxCountIterations(10)
+                .withMaxCountOfClusters(4)
+                .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
                 .fit(ignite, dataCache, (k, v) -> v.features(), (k, v) -> v.label());
 
             System.out.println(">>> GMM means and covariances");
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
index 942c511..1b8e50c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
@@ -22,6 +22,7 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
@@ -95,13 +96,6 @@ class GmmPartitionData implements AutoCloseable {
         return pcxi.length;
     }
 
-    /**
-     * @return count of GMM components.
-     */
-    public int countOfComponents() {
-        return size() != 0 ? pcxi[0].length : 0;
-    }
-
     /** {@inheritDoc} */
     @Override public void close() throws Exception {
         //NOP
@@ -162,6 +156,7 @@ class GmmPartitionData implements AutoCloseable {
 
             Vector x = data.getX(i);
             for (int c = 0; c < initMeans.length; c++) {
+                data.setPcxi(c, i, 0.0);
                 double distance = initMeans[c].getDistanceSquared(x);
                 if (distance < minSquaredDist) {
                     closestClusterId = c;
@@ -174,16 +169,40 @@ class GmmPartitionData implements AutoCloseable {
     }
 
     /**
+     * Updates P(c|xi) values in partitions and computes the maximum of P(x) over all vectors of the dataset.
+     *
+     * @param dataset Dataset.
+     * @param clusterProbs Component probabilities.
+     * @param components Components.
+     * @return maximum over partitions of the per-partition maximum returned by {@code updatePcxi}.
+     */
+    static double updatePcxiAndComputeLikelihood(Dataset<EmptyContext, GmmPartitionData> dataset, Vector clusterProbs,
+        List<MultivariateGaussianDistribution> components) {
+
+        return dataset.compute(
+            data -> updatePcxi(data, clusterProbs, components),
+            (left, right) -> Math.max(asPrimitive(left), asPrimitive(right)) // Max, not sum: each side is a maximum.
+        );
+    }
+
+    /**
      * Updates P(c|xi) values in partitions given components probabilities and components of GMM.
      *
      * @param clusterProbs Component probabilities.
      * @param components Components.
      */
-    static void updatePcxi(GmmPartitionData data, Vector clusterProbs,
+    static double updatePcxi(GmmPartitionData data, Vector clusterProbs,
         List<MultivariateGaussianDistribution> components) {
 
+        GmmModel model = new GmmModel(clusterProbs, components);
+        double maxProb = Double.NEGATIVE_INFINITY;
+
         for (int i = 0; i < data.size(); i++) {
             Vector x = data.getX(i);
+            double xProb = model.prob(x);
+            if(xProb > maxProb)
+                maxProb = xProb;
+
             double normalizer = 0.0;
             for (int c = 0; c < clusterProbs.size(); c++)
                 normalizer += components.get(c).prob(x) * clusterProbs.get(c);
@@ -191,5 +210,15 @@ class GmmPartitionData implements AutoCloseable {
             for (int c = 0; c < clusterProbs.size(); c++)
                 data.pcxi[i][c] = (components.get(c).prob(x) * clusterProbs.get(c)) / normalizer;
         }
+
+        return maxProb;
+    }
+
+    /**
+     * @param val Value.
+     * @return 0 if Value == null and simplified value in terms of type otherwise.
+     */
+    private static double asPrimitive(Double val) {
+        return val == null ? 0.0 : val;
     }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
index 09e1325..09e93f6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
@@ -17,13 +17,20 @@
 
 package org.apache.ignite.ml.clustering.gmm;
 
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
 import org.apache.ignite.internal.util.typedef.internal.A;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.environment.LearningEnvironment;
-import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -36,14 +43,6 @@ import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
 import org.jetbrains.annotations.NotNull;
 
-import java.util.ArrayList;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-import java.util.stream.DoubleStream;
-import java.util.stream.Stream;
-
 /**
  * Traner for GMM model.
  */
@@ -64,6 +63,20 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
     private int maxCountOfInitTries = 3;
 
     /**
+     * Maximum count of clusters that can be achieved.
+     */
+    private int maxCountOfClusters = 2;
+
+    /** Maximum divergence between maximum of likelihood of vector in dataset and other for anomalies identification. */
+    private double maxLikelihoodDivergence = 5;
+
+    /** Minimum required anomalies in terms of maxLikelihoodDivergence for creating new cluster. */
+    private double minElementsForNewCluster = 300;
+
+    /** Min cluster probability. */
+    private double minClusterProbability = 0.05;
+
+    /**
      * Creates an instance of GmmTrainer.
      */
     public GmmTrainer() {
@@ -101,11 +114,13 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
      * @param numberOfComponents Number of components.
      * @return trainer.
      */
-    public GmmTrainer withCountOfComponents(int numberOfComponents) {
+    public GmmTrainer withInitialCountOfComponents(int numberOfComponents) {
         A.ensure(numberOfComponents > 0, "Number of components in GMM cannot equal 0");
 
         this.countOfComponents = numberOfComponents;
         initialMeans = null;
+        if (countOfComponents > maxCountOfClusters)
+            maxCountOfClusters = countOfComponents;
         return this;
     }
 
@@ -120,6 +135,8 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
 
         this.initialMeans = means.toArray(new Vector[means.size()]);
         this.countOfComponents = means.size();
+        if (countOfComponents > maxCountOfClusters)
+            maxCountOfClusters = countOfComponents;
         return this;
     }
 
@@ -164,13 +181,128 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
     }
 
     /**
+     * Sets maximum number of clusters in GMM.
+     *
+     * @param maxCountOfClusters Max count of clusters.
+     * @return trainer.
+     */
+    public GmmTrainer withMaxCountOfClusters(int maxCountOfClusters) {
+        A.ensure(maxCountOfClusters >= countOfComponents, "Max count of components should be greater than " +
+            "initial count of components or equal to it");
+
+        this.maxCountOfClusters = maxCountOfClusters;
+        return this;
+    }
+
+    /**
+     * Sets maximum divergence between maximum of likelihood of vector in dataset and other for anomalies
+     * identification.
+     *
+     * @param maxLikelihoodDivergence Max likelihood divergence.
+     * @return trainer.
+     */
+    public GmmTrainer withMaxLikelihoodDivergence(double maxLikelihoodDivergence) {
+        A.ensure(maxLikelihoodDivergence > 0, "Max likelihood divergence should be > 0");
+
+        this.maxLikelihoodDivergence = maxLikelihoodDivergence;
+        return this;
+    }
+
+    /**
+     * Sets minimum required anomalies in terms of maxLikelihoodDivergence for creating new cluster.
+     *
+     * @param minElementsForNewCluster Min elements for new cluster.
+     * @return trainer.
+     */
+    public GmmTrainer withMinElementsForNewCluster(int minElementsForNewCluster) {
+        A.ensure(minElementsForNewCluster > 0, "Min elements for new cluster should be > 0");
+
+        this.minElementsForNewCluster = minElementsForNewCluster;
+        return this;
+    }
+
+    /**
+     * Sets minimum required probability for cluster. If cluster has probability value less than this value then this
+     * cluster will be eliminated.
+     *
+     * @param minClusterProbability Min cluster probability.
+     * @return trainer.
+     */
+    public GmmTrainer withMinClusterProbability(double minClusterProbability) {
+        this.minClusterProbability = minClusterProbability;
+        return this;
+    }
+
+    /**
      * Trains model based on the specified data.
      *
      * @param dataset Dataset.
      * @return GMM model.
      */
     private Optional<GmmModel> fit(Dataset<EmptyContext, GmmPartitionData> dataset) {
-        return init(dataset).map(model -> updateModel(dataset, model));
+        return init(dataset).map(model -> {
+            GmmModel currentModel = model;
+
+            do {
+                UpdateResult updateResult = updateModel(dataset, currentModel);
+                currentModel = updateResult.model;
+
+                double minCompProb = currentModel.componentsProbs().minElement().get();
+                if (countOfComponents >= maxCountOfClusters || minCompProb < minClusterProbability)
+                    break;
+
+                double maxXProb = updateResult.maxProbInDataset;
+                NewComponentStatisticsAggregator newMeanAdder = NewComponentStatisticsAggregator.computeNewMean(dataset,
+                    maxXProb, maxLikelihoodDivergence, currentModel);
+
+                if (newMeanAdder.rowCountForNewCluster() < minElementsForNewCluster) // Check first: mean() needs anomalies.
+                    break;
+                Vector newMean = newMeanAdder.mean();
+
+                countOfComponents += 1;
+                Vector[] newMeans = new Vector[countOfComponents];
+                for (int i = 0; i < currentModel.countOfComponents(); i++)
+                    newMeans[i] = currentModel.distributions().get(i).mean();
+                newMeans[countOfComponents - 1] = newMean;
+
+                initialMeans = newMeans;
+
+                Optional<GmmModel> newModelOpt = init(dataset);
+                if (newModelOpt.isPresent())
+                    currentModel = newModelOpt.get();
+                else
+                    break;
+            }
+            while (true);
+
+            return filterModel(currentModel);
+        });
+    }
+
+    /**
+     * Remove clusters with probability value < minClusterProbability
+     *
+     * @param model Model.
+     * @return filtered model.
+     */
+    private GmmModel filterModel(GmmModel model) {
+        List<Double> componentProbs = new ArrayList<>();
+        List<MultivariateGaussianDistribution> distributions = new ArrayList<>();
+
+        Vector originalComponentProbs = model.componentsProbs();
+        List<MultivariateGaussianDistribution> originalDistr = model.distributions();
+        for (int i = 0; i < model.countOfComponents(); i++) {
+            double prob = originalComponentProbs.get(i);
+            if (prob > minClusterProbability) {
+                componentProbs.add(prob);
+                distributions.add(originalDistr.get(i));
+            }
+        }
+
+        return new GmmModel(
+            VectorUtils.of(componentProbs.toArray(new Double[0])),
+            distributions
+        );
     }
 
     /**
@@ -180,11 +312,12 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
      * @param model Model.
      * @return updated model.
      */
-    @NotNull private GmmModel updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
+    @NotNull private UpdateResult updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
         boolean isConverged = false;
         int countOfIterations = 0;
+        double maxProbInDataset = Double.NEGATIVE_INFINITY;
         while (!isConverged) {
-            MeanWithClusterProbAggregator.AggregatedStats stats = MeanWithClusterProbAggregator.aggreateStats(dataset);
+            MeanWithClusterProbAggregator.AggregatedStats stats = MeanWithClusterProbAggregator.aggreateStats(dataset, countOfComponents);
             Vector clusterProbs = stats.clusterProbabilities();
             Vector[] newMeans = stats.means().toArray(new Vector[countOfComponents]);
 
@@ -199,9 +332,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
                 countOfIterations += 1;
                 isConverged = isConverged(model, newModel) || countOfIterations > maxCountOfIterations;
                 model = newModel;
-
-                if (!isConverged)
-                    dataset.compute(data -> GmmPartitionData.updatePcxi(data, clusterProbs, components));
+                maxProbInDataset = GmmPartitionData.updatePcxiAndComputeLikelihood(dataset, clusterProbs, components);
             }
             catch (SingularMatrixException | IllegalArgumentException e) {
                 String msg = "Cannot construct non-singular covariance matrix by data. " +
@@ -211,7 +342,27 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
             }
         }
 
-        return model;
+        return new UpdateResult(model, maxProbInDataset);
+    }
+
+    /**
+     * Result of current model update by EM-algorithm.
+     */
+    private static class UpdateResult {
+        /** Model. */
+        private final GmmModel model;
+
+        /** Max likelihood in dataset. */
+        private final double maxProbInDataset;
+
+        /**
+         * @param model Model.
+         * @param maxProbInDataset Max likelihood in dataset.
+         */
+        public UpdateResult(GmmModel model, double maxProbInDataset) {
+            this.model = model;
+            this.maxProbInDataset = maxProbInDataset;
+        }
     }
 
     /**
@@ -226,27 +377,21 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
         while (true) {
             try {
                 if (initialMeans == null) {
-                    List<List<Vector>> randomMeansSets = Stream.of(dataset.compute(
+                    List<Vector> randomMeansSets = Stream.of(dataset.compute(
                         selectNRandomXsMapper(countOfComponents),
                         GmmTrainer::selectNRandomXsReducer
-                    )).map(this::asList).collect(Collectors.toList());
+                    )).flatMap(Stream::of).collect(Collectors.toList());
+
+                    Collections.sort(randomMeansSets, Comparator.comparingDouble(Vector::getLengthSquared));
+                    Collections.shuffle(randomMeansSets, environment.randomNumbersGenerator());
 
                     A.ensure(
-                        randomMeansSets.stream().mapToInt(List::size).sum() >= countOfComponents,
+                        randomMeansSets.size() >= countOfComponents,
                         "There is not enough data in dataset for select N random means"
                     );
 
-                    initialMeans = new Vector[countOfComponents];
-                    int j = 0;
-                    for (int i = 0; i < countOfComponents; ) {
-                        List<Vector> randomMeansPart = randomMeansSets.get(j);
-                        if (!randomMeansPart.isEmpty()) {
-                            initialMeans[i] = randomMeansPart.remove(0);
-                            i++;
-                        }
-
-                        j = (j + 1) % randomMeansSets.size();
-                    }
+                    initialMeans = randomMeansSets.subList(0, countOfComponents)
+                        .toArray(new Vector[countOfComponents]);
                 }
 
                 dataset.compute(data -> GmmPartitionData.estimateLikelihoodClusters(data, initialMeans));
@@ -282,17 +427,6 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
     }
 
     /**
-     * @param vectors Array of vectors.
-     * @return list of vectors.
-     */
-    private LinkedList<Vector> asList(Vector... vectors) {
-        LinkedList<Vector> res = new LinkedList<>();
-        for (Vector v : vectors)
-            res.addFirst(v);
-        return res;
-    }
-
-    /**
      * Create new model components with provided means and covariances.
      *
      * @param means Means.
@@ -340,10 +474,9 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
     @Override protected <K, V> GmmModel updateModel(GmmModel mdl, DatasetBuilder<K, V> datasetBuilder,
         FeatureLabelExtractor<K, V, Double> extractor) {
 
-        try (Dataset<EmptyContext, GmmPartitionData> dataset = datasetBuilder.build(
-            LearningEnvironmentBuilder.defaultBuilder(),
+        try (Dataset<EmptyContext, GmmPartitionData> dataset = datasetBuilder.build(envBuilder,
             new EmptyContextBuilder<>(),
-            new GmmPartitionData.Builder<>(extractor, countOfComponents)
+            new GmmPartitionData.Builder<>(extractor, maxCountOfClusters)
         )) {
             if (mdl != null) {
                 if (initialMeans != null)
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
index 58044a7..99e60ba 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
@@ -47,6 +47,7 @@ class MeanWithClusterProbAggregator implements Serializable {
      * Create an instance of MeanWithClusterProbAggregator.
      */
     MeanWithClusterProbAggregator() {
+        // NO-OP.
     }
 
     /**
@@ -80,10 +81,11 @@ class MeanWithClusterProbAggregator implements Serializable {
      * Aggregates statistics for means and cluster probabilities computing given dataset.
      *
      * @param dataset Dataset.
+     * @param countOfComponents Count of components.
      */
-    public static AggregatedStats aggreateStats(Dataset<EmptyContext, GmmPartitionData> dataset) {
+    public static AggregatedStats aggreateStats(Dataset<EmptyContext, GmmPartitionData> dataset, int countOfComponents) {
         return new AggregatedStats(dataset.compute(
-            MeanWithClusterProbAggregator::map,
+            data -> map(data, countOfComponents),
             MeanWithClusterProbAggregator::reduce
         ));
     }
@@ -123,15 +125,16 @@ class MeanWithClusterProbAggregator implements Serializable {
      * Map stage for statistics aggregation.
      *
      * @param data Partition data.
+     * @param countOfComponents Count of components.
      * @return Aggregated statistics.
      */
-    static List<MeanWithClusterProbAggregator> map(GmmPartitionData data) {
+    static List<MeanWithClusterProbAggregator> map(GmmPartitionData data, int countOfComponents) {
         List<MeanWithClusterProbAggregator> aggregators = new ArrayList<>();
-        for (int i = 0; i < data.countOfComponents(); i++)
+        for (int i = 0; i < countOfComponents; i++)
             aggregators.add(new MeanWithClusterProbAggregator());
 
         for (int i = 0; i < data.size(); i++) {
-            for (int c = 0; c < data.countOfComponents(); c++)
+            for (int c = 0; c < countOfComponents; c++)
                 aggregators.get(c).add(data.getX(i), data.pcxi(c, i));
         }
 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java
new file mode 100644
index 0000000..4fa5406
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.io.Serializable;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Class for aggregate statistics for finding new mean for GMM.
+ */
+public class NewComponentStatisticsAggregator implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 6748270328889375005L;
+
+    /** Total row count in dataset. */
+    private long totalRowCount;
+
+    /** Row count for new cluster. */
+    private long rowCountForNewCluster;
+
+    /** Sum of anomalies vectors. */
+    private Vector sumOfAnomalies;
+
+    /**
+     * Creates an instance of NewComponentStatisticsAggregator.
+     *
+     * @param totalRowCount Total row count in dataset.
+     * @param rowCountForNewCluster Row count for new cluster.
+     * @param sumOfAnomalies Sum of anomalies.
+     */
+    public NewComponentStatisticsAggregator(long totalRowCount, long rowCountForNewCluster, Vector sumOfAnomalies) {
+        this.totalRowCount = totalRowCount;
+        this.rowCountForNewCluster = rowCountForNewCluster;
+        this.sumOfAnomalies = sumOfAnomalies;
+    }
+
+    /**
+     * Creates an instance of NewComponentStatisticsAggregator.
+     */
+    public NewComponentStatisticsAggregator() {
+    }
+
+    /**
+     * @return Mean of anomalies; fails if no anomaly was added (the sum is still null), so check the row count first.
+     */
+    public Vector mean() {
+        return sumOfAnomalies.divide(rowCountForNewCluster);
+    }
+
+    /**
+     * @return Row count for new cluster.
+     */
+    public long rowCountForNewCluster() {
+        return rowCountForNewCluster;
+    }
+
+    /**
+     * @return Total count of rows in partition/dataset.
+     */
+    public long totalRowCount() {
+        return totalRowCount;
+    }
+
+    /**
+     * Compute statistics for new mean for GMM.
+     *
+     * @param dataset Dataset.
+     * @param maxXsProb Max likelihood between all xs.
+     * @param maxProbDivergence Max probability divergence between maximum value and others.
+     * @param currentModel Current model.
+     * @return aggregated statistics for new mean.
+     */
+    static NewComponentStatisticsAggregator computeNewMean(Dataset<EmptyContext, GmmPartitionData> dataset,
+        double maxXsProb, double maxProbDivergence, GmmModel currentModel) {
+
+        return dataset.compute(
+            data -> computeNewMeanMap(data, maxXsProb, maxProbDivergence, currentModel),
+            NewComponentStatisticsAggregator::computeNewMeanReduce
+        );
+    }
+
+    /**
+     * Map stage for new mean computing.
+     *
+     * @param data Data.
+     * @param maxXsProb Max xs prob.
+     * @param maxProbDivergence Max prob divergence.
+     * @param currentModel Current model.
+     * @return aggregator for partition.
+     */
+    static NewComponentStatisticsAggregator computeNewMeanMap(GmmPartitionData data, double maxXsProb,
+        double maxProbDivergence, GmmModel currentModel) {
+
+        NewComponentStatisticsAggregator adder = new NewComponentStatisticsAggregator();
+        for (int i = 0; i < data.size(); i++) {
+            Vector x = data.getX(i);
+            adder.add(x, currentModel.prob(x) < (maxXsProb / maxProbDivergence));
+        }
+        return adder;
+    }
+
+    /**
+     * Adds vector to statistics.
+     *
+     * @param x Vector from dataset.
+     * @param isAnomaly True if vector is anomaly.
+     */
+    void add(Vector x, boolean isAnomaly) {
+        if (isAnomaly) {
+            if (sumOfAnomalies == null)
+                sumOfAnomalies = x.copy();
+            else
+                sumOfAnomalies = sumOfAnomalies.plus(x);
+
+            rowCountForNewCluster += 1;
+        }
+
+        totalRowCount += 1;
+    }
+
+    /**
+     * Reduce stage for new mean computing.
+     *
+     * @param left Left argument of reduce.
+     * @param right Right argument of reduce.
+     * @return sum of aggregators.
+     */
+    static NewComponentStatisticsAggregator computeNewMeanReduce(NewComponentStatisticsAggregator left,
+        NewComponentStatisticsAggregator right) {
+        A.ensure(left != null || right != null, "left != null || right != null");
+
+        if (left == null)
+            return right;
+        else if (right == null)
+            return left;
+        else
+            return left.plus(right);
+    }
+
+    /**
+     * @param other Other aggregator.
+     * @return sum of aggregators; the anomaly sum of a side that saw no anomalies (null) is skipped.
+     */
+    NewComponentStatisticsAggregator plus(NewComponentStatisticsAggregator other) {
+        Vector otherSum = other.sumOfAnomalies;
+        return new NewComponentStatisticsAggregator(totalRowCount + other.totalRowCount,
+            rowCountForNewCluster + other.rowCountForNewCluster,
+            sumOfAnomalies == null ? otherSum : (otherSum == null ? sumOfAnomalies : sumOfAnomalies.plus(otherSum))
+        );
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java
index 4aef8f2..cefd056 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java
@@ -53,7 +53,7 @@ public class DefaultLearningEnvironmentBuilder implements LearningEnvironmentBui
         parallelismStgy = constant(NoParallelismStrategy.INSTANCE);
         loggingFactory = constant(NoOpLogger.factory());
         seed = constant(new Random().nextLong());
-        rngSupplier = constant(new Random());
+        rngSupplier = p -> new Random();
     }
 
     /** {@inheritDoc} */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
index 29bb22f..1b969d3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
@@ -31,9 +31,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
  * @param <C> distributions mixture component class.
  */
 public abstract class DistributionMixture<C extends Distribution> implements Distribution {
-    /** */
-    private static final double EPS = 1e-5;
-
     /** Component probabilities. */
     private final Vector componentProbs;
 
@@ -51,7 +48,7 @@ public abstract class DistributionMixture<C extends Distribution> implements Dis
      */
     public DistributionMixture(Vector componentProbs, List<C> distributions) {
         A.ensure(DoubleStream.of(componentProbs.asArray()).allMatch(v -> v > 0), "All distribution components should be greater than zero");
-        A.ensure(Math.abs(componentProbs.sum() - 1.) < EPS, "Components distribution should be nomalized");
+        componentProbs = componentProbs.divide(componentProbs.sum());
 
         A.ensure(!distributions.isEmpty(), "Distribution mixture should have at least one component");
 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
index e4554e9..7dc284a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
@@ -61,6 +61,13 @@ public class MatrixUtil {
         return res;
     }
 
+    public static Matrix identity(int n) { // Returns an n x n dense matrix with 1.0 set on the main diagonal.
+        Matrix eye = new DenseMatrix(n, n);
+        for (int d = n - 1; d >= 0; d--)
+            eye.set(d, d, 1.0);
+        return eye;
+    }
+
     /**
      * Create the like matrix with specified size with read-only matrices support.
      *
@@ -197,10 +204,10 @@ public class MatrixUtil {
     }
 
     /**
-     * Zip two vectors with given tri-function taking as third argument position in vector
-     * (i.e. apply binary function to both vector elementwise and construct vector from results).
-     * Example zipWith({200, 400, 600}, {100, 300, 500}, plusAndMultiplyByIndex) = {(200 + 100) * 0, (400 + 300) * 1, (600 + 500) * 3}.
-     * Length of result is length of shortest of vectors.
+     * Zip two vectors with given tri-function taking as third argument position in vector (i.e. apply binary function
+     * to both vector elementwise and construct vector from results). Example zipWith({200, 400, 600}, {100, 300, 500},
+     * plusAndMultiplyByIndex) = {(200 + 100) * 0, (400 + 300) * 1, (600 + 500) * 2}. Length of result is length of
+     * shortest of vectors.
      *
      * @param v1 First vector.
      * @param v2 Second vector.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
index 345a885..1070efc 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.regressions.logistic;
 
+import java.util.Arrays;
 import org.apache.ignite.ml.composition.CompositionUtils;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -38,8 +39,6 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.jetbrains.annotations.NotNull;
 
-import java.util.Arrays;
-
 /**
  * Trainer of the logistic regression model based on stochastic gradient descent algorithm.
  */
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java
index cae8bef..d22198c 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java
@@ -23,6 +23,7 @@ import org.apache.ignite.ml.clustering.gmm.GmmPartitionDataTest;
 import org.apache.ignite.ml.clustering.gmm.GmmTrainerIntegrationTest;
 import org.apache.ignite.ml.clustering.gmm.GmmTrainerTest;
 import org.apache.ignite.ml.clustering.gmm.MeanWithClusterProbAggregatorTest;
+import org.apache.ignite.ml.clustering.gmm.NewComponentStatisticsAggregatorTest;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
@@ -41,7 +42,8 @@ import org.junit.runners.Suite;
     GmmPartitionDataTest.class,
     MeanWithClusterProbAggregatorTest.class,
     GmmTrainerTest.class,
-    GmmTrainerIntegrationTest.class
+    GmmTrainerIntegrationTest.class,
+    NewComponentStatisticsAggregatorTest.class
 })
 public class ClusteringTestSuite {
 }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java
index e6307e1..7a3043d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java
@@ -117,7 +117,7 @@ public class MeanWithClusterProbAggregatorTest {
             }
         );
 
-        List<MeanWithClusterProbAggregator> res = MeanWithClusterProbAggregator.map(data);
+        List<MeanWithClusterProbAggregator> res = MeanWithClusterProbAggregator.map(data, 2);
         assertEquals(2, res.size());
 
         MeanWithClusterProbAggregator agg1 = res.get(0);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregatorTest.java
new file mode 100644
index 0000000..83a4f17
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregatorTest.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.apache.ignite.ml.clustering.gmm.NewComponentStatisticsAggregator.computeNewMeanMap;
+import static org.apache.ignite.ml.clustering.gmm.NewComponentStatisticsAggregator.computeNewMeanReduce;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link NewComponentStatisticsAggregator} class.
+ */
+public class NewComponentStatisticsAggregatorTest {
+    /** First partition fixture holding the rows (1, 0), (0, 1) and (3, 7). */
+    GmmPartitionData data1 = new GmmPartitionData(
+        Arrays.asList(
+            vec(1, 0),
+            vec(0, 1),
+            vec(3, 7)
+        ),
+        new double[3][]
+    );
+
+    /** Second partition fixture holding the rows (3, 1), (1, 4) and (1, 3). */
+    GmmPartitionData data2 = new GmmPartitionData(
+        Arrays.asList(
+            vec(3, 1),
+            vec(1, 4),
+            vec(1, 3)
+        ),
+        new double[3][]
+    );
+
+    /** Mocked GMM model; assigned in {@link #before()}. */
+    GmmModel model;
+
+    /** Creates the model mock and stubs {@code prob} with a fixed value for every row of both fixtures. */
+    @Before
+    public void before() {
+        model = mock(GmmModel.class);
+        when(model.prob(data1.getX(0))).thenReturn(0.1);
+        when(model.prob(data1.getX(1))).thenReturn(0.4);
+        when(model.prob(data1.getX(2))).thenReturn(0.9);
+
+        when(model.prob(data2.getX(0))).thenReturn(0.2);
+        when(model.prob(data2.getX(1))).thenReturn(0.6);
+        when(model.prob(data2.getX(2))).thenReturn(0.1);
+    }
+
+    /** Checks both row counters and the mean after adding the same vector ten times with an alternating flag. */
+    @Test
+    public void testAdd() {
+        NewComponentStatisticsAggregator agg = new NewComponentStatisticsAggregator();
+        int rowCount = 10;
+        for (int i = 0; i < rowCount; i++)
+            agg.add(VectorUtils.of(0, 1, 2), i % 2 == 0);
+
+        assertEquals(rowCount / 2, agg.rowCountForNewCluster());
+        assertEquals(rowCount, agg.totalRowCount());
+        assertArrayEquals(new double[] {0, 1, 2}, agg.mean().asArray(), 1e-4);
+    }
+
+    /** Checks that {@code plus} adds up the row counters of two aggregators and yields the mean (1, 1, 1). */
+    @Test
+    public void testPlus() {
+        NewComponentStatisticsAggregator agg1 = new NewComponentStatisticsAggregator();
+        NewComponentStatisticsAggregator agg2 = new NewComponentStatisticsAggregator();
+        int rowCount = 10;
+        for (int i = 0; i < rowCount; i++)
+            agg1.add(VectorUtils.of(0, 1, 2), i % 2 == 0);
+
+        for (int i = 0; i < rowCount; i++)
+            agg2.add(VectorUtils.of(2, 1, 0), i % 2 == 1);
+
+        NewComponentStatisticsAggregator sum = agg1.plus(agg2);
+        assertEquals(rowCount, sum.rowCountForNewCluster());
+        assertEquals(rowCount * 2, sum.totalRowCount());
+        assertArrayEquals(new double[] {1, 1, 1}, sum.mean().asArray(), 1e-4);
+    }
+
+    /** Checks the map step on the first fixture: two rows counted for the new cluster, mean (0.5, 0.5). */
+    @Test
+    public void testMap() {
+        NewComponentStatisticsAggregator agg = computeNewMeanMap(data1, 1.0, 2, model);
+
+        assertEquals(2, agg.rowCountForNewCluster());
+        assertEquals(data1.size(), agg.totalRowCount());
+        assertArrayEquals(new double[] {0.5, 0.5}, agg.mean().asArray(), 1e-4);
+    }
+
+    /** Checks the reduce step with a null operand on either side and with both operands in both orders. */
+    @Test
+    public void testReduce() {
+        double maxXsProb = 1.0;
+        int maxProbDivergence = 2;
+        NewComponentStatisticsAggregator agg1 = computeNewMeanMap(data1, maxXsProb, maxProbDivergence, model);
+        NewComponentStatisticsAggregator agg2 = computeNewMeanMap(data2, maxXsProb, maxProbDivergence, model);
+
+        NewComponentStatisticsAggregator res = computeNewMeanReduce(agg1, null);
+        assertEquals(agg1.rowCountForNewCluster(), res.rowCountForNewCluster());
+        assertEquals(agg1.totalRowCount(), res.totalRowCount());
+        assertArrayEquals(agg1.mean().asArray(), res.mean().asArray(), 1e-4);
+
+        res = computeNewMeanReduce(null, agg1);
+        assertEquals(agg1.rowCountForNewCluster(), res.rowCountForNewCluster());
+        assertEquals(agg1.totalRowCount(), res.totalRowCount());
+        assertArrayEquals(agg1.mean().asArray(), res.mean().asArray(), 1e-4);
+
+        res = computeNewMeanReduce(agg2, agg1);
+        assertEquals(4, res.rowCountForNewCluster());
+        assertEquals(6, res.totalRowCount());
+        assertArrayEquals(new double[] {1.25, 1.25}, res.mean().asArray(), 1e-4);
+
+        res = computeNewMeanReduce(agg1, agg2);
+        assertEquals(4, res.rowCountForNewCluster());
+        assertEquals(6, res.totalRowCount());
+        assertArrayEquals(new double[] {1.25, 1.25}, res.mean().asArray(), 1e-4);
+    }
+
+    /** Wraps the given values into a labeled vector with the constant label 1.0. */
+    private LabeledVector<Double> vec(double... values) {
+        return new LabeledVector<>(VectorUtils.of(values), 1.0);
+    }
+}