You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/09/02 12:33:14 UTC
[flink-ml] branch master updated: [FLINK-28906] Add AlgoOperator for AgglomerativeClustering
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 061be95 [FLINK-28906] Add AlgoOperator for AgglomerativeClustering
061be95 is described below
commit 061be9566aa78581c6356a92debbe72e783dc215
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Fri Sep 2 20:33:09 2022 +0800
[FLINK-28906] Add AlgoOperator for AgglomerativeClustering
This closes #148.
---
.../clustering/agglomerativeclustering.md | 181 +++++++++
docs/content/docs/operators/clustering/kmeans.md | 36 +-
...anceMeasure.java => CosineDistanceMeasure.java} | 37 +-
.../flink/ml/common/distance/DistanceMeasure.java | 31 +-
.../common/distance/EuclideanDistanceMeasure.java | 2 +-
...eMeasure.java => ManhattanDistanceMeasure.java} | 38 +-
.../ml/common/distance/DistanceMeasureTest.java | 76 ++++
.../clustering/AgglomerativeClusteringExample.java | 70 ++++
.../AgglomerativeClustering.java | 421 +++++++++++++++++++++
.../AgglomerativeClusteringParams.java | 108 ++++++
.../flink/ml/common/param/HasDistanceMeasure.java | 7 +-
.../ml/clustering/AgglomerativeClusteringTest.java | 302 +++++++++++++++
.../clustering/agglomerativeclustering_example.py | 84 ++++
.../pyflink/ml/core/tests/test_param.py | 3 +-
.../ml/lib/clustering/agglomerativeclustering.py | 136 +++++++
.../pyflink/ml/lib/clustering/common.py | 28 +-
.../tests/test_agglomerativeclustering.py | 213 +++++++++++
flink-ml-python/pyflink/ml/lib/param.py | 4 +-
.../ml/lib/tests/test_ml_lib_completeness.py | 7 +-
19 files changed, 1710 insertions(+), 74 deletions(-)
diff --git a/docs/content/docs/operators/clustering/agglomerativeclustering.md b/docs/content/docs/operators/clustering/agglomerativeclustering.md
new file mode 100644
index 0000000..1803c87
--- /dev/null
+++ b/docs/content/docs/operators/clustering/agglomerativeclustering.md
@@ -0,0 +1,181 @@
+---
+title: "AgglomerativeClustering"
+type: docs
+aliases:
+- /operators/clustering/agglomerativeclustering.html
+---
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## AgglomerativeClustering
+
+AgglomerativeClustering performs a hierarchical clustering
+using a bottom-up approach. Each observation starts in its
+own cluster and the clusters are merged together one by one.
+
+The output contains two tables. The first one assigns a
+cluster ID to each data point. The second one records which
+two clusters are merged at each step. Each row of the merging
+information has the format
+(clusterId1, clusterId2, distance, sizeOfMergedCluster).
+
+### Input Columns
+
+| Param name | Type | Default | Description |
+|:------------|:-------|:-------------|:----------------|
+| featuresCol | Vector | `"features"` | Feature vector. |
+
+### Output Columns
+
+| Param name | Type | Default | Description |
+|:--------------|:--------|:---------------|:--------------------------|
+| predictionCol | Integer | `"prediction"` | Predicted cluster ID. |
+
+### Parameters
+
+| Key | Default | Type | Required | Description |
+|:------------------|:---------------|:--------|:---------|:-----------------------------------------------------------|
+| numClusters | `2` | Integer | no | The max number of clusters to create. |
+| distanceThreshold | `null` | Double | no | Threshold to decide whether two clusters should be merged. |
+| linkage | `"ward"` | String | no | Criterion for computing distance between two clusters. |
+| computeFullTree | `false` | Boolean | no | Whether to compute the full tree after convergence. |
+| distanceMeasure | `"euclidean"` | String | no | Distance measure. |
+| featuresCol | `"features"` | String | no | Features column name. |
+| predictionCol | `"prediction"` | String | no | Prediction column name. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+```java
+import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
+import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates an AgglomerativeClustering instance and uses it for clustering. */
+public class AgglomerativeClusteringExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Six 2-D points: three with x = 1 and three with x = 4.
+        DataStream<DenseVector> points =
+                env.fromElements(
+                        Vectors.dense(1, 1),
+                        Vectors.dense(1, 4),
+                        Vectors.dense(1, 0),
+                        Vectors.dense(4, 1.5),
+                        Vectors.dense(4, 4),
+                        Vectors.dense(4, 0));
+        Table inputTable = tEnv.fromDataStream(points).as("features");
+
+        // Ward linkage over Euclidean distances; cluster IDs are written to "prediction".
+        AgglomerativeClustering clustering =
+                new AgglomerativeClustering()
+                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD)
+                        .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
+                        .setPredictionCol("prediction");
+
+        // outputs[0] holds the per-point cluster assignments.
+        Table[] outputs = clustering.transform(inputTable);
+        String featuresCol = clustering.getFeaturesCol();
+        String predictionCol = clustering.getPredictionCol();
+
+        // Prints every point together with the cluster it was assigned to.
+        CloseableIterator<Row> rows = outputs[0].execute().collect();
+        while (rows.hasNext()) {
+            Row row = rows.next();
+            DenseVector features = (DenseVector) row.getField(featuresCol);
+            int clusterId = (Integer) row.getField(predictionCol);
+            System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
+        }
+    }
+}
+
+```
+{{< /tab>}}
+
+{{< tab "Python">}}
+```python
+# Simple program that creates a Bucketizer instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.clustering.agglomerativeclustering import AgglomerativeClustering
+from pyflink.table import StreamTableEnvironment
+from matplotlib import pyplot as plt
+from scipy.cluster.hierarchy import dendrogram
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input data.
+input_data = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense([1, 1]),),
+ (Vectors.dense([1, 4]),),
+ (Vectors.dense([1, 0]),),
+ (Vectors.dense([4, 1.5]),),
+ (Vectors.dense([4, 4]),),
+ (Vectors.dense([4, 0]),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['features'],
+ [DenseVectorTypeInfo()])))
+
+# Creates an AgglomerativeClustering object and initializes its parameters.
+agglomerative_clustering = AgglomerativeClustering()
+ .set_linkage('ward')
+ .set_distance_measure('euclidean')
+ .set_prediction_col('prediction')
+
+# Uses the AgglomerativeClustering for clustering.
+outputs = agglomerative_clustering.transform(input_data)
+
+# Extracts and display the clustering results.
+field_names = outputs[0].get_schema().get_field_names()
+for result in t_env.to_data_stream(outputs[0]).execute_and_collect():
+ features = result[field_names.index(agglomerative_clustering.features_col)]
+ cluster_id = result[field_names.index(agglomerative_clustering.prediction_col)]
+ print('Features: ' + str(features) + '\tCluster ID: ' + str(cluster_id))
+
+# Visualizes the merge info.
+merge_info = [result for result in
+ t_env.to_data_stream(outputs[1]).execute_and_collect()]
+plt.title("Agglomerative Clustering Dendrogram")
+dendrogram(merge_info)
+plt.xlabel("Index of data point.")
+plt.ylabel("Distances between merged clusters.")
+plt.show()
+```
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/docs/content/docs/operators/clustering/kmeans.md b/docs/content/docs/operators/clustering/kmeans.md
index 0883729..f5c4bd2 100644
--- a/docs/content/docs/operators/clustering/kmeans.md
+++ b/docs/content/docs/operators/clustering/kmeans.md
@@ -31,30 +31,30 @@ into a predefined number of clusters.
### Input Columns
| Param name | Type | Default | Description |
-| :---------- | :----- | :----------- | :------------- |
+|:------------|:-------|:-------------|:---------------|
| featuresCol | Vector | `"features"` | Feature vector |
### Output Columns
| Param name | Type | Default | Description |
-| :------------ | :------ | :------------- | :----------------------- |
+|:--------------|:--------|:---------------|:-------------------------|
| predictionCol | Integer | `"prediction"` | Predicted cluster center |
### Parameters
Below are the parameters required by `KMeansModel`.
-| Key | Default | Type | Required | Description |
-| --------------- | ------------------------------- | ------- | -------- | ------------------------------------------------------------ |
-| distanceMeasure | `EuclideanDistanceMeasure.NAME` | String | no | Distance measure. Supported values: `EuclideanDistanceMeasure.NAME` |
-| featuresCol | `"features"` | String | no | Features column name. |
-| predictionCol | `"prediction"` | String | no | Prediction column name. |
-| k | `2` | Integer | no | The max number of clusters to create. |
+| Key | Default | Type | Required | Description |
+|-----------------|----------------|---------|----------|---------------------------------------------------------------------------|
+| distanceMeasure | `"euclidean"` | String | no | Distance measure. Supported values: `'euclidean', 'manhattan', 'cosine'`. |
+| featuresCol | `"features"` | String | no | Features column name. |
+| predictionCol | `"prediction"` | String | no | Prediction column name. |
+| k | `2` | Integer | no | The max number of clusters to create. |
`KMeans` needs parameters above and also below.
| Key | Default | Type | Required | Description |
-| -------- | ---------- | ------- | -------- | ---------------------------------------------------------- |
+|----------|------------|---------|----------|------------------------------------------------------------|
| initMode | `"random"` | String | no | The initialization algorithm. Supported options: 'random'. |
| seed | `null` | Long | no | The random seed. |
| maxIter | `20` | Integer | no | Maximum number of iterations. |
@@ -187,30 +187,30 @@ correspond to more forgetting.
### Input Columns
| Param name | Type | Default | Description |
-| :---------- | :----- | :----------- | :------------- |
+|:------------|:-------|:-------------|:---------------|
| featuresCol | Vector | `"features"` | Feature vector |
### Output Columns
| Param name | Type | Default | Description |
-| :------------ | :------ | :------------- | :----------------------- |
+|:--------------|:--------|:---------------|:-------------------------|
| predictionCol | Integer | `"prediction"` | Predicted cluster center |
### Parameters
Below are the parameters required by `OnlineKMeansModel`.
-| Key | Default | Type | Required | Description |
-| --------------- | ------------------------------- | ------- | -------- | ------------------------------------------------------------ |
-| distanceMeasure | `EuclideanDistanceMeasure.NAME` | String | no | Distance measure. Supported values: `EuclideanDistanceMeasure.NAME` |
-| featuresCol | `"features"` | String | no | Features column name. |
-| predictionCol | `"prediction"` | String | no | Prediction column name. |
-| k | `2` | Integer | no | The max number of clusters to create. |
+| Key | Default | Type | Required | Description |
+|-----------------|----------------|---------|----------|---------------------------------------------------------------------------|
+| distanceMeasure | `"euclidean"` | String | no | Distance measure. Supported values: `'euclidean', 'manhattan', 'cosine'`. |
+| featuresCol | `"features"` | String | no | Features column name. |
+| predictionCol | `"prediction"` | String | no | Prediction column name. |
+| k | `2` | Integer | no | The max number of clusters to create. |
`OnlineKMeans` needs parameters above and also below.
| Key | Default | Type | Required | Description |
-| --------------- | ---------------- | ------- | -------- | ----------------------------------------------------- |
+|-----------------|------------------|---------|----------|-------------------------------------------------------|
| batchStrategy | `COUNT_STRATEGY` | String | no | Strategy to create mini batch from online train data. |
| globalBatchSize | `32` | Integer | no | Global batch size of training algorithms. |
| decayFactor | `0.` | Double | no | The forgetfulness of the previous centroids. |
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/CosineDistanceMeasure.java
similarity index 50%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
copy to flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/CosineDistanceMeasure.java
index ee7b25d..251b933 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/CosineDistanceMeasure.java
@@ -18,30 +18,27 @@
package org.apache.flink.ml.common.distance;
+import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.util.Preconditions;
-import java.io.Serializable;
+/**
+ * Cosine distance between two vectors, computed as one minus the dot product of the vectors divided
+ * by the product of their L2 norms.
+ *
+ * <p>The distance is not defined when either vector has a zero L2 norm.
+ */
+public class CosineDistanceMeasure implements DistanceMeasure {
-/** Interface for measuring distance between two vectors. */
-public interface DistanceMeasure extends Serializable {
+    private static final CosineDistanceMeasure instance = new CosineDistanceMeasure();
+    public static final String NAME = "cosine";
- static DistanceMeasure getInstance(String distanceMeasure) {
- if (distanceMeasure.equals(EuclideanDistanceMeasure.NAME)) {
- return EuclideanDistanceMeasure.getInstance();
- }
- throw new IllegalArgumentException(
- "distanceMeasure "
- + distanceMeasure
- + " is not recognized. Supported options: 'euclidean'.");
- }
+    private CosineDistanceMeasure() {}
- /**
- * Measures the distance between two vectors.
- *
- * <p>Required: The two vectors should have the same dimension.
- */
- double distance(VectorWithNorm v1, VectorWithNorm v2);
+    /** Returns the shared singleton instance. */
+    public static CosineDistanceMeasure getInstance() {
+        return instance;
+    }
- /** Finds the index of the closest center to the given point. */
- int findClosest(VectorWithNorm[] centroids, VectorWithNorm point);
+    /**
+     * Computes {@code 1 - dot(v1, v2) / (|v1| * |v2|)}.
+     *
+     * @throws IllegalArgumentException if either vector's L2 norm is not positive.
+     */
+    @Override
+    public double distance(VectorWithNorm v1, VectorWithNorm v2) {
+        Preconditions.checkArgument(
+                v1.l2Norm > 0 && v2.l2Norm > 0,
+                "Cosine distance is not defined for zero-length vectors.");
+        return 1 - BLAS.dot(v1.vector, v2.vector) / v1.l2Norm / v2.l2Norm;
+    }
}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
index ee7b25d..ada5c0e 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
@@ -26,13 +26,19 @@ import java.io.Serializable;
public interface DistanceMeasure extends Serializable {
static DistanceMeasure getInstance(String distanceMeasure) {
- if (distanceMeasure.equals(EuclideanDistanceMeasure.NAME)) {
- return EuclideanDistanceMeasure.getInstance();
+ switch (distanceMeasure) {
+ case EuclideanDistanceMeasure.NAME:
+ return EuclideanDistanceMeasure.getInstance();
+ case ManhattanDistanceMeasure.NAME:
+ return ManhattanDistanceMeasure.getInstance();
+ case CosineDistanceMeasure.NAME:
+ return CosineDistanceMeasure.getInstance();
+ default:
+ throw new IllegalArgumentException(
+ "distanceMeasure "
+ + distanceMeasure
+ + " is not recognized. Supported options: 'euclidean, manhattan, cosine'.");
}
- throw new IllegalArgumentException(
- "distanceMeasure "
- + distanceMeasure
- + " is not recognized. Supported options: 'euclidean'.");
}
/**
@@ -43,5 +49,16 @@ public interface DistanceMeasure extends Serializable {
double distance(VectorWithNorm v1, VectorWithNorm v2);
/** Finds the index of the closest center to the given point. */
- int findClosest(VectorWithNorm[] centroids, VectorWithNorm point);
+    default int findClosest(VectorWithNorm[] centroids, VectorWithNorm point) {
+        // Stays -1 if centroids is empty or no distance is below Double.MAX_VALUE (e.g. all NaN).
+        int closestIndex = -1;
+        double closestDistance = Double.MAX_VALUE;
+        int index = 0;
+        for (VectorWithNorm centroid : centroids) {
+            double candidate = distance(centroid, point);
+            // Strict comparison keeps the earliest centroid on ties.
+            if (candidate < closestDistance) {
+                closestDistance = candidate;
+                closestIndex = index;
+            }
+            index++;
+        }
+        return closestIndex;
+    }
}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
index a3798a0..8ef341a 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
@@ -21,7 +21,7 @@ package org.apache.flink.ml.common.distance;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.VectorWithNorm;
-/** Interface for measuring the Euclidean distance between two vectors. */
+/** Euclidean distance (also known as L2 distance) between two vectors. */
public class EuclideanDistanceMeasure implements DistanceMeasure {
private static final EuclideanDistanceMeasure instance = new EuclideanDistanceMeasure();
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/ManhattanDistanceMeasure.java
similarity index 51%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
copy to flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/ManhattanDistanceMeasure.java
index ee7b25d..a21364a 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/ManhattanDistanceMeasure.java
@@ -19,29 +19,27 @@
package org.apache.flink.ml.common.distance;
import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.util.Preconditions;
-import java.io.Serializable;
+/**
+ * Manhattan distance (also known as L1 distance) between two vectors: the sum of the absolute
+ * differences of their elements.
+ */
+public class ManhattanDistanceMeasure implements DistanceMeasure {
-/** Interface for measuring distance between two vectors. */
-public interface DistanceMeasure extends Serializable {
+    private static final ManhattanDistanceMeasure instance = new ManhattanDistanceMeasure();
+    public static final String NAME = "manhattan";
- static DistanceMeasure getInstance(String distanceMeasure) {
- if (distanceMeasure.equals(EuclideanDistanceMeasure.NAME)) {
- return EuclideanDistanceMeasure.getInstance();
- }
- throw new IllegalArgumentException(
- "distanceMeasure "
- + distanceMeasure
- + " is not recognized. Supported options: 'euclidean'.");
- }
+    private ManhattanDistanceMeasure() {}
- /**
- * Measures the distance between two vectors.
- *
- * <p>Required: The two vectors should have the same dimension.
- */
- double distance(VectorWithNorm v1, VectorWithNorm v2);
+    /** Returns the shared singleton instance. */
+    public static ManhattanDistanceMeasure getInstance() {
+        return instance;
+    }
- /** Finds the index of the closest center to the given point. */
- int findClosest(VectorWithNorm[] centroids, VectorWithNorm point);
+    /**
+     * Computes the sum of absolute element-wise differences of the two vectors.
+     *
+     * @throws IllegalArgumentException if the vectors have different sizes.
+     */
+    @Override
+    public double distance(VectorWithNorm v1, VectorWithNorm v2) {
+        int size = v1.vector.size();
+        // A constant message avoids boxing/formatting work on this hot path.
+        Preconditions.checkArgument(
+                size == v2.vector.size(),
+                "Manhattan distance is not defined for vectors of different dimensions.");
+        double sum = 0;
+        for (int i = 0; i < size; i++) {
+            sum += Math.abs(v1.vector.get(i) - v2.vector.get(i));
+        }
+        return sum;
+    }
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/distance/DistanceMeasureTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/distance/DistanceMeasureTest.java
new file mode 100644
index 0000000..7370f2b
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/distance/DistanceMeasureTest.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.distance;
+
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+
+import org.junit.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Tests {@link CosineDistanceMeasure}, {@link EuclideanDistanceMeasure} and {@link
+ * ManhattanDistanceMeasure}, as well as the name-based lookup in {@link
+ * DistanceMeasure#getInstance(String)}.
+ */
+public class DistanceMeasureTest {
+    private static final VectorWithNorm VECTOR_WITH_NORM_A =
+            new VectorWithNorm(Vectors.sparse(3, new int[] {1, 2}, new double[] {1, 2}));
+    private static final VectorWithNorm VECTOR_WITH_NORM_B =
+            new VectorWithNorm(Vectors.dense(1, 2, 3));
+    private static final VectorWithNorm[] CENTROIDS =
+            new VectorWithNorm[] {
+                new VectorWithNorm(Vectors.dense(0, 1, 2)),
+                new VectorWithNorm(Vectors.dense(1, 2, 3)),
+                new VectorWithNorm(Vectors.dense(2, 3, 4))
+            };
+
+    private static final double TOLERANCE = 1e-7;
+
+    @Test
+    public void testEuclidean() {
+        DistanceMeasure distanceMeasure = EuclideanDistanceMeasure.getInstance();
+        assertEquals(
+                Math.sqrt(3),
+                distanceMeasure.distance(VECTOR_WITH_NORM_A, VECTOR_WITH_NORM_B),
+                TOLERANCE);
+        assertEquals(0, distanceMeasure.findClosest(CENTROIDS, VECTOR_WITH_NORM_A));
+        assertEquals(1, distanceMeasure.findClosest(CENTROIDS, VECTOR_WITH_NORM_B));
+    }
+
+    @Test
+    public void testManhattan() {
+        DistanceMeasure distanceMeasure = ManhattanDistanceMeasure.getInstance();
+        assertEquals(
+                3, distanceMeasure.distance(VECTOR_WITH_NORM_A, VECTOR_WITH_NORM_B), TOLERANCE);
+        assertEquals(0, distanceMeasure.findClosest(CENTROIDS, VECTOR_WITH_NORM_A));
+        assertEquals(1, distanceMeasure.findClosest(CENTROIDS, VECTOR_WITH_NORM_B));
+    }
+
+    @Test
+    public void testCosine() {
+        DistanceMeasure distanceMeasure = CosineDistanceMeasure.getInstance();
+        assertEquals(
+                0.04381711,
+                distanceMeasure.distance(VECTOR_WITH_NORM_A, VECTOR_WITH_NORM_B),
+                TOLERANCE);
+        assertEquals(0, distanceMeasure.findClosest(CENTROIDS, VECTOR_WITH_NORM_A));
+        assertEquals(1, distanceMeasure.findClosest(CENTROIDS, VECTOR_WITH_NORM_B));
+    }
+
+    /** Cosine distance must reject a vector whose L2 norm is zero. */
+    @Test(expected = IllegalArgumentException.class)
+    public void testCosineRejectsZeroVector() {
+        VectorWithNorm zero = new VectorWithNorm(Vectors.dense(0, 0, 0));
+        CosineDistanceMeasure.getInstance().distance(zero, VECTOR_WITH_NORM_B);
+    }
+
+    /** Every supported name resolves to the matching singleton. */
+    @Test
+    public void testGetInstanceByName() {
+        assertEquals(
+                EuclideanDistanceMeasure.getInstance(),
+                DistanceMeasure.getInstance(EuclideanDistanceMeasure.NAME));
+        assertEquals(
+                ManhattanDistanceMeasure.getInstance(),
+                DistanceMeasure.getInstance(ManhattanDistanceMeasure.NAME));
+        assertEquals(
+                CosineDistanceMeasure.getInstance(),
+                DistanceMeasure.getInstance(CosineDistanceMeasure.NAME));
+    }
+
+    /** An unknown name is reported as an illegal argument. */
+    @Test(expected = IllegalArgumentException.class)
+    public void testGetInstanceWithUnknownName() {
+        DistanceMeasure.getInstance("unknown");
+    }
+}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java
new file mode 100644
index 0000000..c48448f
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.clustering;
+
+import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
+import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates an AgglomerativeClustering instance and uses it for clustering. */
+public class AgglomerativeClusteringExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Six 2-D points: three with x = 1 and three with x = 4.
+        DataStream<DenseVector> points =
+                env.fromElements(
+                        Vectors.dense(1, 1),
+                        Vectors.dense(1, 4),
+                        Vectors.dense(1, 0),
+                        Vectors.dense(4, 1.5),
+                        Vectors.dense(4, 4),
+                        Vectors.dense(4, 0));
+        Table inputTable = tEnv.fromDataStream(points).as("features");
+
+        // Ward linkage over Euclidean distances, writing cluster IDs to "prediction" and
+        // requesting the full merge tree.
+        AgglomerativeClustering clustering =
+                new AgglomerativeClustering()
+                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD)
+                        .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
+                        .setPredictionCol("prediction")
+                        .setComputeFullTree(true);
+
+        // outputs[0] holds the per-point cluster assignments.
+        Table[] outputs = clustering.transform(inputTable);
+        String featuresCol = clustering.getFeaturesCol();
+        String predictionCol = clustering.getPredictionCol();
+
+        // Prints every point together with the cluster it was assigned to.
+        CloseableIterator<Row> rows = outputs[0].execute().collect();
+        while (rows.hasNext()) {
+            Row row = rows.next();
+            DenseVector features = (DenseVector) row.getField(featuresCol);
+            int clusterId = (Integer) row.getField(predictionCol);
+            System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java
new file mode 100644
index 0000000..0d8bcf1
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java
@@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom-up approach. Each
+ * observation starts in its own cluster and the clusters are merged together one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measures by setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>The output contains two tables. The first one assigns one cluster Id for each data point. The
+ * second one contains the information of merging two clusters at each step. The data format of the
+ * merging information is (clusterId1, clusterId2, distance, sizeOfMergedCluster).
+ *
+ * <p>All input rows are processed by a single operator instance (parallelism 1) and are buffered in
+ * memory until the input ends; every merge step then scans all pairs of active clusters.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+        implements AlgoOperator<AgglomerativeClustering>,
+                AgglomerativeClusteringParams<AgglomerativeClustering> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public AgglomerativeClustering() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Integer numCluster = getNumClusters();
+        Double distanceThreshold = getDistanceThreshold();
+        // Exactly one stopping criterion must be configured.
+        Preconditions.checkArgument(
+                (numCluster == null && distanceThreshold != null)
+                        || (numCluster != null && distanceThreshold == null),
+                "One of param numCluster and distanceThreshold should be null.");
+
+        if (getLinkage().equals(LINKAGE_WARD)) {
+            String distanceMeasure = getDistanceMeasure();
+            Preconditions.checkArgument(
+                    distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+                    distanceMeasure
+                            + " was provided as distance measure while linkage was ward. Ward only works with euclidean.");
+        }
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
+                new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};
+
+        SingleOutputStreamOperator<Row> output =
+                dataStream.transform(
+                        "doLocalAgglomerativeClustering",
+                        outputTypeInfo,
+                        new LocalAgglomerativeClusteringOperator(
+                                getFeaturesCol(),
+                                getLinkage(),
+                                getDistanceMeasure(),
+                                numCluster,
+                                distanceThreshold,
+                                getComputeFullTree(),
+                                mergeInfoOutputTag));
+        // The clustering needs to see every data point, so it runs on a single task.
+        output.getTransformation().setParallelism(1);
+
+        Table outputTable = tEnv.fromDataStream(output);
+
+        DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+                output.getSideOutput(mergeInfoOutputTag);
+        mergeInfo.getTransformation().setParallelism(1);
+        Table mergeInfoTable =
+                tEnv.fromDataStream(mergeInfo)
+                        .as("clusterId1", "clusterId2", "distance", "sizeOfMergedCluster");
+
+        return new Table[] {outputTable, mergeInfoTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static AgglomerativeClustering load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Operator that buffers every input row in operator list state and, once the bounded input
+     * ends, clusters all buffered rows. Each row is emitted with its cluster Id appended, and every
+     * merge step is emitted to {@code mergeInfoOutputTag}.
+     */
+    private static class LocalAgglomerativeClusteringOperator extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row>, BoundedOneInput {
+        private final String featuresCol;
+        private final String linkage;
+        private final DistanceMeasure distanceMeasure;
+        private final Integer numCluster;
+        private final Double distanceThreshold;
+        private final boolean computeFullTree;
+        private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag;
+
+        /** State for the input data. */
+        private ListState<Row> inputListState;
+        /** Cluster id of each data point in inputList. */
+        private int[] clusterIds;
+        /** Precomputes the norm of each vector for performance. */
+        private VectorWithNorm[] vectorWithNorms;
+        /** Next cluster Id to be assigned. */
+        private int nextClusterId = 0;
+
+        public LocalAgglomerativeClusteringOperator(
+                String featuresCol,
+                String linkage,
+                String distanceMeasureName,
+                Integer numCluster,
+                Double distanceThreshold,
+                boolean computeFullTree,
+                OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag) {
+            this.featuresCol = featuresCol;
+            this.linkage = linkage;
+            this.numCluster = numCluster;
+            this.distanceThreshold = distanceThreshold;
+            this.computeFullTree = computeFullTree;
+            this.mergeInfoOutputTag = mergeInfoOutputTag;
+
+            distanceMeasure = DistanceMeasure.getInstance(distanceMeasureName);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            inputListState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("inputListState", Row.class));
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> input) throws Exception {
+            inputListState.add(input.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void endInput() throws Exception {
+            List<Row> inputList = IteratorUtils.toList(inputListState.get().iterator());
+            int numDataPoints = inputList.size();
+
+            // Assigns initial cluster Ids: data point i starts in its own cluster i.
+            clusterIds = new int[numDataPoints];
+            for (int i = 0; i < numDataPoints; i++) {
+                clusterIds[i] = getNextClusterId();
+            }
+
+            List<Cluster> activeClusters = new ArrayList<>();
+            for (int i = 0; i < numDataPoints; i++) {
+                List<Integer> dataPointIds = new ArrayList<>();
+                dataPointIds.add(i);
+                activeClusters.add(new Cluster(i, dataPointIds));
+            }
+
+            // Precomputes vector norms for faster computation.
+            vectorWithNorms = new VectorWithNorm[inputList.size()];
+            for (int i = 0; i < numDataPoints; i++) {
+                vectorWithNorms[i] =
+                        new VectorWithNorm((Vector) inputList.get(i).getField(featuresCol));
+            }
+
+            // Clustering process.
+            doClustering(activeClusters);
+
+            // Remaps the cluster Ids to consecutive values starting from 0, in order of first
+            // appearance, and outputs the results.
+            HashMap<Integer, Integer> remappedClusterIds = new HashMap<>();
+            int cnt = 0;
+            for (int i = 0; i < clusterIds.length; i++) {
+                int clusterId = clusterIds[i];
+                if (remappedClusterIds.containsKey(clusterId)) {
+                    clusterIds[i] = remappedClusterIds.get(clusterId);
+                } else {
+                    clusterIds[i] = cnt;
+                    remappedClusterIds.put(clusterId, cnt++);
+                }
+            }
+
+            for (int i = 0; i < numDataPoints; i++) {
+                output.collect(
+                        new StreamRecord<>(Row.join(inputList.get(i), Row.of(clusterIds[i]))));
+            }
+        }
+
+        private int getNextClusterId() {
+            return nextClusterId++;
+        }
+
+        /**
+         * Repeatedly merges the closest pair of active clusters and emits one merge-info record per
+         * merge.
+         *
+         * <p>The cluster Ids of data points are only updated while the stopping criterion
+         * (numCluster or distanceThreshold) has not been met; extra merges done because {@code
+         * computeFullTree} is true are only reported as merge info. Merging always stops once a
+         * single cluster remains.
+         *
+         * <p>NOTE(review): the criterion is re-evaluated after each merge, so when
+         * distanceThreshold is set the first merge, and the merge whose distance first reaches the
+         * threshold, are still applied to the cluster Ids. This differs from scikit-learn, where
+         * merges at or above the threshold are not performed. Confirm this is intended.
+         *
+         * @param activeClusters The current clusters; modified in place.
+         */
+        private void doClustering(List<Cluster> activeClusters) {
+            boolean clusteringRunning =
+                    (numCluster != null && activeClusters.size() > numCluster)
+                            || (distanceThreshold != null);
+
+            // At least two clusters are required for a merge. Without this guard, a large
+            // distanceThreshold (or an input with fewer than two data points) keeps the loop alive
+            // after everything is merged and fails with an out-of-bounds list access.
+            while (activeClusters.size() > 1 && (clusteringRunning || computeFullTree)) {
+                // Finds the pair of clusters with the smallest distance.
+                int clusterOffset1 = -1, clusterOffset2 = -1;
+                double minDistance = Double.MAX_VALUE;
+                for (int i = 0; i < activeClusters.size(); i++) {
+                    for (int j = i + 1; j < activeClusters.size(); j++) {
+                        double distance =
+                                computeDistanceBetweenClusters(
+                                        activeClusters.get(i), activeClusters.get(j));
+                        if (distance < minDistance) {
+                            minDistance = distance;
+                            clusterOffset1 = i;
+                            clusterOffset2 = j;
+                        }
+                    }
+                }
+
+                // Outputs the merge info.
+                Cluster cluster1 = activeClusters.get(clusterOffset1);
+                Cluster cluster2 = activeClusters.get(clusterOffset2);
+                int clusterId1 = cluster1.clusterId;
+                int clusterId2 = cluster2.clusterId;
+                output.collect(
+                        mergeInfoOutputTag,
+                        new StreamRecord<>(
+                                Tuple4.of(
+                                        Math.min(clusterId1, clusterId2),
+                                        Math.max(clusterId1, clusterId2),
+                                        minDistance,
+                                        cluster1.dataPointIds.size()
+                                                + cluster2.dataPointIds.size())));
+
+                // Merges these two clusters. clusterOffset1 < clusterOffset2, so replacing the
+                // first before removing the second keeps the first offset valid.
+                Cluster mergedCluster =
+                        new Cluster(
+                                getNextClusterId(), cluster1.dataPointIds, cluster2.dataPointIds);
+                activeClusters.set(clusterOffset1, mergedCluster);
+                activeClusters.remove(clusterOffset2);
+
+                // Updates cluster Ids for each data point if clustering is still running.
+                if (clusteringRunning) {
+                    int mergedClusterId = mergedCluster.clusterId;
+                    for (int dataPointId : mergedCluster.dataPointIds) {
+                        clusterIds[dataPointId] = mergedClusterId;
+                    }
+                }
+
+                clusteringRunning =
+                        (numCluster != null && activeClusters.size() > numCluster)
+                                || (distanceThreshold != null && distanceThreshold > minDistance);
+            }
+        }
+
+        /**
+         * Computes the distance between two clusters according to the configured linkage.
+         *
+         * @throws UnsupportedOperationException if the linkage is not recognized.
+         */
+        private double computeDistanceBetweenClusters(Cluster cluster1, Cluster cluster2) {
+            double distance;
+            int size1 = cluster1.dataPointIds.size();
+            int size2 = cluster2.dataPointIds.size();
+
+            switch (linkage) {
+                case LINKAGE_AVERAGE:
+                    distance = 0;
+                    for (int i = 0; i < size1; i++) {
+                        for (int j = 0; j < size2; j++) {
+                            VectorWithNorm vectorWithNorm1 =
+                                    vectorWithNorms[cluster1.dataPointIds.get(i)];
+                            VectorWithNorm vectorWithNorm2 =
+                                    vectorWithNorms[cluster2.dataPointIds.get(j)];
+                            distance += distanceMeasure.distance(vectorWithNorm1, vectorWithNorm2);
+                        }
+                    }
+                    distance /= size1 * size2;
+                    break;
+                case LINKAGE_COMPLETE:
+                    // NEGATIVE_INFINITY is the identity for max. Double.MIN_VALUE is the smallest
+                    // *positive* double and would make identical points report a non-zero distance.
+                    distance = Double.NEGATIVE_INFINITY;
+                    for (int i = 0; i < size1; i++) {
+                        for (int j = 0; j < size2; j++) {
+                            VectorWithNorm vectorWithNorm1 =
+                                    vectorWithNorms[cluster1.dataPointIds.get(i)];
+                            VectorWithNorm vectorWithNorm2 =
+                                    vectorWithNorms[cluster2.dataPointIds.get(j)];
+                            distance =
+                                    Math.max(
+                                            distance,
+                                            distanceMeasure.distance(
+                                                    vectorWithNorm1, vectorWithNorm2));
+                        }
+                    }
+                    break;
+                case LINKAGE_SINGLE:
+                    distance = Double.MAX_VALUE;
+                    for (int i = 0; i < size1; i++) {
+                        for (int j = 0; j < size2; j++) {
+                            VectorWithNorm vectorWithNorm1 =
+                                    vectorWithNorms[cluster1.dataPointIds.get(i)];
+                            VectorWithNorm vectorWithNorm2 =
+                                    vectorWithNorms[cluster2.dataPointIds.get(j)];
+                            distance =
+                                    Math.min(
+                                            distance,
+                                            distanceMeasure.distance(
+                                                    vectorWithNorm1, vectorWithNorm2));
+                        }
+                    }
+                    break;
+                case LINKAGE_WARD:
+                    int vecSize = vectorWithNorms[0].vector.size();
+                    DenseVector mean1 = Vectors.dense(new double[vecSize]);
+                    DenseVector mean2 = Vectors.dense(new double[vecSize]);
+
+                    for (int i = 0; i < size1; i++) {
+                        BLAS.axpy(1.0, vectorWithNorms[cluster1.dataPointIds.get(i)].vector, mean1);
+                    }
+                    for (int i = 0; i < size2; i++) {
+                        BLAS.axpy(1.0, vectorWithNorms[cluster2.dataPointIds.get(i)].vector, mean2);
+                    }
+
+                    DenseVector meanMerged = mean1.clone();
+                    BLAS.axpy(1.0, mean2, meanMerged);
+                    BLAS.scal(1.0 / size1, mean1);
+                    BLAS.scal(1.0 / size2, mean2);
+                    BLAS.scal(1.0 / (size1 + size2), meanMerged);
+                    // Increase of the error sum of squares caused by merging the two clusters.
+                    double essInc =
+                            size1 * BLAS.dot(mean1, mean1)
+                                    + size2 * BLAS.dot(mean2, mean2)
+                                    - (size1 + size2) * BLAS.dot(meanMerged, meanMerged);
+
+                    // The increase is mathematically non-negative, but rounding can make it
+                    // slightly negative; clamp it so that sqrt does not produce NaN.
+                    distance = Math.sqrt(2 * Math.max(essInc, 0.0));
+                    break;
+                default:
+                    throw new UnsupportedOperationException(
+                            "Unsupported " + LINKAGE + " type: " + linkage + ".");
+            }
+            return distance;
+        }
+
+        /** A cluster with cluster Id specified and data points that belong to this cluster. */
+        private static class Cluster {
+            private final int clusterId;
+            private final List<Integer> dataPointIds;
+
+            public Cluster(int clusterId, List<Integer> dataPointIds) {
+                this.clusterId = clusterId;
+                this.dataPointIds = dataPointIds;
+            }
+
+            /**
+             * Creates a merged cluster. Note that {@code dataPointIds} is reused (and extended with
+             * {@code otherDataPointIds}) rather than copied.
+             */
+            public Cluster(
+                    int clusterId, List<Integer> dataPointIds, List<Integer> otherDataPointIds) {
+                this.clusterId = clusterId;
+                this.dataPointIds = dataPointIds;
+                this.dataPointIds.addAll(otherDataPointIds);
+            }
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClusteringParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClusteringParams.java
new file mode 100644
index 0000000..f3369b3
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClusteringParams.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.ml.common.param.HasDistanceMeasure;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link AgglomerativeClustering}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface AgglomerativeClusteringParams<T>
+        extends HasDistanceMeasure<T>, HasFeaturesCol<T>, HasPredictionCol<T> {
+    /**
+     * The max number of clusters to create. It may be set to null (then {@link
+     * #DISTANCE_THRESHOLD} is used instead); a non-null value must be positive.
+     */
+    Param<Integer> NUM_CLUSTERS =
+            new IntParam(
+                    "numClusters",
+                    "The max number of clusters to create.",
+                    2,
+                    value -> value == null || value > 0);
+
+    /**
+     * Threshold to decide whether two clusters should be merged. It defaults to null; a non-null
+     * value must be non-negative.
+     */
+    Param<Double> DISTANCE_THRESHOLD =
+            new DoubleParam(
+                    "distanceThreshold",
+                    "Threshold to decide whether two clusters should be merged.",
+                    null,
+                    value -> value == null || value >= 0);
+
+    String LINKAGE_WARD = "ward";
+    String LINKAGE_COMPLETE = "complete";
+    String LINKAGE_SINGLE = "single";
+    String LINKAGE_AVERAGE = "average";
+    /**
+     * Supported options to compute the distance between two clusters. The algorithm will merge the
+     * pairs of cluster that minimize this criterion.
+     *
+     * <ul>
+     *   <li>ward: the variance between the two clusters.
+     *   <li>complete: the maximum distance between all observations of the two clusters.
+     *   <li>single: the minimum distance between all observations of the two clusters.
+     *   <li>average: the average distance between all observations of the two clusters.
+     * </ul>
+     */
+    Param<String> LINKAGE =
+            new StringParam(
+                    "linkage",
+                    "Criterion for computing distance between two clusters.",
+                    LINKAGE_WARD,
+                    ParamValidators.inArray(
+                            LINKAGE_WARD, LINKAGE_COMPLETE, LINKAGE_AVERAGE, LINKAGE_SINGLE));
+
+    /** Whether to keep merging (and reporting merge info) until a single cluster remains. */
+    Param<Boolean> COMPUTE_FULL_TREE =
+            new BooleanParam(
+                    "computeFullTree",
+                    "Whether computes the full tree after convergence.",
+                    false,
+                    ParamValidators.notNull());
+
+    default Integer getNumClusters() {
+        return get(NUM_CLUSTERS);
+    }
+
+    default T setNumClusters(Integer value) {
+        return set(NUM_CLUSTERS, value);
+    }
+
+    default String getLinkage() {
+        return get(LINKAGE);
+    }
+
+    default T setLinkage(String value) {
+        return set(LINKAGE, value);
+    }
+
+    default Double getDistanceThreshold() {
+        return get(DISTANCE_THRESHOLD);
+    }
+
+    default T setDistanceThreshold(Double value) {
+        return set(DISTANCE_THRESHOLD, value);
+    }
+
+    default Boolean getComputeFullTree() {
+        return get(COMPUTE_FULL_TREE);
+    }
+
+    default T setComputeFullTree(Boolean value) {
+        return set(COMPUTE_FULL_TREE, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
index f9f2190..2c4c86f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
@@ -18,7 +18,9 @@
package org.apache.flink.ml.common.param;
+import org.apache.flink.ml.common.distance.CosineDistanceMeasure;
import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.common.distance.ManhattanDistanceMeasure;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.ml.param.StringParam;
@@ -31,7 +33,10 @@ public interface HasDistanceMeasure<T> extends WithParams<T> {
"distanceMeasure",
"Distance measure.",
EuclideanDistanceMeasure.NAME,
- ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
+ ParamValidators.inArray(
+ EuclideanDistanceMeasure.NAME,
+ ManhattanDistanceMeasure.NAME,
+ CosineDistanceMeasure.NAME));
default String getDistanceMeasure() {
return get(DISTANCE_MEASURE);
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
new file mode 100644
index 0000000..064aa63
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java
@@ -0,0 +1,302 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
+import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
+import org.apache.flink.ml.common.distance.CosineDistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.common.distance.ManhattanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link AgglomerativeClustering}. */
+public class AgglomerativeClusteringTest extends AbstractTestBase {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private StreamExecutionEnvironment env;
+    private Table inputDataTable;
+
+    // Expected merge distances, in merge order, for each distance measure / linkage combination.
+    private static final double[] EUCLIDEAN_AVERAGE_MERGE_DISTANCES =
+            new double[] {1, 1.5, 3, 3.1394402, 3.9559706};
+
+    private static final double[] COSINE_AVERAGE_MERGE_DISTANCES =
+            new double[] {0, 1.1102230E-16, 0.0636708, 0.1425070, 0.3664484};
+
+    private static final double[] MANHATTAN_AVERAGE_MERGE_DISTANCES =
+            new double[] {1, 1.5, 3, 3.75, 4.875};
+
+    private static final double[] EUCLIDEAN_SINGLE_MERGE_DISTANCES =
+            new double[] {1, 1.5, 2.5, 3, 3};
+
+    private static final double[] EUCLIDEAN_WARD_MERGE_DISTANCES =
+            new double[] {1, 1.5, 3, 4.2573465, 5.5113519};
+
+    private static final double[] EUCLIDEAN_COMPLETE_MERGE_DISTANCES =
+            new double[] {1, 1.5, 3, 3.3541019, 5};
+
+    /** Expected clusters for euclidean distance, average linkage and numClusters = 2. */
+    private static final List<Set<DenseVector>> EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(1, 1),
+                                    Vectors.dense(1, 0),
+                                    Vectors.dense(4, 1.5),
+                                    Vectors.dense(4, 0))),
+                    new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4))));
+
+    /** Expected clusters for euclidean distance, average linkage and distanceThreshold = 2. */
+    private static final List<Set<DenseVector>> EUCLIDEAN_AVERAGE_THRESHOLD_AS_TWO_RESULT =
+            Arrays.asList(
+                    new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))),
+                    new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4))),
+                    new HashSet<>(Arrays.asList(Vectors.dense(4, 1.5), Vectors.dense(4, 0))));
+
+    private static final double TOLERANCE = 1e-7;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        List<DenseVector> inputData =
+                Arrays.asList(
+                        Vectors.dense(1, 1),
+                        Vectors.dense(1, 4),
+                        Vectors.dense(1, 0),
+                        Vectors.dense(4, 1.5),
+                        Vectors.dense(4, 4),
+                        Vectors.dense(4, 0));
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(inputData).map(x -> x)).as("features");
+    }
+
+    @Test
+    public void testParam() {
+        AgglomerativeClustering agglomerativeClustering = new AgglomerativeClustering();
+        assertEquals("features", agglomerativeClustering.getFeaturesCol());
+        assertEquals(2, agglomerativeClustering.getNumClusters().intValue());
+        assertNull(agglomerativeClustering.getDistanceThreshold());
+        assertEquals(AgglomerativeClustering.LINKAGE_WARD, agglomerativeClustering.getLinkage());
+        assertEquals(EuclideanDistanceMeasure.NAME, agglomerativeClustering.getDistanceMeasure());
+        assertFalse(agglomerativeClustering.getComputeFullTree());
+        assertEquals("prediction", agglomerativeClustering.getPredictionCol());
+
+        agglomerativeClustering
+                .setFeaturesCol("test_features")
+                .setNumClusters(null)
+                .setDistanceThreshold(0.01)
+                .setLinkage(AgglomerativeClusteringParams.LINKAGE_AVERAGE)
+                .setDistanceMeasure(CosineDistanceMeasure.NAME)
+                .setComputeFullTree(true)
+                .setPredictionCol("cluster_id");
+
+        assertEquals("test_features", agglomerativeClustering.getFeaturesCol());
+        assertNull(agglomerativeClustering.getNumClusters());
+        assertEquals(0.01, agglomerativeClustering.getDistanceThreshold(), TOLERANCE);
+        assertEquals(AgglomerativeClustering.LINKAGE_AVERAGE, agglomerativeClustering.getLinkage());
+        assertEquals(CosineDistanceMeasure.NAME, agglomerativeClustering.getDistanceMeasure());
+        assertTrue(agglomerativeClustering.getComputeFullTree());
+        assertEquals("cluster_id", agglomerativeClustering.getPredictionCol());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_features", "dummy_input");
+        AgglomerativeClustering agglomerativeClustering =
+                new AgglomerativeClustering()
+                        .setFeaturesCol("test_features")
+                        .setPredictionCol("test_prediction");
+        Table[] outputs = agglomerativeClustering.transform(tempTable);
+        assertEquals(2, outputs.length);
+        assertEquals(
+                Arrays.asList("test_features", "dummy_input", "test_prediction"),
+                outputs[0].getResolvedSchema().getColumnNames());
+        assertEquals(
+                Arrays.asList("clusterId1", "clusterId2", "distance", "sizeOfMergedCluster"),
+                outputs[1].getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testTransform() throws Exception {
+        Table[] outputs;
+        AgglomerativeClustering agglomerativeClustering =
+                new AgglomerativeClustering()
+                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_AVERAGE)
+                        .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
+                        .setPredictionCol("pred");
+
+        // Tests euclidean distance with linkage as average, numClusters = 2.
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyClusteringResult(
+                EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT,
+                outputs[0],
+                agglomerativeClustering.getFeaturesCol(),
+                agglomerativeClustering.getPredictionCol());
+
+        // Tests euclidean distance with linkage as average, numClusters = 2,
+        // computeFullTree = true.
+        outputs = agglomerativeClustering.setComputeFullTree(true).transform(inputDataTable);
+        verifyClusteringResult(
+                EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT,
+                outputs[0],
+                agglomerativeClustering.getFeaturesCol(),
+                agglomerativeClustering.getPredictionCol());
+
+        // Tests euclidean distance with linkage as average, distanceThreshold = 2.
+        outputs =
+                agglomerativeClustering
+                        .setNumClusters(null)
+                        .setDistanceThreshold(2.0)
+                        .transform(inputDataTable);
+        verifyClusteringResult(
+                EUCLIDEAN_AVERAGE_THRESHOLD_AS_TWO_RESULT,
+                outputs[0],
+                agglomerativeClustering.getFeaturesCol(),
+                agglomerativeClustering.getPredictionCol());
+    }
+
+    @Test
+    public void testMergeInfo() throws Exception {
+        Table[] outputs;
+        AgglomerativeClustering agglomerativeClustering =
+                new AgglomerativeClustering()
+                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_AVERAGE)
+                        .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
+                        .setPredictionCol("pred")
+                        .setComputeFullTree(true);
+
+        // Tests euclidean distance with linkage as average.
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyMergeInfo(EUCLIDEAN_AVERAGE_MERGE_DISTANCES, outputs[1]);
+
+        // Tests cosine distance with linkage as average.
+        agglomerativeClustering.setDistanceMeasure(CosineDistanceMeasure.NAME);
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyMergeInfo(COSINE_AVERAGE_MERGE_DISTANCES, outputs[1]);
+
+        // Tests manhattan distance with linkage as average.
+        agglomerativeClustering.setDistanceMeasure(ManhattanDistanceMeasure.NAME);
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyMergeInfo(MANHATTAN_AVERAGE_MERGE_DISTANCES, outputs[1]);
+
+        // Tests euclidean distance with linkage as complete.
+        agglomerativeClustering
+                .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
+                .setLinkage(AgglomerativeClusteringParams.LINKAGE_COMPLETE);
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyMergeInfo(EUCLIDEAN_COMPLETE_MERGE_DISTANCES, outputs[1]);
+
+        // Tests euclidean distance with linkage as single.
+        agglomerativeClustering.setLinkage(AgglomerativeClusteringParams.LINKAGE_SINGLE);
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyMergeInfo(EUCLIDEAN_SINGLE_MERGE_DISTANCES, outputs[1]);
+
+        // Tests euclidean distance with linkage as ward.
+        agglomerativeClustering.setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD);
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyMergeInfo(EUCLIDEAN_WARD_MERGE_DISTANCES, outputs[1]);
+
+        // Tests merge info not fully computed: with numClusters = 2 the last merge is skipped.
+        agglomerativeClustering.setComputeFullTree(false);
+        outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyMergeInfo(
+                Arrays.copyOfRange(
+                        EUCLIDEAN_WARD_MERGE_DISTANCES,
+                        0,
+                        EUCLIDEAN_WARD_MERGE_DISTANCES.length - 1),
+                outputs[1]);
+    }
+
+    @Test
+    public void testSaveLoadTransform() throws Exception {
+        AgglomerativeClustering agglomerativeClustering =
+                new AgglomerativeClustering()
+                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_AVERAGE)
+                        .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
+                        .setPredictionCol("pred");
+
+        agglomerativeClustering =
+                TestUtils.saveAndReload(
+                        tEnv, agglomerativeClustering, tempFolder.newFolder().getAbsolutePath());
+
+        Table[] outputs = agglomerativeClustering.transform(inputDataTable);
+        verifyClusteringResult(
+                EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT,
+                outputs[0],
+                agglomerativeClustering.getFeaturesCol(),
+                agglomerativeClustering.getPredictionCol());
+    }
+
+    /** Checks that the merge distances in {@code mergeInfoTable} match, in order. */
+    @SuppressWarnings("unchecked")
+    private void verifyMergeInfo(double[] expectedDistances, Table mergeInfoTable)
+            throws Exception {
+        List<Row> mergeInfo =
+                IteratorUtils.toList(tEnv.toDataStream(mergeInfoTable).executeAndCollect());
+        assertEquals(expectedDistances.length, mergeInfo.size());
+        for (int i = 0; i < mergeInfo.size(); i++) {
+            double actualDistance = ((Number) mergeInfo.get(i).getFieldAs(2)).doubleValue();
+            assertEquals(expectedDistances[i], actualDistance, TOLERANCE);
+        }
+    }
+
+    /**
+     * Checks that grouping the output features by predicted cluster Id yields exactly the expected
+     * groups, regardless of the concrete Id values.
+     */
+    @SuppressWarnings("unchecked")
+    private void verifyClusteringResult(
+            List<Set<DenseVector>> expected,
+            Table outputTable,
+            String featureCol,
+            String predictionCol)
+            throws Exception {
+        List<Row> output = IteratorUtils.toList(tEnv.toDataStream(outputTable).executeAndCollect());
+        List<Set<DenseVector>> actualGroups =
+                KMeansTest.groupFeaturesByPrediction(output, featureCol, predictionCol);
+        assertTrue(CollectionUtils.isEqualCollection(expected, actualGroups));
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/clustering/agglomerativeclustering_example.py b/flink-ml-python/pyflink/examples/ml/clustering/agglomerativeclustering_example.py
new file mode 100644
index 0000000..8d90f28
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/clustering/agglomerativeclustering_example.py
@@ -0,0 +1,84 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates an AgglomerativeClustering instance and uses it for
+# clustering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.clustering.agglomerativeclustering import AgglomerativeClustering
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input data.
+input_data = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense([1, 1]),),
+ (Vectors.dense([1, 4]),),
+ (Vectors.dense([1, 0]),),
+ (Vectors.dense([4, 1.5]),),
+ (Vectors.dense([4, 4]),),
+ (Vectors.dense([4, 0]),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['features'],
+ [DenseVectorTypeInfo()])))
+
+# Creates an AgglomerativeClustering object and initializes its parameters.
+agglomerative_clustering = AgglomerativeClustering() \
+ .set_linkage('ward') \
+ .set_distance_measure('euclidean') \
+ .set_prediction_col('prediction') \
+ .set_compute_full_tree(True)
+
+# Uses the AgglomerativeClustering for clustering.
+outputs = agglomerative_clustering.transform(input_data)
+
+# Extracts and display the results.
+field_names = outputs[0].get_schema().get_field_names()
+for result in t_env.to_data_stream(outputs[0]).execute_and_collect():
+ features = result[field_names.index(agglomerative_clustering.features_col)]
+ cluster_id = result[field_names.index(agglomerative_clustering.prediction_col)]
+ print('Features: ' + str(features) + '\tCluster ID: ' + str(cluster_id))
+
+"""
+# The following code snippet could be used to visualize the merge info.
+
+from matplotlib import pyplot as plt
+from scipy.cluster.hierarchy import dendrogram
+
+merge_info = [result for result in
+ t_env.to_data_stream(outputs[1]).execute_and_collect()]
+plt.title("Agglomerative Clustering Dendrogram")
+dendrogram(merge_info)
+plt.xlabel("Index of data point.")
+plt.ylabel("Distances between merged clusters.")
+plt.show()
+"""
diff --git a/flink-ml-python/pyflink/ml/core/tests/test_param.py b/flink-ml-python/pyflink/ml/core/tests/test_param.py
index 17e507a..d27f479 100644
--- a/flink-ml-python/pyflink/ml/core/tests/test_param.py
+++ b/flink-ml-python/pyflink/ml/core/tests/test_param.py
@@ -41,7 +41,8 @@ class ParamTests(unittest.TestCase):
distance_measure = param.DISTANCE_MEASURE
self.assertEqual(distance_measure.name, "distance_measure")
self.assertEqual(distance_measure.description,
- "Distance measure. Supported options: 'euclidean' and 'cosine'.")
+ "Distance measure. Supported options: "
+ "'euclidean', 'manhattan' and 'cosine'.")
self.assertEqual(distance_measure.default_value, "euclidean")
param.set_distance_measure("cosine")
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/agglomerativeclustering.py b/flink-ml-python/pyflink/ml/lib/clustering/agglomerativeclustering.py
new file mode 100644
index 0000000..8c32e76
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/agglomerativeclustering.py
@@ -0,0 +1,136 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import Param, StringParam, IntParam, FloatParam, \
+ BooleanParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.clustering.common import JavaClusteringAlgoOperator
+from pyflink.ml.lib.param import HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
+
+
+class _AgglomerativeClusteringParams(
+ JavaWithParams,
+ HasDistanceMeasure,
+ HasFeaturesCol,
+ HasPredictionCol
+):
+ """
+ Params for :class:`AgglomerativeClustering`.
+ """
+ NUM_CLUSTERS: Param[int] = IntParam("num_clusters",
+ "The max number of clusters to create.",
+ 2)
+
+ DISTANCE_THRESHOLD: Param[float] = \
+ FloatParam("distance_threshold",
+ "Threshold to decide whether two clusters should be merged.",
+ None)
+
+ """
+ Supported options to compute the distance between two clusters. The
+ algorithm will merge the pairs of cluster that minimize this criterion.
+ <ul>
+ <li>ward: the variance between the two clusters.
+ <li>complete: the maximum distance between all observations of the two clusters.
+ <li>single: the minimum distance between all observations of the two clusters.
+ <li>average: the average distance between all observations of the two clusters.
+ </ul>
+ """
+ LINKAGE: Param[str] = StringParam(
+ "linkage",
+ "Criterion for computing distance between two clusters.",
+ "ward",
+ ParamValidators.in_array(
+ ["ward", "complete", "single", "average"]))
+
+ COMPUTE_FULL_TREE: Param[bool] = BooleanParam(
+ "compute_full_tree",
+ "Whether computes the full tree after convergence.",
+ False,
+ ParamValidators.not_null())
+
+ def __init__(self, java_params):
+ super(_AgglomerativeClusteringParams, self).__init__(java_params)
+
+ def set_num_clusters(self, value: int):
+ return typing.cast(_AgglomerativeClusteringParams, self.set(self.NUM_CLUSTERS, value))
+
+ def get_num_clusters(self) -> int:
+ return self.get(self.NUM_CLUSTERS)
+
+ def set_distance_threshold(self, value: float):
+ return typing.cast(_AgglomerativeClusteringParams, self.set(self.DISTANCE_THRESHOLD, value))
+
+ def get_distance_threshold(self) -> float:
+ return self.get(self.DISTANCE_THRESHOLD)
+
+ def set_linkage(self, value: str):
+ return typing.cast(_AgglomerativeClusteringParams, self.set(self.LINKAGE, value))
+
+ def get_linkage(self) -> str:
+ return self.get(self.LINKAGE)
+
+ def set_compute_full_tree(self, value: bool):
+ return typing.cast(_AgglomerativeClusteringParams, self.set(self.COMPUTE_FULL_TREE, value))
+
+ def get_compute_full_tree(self) -> bool:
+ return self.get(self.COMPUTE_FULL_TREE)
+
+ @property
+ def num_clusters(self):
+ return self.get_num_clusters()
+
+ @property
+ def distance_threshold(self):
+ return self.get_distance_threshold()
+
+ @property
+ def linkage(self):
+ return self.get_linkage()
+
+ @property
+ def compute_full_tree(self):
+ return self.get_compute_full_tree()
+
+
+class AgglomerativeClustering(JavaClusteringAlgoOperator, _AgglomerativeClusteringParams):
+ """
+ An AlgoOperator that performs a hierarchical clustering using a bottom-up approach. Each
+ observation starts in its own cluster and the clusters are merged together one by one.
+ Users can choose different strategies to merge two clusters by setting
+ {@link AgglomerativeClusteringParams#LINKAGE} and different distance measures by setting
+ {@link AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+
+ <p>The output contains two tables. The first one assigns one cluster Id for each data point.
+ The second one contains the information of merging two clusters at each step. The data format
+ of the merging information is (clusterId1, clusterId2, distance, sizeOfMergedCluster).
+
+ <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ """
+
+ def __init__(self, java_algo_operator=None):
+ super(AgglomerativeClustering, self).__init__(java_algo_operator)
+
+ @classmethod
+ def _java_algo_operator_package_name(cls) -> str:
+ return "agglomerativeclustering"
+
+ @classmethod
+ def _java_algo_operator_class_name(cls) -> str:
+ return "AgglomerativeClustering"
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/common.py b/flink-ml-python/pyflink/ml/lib/clustering/common.py
index 8f3153d..f64cb1f 100644
--- a/flink-ml-python/pyflink/ml/lib/clustering/common.py
+++ b/flink-ml-python/pyflink/ml/lib/clustering/common.py
@@ -17,7 +17,7 @@
################################################################################
from abc import ABC, abstractmethod
-from pyflink.ml.core.wrapper import JavaModel, JavaEstimator
+from pyflink.ml.core.wrapper import JavaModel, JavaEstimator, JavaAlgoOperator
JAVA_CLUSTERING_PACKAGE_NAME = "org.apache.flink.ml.clustering"
@@ -72,3 +72,29 @@ class JavaClusteringEstimator(JavaEstimator, ABC):
@abstractmethod
def _java_estimator_class_name(cls) -> str:
pass
+
+
+class JavaClusteringAlgoOperator(JavaAlgoOperator, ABC):
+ """
+ Wrapper class for a Java clustering AlgoOperator.
+ """
+
+ def __init__(self, java_algo_operator):
+ super(JavaClusteringAlgoOperator, self).__init__(java_algo_operator)
+
+ @classmethod
+ def _java_stage_path(cls) -> str:
+ return ".".join(
+ [JAVA_CLUSTERING_PACKAGE_NAME,
+ cls._java_algo_operator_package_name(),
+ cls._java_algo_operator_class_name()])
+
+ @classmethod
+ @abstractmethod
+ def _java_algo_operator_package_name(cls) -> str:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def _java_algo_operator_class_name(cls) -> str:
+ pass
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/tests/test_agglomerativeclustering.py b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_agglomerativeclustering.py
new file mode 100644
index 0000000..48606be
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_agglomerativeclustering.py
@@ -0,0 +1,213 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.clustering.agglomerativeclustering import AgglomerativeClustering
+from pyflink.ml.lib.clustering.tests.test_kmeans import group_features_by_prediction
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class AgglomerativeClusteringTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(AgglomerativeClusteringTest, self).setUp()
+ self.input_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense([1, 1]),),
+ (Vectors.dense([1, 4]),),
+ (Vectors.dense([1, 0]),),
+ (Vectors.dense([4, 1.5]),),
+ (Vectors.dense([4, 4]),),
+ (Vectors.dense([4, 0]),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['features'],
+ [DenseVectorTypeInfo()])))
+
+ self.euclidean_average_merge_distances = [1.0, 1.5, 3.0, 3.1394402, 3.9559706]
+ self.cosine_average_merge_distances = [0, 1.1102230E-16, 0.0636708, 0.1425070, 0.3664484]
+ self.manhattan_average_merge_distances = [1, 1.5, 3, 3.75, 4.875]
+ self.eucliean_single_merge_distances = [1, 1.5, 2.5, 3, 3]
+ self.eucliean_ward_merge_distances = [1, 1.5, 3, 4.2573465, 5.5113519]
+ self.eucliean_complete_merge_distances = [1, 1.5, 3, 3.3541019, 5]
+
+ self.eucliean_ward_num_clusters_as_two_result = [
+ {Vectors.dense(1, 1), Vectors.dense(1, 0), Vectors.dense(4, 1.5), Vectors.dense(4, 0)},
+ {Vectors.dense(1, 4), Vectors.dense(4, 4)}
+ ]
+
+ self.eucliean_ward_threshold_as_two_result = [
+ {Vectors.dense(1, 1), Vectors.dense(1, 0)},
+ {Vectors.dense(1, 4), Vectors.dense(4, 4)},
+ {Vectors.dense(4, 1.5), Vectors.dense(4, 0)}
+ ]
+
+ self.tolerance = 1e-7
+
+ def test_param(self):
+ agglomerative_clustering = AgglomerativeClustering()
+ self.assertEqual('features', agglomerative_clustering.features_col)
+ self.assertEqual(2, agglomerative_clustering.num_clusters)
+ self.assertIsNone(agglomerative_clustering.distance_threshold)
+ self.assertEqual('ward', agglomerative_clustering.linkage)
+ self.assertEqual('euclidean', agglomerative_clustering.distance_measure)
+ self.assertFalse(agglomerative_clustering.compute_full_tree)
+ self.assertEqual('prediction', agglomerative_clustering.prediction_col)
+
+ agglomerative_clustering \
+ .set_features_col("test_features") \
+ .set_num_clusters(None) \
+ .set_distance_threshold(0.01) \
+ .set_linkage('average') \
+ .set_distance_measure('cosine') \
+ .set_compute_full_tree(True) \
+ .set_prediction_col('cluster_id')
+
+ self.assertEqual('test_features', agglomerative_clustering.features_col)
+ self.assertIsNone(agglomerative_clustering.num_clusters)
+ self.assertEqual(0.01, agglomerative_clustering.distance_threshold)
+ self.assertEqual('average', agglomerative_clustering.linkage)
+ self.assertEqual('cosine', agglomerative_clustering.distance_measure)
+ self.assertTrue(agglomerative_clustering.compute_full_tree)
+ self.assertEqual('cluster_id', agglomerative_clustering.prediction_col)
+
+ def test_output_schema(self):
+ input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ ('', ''),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['test_input', 'dummy_input'],
+ [Types.STRING(), Types.STRING()])))
+
+ agglomerative_clustering = AgglomerativeClustering() \
+ .set_features_col("test_input") \
+ .set_prediction_col("test_prediction")
+
+ outputs = agglomerative_clustering \
+ .transform(input_data_table)
+
+ self.assertEqual(2, len(outputs))
+
+ self.assertEqual(
+ ['test_input', 'dummy_input', 'test_prediction'],
+ outputs[0].get_schema().get_field_names())
+
+ self.assertEqual(
+ ['clusterId1', 'clusterId2', 'distance', 'sizeOfMergedCluster'],
+ outputs[1].get_schema().get_field_names())
+
+ def verify_clustering_result(self, expected, output_table, features_col, prediction_col):
+ predicted_results = [result for result in
+ self.t_env.to_data_stream(output_table).execute_and_collect()]
+ field_names = output_table.get_schema().get_field_names()
+ actual_groups = group_features_by_prediction(
+ predicted_results,
+ field_names.index(features_col),
+ field_names.index(prediction_col))
+
+ self.assertTrue(expected == actual_groups)
+
+ def verify_merge_info(self, expected, output_table):
+ merge_infos = [result for result in
+ self.t_env.to_data_stream(output_table).execute_and_collect()]
+
+ self.assertEquals(len(expected), len(merge_infos))
+ for i in range(len(expected)):
+ self.assertAlmostEqual(expected[i], merge_infos[i][2], delta=self.tolerance)
+
+ def test_transform(self):
+ agglomerative_clustering = AgglomerativeClustering() \
+ .set_linkage('average') \
+ .set_distance_measure('euclidean') \
+ .set_prediction_col('pred')
+
+ # Tests euclidean distance with linkage as average, num_clusters = 2.
+ outputs = agglomerative_clustering.transform(self.input_table)
+ self.verify_clustering_result(self.eucliean_ward_num_clusters_as_two_result,
+ outputs[0], "features", "pred")
+
+ # Tests euclidean distance with linkage as average, num_clusters = 2,
+ # compute_full_tree = true.
+ outputs = agglomerative_clustering.set_compute_full_tree(True).transform(self.input_table)
+ self.verify_clustering_result(self.eucliean_ward_num_clusters_as_two_result,
+ outputs[0], "features", "pred")
+
+ # Tests euclidean distance with linkage as average, distance_threshold = 2.
+ outputs = agglomerative_clustering \
+ .set_num_clusters(None) \
+ .set_distance_threshold(2.0) \
+ .transform(self.input_table)
+ self.verify_clustering_result(self.eucliean_ward_threshold_as_two_result,
+ outputs[0], "features", "pred")
+
+ def test_merge_info(self):
+ agglomerative_clustering = AgglomerativeClustering() \
+ .set_linkage('average') \
+ .set_distance_measure('euclidean') \
+ .set_prediction_col('pred') \
+ .set_compute_full_tree(True)
+
+ # Tests euclidean distance with linkage as average.
+ outputs = agglomerative_clustering.transform(self.input_table)
+ self.verify_merge_info(self.euclidean_average_merge_distances, outputs[1])
+
+ # Tests cosine distance with linkage as average.
+ outputs = agglomerative_clustering \
+ .set_distance_measure('cosine') \
+ .transform(self.input_table)
+ self.verify_merge_info(self.cosine_average_merge_distances, outputs[1])
+
+ # Tests manhattan distance with linkage as average.
+ outputs = agglomerative_clustering \
+ .set_distance_measure('manhattan') \
+ .transform(self.input_table)
+ self.verify_merge_info(self.manhattan_average_merge_distances, outputs[1])
+
+ # Tests euclidean distance with linkage as complete.
+ outputs = agglomerative_clustering \
+ .set_distance_measure('euclidean') \
+ .set_linkage('complete') \
+ .transform(self.input_table)
+ self.verify_merge_info(self.eucliean_complete_merge_distances, outputs[1])
+
+ # Tests euclidean distance with linkage as single.
+ outputs = agglomerative_clustering.set_linkage('single').transform(self.input_table)
+ self.verify_merge_info(self.eucliean_single_merge_distances, outputs[1])
+
+ # Tests euclidean distance with linkage as ward.
+ outputs = agglomerative_clustering.set_linkage('ward').transform(self.input_table)
+ self.verify_merge_info(self.eucliean_ward_merge_distances, outputs[1])
+
+ # Tests merge info not fully computed.
+ outputs = agglomerative_clustering.set_compute_full_tree(False).transform(self.input_table)
+ self.verify_merge_info(self.eucliean_ward_merge_distances[0:4], outputs[1])
+
+ def test_save_load_transform(self):
+ agglomerative_clustering = AgglomerativeClustering() \
+ .set_linkage('average') \
+ .set_distance_measure('euclidean') \
+ .set_prediction_col('pred')
+
+ path = os.path.join(self.temp_dir, 'test_save_load_and_transform_agglomerativeclustering')
+ agglomerative_clustering.save(path)
+ loaded_agglomerative_clustering = AgglomerativeClustering.load(self.t_env, path)
+ outputs = loaded_agglomerative_clustering.transform(self.input_table)
+ self.verify_clustering_result(self.eucliean_ward_num_clusters_as_two_result,
+ outputs[0], "features", "pred")
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index 39c6277..2875730 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -28,9 +28,9 @@ class HasDistanceMeasure(WithParams, ABC):
"""
DISTANCE_MEASURE: Param[str] = StringParam(
"distance_measure",
- "Distance measure. Supported options: 'euclidean' and 'cosine'.",
+ "Distance measure. Supported options: 'euclidean', 'manhattan' and 'cosine'.",
"euclidean",
- ParamValidators.in_array(['euclidean', 'cosine']))
+ ParamValidators.in_array(['euclidean', 'manhattan', 'cosine']))
def set_distance_measure(self, distance_measure: str):
return self.set(self.DISTANCE_MEASURE, distance_measure)
diff --git a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
index 7bd894d..5036109 100644
--- a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
+++ b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
@@ -68,9 +68,10 @@ class MLLibTest(PyFlinkMLTestCase):
if hasattr(obj, '_java_stage_path') and name not in (
'JavaClassificationEstimator', 'JavaClassificationModel',
'JavaClusteringEstimator', 'JavaClusteringModel',
- 'JavaEvaluationAlgoOperator', 'JavaFeatureTransformer',
- 'JavaFeatureEstimator', 'JavaFeatureModel',
- 'JavaRegressionEstimator', 'JavaRegressionModel')]
+ 'JavaClusteringAlgoOperator', 'JavaEvaluationAlgoOperator',
+ 'JavaFeatureTransformer', 'JavaFeatureEstimator',
+ 'JavaFeatureModel', 'JavaRegressionEstimator',
+ 'JavaRegressionModel')]
@abstractmethod
def module_name(self):