You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/06/21 03:24:53 UTC

[flink-ml] branch master updated: [FLINK-27096] Optimize KMeans performance

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 10c1ef4  [FLINK-27096] Optimize KMeans performance
10c1ef4 is described below

commit 10c1ef4e91700045b7515e5d80a7b5a143cc1821
Author: yunfengzhou-hub <yu...@outlook.com>
AuthorDate: Tue Jun 21 11:24:48 2022 +0800

    [FLINK-27096] Optimize KMeans performance
    
    This closes #110.
---
 .../flink/ml/common/distance/DistanceMeasure.java  |   7 +-
 .../common/distance/EuclideanDistanceMeasure.java  |  39 +++--
 .../main/java/org/apache/flink/ml/linalg/BLAS.java |  13 +-
 .../org/apache/flink/ml/linalg/VectorWithNorm.java |  40 +++++
 .../ml/linalg/typeinfo/DenseVectorSerializer.java  |   6 +-
 .../linalg/typeinfo/VectorWithNormSerializer.java  | 122 +++++++++++++++
 .../ml/linalg/typeinfo/VectorWithNormTypeInfo.java |  83 +++++++++++
 .../typeinfo/VectorWithNormTypeInfoFactory.java    |  37 +++++
 .../java/org/apache/flink/ml/linalg/BLASTest.java  |   6 +-
 .../apache/flink/ml/linalg/VectorWithNormTest.java |  39 +++++
 .../apache/flink/ml/clustering/kmeans/KMeans.java  | 166 ++++++++-------------
 .../flink/ml/clustering/kmeans/KMeansModel.java    |  11 +-
 .../flink/ml/clustering/kmeans/OnlineKMeans.java   |   7 +-
 .../ml/clustering/kmeans/OnlineKMeansModel.java    |  11 +-
 14 files changed, 460 insertions(+), 127 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
index 0fe01fe..ee7b25d 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.ml.common.distance;
 
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
 
 import java.io.Serializable;
 
@@ -40,5 +40,8 @@ public interface DistanceMeasure extends Serializable {
      *
      * <p>Required: The two vectors should have the same dimension.
      */
-    double distance(Vector v1, Vector v2);
+    double distance(VectorWithNorm v1, VectorWithNorm v2);
+
+    /** Finds the index of the centroid closest to the given point. */
+    int findClosest(VectorWithNorm[] centroids, VectorWithNorm point);
 }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
index 5864c17..a3798a0 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
@@ -18,8 +18,8 @@
 
 package org.apache.flink.ml.common.distance;
 
