You are viewing a plain-text version of this content. The canonical link to the original message was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/01 05:02:49 UTC

[flink-ml] branch master updated: [FLINK-29824] Stop clustering when size of the active clusters is one

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 135763c  [FLINK-29824] Stop clustering when size of the active clusters is one
135763c is described below

commit 135763c195bb4724952222467b63aea5b65c525b
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Tue Nov 1 13:02:45 2022 +0800

    [FLINK-29824] Stop clustering when size of the active clusters is one
    
    This closes #169.
---
 .../AgglomerativeClustering.java                        |  6 ++++--
 .../ml/clustering/AgglomerativeClusteringTest.java      | 17 +++++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java
index cc0750a..60f5efc 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java
@@ -258,12 +258,12 @@ public class AgglomerativeClustering
         private void doClustering(
                 List<Cluster> activeClusters,
                 ProcessAllWindowFunction<Row, Row, ?>.Context context) {
-            int clusterOffset1 = -1, clusterOffset2 = -1;
             boolean clusteringRunning =
                     (numCluster != null && activeClusters.size() > numCluster)
                             || (distanceThreshold != null);
 
             while (clusteringRunning || (computeFullTree && activeClusters.size() > 1)) {
+                int clusterOffset1 = -1, clusterOffset2 = -1;
                 // Computes the distance between two clusters.
                 double minDistance = Double.MAX_VALUE;
                 for (int i = 0; i < activeClusters.size(); i++) {
@@ -309,7 +309,9 @@ public class AgglomerativeClustering
 
                 clusteringRunning =
                         (numCluster != null && activeClusters.size() > numCluster)
-                                || (distanceThreshold != null && distanceThreshold > minDistance);
+                                || (distanceThreshold != null
+                                        && distanceThreshold > minDistance
+                                        && activeClusters.size() > 1);
             }
         }
 
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
index fdc7eb2..99f5ca3 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
@@ -239,6 +239,23 @@ public class AgglomerativeClusteringTest extends AbstractTestBase {
                 agglomerativeClustering.getPredictionCol());
     }
 
+    @Test
+    public void testLargeDistanceThreshold() throws Exception {
+        AgglomerativeClustering agglomerativeClustering =
+                new AgglomerativeClustering()
+                        .setNumClusters(null)
+                        .setDistanceThreshold(Double.MAX_VALUE);
+        Table output = agglomerativeClustering.transform(inputDataTable)[0];
+        HashSet<Integer> clusterIds = new HashSet<>();
+        tEnv.toDataStream(output)
+                .executeAndCollect()
+                .forEachRemaining(
+                        x ->
+                                clusterIds.add(
+                                        x.getFieldAs(agglomerativeClustering.getPredictionCol())));
+        assertEquals(1, clusterIds.size());
+    }
+
     @Test
     public void testTransformWithCountTumblingWindows() throws Exception {
         env.setParallelism(1);