You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/18 10:20:13 UTC
[ignite] branch master updated: IGNITE-10546: [ML] GMM with adding
and removal of components
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 6db24b0 IGNITE-10546: [ML] GMM with adding and removal of components
6db24b0 is described below
commit 6db24b0148c4b394b074973ee9425a2af9aa6213
Author: Alexey Platonov <ap...@gmail.com>
AuthorDate: Mon Feb 18 13:19:28 2019 +0300
IGNITE-10546: [ML] GMM with adding and removal of components
This closes #6113
---
.../ml/clustering/GmmClusterizationExample.java | 16 +-
.../ignite/ml/clustering/gmm/GmmPartitionData.java | 45 ++++-
.../ignite/ml/clustering/gmm/GmmTrainer.java | 223 ++++++++++++++++-----
.../gmm/MeanWithClusterProbAggregator.java | 13 +-
.../gmm/NewComponentStatisticsAggregator.java | 169 ++++++++++++++++
.../DefaultLearningEnvironmentBuilder.java | 2 +-
.../ignite/ml/math/stat/DistributionMixture.java | 5 +-
.../org/apache/ignite/ml/math/util/MatrixUtil.java | 15 +-
.../logistic/LogisticRegressionSGDTrainer.java | 3 +-
.../ignite/ml/clustering/ClusteringTestSuite.java | 4 +-
.../gmm/MeanWithClusterProbAggregatorTest.java | 2 +-
.../gmm/NewComponentStatisticsAggregatorTest.java | 147 ++++++++++++++
12 files changed, 566 insertions(+), 78 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
index c15b839..d9f03c1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
@@ -17,6 +17,7 @@
package org.apache.ignite.examples.ml.clustering;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -34,8 +35,6 @@ import org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProduce
import org.apache.ignite.ml.util.generators.primitives.scalar.RandomProducer;
import org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily;
-import java.util.UUID;
-
/**
* Example of using GMM clusterization algorithm. Gaussian Mixture Algorithm (GMM, see {@link GmmModel}, {@link
* GmmTrainer}) can be used for input dataset data distribution representation as mixture of multivariance gaussians.
@@ -57,8 +56,8 @@ public class GmmClusterizationExample {
System.out.println(">>> Ignite grid started.");
long seed = 0;
- IgniteCache<UUID, LabeledVector<Double>> dataCache = ignite.getOrCreateCache(
- new CacheConfiguration<UUID, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
+ IgniteCache<Integer, LabeledVector<Double>> dataCache = ignite.getOrCreateCache(
+ new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
.setAffinity(new RendezvousAffinityFunction(false, 10))
);
@@ -78,11 +77,14 @@ public class GmmClusterizationExample {
).move(VectorUtils.of(0., -10.))
).build(seed++).asDataStream();
- dataStream.fillCacheWithVecUUIDAsKey(50000, dataCache);
- GmmTrainer trainer = new GmmTrainer(3);
+ AtomicInteger keyGen = new AtomicInteger();
+ dataStream.fillCacheWithCustomKey(50000, dataCache, v -> keyGen.getAndIncrement());
+ GmmTrainer trainer = new GmmTrainer(1);
GmmModel mdl = trainer
- .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed++))
+ .withMaxCountIterations(10)
+ .withMaxCountOfClusters(4)
+ .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
.fit(ignite, dataCache, (k, v) -> v.features(), (k, v) -> v.label());
System.out.println(">>> GMM means and covariances");
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
index 942c511..1b8e50c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
@@ -22,6 +22,7 @@ import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
@@ -95,13 +96,6 @@ class GmmPartitionData implements AutoCloseable {
return pcxi.length;
}
- /**
- * @return count of GMM components.
- */
- public int countOfComponents() {
- return size() != 0 ? pcxi[0].length : 0;
- }
-
/** {@inheritDoc} */
@Override public void close() throws Exception {
//NOP
@@ -162,6 +156,7 @@ class GmmPartitionData implements AutoCloseable {
Vector x = data.getX(i);
for (int c = 0; c < initMeans.length; c++) {
+ data.setPcxi(c, i, 0.0);
double distance = initMeans[c].getDistanceSquared(x);
if (distance < minSquaredDist) {
closestClusterId = c;
@@ -174,16 +169,40 @@ class GmmPartitionData implements AutoCloseable {
}
/**
+ * Updates P(c|xi) values in partitions and computes dataset likelihood.
+ *
+ * @param dataset Dataset.
+ * @param clusterProbs Component probabilities.
+ * @param components Components.
+ * @return dataset likelihood.
+ */
+ static double updatePcxiAndComputeLikelihood(Dataset<EmptyContext, GmmPartitionData> dataset, Vector clusterProbs,
+ List<MultivariateGaussianDistribution> components) {
+
+ return dataset.compute(
+ data -> updatePcxi(data, clusterProbs, components),
+ (left, right) -> asPrimitive(left) + asPrimitive(right)
+ );
+ }
+
+ /**
* Updates P(c|xi) values in partitions given components probabilities and components of GMM.
*
* @param clusterProbs Component probabilities.
* @param components Components.
*/
- static void updatePcxi(GmmPartitionData data, Vector clusterProbs,
+ static double updatePcxi(GmmPartitionData data, Vector clusterProbs,
List<MultivariateGaussianDistribution> components) {
+ GmmModel model = new GmmModel(clusterProbs, components);
+ double maxProb = Double.NEGATIVE_INFINITY;
+
for (int i = 0; i < data.size(); i++) {
Vector x = data.getX(i);
+ double xProb = model.prob(x);
+ if(xProb > maxProb)
+ maxProb = xProb;
+
double normalizer = 0.0;
for (int c = 0; c < clusterProbs.size(); c++)
normalizer += components.get(c).prob(x) * clusterProbs.get(c);
@@ -191,5 +210,15 @@ class GmmPartitionData implements AutoCloseable {
for (int c = 0; c < clusterProbs.size(); c++)
data.pcxi[i][c] = (components.get(c).prob(x) * clusterProbs.get(c)) / normalizer;
}
+
+ return maxProb;
+ }
+
+ /**
+ * @param val Value.
+ * @return 0 if val is null, otherwise the unboxed primitive value.
+ */
+ private static double asPrimitive(Double val) {
+ return val == null ? 0.0 : val;
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
index 09e1325..09e93f6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
@@ -17,13 +17,20 @@
package org.apache.ignite.ml.clustering.gmm;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
-import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -36,14 +43,6 @@ import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.jetbrains.annotations.NotNull;
-import java.util.ArrayList;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-import java.util.stream.DoubleStream;
-import java.util.stream.Stream;
-
/**
* Traner for GMM model.
*/
@@ -64,6 +63,20 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
private int maxCountOfInitTries = 3;
/**
+ * Maximum count of clusters that can be achieved.
+ */
+ private int maxCountOfClusters = 2;
+
+ /** Maximum divergence between maximum of likelihood of vector in dataset and other for anomalies identification. */
+ private double maxLikelihoodDivergence = 5;
+
+ /** Minimum required anomalies in terms of maxLikelihoodDivergence for creating new cluster. */
+ private double minElementsForNewCluster = 300;
+
+ /** Min cluster probability. */
+ private double minClusterProbability = 0.05;
+
+ /**
* Creates an instance of GmmTrainer.
*/
public GmmTrainer() {
@@ -101,11 +114,13 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* @param numberOfComponents Number of components.
* @return trainer.
*/
- public GmmTrainer withCountOfComponents(int numberOfComponents) {
+ public GmmTrainer withInitialCountOfComponents(int numberOfComponents) {
A.ensure(numberOfComponents > 0, "Number of components in GMM cannot equal 0");
this.countOfComponents = numberOfComponents;
initialMeans = null;
+ if (countOfComponents > maxCountOfClusters)
+ maxCountOfClusters = countOfComponents;
return this;
}
@@ -120,6 +135,8 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
this.initialMeans = means.toArray(new Vector[means.size()]);
this.countOfComponents = means.size();
+ if (countOfComponents > maxCountOfClusters)
+ maxCountOfClusters = countOfComponents;
return this;
}
@@ -164,13 +181,128 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
/**
+ * Sets maximum number of clusters in GMM.
+ *
+ * @param maxCountOfClusters Max count of clusters.
+ * @return trainer.
+ */
+ public GmmTrainer withMaxCountOfClusters(int maxCountOfClusters) {
+ A.ensure(maxCountOfClusters >= countOfComponents, "Max count of components should be greater than " +
+ "initial count of components or equal to it");
+
+ this.maxCountOfClusters = maxCountOfClusters;
+ return this;
+ }
+
+ /**
+ * Sets maximum divergence between maximum of likelihood of vector in dataset and other for anomalies
+ * identification.
+ *
+ * @param maxLikelihoodDivergence Max likelihood divergence.
+ * @return trainer.
+ */
+ public GmmTrainer withMaxLikelihoodDivergence(double maxLikelihoodDivergence) {
+ A.ensure(maxLikelihoodDivergence > 0, "Max likelihood divergence should be > 0");
+
+ this.maxLikelihoodDivergence = maxLikelihoodDivergence;
+ return this;
+ }
+
+ /**
+ * Sets minimum required anomalies in terms of maxLikelihoodDivergence for creating new cluster.
+ *
+ * @param minElementsForNewCluster Min elements for new cluster.
+ * @return trainer.
+ */
+ public GmmTrainer withMinElementsForNewCluster(int minElementsForNewCluster) {
+ A.ensure(minElementsForNewCluster > 0, "Min elements for new cluster should be > 0");
+
+ this.minElementsForNewCluster = minElementsForNewCluster;
+ return this;
+ }
+
+ /**
+ * Sets minimum required probability for cluster. If cluster has probability value less than this value then this
+ * cluster will be eliminated.
+ *
+ * @param minClusterProbability Min cluster probability.
+ * @return trainer.
+ */
+ public GmmTrainer withMinClusterProbability(double minClusterProbability) {
+ this.minClusterProbability = minClusterProbability;
+ return this;
+ }
+
+ /**
* Trains model based on the specified data.
*
* @param dataset Dataset.
* @return GMM model.
*/
private Optional<GmmModel> fit(Dataset<EmptyContext, GmmPartitionData> dataset) {
- return init(dataset).map(model -> updateModel(dataset, model));
+ return init(dataset).map(model -> {
+ GmmModel currentModel = model;
+
+ do {
+ UpdateResult updateResult = updateModel(dataset, currentModel);
+ currentModel = updateResult.model;
+
+ double minCompProb = currentModel.componentsProbs().minElement().get();
+ if (countOfComponents >= maxCountOfClusters || minCompProb < minClusterProbability)
+ break;
+
+ double maxXProb = updateResult.maxProbInDataset;
+ NewComponentStatisticsAggregator newMeanAdder = NewComponentStatisticsAggregator.computeNewMean(dataset,
+ maxXProb, maxLikelihoodDivergence, currentModel);
+
+ Vector newMean = newMeanAdder.mean();
+ if (newMeanAdder.rowCountForNewCluster() < minElementsForNewCluster)
+ break;
+
+ countOfComponents += 1;
+ Vector[] newMeans = new Vector[countOfComponents];
+ for (int i = 0; i < currentModel.countOfComponents(); i++)
+ newMeans[i] = currentModel.distributions().get(i).mean();
+ newMeans[countOfComponents - 1] = newMean;
+
+ initialMeans = newMeans;
+
+ Optional<GmmModel> newModelOpt = init(dataset);
+ if (newModelOpt.isPresent())
+ currentModel = newModelOpt.get();
+ else
+ break;
+ }
+ while (true);
+
+ return filterModel(currentModel);
+ });
+ }
+
+ /**
+ * Removes clusters whose probability value does not exceed minClusterProbability.
+ *
+ * @param model Model.
+ * @return filtered model.
+ */
+ private GmmModel filterModel(GmmModel model) {
+ List<Double> componentProbs = new ArrayList<>();
+ List<MultivariateGaussianDistribution> distributions = new ArrayList<>();
+
+ Vector originalComponentProbs = model.componentsProbs();
+ List<MultivariateGaussianDistribution> originalDistr = model.distributions();
+ for (int i = 0; i < model.countOfComponents(); i++) {
+ double prob = originalComponentProbs.get(i);
+ if (prob > minClusterProbability) {
+ componentProbs.add(prob);
+ distributions.add(originalDistr.get(i));
+ }
+ }
+
+ return new GmmModel(
+ VectorUtils.of(componentProbs.toArray(new Double[0])),
+ distributions
+ );
}
/**
@@ -180,11 +312,12 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* @param model Model.
* @return updated model.
*/
- @NotNull private GmmModel updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
+ @NotNull private UpdateResult updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
boolean isConverged = false;
int countOfIterations = 0;
+ double maxProbInDataset = Double.NEGATIVE_INFINITY;
while (!isConverged) {
- MeanWithClusterProbAggregator.AggregatedStats stats = MeanWithClusterProbAggregator.aggreateStats(dataset);
+ MeanWithClusterProbAggregator.AggregatedStats stats = MeanWithClusterProbAggregator.aggreateStats(dataset, countOfComponents);
Vector clusterProbs = stats.clusterProbabilities();
Vector[] newMeans = stats.means().toArray(new Vector[countOfComponents]);
@@ -199,9 +332,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
countOfIterations += 1;
isConverged = isConverged(model, newModel) || countOfIterations > maxCountOfIterations;
model = newModel;
-
- if (!isConverged)
- dataset.compute(data -> GmmPartitionData.updatePcxi(data, clusterProbs, components));
+ maxProbInDataset = GmmPartitionData.updatePcxiAndComputeLikelihood(dataset, clusterProbs, components);
}
catch (SingularMatrixException | IllegalArgumentException e) {
String msg = "Cannot construct non-singular covariance matrix by data. " +
@@ -211,7 +342,27 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
}
- return model;
+ return new UpdateResult(model, maxProbInDataset);
+ }
+
+ /**
+ * Result of current model update by EM-algorithm.
+ */
+ private static class UpdateResult {
+ /** Model. */
+ private final GmmModel model;
+
+ /** Max likelihood in dataset. */
+ private final double maxProbInDataset;
+
+ /**
+ * @param model Model.
+ * @param maxProbInDataset Max likelihood in dataset.
+ */
+ public UpdateResult(GmmModel model, double maxProbInDataset) {
+ this.model = model;
+ this.maxProbInDataset = maxProbInDataset;
+ }
}
/**
@@ -226,27 +377,21 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
while (true) {
try {
if (initialMeans == null) {
- List<List<Vector>> randomMeansSets = Stream.of(dataset.compute(
+ List<Vector> randomMeansSets = Stream.of(dataset.compute(
selectNRandomXsMapper(countOfComponents),
GmmTrainer::selectNRandomXsReducer
- )).map(this::asList).collect(Collectors.toList());
+ )).flatMap(Stream::of).collect(Collectors.toList());
+
+ Collections.sort(randomMeansSets, Comparator.comparingDouble(Vector::getLengthSquared));
+ Collections.shuffle(randomMeansSets, environment.randomNumbersGenerator());
A.ensure(
- randomMeansSets.stream().mapToInt(List::size).sum() >= countOfComponents,
+ randomMeansSets.size() >= countOfComponents,
"There is not enough data in dataset for select N random means"
);
- initialMeans = new Vector[countOfComponents];
- int j = 0;
- for (int i = 0; i < countOfComponents; ) {
- List<Vector> randomMeansPart = randomMeansSets.get(j);
- if (!randomMeansPart.isEmpty()) {
- initialMeans[i] = randomMeansPart.remove(0);
- i++;
- }
-
- j = (j + 1) % randomMeansSets.size();
- }
+ initialMeans = randomMeansSets.subList(0, countOfComponents)
+ .toArray(new Vector[countOfComponents]);
}
dataset.compute(data -> GmmPartitionData.estimateLikelihoodClusters(data, initialMeans));
@@ -282,17 +427,6 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
/**
- * @param vectors Array of vectors.
- * @return list of vectors.
- */
- private LinkedList<Vector> asList(Vector... vectors) {
- LinkedList<Vector> res = new LinkedList<>();
- for (Vector v : vectors)
- res.addFirst(v);
- return res;
- }
-
- /**
* Create new model components with provided means and covariances.
*
* @param means Means.
@@ -340,10 +474,9 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
@Override protected <K, V> GmmModel updateModel(GmmModel mdl, DatasetBuilder<K, V> datasetBuilder,
FeatureLabelExtractor<K, V, Double> extractor) {
- try (Dataset<EmptyContext, GmmPartitionData> dataset = datasetBuilder.build(
- LearningEnvironmentBuilder.defaultBuilder(),
+ try (Dataset<EmptyContext, GmmPartitionData> dataset = datasetBuilder.build(envBuilder,
new EmptyContextBuilder<>(),
- new GmmPartitionData.Builder<>(extractor, countOfComponents)
+ new GmmPartitionData.Builder<>(extractor, maxCountOfClusters)
)) {
if (mdl != null) {
if (initialMeans != null)
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
index 58044a7..99e60ba 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
@@ -47,6 +47,7 @@ class MeanWithClusterProbAggregator implements Serializable {
* Create an instance of MeanWithClusterProbAggregator.
*/
MeanWithClusterProbAggregator() {
+ // NO-OP.
}
/**
@@ -80,10 +81,11 @@ class MeanWithClusterProbAggregator implements Serializable {
* Aggregates statistics for means and cluster probabilities computing given dataset.
*
* @param dataset Dataset.
+ * @param countOfComponents Count of components.
*/
- public static AggregatedStats aggreateStats(Dataset<EmptyContext, GmmPartitionData> dataset) {
+ public static AggregatedStats aggreateStats(Dataset<EmptyContext, GmmPartitionData> dataset, int countOfComponents) {
return new AggregatedStats(dataset.compute(
- MeanWithClusterProbAggregator::map,
+ data -> map(data, countOfComponents),
MeanWithClusterProbAggregator::reduce
));
}
@@ -123,15 +125,16 @@ class MeanWithClusterProbAggregator implements Serializable {
* Map stage for statistics aggregation.
*
* @param data Partition data.
+ * @param countOfComponents Count of components.
* @return Aggregated statistics.
*/
- static List<MeanWithClusterProbAggregator> map(GmmPartitionData data) {
+ static List<MeanWithClusterProbAggregator> map(GmmPartitionData data, int countOfComponents) {
List<MeanWithClusterProbAggregator> aggregators = new ArrayList<>();
- for (int i = 0; i < data.countOfComponents(); i++)
+ for (int i = 0; i < countOfComponents; i++)
aggregators.add(new MeanWithClusterProbAggregator());
for (int i = 0; i < data.size(); i++) {
- for (int c = 0; c < data.countOfComponents(); c++)
+ for (int c = 0; c < countOfComponents; c++)
aggregators.get(c).add(data.getX(i), data.pcxi(c, i));
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java
new file mode 100644
index 0000000..4fa5406
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.io.Serializable;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Class for aggregating statistics used to find a new mean for GMM.
+ */
+public class NewComponentStatisticsAggregator implements Serializable {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 6748270328889375005L;
+
+ /** Total row count in dataset. */
+ private long totalRowCount;
+
+ /** Row count for new cluster. */
+ private long rowCountForNewCluster;
+
+ /** Sum of anomalies vectors. */
+ private Vector sumOfAnomalies;
+
+ /**
+ * Creates an instance of NewComponentStatisticsAggregator.
+ *
+ * @param totalRowCount Total row count in dataset.
+ * @param rowCountForNewCluster Row count for new cluster.
+ * @param sumOfAnomalies Sum of anomalies.
+ */
+ public NewComponentStatisticsAggregator(long totalRowCount, long rowCountForNewCluster, Vector sumOfAnomalies) {
+ this.totalRowCount = totalRowCount;
+ this.rowCountForNewCluster = rowCountForNewCluster;
+ this.sumOfAnomalies = sumOfAnomalies;
+ }
+
+ /**
+ * Creates an instance of NewComponentStatisticsAggregator.
+ */
+ public NewComponentStatisticsAggregator() {
+ }
+
+ /**
+ * @return Mean of anomalies.
+ */
+ public Vector mean() {
+ return sumOfAnomalies.divide(rowCountForNewCluster);
+ }
+
+ /**
+ * @return Row count for new cluster.
+ */
+ public long rowCountForNewCluster() {
+ return rowCountForNewCluster;
+ }
+
+ /**
+ * @return Total count of rows in partition/dataset.
+ */
+ public long totalRowCount() {
+ return totalRowCount;
+ }
+
+ /**
+ * Compute statistics for new mean for GMM.
+ *
+ * @param dataset Dataset.
+ * @param maxXsProb Max likelihood between all xs.
+ * @param maxProbDivergence Max probability divergence between maximum value and others.
+ * @param currentModel Current model.
+ * @return aggregated statistics for new mean.
+ */
+ static NewComponentStatisticsAggregator computeNewMean(Dataset<EmptyContext, GmmPartitionData> dataset,
+ double maxXsProb, double maxProbDivergence, GmmModel currentModel) {
+
+ return dataset.compute(
+ data -> computeNewMeanMap(data, maxXsProb, maxProbDivergence, currentModel),
+ NewComponentStatisticsAggregator::computeNewMeanReduce
+ );
+ }
+
+ /**
+ * Map stage for new mean computing.
+ *
+ * @param data Data.
+ * @param maxXsProb Max xs prob.
+ * @param maxProbDivergence Max prob divergence.
+ * @param currentModel Current model.
+ * @return aggregator for partition.
+ */
+ static NewComponentStatisticsAggregator computeNewMeanMap(GmmPartitionData data, double maxXsProb,
+ double maxProbDivergence, GmmModel currentModel) {
+
+ NewComponentStatisticsAggregator adder = new NewComponentStatisticsAggregator();
+ for (int i = 0; i < data.size(); i++) {
+ Vector x = data.getX(i);
+ adder.add(x, currentModel.prob(x) < (maxXsProb / maxProbDivergence));
+ }
+ return adder;
+ }
+
+ /**
+ * Adds vector to statistics.
+ *
+ * @param x Vector from dataset.
+ * @param isAnomaly True if vector is anomaly.
+ */
+ void add(Vector x, boolean isAnomaly) {
+ if (isAnomaly) {
+ if (sumOfAnomalies == null)
+ sumOfAnomalies = x.copy();
+ else
+ sumOfAnomalies = sumOfAnomalies.plus(x);
+
+ rowCountForNewCluster += 1;
+ }
+
+ totalRowCount += 1;
+ }
+
+ /**
+ * Reduce stage for new mean computing.
+ *
+ * @param left Left argument of reduce.
+ * @param right Right argument of reduce.
+ * @return sum of aggregators.
+ */
+ static NewComponentStatisticsAggregator computeNewMeanReduce(NewComponentStatisticsAggregator left,
+ NewComponentStatisticsAggregator right) {
+ A.ensure(left != null || right != null, "left != null || right != null");
+
+ if (left == null)
+ return right;
+ else if (right == null)
+ return left;
+ else
+ return left.plus(right);
+ }
+
+ /**
+ * @param other Other aggregator.
+ * @return sum of aggregators.
+ */
+ NewComponentStatisticsAggregator plus(NewComponentStatisticsAggregator other) {
+ return new NewComponentStatisticsAggregator(
+ totalRowCount + other.totalRowCount,
+ rowCountForNewCluster + other.rowCountForNewCluster,
+ sumOfAnomalies.plus(other.sumOfAnomalies)
+ );
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java
index 4aef8f2..cefd056 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java
@@ -53,7 +53,7 @@ public class DefaultLearningEnvironmentBuilder implements LearningEnvironmentBui
parallelismStgy = constant(NoParallelismStrategy.INSTANCE);
loggingFactory = constant(NoOpLogger.factory());
seed = constant(new Random().nextLong());
- rngSupplier = constant(new Random());
+ rngSupplier = p -> new Random();
}
/** {@inheritDoc} */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
index 29bb22f..1b969d3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
@@ -31,9 +31,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
* @param <C> distributions mixture component class.
*/
public abstract class DistributionMixture<C extends Distribution> implements Distribution {
- /** */
- private static final double EPS = 1e-5;
-
/** Component probabilities. */
private final Vector componentProbs;
@@ -51,7 +48,7 @@ public abstract class DistributionMixture<C extends Distribution> implements Dis
*/
public DistributionMixture(Vector componentProbs, List<C> distributions) {
A.ensure(DoubleStream.of(componentProbs.asArray()).allMatch(v -> v > 0), "All distribution components should be greater than zero");
- A.ensure(Math.abs(componentProbs.sum() - 1.) < EPS, "Components distribution should be nomalized");
+ componentProbs = componentProbs.divide(componentProbs.sum());
A.ensure(!distributions.isEmpty(), "Distribution mixture should have at least one component");
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
index e4554e9..7dc284a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java
@@ -61,6 +61,13 @@ public class MatrixUtil {
return res;
}
+ public static Matrix identity(int n) {
+ DenseMatrix res = new DenseMatrix(n, n);
+ for (int i = 0; i < n; i++)
+ res.set(i, i, 1.0);
+ return res;
+ }
+
/**
* Create the like matrix with specified size with read-only matrices support.
*
@@ -197,10 +204,10 @@ public class MatrixUtil {
}
/**
- * Zip two vectors with given tri-function taking as third argument position in vector
- * (i.e. apply binary function to both vector elementwise and construct vector from results).
- * Example zipWith({200, 400, 600}, {100, 300, 500}, plusAndMultiplyByIndex) = {(200 + 100) * 0, (400 + 300) * 1, (600 + 500) * 3}.
- * Length of result is length of shortest of vectors.
+ * Zip two vectors with given tri-function taking as third argument position in vector (i.e. apply binary function
+ * to both vector elementwise and construct vector from results). Example zipWith({200, 400, 600}, {100, 300, 500},
+ * plusAndMultiplyByIndex) = {(200 + 100) * 0, (400 + 300) * 1, (600 + 500) * 2}. Length of result is length of
+ * shortest of vectors.
*
* @param v1 First vector.
* @param v2 Second vector.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
index 345a885..1070efc 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
@@ -17,6 +17,7 @@
package org.apache.ignite.ml.regressions.logistic;
+import java.util.Arrays;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -38,8 +39,6 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;
-import java.util.Arrays;
-
/**
* Trainer of the logistic regression model based on stochastic gradient descent algorithm.
*/
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java
index cae8bef..d22198c 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java
@@ -23,6 +23,7 @@ import org.apache.ignite.ml.clustering.gmm.GmmPartitionDataTest;
import org.apache.ignite.ml.clustering.gmm.GmmTrainerIntegrationTest;
import org.apache.ignite.ml.clustering.gmm.GmmTrainerTest;
import org.apache.ignite.ml.clustering.gmm.MeanWithClusterProbAggregatorTest;
+import org.apache.ignite.ml.clustering.gmm.NewComponentStatisticsAggregatorTest;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -41,7 +42,8 @@ import org.junit.runners.Suite;
GmmPartitionDataTest.class,
MeanWithClusterProbAggregatorTest.class,
GmmTrainerTest.class,
- GmmTrainerIntegrationTest.class
+ GmmTrainerIntegrationTest.class,
+ NewComponentStatisticsAggregatorTest.class
})
public class ClusteringTestSuite {
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java
index e6307e1..7a3043d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java
@@ -117,7 +117,7 @@ public class MeanWithClusterProbAggregatorTest {
}
);
- List<MeanWithClusterProbAggregator> res = MeanWithClusterProbAggregator.map(data);
+ List<MeanWithClusterProbAggregator> res = MeanWithClusterProbAggregator.map(data, 2);
assertEquals(2, res.size());
MeanWithClusterProbAggregator agg1 = res.get(0);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregatorTest.java
new file mode 100644
index 0000000..83a4f17
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregatorTest.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.apache.ignite.ml.clustering.gmm.NewComponentStatisticsAggregator.computeNewMeanMap;
+import static org.apache.ignite.ml.clustering.gmm.NewComponentStatisticsAggregator.computeNewMeanReduce;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link NewComponentStatisticsAggregator}: add, plus and the static map/reduce helpers.
+ */
+public class NewComponentStatisticsAggregatorTest {
+ /** First {@link GmmPartitionData} fixture: three 2-element labeled vectors and a {@code double[3][]} with unset rows. */
+ GmmPartitionData data1 = new GmmPartitionData(
+ Arrays.asList(
+ vec(1, 0),
+ vec(0, 1),
+ vec(3, 7)
+ ),
+ new double[3][]
+ );
+
+ /** Second {@link GmmPartitionData} fixture: three 2-element labeled vectors and a {@code double[3][]} with unset rows. */
+ GmmPartitionData data2 = new GmmPartitionData(
+ Arrays.asList(
+ vec(3, 1),
+ vec(1, 4),
+ vec(1, 3)
+ ),
+ new double[3][]
+ );
+
+ /** Mocked {@link GmmModel}; its {@code prob} results are stubbed in {@link #before()}. */
+ GmmModel model;
+
+ /** Creates the model mock and stubs {@code prob} with a fixed value for each vector of both fixtures. */
+ @Before
+ public void before() {
+ model = mock(GmmModel.class);
+ when(model.prob(data1.getX(0))).thenReturn(0.1);
+ when(model.prob(data1.getX(1))).thenReturn(0.4);
+ when(model.prob(data1.getX(2))).thenReturn(0.9);
+
+ when(model.prob(data2.getX(0))).thenReturn(0.2);
+ when(model.prob(data2.getX(1))).thenReturn(0.6);
+ when(model.prob(data2.getX(2))).thenReturn(0.1);
+ }
+
+ /** Adds the same vector ten times with an alternating flag and checks both row counters and the mean. */
+ @Test
+ public void testAdd() {
+ NewComponentStatisticsAggregator agg = new NewComponentStatisticsAggregator();
+ int rowCount = 10;
+ for (int i = 0; i < rowCount; i++)
+ agg.add(VectorUtils.of(0, 1, 2), i % 2 == 0);
+
+ assertEquals(rowCount / 2, agg.rowCountForNewCluster());
+ assertEquals(rowCount, agg.totalRowCount());
+ assertArrayEquals(new double[] {0, 1, 2}, agg.mean().asArray(), 1e-4);
+ }
+
+ /** Merges two separately filled aggregators with {@code plus} and checks the combined row counters and mean. */
+ @Test
+ public void testPlus() {
+ NewComponentStatisticsAggregator agg1 = new NewComponentStatisticsAggregator();
+ NewComponentStatisticsAggregator agg2 = new NewComponentStatisticsAggregator();
+ int rowCount = 10;
+ for (int i = 0; i < rowCount; i++)
+ agg1.add(VectorUtils.of(0, 1, 2), i % 2 == 0);
+
+ for (int i = 0; i < rowCount; i++)
+ agg2.add(VectorUtils.of(2, 1, 0), i % 2 == 1);
+
+ NewComponentStatisticsAggregator sum = agg1.plus(agg2);
+ assertEquals(rowCount, sum.rowCountForNewCluster());
+ assertEquals(rowCount * 2, sum.totalRowCount());
+ assertArrayEquals(new double[] {1, 1, 1}, sum.mean().asArray(), 1e-4);
+ }
+
+ /** Runs {@code computeNewMeanMap} on {@code data1} with the mocked model and checks the counters and mean. */
+ @Test
+ public void testMap() {
+ NewComponentStatisticsAggregator agg = computeNewMeanMap(data1, 1.0, 2, model);
+
+ assertEquals(2, agg.rowCountForNewCluster());
+ assertEquals(data1.size(), agg.totalRowCount());
+ assertArrayEquals(new double[] {0.5, 0.5}, agg.mean().asArray(), 1e-4);
+ }
+
+ /** Checks {@code computeNewMeanReduce} with {@code null} on either side and with both argument orders. */
+ @Test
+ public void testReduce() {
+ double maxXsProb = 1.0;
+ int maxProbDivergence = 2;
+ NewComponentStatisticsAggregator agg1 = computeNewMeanMap(data1, maxXsProb, maxProbDivergence, model);
+ NewComponentStatisticsAggregator agg2 = computeNewMeanMap(data2, maxXsProb, maxProbDivergence, model);
+
+ NewComponentStatisticsAggregator res = computeNewMeanReduce(agg1, null);
+ assertEquals(agg1.rowCountForNewCluster(), res.rowCountForNewCluster());
+ assertEquals(agg1.totalRowCount(), res.totalRowCount());
+ assertArrayEquals(agg1.mean().asArray(), res.mean().asArray(), 1e-4);
+
+ res = computeNewMeanReduce(null, agg1);
+ assertEquals(agg1.rowCountForNewCluster(), res.rowCountForNewCluster());
+ assertEquals(agg1.totalRowCount(), res.totalRowCount());
+ assertArrayEquals(agg1.mean().asArray(), res.mean().asArray(), 1e-4);
+
+ res = computeNewMeanReduce(agg2, agg1);
+ assertEquals(4, res.rowCountForNewCluster());
+ assertEquals(6, res.totalRowCount());
+ assertArrayEquals(new double[] {1.25, 1.25}, res.mean().asArray(), 1e-4);
+
+ res = computeNewMeanReduce(agg1, agg2);
+ assertEquals(4, res.rowCountForNewCluster());
+ assertEquals(6, res.totalRowCount());
+ assertArrayEquals(new double[] {1.25, 1.25}, res.mean().asArray(), 1e-4);
+ }
+
+ /** Wraps the given values into a {@link LabeledVector} with the fixed label {@code 1.0}. */
+ private LabeledVector<Double> vec(double... values) {
+ return new LabeledVector<>(VectorUtils.of(values), 1.0);
+ }
+}