-import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.util.Preconditions;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.VectorWithNorm;
 
 /** Interface for measuring the Euclidean distance between two vectors. */
 public class EuclideanDistanceMeasure implements DistanceMeasure {
@@ -33,16 +33,35 @@ public class EuclideanDistanceMeasure implements DistanceMeasure {
         return instance;
     }
 
-    // TODO: Improve distance calculation with BLAS.
     @Override
-    public double distance(Vector v1, Vector v2) {
-        Preconditions.checkArgument(v1.size() == v2.size());
-        double squaredDistance = 0.0;
+    public double distance(VectorWithNorm v1, VectorWithNorm v2) {
+        return Math.sqrt(distanceSquare(v1, v2));
+    }
+
+    private double distanceSquare(VectorWithNorm v1, VectorWithNorm v2) {
+        return v1.l2Norm * v1.l2Norm + v2.l2Norm * v2.l2Norm - 2.0 * BLAS.dot(v1.vector, v2.vector);
+    }
 
-        for (int i = 0; i < v1.size(); i++) {
-            double diff = v1.get(i) - v2.get(i);
-            squaredDistance += diff * diff;
+    @Override
+    public int findClosest(VectorWithNorm[] centroids, VectorWithNorm point) {
+        double bestL2DistanceSquare = Double.POSITIVE_INFINITY;
+        int bestIndex = 0;
+        for (int i = 0; i < centroids.length; i++) {
+            VectorWithNorm centroid = centroids[i];
+
+            double lowerBoundSqrt = point.l2Norm - centroid.l2Norm;
+            double lowerBound = lowerBoundSqrt * lowerBoundSqrt;
+            if (lowerBound >= bestL2DistanceSquare) {
+                continue;
+            }
+
+            double l2DistanceSquare = distanceSquare(point, centroid);
+            if (l2DistanceSquare < bestL2DistanceSquare) {
+                bestL2DistanceSquare = l2DistanceSquare;
+                bestIndex = i;
+            }
         }
-        return Math.sqrt(squaredDistance);
+
+        return bestIndex;
     }
 }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
index c00f642..3b46466 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -114,10 +114,21 @@ public class BLAS {
     }
 
     /** \sqrt(\sum_i x_i * x_i) . */
-    public static double norm2(DenseVector x) {
+    public static double norm2(Vector x) {
+        if (x instanceof DenseVector) {
+            return norm2((DenseVector) x);
+        }
+        return norm2((SparseVector) x);
+    }
+
+    private static double norm2(DenseVector x) {
         return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
     }
 
+    private static double norm2(SparseVector x) {
+        return JAVA_BLAS.dnrm2(x.values.length, x.values, 1);
+    }
+
     /** x = x * a . */
     public static void scal(double a, DenseVector x) {
         JAVA_BLAS.dscal(x.size(), a, x.values, 1);
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java
new file mode 100644
index 0000000..bb78ef2
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorWithNormTypeInfoFactory;
+
+/** A vector stored together with its L2 norm. */
+@TypeInfo(VectorWithNormTypeInfoFactory.class)
+public class VectorWithNorm {
+    public final Vector vector;
+
+    public final double l2Norm;
+
+    public VectorWithNorm(Vector vector) {
+        this(vector, BLAS.norm2(vector));
+    }
+
+    public VectorWithNorm(Vector vector, double l2Norm) {
+        this.vector = vector;
+        this.l2Norm = l2Norm;
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
index 94b8d91..5b6f984 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
@@ -84,7 +84,7 @@ public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
         target.writeInt(len);
 
         for (int i = 0; i < len; i++) {
-            Bits.putDouble(buf, i << 3, vector.values[i]);
+            Bits.putDouble(buf, (i & 127) << 3, vector.values[i]);
             if ((i & 127) == 127) {
                 target.write(buf);
             }
@@ -104,12 +104,12 @@ public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
     private void readDoubleArray(double[] dst, DataInputView source, int len) throws IOException {
         int index = 0;
         for (int i = 0; i < (len >> 7); i++) {
-            source.read(buf, 0, 1024);
+            source.readFully(buf, 0, 1024);
             for (int j = 0; j < 128; j++) {
                 dst[index++] = Bits.getDouble(buf, j << 3);
             }
         }
-        source.read(buf, 0, (len << 3) & 1023);
+        source.readFully(buf, 0, (len << 3) & 1023);
         for (int j = 0; j < (len & 127); j++) {
             dst[index++] = Bits.getDouble(buf, j << 3);
         }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java
new file mode 100644
index 0000000..92d1de1
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+
+import java.io.IOException;
+
+/** Specialized serializer for {@link VectorWithNorm}. */
+public class VectorWithNormSerializer extends TypeSerializer<VectorWithNorm> {
+    private final VectorSerializer vectorSerializer = new VectorSerializer();
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY = new double[0];
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<VectorWithNorm> duplicate() {
+        return new VectorWithNormSerializer();
+    }
+
+    @Override
+    public VectorWithNorm createInstance() {
+        return new VectorWithNorm(new DenseVector(EMPTY));
+    }
+
+    @Override
+    public VectorWithNorm copy(VectorWithNorm from) {
+        Vector vector = vectorSerializer.copy(from.vector);
+        return new VectorWithNorm(vector, from.l2Norm);
+    }
+
+    @Override
+    public VectorWithNorm copy(VectorWithNorm from, VectorWithNorm reuse) {
+        Vector vector = vectorSerializer.copy(from.vector, reuse.vector);
+        return new VectorWithNorm(vector, from.l2Norm);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(VectorWithNorm from, DataOutputView dataOutputView) throws IOException {
+        vectorSerializer.serialize(from.vector, dataOutputView);
+        dataOutputView.writeDouble(from.l2Norm);
+    }
+
+    @Override
+    public VectorWithNorm deserialize(DataInputView dataInputView) throws IOException {
+        Vector vector = vectorSerializer.deserialize(dataInputView);
+        double l2NormSquare = dataInputView.readDouble();
+        return new VectorWithNorm(vector, l2NormSquare);
+    }
+
+    @Override
+    public VectorWithNorm deserialize(VectorWithNorm reuse, DataInputView dataInputView)
+            throws IOException {
+        Vector vector = vectorSerializer.deserialize(reuse.vector, dataInputView);
+        double l2NormSquare = dataInputView.readDouble();
+        return new VectorWithNorm(vector, l2NormSquare);
+    }
+
+    @Override
+    public void copy(DataInputView dataInputView, DataOutputView dataOutputView)
+            throws IOException {
+        vectorSerializer.copy(dataInputView, dataOutputView);
+        dataOutputView.write(dataInputView, 8);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        return o instanceof VectorWithNormSerializer;
+    }
+
+    @Override
+    public int hashCode() {
+        return VectorWithNormSerializer.class.hashCode();
+    }
+
+    @Override
+    public TypeSerializerSnapshot<VectorWithNorm> snapshotConfiguration() {
+        return new VectorWithNormSerializerSnapshot();
+    }
+
+    private static class VectorWithNormSerializerSnapshot
+            extends SimpleTypeSerializerSnapshot<VectorWithNorm> {
+        public VectorWithNormSerializerSnapshot() {
+            super(VectorWithNormSerializer::new);
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfo.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfo.java
new file mode 100644
index 0000000..e56a354
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfo.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+
+/** A {@link TypeInformation} for the {@link VectorWithNorm} type. */
+public class VectorWithNormTypeInfo extends TypeInformation<VectorWithNorm> {
+    @Override
+    public boolean isBasicType() {
+        return false;
+    }
+
+    @Override
+    public boolean isTupleType() {
+        return false;
+    }
+
+    @Override
+    public int getArity() {
+        return 2;
+    }
+
+    @Override
+    public int getTotalFields() {
+        return 2;
+    }
+
+    @Override
+    public Class<VectorWithNorm> getTypeClass() {
+        return VectorWithNorm.class;
+    }
+
+    @Override
+    public boolean isKeyType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<VectorWithNorm> createSerializer(ExecutionConfig executionConfig) {
+        return new VectorWithNormSerializer();
+    }
+
+    @Override
+    public String toString() {
+        return "VectorWithNormType";
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        return o instanceof VectorWithNormTypeInfo;
+    }
+
+    @Override
+    public int hashCode() {
+        return getClass().hashCode();
+    }
+
+    @Override
+    public boolean canEqual(Object o) {
+        return o instanceof VectorWithNormTypeInfo;
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfoFactory.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfoFactory.java
new file mode 100644
index 0000000..13f46bf
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfoFactory.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeinfo.TypeInfoFactory;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+
+import java.lang.reflect.Type;
+import java.util.Map;
+
+/** Used by {@link TypeExtractor} to create a {@link TypeInformation} for {@link VectorWithNorm}. */
+public class VectorWithNormTypeInfoFactory extends TypeInfoFactory<VectorWithNorm> {
+    @Override
+    public TypeInformation<VectorWithNorm> createTypeInfo(
+            Type type, Map<String, TypeInformation<?>> map) {
+        return new VectorWithNormTypeInfo();
+    }
+}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
index 21d68a9..469bbe7 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
@@ -86,8 +86,10 @@ public class BLASTest {
 
     @Test
     public void testNorm2() {
-        double expectedResult = Math.sqrt(55);
-        assertEquals(expectedResult, BLAS.norm2(inputDenseVec), TOLERANCE);
+        assertEquals(Math.sqrt(55), BLAS.norm2(inputDenseVec), TOLERANCE);
+
+        SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
+        assertEquals(Math.sqrt(35), BLAS.norm2(sparseVector), TOLERANCE);
     }
 
     @Test
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java
new file mode 100644
index 0000000..25b45b0
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests the behavior of {@link VectorWithNorm}. */
+public class VectorWithNormTest {
+    @Test
+    public void testL2Norm() {
+        DenseVector denseVector = Vectors.dense(1, 2, 3);
+        VectorWithNorm denseVectorWithNorm = new VectorWithNorm(denseVector);
+        assertEquals(denseVector, denseVectorWithNorm.vector);
+        assertEquals(Math.sqrt(14), denseVectorWithNorm.l2Norm, 1e-7);
+
+        SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 2, 3});
+        VectorWithNorm sparseVectorWithNorm = new VectorWithNorm(sparseVector);
+        assertEquals(sparseVector, sparseVectorWithNorm.vector);
+        assertEquals(Math.sqrt(14), sparseVectorWithNorm.l2Norm, 1e-7);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 47a185c..6f2120e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -23,10 +23,9 @@ import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.iteration.DataStreamList;
@@ -40,24 +39,23 @@ import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
-import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
 import org.apache.flink.ml.common.distance.DistanceMeasure;
 import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound;
 import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
+import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.linalg.VectorWithNorm;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorWithNormSerializer;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -65,10 +63,9 @@ import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
-import org.apache.commons.collections.IteratorUtils;
-
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -153,39 +150,24 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
             DataStream<Integer> terminationCriteria =
                     centroids.flatMap(new TerminateOnMaxIter(maxIterationNum));
 
-            DataStream<Tuple2<Integer, DenseVector>> centroidIdAndPoints =
+            DataStream<Tuple2<Integer[], DenseVector[]>> centroidIdAndPoints =
                     points.connect(centroids.broadcast())
                             .transform(
-                                    "SelectNearestCentroid",
+                                    "CentroidsUpdateAccumulator",
                                     new TupleTypeInfo<>(
-                                            BasicTypeInfo.INT_TYPE_INFO,
-                                            DenseVectorTypeInfo.INSTANCE),
-                                    new SelectNearestCentroidOperator(distanceMeasure));
+                                            BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO,
+                                            ObjectArrayTypeInfo.getInfoFor(
+                                                    DenseVectorTypeInfo.INSTANCE)),
+                                    new CentroidsUpdateAccumulator(distanceMeasure));
 
             DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints.getTransformation(), 100);
 
-            PerRoundSubBody perRoundSubBody =
-                    new PerRoundSubBody() {
-                        @Override
-                        public DataStreamList process(DataStreamList inputs) {
-                            DataStream<Tuple2<Integer, DenseVector>> centroidIdAndPoints =
-                                    inputs.get(0);
-                            DataStream<KMeansModelData> modelDataStream =
-                                    centroidIdAndPoints
-                                            .map(new CountAppender())
-                                            .keyBy(t -> t.f0)
-                                            .window(EndOfStreamWindows.get())
-                                            .reduce(new CentroidAccumulator())
-                                            .map(new CentroidAverager())
-                                            .windowAll(EndOfStreamWindows.get())
-                                            .apply(new ModelDataGenerator());
-                            return DataStreamList.of(modelDataStream);
-                        }
-                    };
+            int parallelism = centroidIdAndPoints.getParallelism();
             DataStream<KMeansModelData> newModelData =
-                    IterationBody.forEachRound(
-                                    DataStreamList.of(centroidIdAndPoints), perRoundSubBody)
-                            .get(0);
+                    centroidIdAndPoints
+                            .countWindowAll(parallelism)
+                            .reduce(new CentroidsUpdateReducer())
+                            .map(new ModelDataGenerator());
 
             DataStream<DenseVector[]> newCentroids =
                     newModelData.map(x -> x.centroids).setParallelism(1);
@@ -200,70 +182,48 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
         }
     }
 
-    private static class ModelDataGenerator
-            implements AllWindowFunction<Tuple2<DenseVector, Double>, KMeansModelData, TimeWindow> {
+    private static class CentroidsUpdateReducer
+            implements ReduceFunction<Tuple2<Integer[], DenseVector[]>> {
         @Override
-        public void apply(
-                TimeWindow timeWindow,
-                Iterable<Tuple2<DenseVector, Double>> iterable,
-                Collector<KMeansModelData> collector) {
-            List<Tuple2<DenseVector, Double>> list = IteratorUtils.toList(iterable.iterator());
-            DenseVector[] centroids = new DenseVector[list.size()];
-            DenseVector weights = new DenseVector(list.size());
-            for (int i = 0; i < list.size(); i++) {
-                centroids[i] = list.get(i).f0;
-                weights.values[i] = list.get(i).f1;
+        public Tuple2<Integer[], DenseVector[]> reduce(
+                Tuple2<Integer[], DenseVector[]> tuple2, Tuple2<Integer[], DenseVector[]> t1)
+                throws Exception {
+            for (int i = 0; i < tuple2.f0.length; i++) {
+                tuple2.f0[i] += t1.f0[i];
+                BLAS.axpy(1.0, t1.f1[i], tuple2.f1[i]);
             }
-            collector.collect(new KMeansModelData(centroids, weights));
-        }
-    }
 
-    private static class CentroidAverager
-            implements MapFunction<
-                    Tuple3<Integer, DenseVector, Long>, Tuple2<DenseVector, Double>> {
-        @Override
-        public Tuple2<DenseVector, Double> map(Tuple3<Integer, DenseVector, Long> value) {
-            for (int i = 0; i < value.f1.size(); i++) {
-                value.f1.values[i] /= value.f2;
-            }
-            return Tuple2.of(value.f1, value.f2.doubleValue());
+            return tuple2;
         }
     }
 
-    private static class CentroidAccumulator
-            implements ReduceFunction<Tuple3<Integer, DenseVector, Long>> {
+    private static class ModelDataGenerator
+            implements MapFunction<Tuple2<Integer[], DenseVector[]>, KMeansModelData> {
         @Override
-        public Tuple3<Integer, DenseVector, Long> reduce(
-                Tuple3<Integer, DenseVector, Long> v1, Tuple3<Integer, DenseVector, Long> v2) {
-            for (int i = 0; i < v1.f1.size(); i++) {
-                v1.f1.values[i] += v2.f1.values[i];
+        public KMeansModelData map(Tuple2<Integer[], DenseVector[]> tuple2) throws Exception {
+            double[] weights = new double[tuple2.f0.length];
+            for (int i = 0; i < tuple2.f0.length; i++) {
+                BLAS.scal(1.0 / tuple2.f0[i], tuple2.f1[i]);
+                weights[i] = tuple2.f0[i];
             }
-            return new Tuple3<>(v1.f0, v1.f1, v1.f2 + v2.f2);
-        }
-    }
 
-    private static class CountAppender
-            implements MapFunction<
-                    Tuple2<Integer, DenseVector>, Tuple3<Integer, DenseVector, Long>> {
-        @Override
-        public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value) {
-            return Tuple3.of(value.f0, value.f1, 1L);
+            return new KMeansModelData(tuple2.f1, new DenseVector(weights));
         }
     }
 
-    private static class SelectNearestCentroidOperator
-            extends AbstractStreamOperator<Tuple2<Integer, DenseVector>>
+    private static class CentroidsUpdateAccumulator
+            extends AbstractStreamOperator<Tuple2<Integer[], DenseVector[]>>
             implements TwoInputStreamOperator<
-                            DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
-                    IterationListener<Tuple2<Integer, DenseVector>> {
+                            DenseVector, DenseVector[], Tuple2<Integer[], DenseVector[]>>,
+                    IterationListener<Tuple2<Integer[], DenseVector[]>> {
 
         private final DistanceMeasure distanceMeasure;
 
         private ListState<DenseVector[]> centroids;
 
-        private ListStateWithCache<DenseVector> points;
+        private ListStateWithCache<VectorWithNorm> points;
 
-        public SelectNearestCentroidOperator(DistanceMeasure distanceMeasure) {
+        public CentroidsUpdateAccumulator(DistanceMeasure distanceMeasure) {
             super();
             this.distanceMeasure = distanceMeasure;
         }
@@ -281,7 +241,7 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
 
             points =
                     new ListStateWithCache<>(
-                            new DenseVectorSerializer(),
+                            new VectorWithNormSerializer(),
                             getContainingTask(),
                             getRuntimeContext(),
                             context,
@@ -296,7 +256,7 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
 
         @Override
         public void processElement1(StreamRecord<DenseVector> streamRecord) throws Exception {
-            points.add(streamRecord.getValue());
+            points.add(new VectorWithNorm(streamRecord.getValue()));
         }
 
         @Override
@@ -307,44 +267,46 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
 
         @Override
         public void onEpochWatermarkIncremented(
-                int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out)
+                int epochWatermark,
+                Context context,
+                Collector<Tuple2<Integer[], DenseVector[]>> out)
                 throws Exception {
             DenseVector[] centroidValues =
                     Objects.requireNonNull(
                             OperatorStateUtils.getUniqueElement(centroids, "centroids")
                                     .orElse(null));
-            for (DenseVector point : points.get()) {
-                int closestCentroidId =
-                        findClosestCentroidId(centroidValues, point, distanceMeasure);
-                output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
+
+            VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[centroidValues.length];
+            for (int i = 0; i < centroidsWithNorm.length; i++) {
+                centroidsWithNorm[i] = new VectorWithNorm(centroidValues[i]);
+            }
+
+            DenseVector[] newCentroids = new DenseVector[centroidValues.length];
+            Integer[] counts = new Integer[centroidValues.length];
+            Arrays.fill(counts, 0);
+            for (int i = 0; i < centroidValues.length; i++) {
+                newCentroids[i] = new DenseVector(centroidValues[i].size());
+            }
+
+            for (VectorWithNorm point : points.get()) {
+                int closestCentroidId = distanceMeasure.findClosest(centroidsWithNorm, point);
+                BLAS.axpy(1.0, point.vector, newCentroids[closestCentroidId]);
+                counts[closestCentroidId]++;
             }
 
+            output.collect(new StreamRecord<>(Tuple2.of(counts, newCentroids)));
+
             centroids.clear();
         }
 
         @Override
         public void onIterationTerminated(
-                Context context, Collector<Tuple2<Integer, DenseVector>> collector) {
+                Context context, Collector<Tuple2<Integer[], DenseVector[]>> collector) {
             centroids.clear();
             points.clear();
         }
     }
 
-    protected static int findClosestCentroidId(
-            DenseVector[] centroids, DenseVector point, DistanceMeasure distanceMeasure) {
-        double minDistance = Double.MAX_VALUE;
-        int closestCentroidId = -1;
-        for (int i = 0; i < centroids.length; i++) {
-            DenseVector centroid = centroids[i];
-            double distance = distanceMeasure.distance(centroid, point);
-            if (distance < minDistance) {
-                minDistance = distance;
-                closestCentroidId = i;
-            }
-        }
-        return closestCentroidId;
-    }
-
     public static DataStream<DenseVector[]> selectRandomCentroids(
             DataStream<DenseVector> data, int k, long seed) {
         DataStream<DenseVector[]> resultStream =
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
index 9cfea62..5aa57ae 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.distance.DistanceMeasure;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -111,7 +112,7 @@ public class KMeansModel implements Model<KMeansModel>, KMeansModelParams<KMeans
 
         private final int k;
 
-        private DenseVector[] centroids;
+        private VectorWithNorm[] centroids;
 
         public PredictLabelFunction(
                 String broadcastModelKey,
@@ -131,10 +132,14 @@ public class KMeansModel implements Model<KMeansModel>, KMeansModelParams<KMeans
                         (KMeansModelData)
                                 getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
                 Preconditions.checkArgument(modelData.centroids.length <= k);
-                centroids = modelData.centroids;
+                centroids = new VectorWithNorm[modelData.centroids.length];
+                for (int i = 0; i < modelData.centroids.length; i++) {
+                    centroids[i] = new VectorWithNorm(modelData.centroids[i]);
+                }
             }
             DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense();
-            int closestCentroidId = KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+            int closestCentroidId =
+                    distanceMeasure.findClosest(centroids, new VectorWithNorm(point));
             return Row.join(dataPoint, Row.of(closestCentroidId));
         }
     }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
index 4112527..2c8dbd1 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
@@ -35,6 +35,7 @@ import org.apache.flink.ml.common.distance.DistanceMeasure;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -276,6 +277,10 @@ public class OnlineKMeans
             KMeansModelData modelData =
                     OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
             DenseVector[] centroids = modelData.centroids;
+            VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[modelData.centroids.length];
+            for (int i = 0; i < centroidsWithNorm.length; i++) {
+                centroidsWithNorm[i] = new VectorWithNorm(modelData.centroids[i]);
+            }
             DenseVector weights = modelData.weights;
             modelDataState.clear();
 
@@ -296,7 +301,7 @@ public class OnlineKMeans
             }
             for (DenseVector point : points) {
                 int closestCentroidId =
-                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                        distanceMeasure.findClosest(centroidsWithNorm, new VectorWithNorm(point));
                 counts[closestCentroidId]++;
                 BLAS.axpy(1.0, point, sums[closestCentroidId]);
             }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
index 6a7c852..43742f1 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.distance.DistanceMeasure;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -114,7 +115,7 @@ public class OnlineKMeansModel
 
         private final int k;
 
-        private DenseVector[] centroids;
+        private VectorWithNorm[] centroids;
 
         // TODO: replace this with a complete solution of reading first model data from unbounded
         // model data stream before processing the first predict data.
@@ -173,7 +174,8 @@ public class OnlineKMeansModel
                 return;
             }
             DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense();
-            int closestCentroidId = KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+            int closestCentroidId =
+                    distanceMeasure.findClosest(centroids, new VectorWithNorm(point));
             output.collect(new StreamRecord<>(Row.join(dataPoint, Row.of(closestCentroidId))));
         }
 
@@ -181,7 +183,10 @@ public class OnlineKMeansModel
         public void processElement2(StreamRecord<KMeansModelData> streamRecord) throws Exception {
             KMeansModelData modelData = streamRecord.getValue();
             Preconditions.checkArgument(modelData.centroids.length <= k);
-            centroids = modelData.centroids;
+            centroids = new VectorWithNorm[modelData.centroids.length];
+            for (int i = 0; i < centroids.length; i++) {
+                centroids[i] = new VectorWithNorm(modelData.centroids[i]);
+            }
             modelDataVersion++;
             for (Row dataPoint : bufferedPointsState.get()) {
                 processElement1(new StreamRecord<>(dataPoint));