You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/06/21 03:24:53 UTC
[flink-ml] branch master updated: [FLINK-27096] Optimize KMeans performance
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 10c1ef4 [FLINK-27096] Optimize KMeans performance
10c1ef4 is described below
commit 10c1ef4e91700045b7515e5d80a7b5a143cc1821
Author: yunfengzhou-hub <yu...@outlook.com>
AuthorDate: Tue Jun 21 11:24:48 2022 +0800
[FLINK-27096] Optimize KMeans performance
This closes #110.
---
.../flink/ml/common/distance/DistanceMeasure.java | 7 +-
.../common/distance/EuclideanDistanceMeasure.java | 39 +++--
.../main/java/org/apache/flink/ml/linalg/BLAS.java | 13 +-
.../org/apache/flink/ml/linalg/VectorWithNorm.java | 40 +++++
.../ml/linalg/typeinfo/DenseVectorSerializer.java | 6 +-
.../linalg/typeinfo/VectorWithNormSerializer.java | 122 +++++++++++++++
.../ml/linalg/typeinfo/VectorWithNormTypeInfo.java | 83 +++++++++++
.../typeinfo/VectorWithNormTypeInfoFactory.java | 37 +++++
.../java/org/apache/flink/ml/linalg/BLASTest.java | 6 +-
.../apache/flink/ml/linalg/VectorWithNormTest.java | 39 +++++
.../apache/flink/ml/clustering/kmeans/KMeans.java | 166 ++++++++-------------
.../flink/ml/clustering/kmeans/KMeansModel.java | 11 +-
.../flink/ml/clustering/kmeans/OnlineKMeans.java | 7 +-
.../ml/clustering/kmeans/OnlineKMeansModel.java | 11 +-
14 files changed, 460 insertions(+), 127 deletions(-)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
index 0fe01fe..ee7b25d 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java
@@ -18,7 +18,7 @@
package org.apache.flink.ml.common.distance;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
import java.io.Serializable;
@@ -40,5 +40,8 @@ public interface DistanceMeasure extends Serializable {
*
* <p>Required: The two vectors should have the same dimension.
*/
- double distance(Vector v1, Vector v2);
+ double distance(VectorWithNorm v1, VectorWithNorm v2); // inputs carry a precomputed l2Norm that implementations may reuse
+
+ /** Returns the index in {@code centroids} of the centroid closest to {@code point}. */
+ int findClosest(VectorWithNorm[] centroids, VectorWithNorm point);
}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
index 5864c17..a3798a0 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
@@ -18,8 +18,8 @@
package org.apache.flink.ml.common.distance;
-import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.util.Preconditions;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.VectorWithNorm;
/** Interface for measuring the Euclidean distance between two vectors. */
public class EuclideanDistanceMeasure implements DistanceMeasure {
@@ -33,16 +33,35 @@ public class EuclideanDistanceMeasure implements DistanceMeasure {
return instance;
}
- // TODO: Improve distance calculation with BLAS.
@Override
- public double distance(Vector v1, Vector v2) {
- Preconditions.checkArgument(v1.size() == v2.size());
- double squaredDistance = 0.0;
+ public double distance(VectorWithNorm v1, VectorWithNorm v2) { // Euclidean distance from the stored norms plus one dot product
+ return Math.sqrt(distanceSquare(v1, v2));
+ }
+
+ private double distanceSquare(VectorWithNorm v1, VectorWithNorm v2) { // ||v1 - v2||^2 = ||v1||^2 + ||v2||^2 - 2 * <v1, v2>
+ return Math.max(v1.l2Norm * v1.l2Norm + v2.l2Norm * v2.l2Norm - 2.0 * BLAS.dot(v1.vector, v2.vector), 0.0); // clamp: cancellation can leave a tiny negative value (e.g. v1 == v2), and sqrt of a negative is NaN
+ }
- for (int i = 0; i < v1.size(); i++) {
- double diff = v1.get(i) - v2.get(i);
- squaredDistance += diff * diff;
+ @Override
+ public int findClosest(VectorWithNorm[] centroids, VectorWithNorm point) { // returns 0 if centroids is empty
+ double bestL2DistanceSquare = Double.POSITIVE_INFINITY;
+ int bestIndex = 0;
+ for (int i = 0; i < centroids.length; i++) {
+ VectorWithNorm centroid = centroids[i];
+
+ double lowerBoundSqrt = point.l2Norm - centroid.l2Norm; // (|p| - |c|)^2 <= |p - c|^2, so skip c when this bound cannot beat the best
+ double lowerBound = lowerBoundSqrt * lowerBoundSqrt;
+ if (lowerBound >= bestL2DistanceSquare) {
+ continue;
+ }
+
+ double l2DistanceSquare = distanceSquare(point, centroid);
+ if (l2DistanceSquare < bestL2DistanceSquare) {
+ bestL2DistanceSquare = l2DistanceSquare;
+ bestIndex = i;
+ }
}
- return Math.sqrt(squaredDistance);
+
+ return bestIndex;
}
}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
index c00f642..3b46466 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -114,10 +114,21 @@ public class BLAS {
}
/** \sqrt(\sum_i x_i * x_i) . */
- public static double norm2(DenseVector x) {
+ public static double norm2(Vector x) { // dispatches on the runtime type of x
+ if (x instanceof DenseVector) {
+ return norm2((DenseVector) x);
+ }
+ return norm2((SparseVector) x); // NOTE(review): any non-DenseVector is assumed to be a SparseVector; another Vector subtype would throw ClassCastException
+ }
+
+ private static double norm2(DenseVector x) {
return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
}
+ private static double norm2(SparseVector x) {
+ return JAVA_BLAS.dnrm2(x.values.length, x.values, 1); // only stored values are needed: implicit zeros add nothing to the sum of squares
+ }
+
/** x = x * a . */
public static void scal(double a, DenseVector x) {
JAVA_BLAS.dscal(x.size(), a, x.values, 1);
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java
new file mode 100644
index 0000000..bb78ef2
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorWithNormTypeInfoFactory;
+
+/** A vector paired with its L2 (Euclidean) norm, which is computed once and stored alongside it. */
+@TypeInfo(VectorWithNormTypeInfoFactory.class)
+public class VectorWithNorm {
+ public final Vector vector;
+
+ public final double l2Norm; // expected to equal BLAS.norm2(vector)
+
+ public VectorWithNorm(Vector vector) { // computes the norm eagerly via BLAS.norm2
+ this(vector, BLAS.norm2(vector));
+ }
+
+ public VectorWithNorm(Vector vector, double l2Norm) { // trusts the caller: l2Norm is stored as given, not recomputed or checked
+ this.vector = vector;
+ this.l2Norm = l2Norm;
+ }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
index 94b8d91..5b6f984 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
@@ -84,7 +84,7 @@ public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
target.writeInt(len);
for (int i = 0; i < len; i++) {
- Bits.putDouble(buf, i << 3, vector.values[i]);
+ Bits.putDouble(buf, (i & 127) << 3, vector.values[i]); // offset within the 128-double buffer, which is flushed every 128 values
if ((i & 127) == 127) {
target.write(buf);
}
@@ -104,12 +104,12 @@ public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
private void readDoubleArray(double[] dst, DataInputView source, int len) throws IOException {
int index = 0;
for (int i = 0; i < (len >> 7); i++) {
- source.read(buf, 0, 1024);
+ source.readFully(buf, 0, 1024); // read() may return fewer bytes than requested; readFully fills the whole range
for (int j = 0; j < 128; j++) {
dst[index++] = Bits.getDouble(buf, j << 3);
}
}
- source.read(buf, 0, (len << 3) & 1023);
+ source.readFully(buf, 0, (len << 3) & 1023); // remaining (len % 128) doubles
for (int j = 0; j < (len & 127); j++) {
dst[index++] = Bits.getDouble(buf, j << 3);
}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java
new file mode 100644
index 0000000..92d1de1
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+
+import java.io.IOException;
+
+/** Specialized serializer for {@link VectorWithNorm}. */
+public class VectorWithNormSerializer extends TypeSerializer<VectorWithNorm> {
+ private final VectorSerializer vectorSerializer = new VectorSerializer(); // serializes the wrapped vector; the norm is written after it
+
+ private static final long serialVersionUID = 1L;
+
+ private static final double[] EMPTY = new double[0];
+
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<VectorWithNorm> duplicate() {
+ return new VectorWithNormSerializer();
+ }
+
+ @Override
+ public VectorWithNorm createInstance() {
+ return new VectorWithNorm(new DenseVector(EMPTY));
+ }
+
+ @Override
+ public VectorWithNorm copy(VectorWithNorm from) {
+ Vector vector = vectorSerializer.copy(from.vector);
+ return new VectorWithNorm(vector, from.l2Norm);
+ }
+
+ @Override
+ public VectorWithNorm copy(VectorWithNorm from, VectorWithNorm reuse) {
+ Vector vector = vectorSerializer.copy(from.vector, reuse.vector);
+ return new VectorWithNorm(vector, from.l2Norm);
+ }
+
+ @Override
+ public int getLength() {
+ return -1;
+ }
+
+ @Override
+ public void serialize(VectorWithNorm from, DataOutputView dataOutputView) throws IOException {
+ vectorSerializer.serialize(from.vector, dataOutputView);
+ dataOutputView.writeDouble(from.l2Norm); // stored so deserialization need not recompute it
+ }
+
+ @Override
+ public VectorWithNorm deserialize(DataInputView dataInputView) throws IOException {
+ Vector vector = vectorSerializer.deserialize(dataInputView);
+ double l2Norm = dataInputView.readDouble();
+ return new VectorWithNorm(vector, l2Norm);
+ }
+
+ @Override
+ public VectorWithNorm deserialize(VectorWithNorm reuse, DataInputView dataInputView)
+ throws IOException {
+ Vector vector = vectorSerializer.deserialize(reuse.vector, dataInputView);
+ double l2Norm = dataInputView.readDouble();
+ return new VectorWithNorm(vector, l2Norm);
+ }
+
+ @Override
+ public void copy(DataInputView dataInputView, DataOutputView dataOutputView)
+ throws IOException {
+ vectorSerializer.copy(dataInputView, dataOutputView);
+ dataOutputView.write(dataInputView, 8); // the 8-byte l2Norm double that follows the vector
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return o instanceof VectorWithNormSerializer;
+ }
+
+ @Override
+ public int hashCode() {
+ return VectorWithNormSerializer.class.hashCode();
+ }
+
+ @Override
+ public TypeSerializerSnapshot<VectorWithNorm> snapshotConfiguration() {
+ return new VectorWithNormSerializerSnapshot();
+ }
+
+ public static final class VectorWithNormSerializerSnapshot // public: Flink instantiates serializer snapshots reflectively on restore
+ extends SimpleTypeSerializerSnapshot<VectorWithNorm> {
+ public VectorWithNormSerializerSnapshot() {
+ super(VectorWithNormSerializer::new);
+ }
+ }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfo.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfo.java
new file mode 100644
index 0000000..e56a354
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfo.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+
+/** A {@link TypeInformation} for the {@link VectorWithNorm} type. */
+public class VectorWithNormTypeInfo extends TypeInformation<VectorWithNorm> {
+ @Override
+ public boolean isBasicType() {
+ return false;
+ }
+
+ @Override
+ public boolean isTupleType() {
+ return false;
+ }
+
+ @Override
+ public int getArity() {
+ return 2; // NOTE(review): counts the two fields (vector, l2Norm); confirm this is intended for a non-tuple type
+ }
+
+ @Override
+ public int getTotalFields() {
+ return 2; // same two fields as getArity()
+ }
+
+ @Override
+ public Class<VectorWithNorm> getTypeClass() {
+ return VectorWithNorm.class;
+ }
+
+ @Override
+ public boolean isKeyType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<VectorWithNorm> createSerializer(ExecutionConfig executionConfig) {
+ return new VectorWithNormSerializer(); // the execution config is not consulted
+ }
+
+ @Override
+ public String toString() {
+ return "VectorWithNormType";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return o instanceof VectorWithNormTypeInfo;
+ }
+
+ @Override
+ public int hashCode() {
+ return getClass().hashCode();
+ }
+
+ @Override
+ public boolean canEqual(Object o) {
+ return o instanceof VectorWithNormTypeInfo;
+ }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfoFactory.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfoFactory.java
new file mode 100644
index 0000000..13f46bf
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormTypeInfoFactory.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeinfo.TypeInfoFactory;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+
+import java.lang.reflect.Type;
+import java.util.Map;
+
+/** Used by {@link TypeExtractor} to create a {@link TypeInformation} for {@link VectorWithNorm}. */
+public class VectorWithNormTypeInfoFactory extends TypeInfoFactory<VectorWithNorm> {
+ @Override
+ public TypeInformation<VectorWithNorm> createTypeInfo(
+ Type type, Map<String, TypeInformation<?>> map) {
+ return new VectorWithNormTypeInfo(); // both arguments are ignored; the result does not depend on them
+ }
+}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
index 21d68a9..469bbe7 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
@@ -86,8 +86,10 @@ public class BLASTest {
@Test
public void testNorm2() {
- double expectedResult = Math.sqrt(55);
- assertEquals(expectedResult, BLAS.norm2(inputDenseVec), TOLERANCE);
+ assertEquals(Math.sqrt(55), BLAS.norm2(inputDenseVec), TOLERANCE);
+
+ SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
+ assertEquals(Math.sqrt(35), BLAS.norm2(sparseVector), TOLERANCE); // stored values 1, 3, 5: 1 + 9 + 25 = 35
}
@Test
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java
new file mode 100644
index 0000000..25b45b0
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests the behavior of {@link VectorWithNorm}. */
+public class VectorWithNormTest {
+ @Test
+ public void testL2Norm() {
+ DenseVector denseVector = Vectors.dense(1, 2, 3);
+ VectorWithNorm denseVectorWithNorm = new VectorWithNorm(denseVector);
+ assertEquals(denseVector, denseVectorWithNorm.vector);
+ assertEquals(Math.sqrt(14), denseVectorWithNorm.l2Norm, 1e-7); // 1 + 4 + 9 = 14
+
+ SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 2, 3});
+ VectorWithNorm sparseVectorWithNorm = new VectorWithNorm(sparseVector);
+ assertEquals(sparseVector, sparseVectorWithNorm.vector);
+ assertEquals(Math.sqrt(14), sparseVectorWithNorm.l2Norm, 1e-7); // same stored values as the dense case, so the same norm
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 47a185c..6f2120e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -23,10 +23,9 @@ import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.iteration.DataStreamList;
@@ -40,24 +39,23 @@ import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
-import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
+import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorWithNormSerializer;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -65,10 +63,9 @@ import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
-import org.apache.commons.collections.IteratorUtils;
-
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -153,39 +150,24 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
DataStream<Integer> terminationCriteria =
centroids.flatMap(new TerminateOnMaxIter(maxIterationNum));
- DataStream<Tuple2<Integer, DenseVector>> centroidIdAndPoints =
+ DataStream<Tuple2<Integer[], DenseVector[]>> centroidIdAndPoints = // per-subtask (counts, coordinate sums) for every centroid; the name predates this shape
points.connect(centroids.broadcast())
.transform(
- "SelectNearestCentroid",
+ "CentroidsUpdateAccumulator",
new TupleTypeInfo<>(
- BasicTypeInfo.INT_TYPE_INFO,
- DenseVectorTypeInfo.INSTANCE),
- new SelectNearestCentroidOperator(distanceMeasure));
+ BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO,
+ ObjectArrayTypeInfo.getInfoFor(
+ DenseVectorTypeInfo.INSTANCE)),
+ new CentroidsUpdateAccumulator(distanceMeasure));
DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints.getTransformation(), 100);
- PerRoundSubBody perRoundSubBody =
- new PerRoundSubBody() {
- @Override
- public DataStreamList process(DataStreamList inputs) {
- DataStream<Tuple2<Integer, DenseVector>> centroidIdAndPoints =
- inputs.get(0);
- DataStream<KMeansModelData> modelDataStream =
- centroidIdAndPoints
- .map(new CountAppender())
- .keyBy(t -> t.f0)
- .window(EndOfStreamWindows.get())
- .reduce(new CentroidAccumulator())
- .map(new CentroidAverager())
- .windowAll(EndOfStreamWindows.get())
- .apply(new ModelDataGenerator());
- return DataStreamList.of(modelDataStream);
- }
- };
+ int parallelism = centroidIdAndPoints.getParallelism(); // NOTE(review): assumes each subtask emits exactly one partial result per round, so one count window == one round
DataStream<KMeansModelData> newModelData =
- IterationBody.forEachRound(
- DataStreamList.of(centroidIdAndPoints), perRoundSubBody)
- .get(0);
+ centroidIdAndPoints
+ .countWindowAll(parallelism)
+ .reduce(new CentroidsUpdateReducer()) // merge the partial results of all subtasks
+ .map(new ModelDataGenerator()); // sums -> means, counts -> weights
DataStream<DenseVector[]> newCentroids =
newModelData.map(x -> x.centroids).setParallelism(1);
@@ -200,70 +182,48 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
}
}
- private static class ModelDataGenerator
- implements AllWindowFunction<Tuple2<DenseVector, Double>, KMeansModelData, TimeWindow> {
+ private static class CentroidsUpdateReducer
+ implements ReduceFunction<Tuple2<Integer[], DenseVector[]>> { // element-wise sum of per-subtask counts and coordinate sums
@Override
- public void apply(
- TimeWindow timeWindow,
- Iterable<Tuple2<DenseVector, Double>> iterable,
- Collector<KMeansModelData> collector) {
- List<Tuple2<DenseVector, Double>> list = IteratorUtils.toList(iterable.iterator());
- DenseVector[] centroids = new DenseVector[list.size()];
- DenseVector weights = new DenseVector(list.size());
- for (int i = 0; i < list.size(); i++) {
- centroids[i] = list.get(i).f0;
- weights.values[i] = list.get(i).f1;
+ public Tuple2<Integer[], DenseVector[]> reduce(
+ Tuple2<Integer[], DenseVector[]> tuple2, Tuple2<Integer[], DenseVector[]> t1)
+ throws Exception {
+ for (int i = 0; i < tuple2.f0.length; i++) {
+ tuple2.f0[i] += t1.f0[i];
+ BLAS.axpy(1.0, t1.f1[i], tuple2.f1[i]); // tuple2.f1[i] += t1.f1[i]; the first argument is mutated and returned
}
- collector.collect(new KMeansModelData(centroids, weights));
- }
- }
- private static class CentroidAverager
- implements MapFunction<
- Tuple3<Integer, DenseVector, Long>, Tuple2<DenseVector, Double>> {
- @Override
- public Tuple2<DenseVector, Double> map(Tuple3<Integer, DenseVector, Long> value) {
- for (int i = 0; i < value.f1.size(); i++) {
- value.f1.values[i] /= value.f2;
- }
- return Tuple2.of(value.f1, value.f2.doubleValue());
+ return tuple2;
}
}
- private static class CentroidAccumulator
- implements ReduceFunction<Tuple3<Integer, DenseVector, Long>> {
+ private static class ModelDataGenerator
+ implements MapFunction<Tuple2<Integer[], DenseVector[]>, KMeansModelData> { // turns summed coordinates into means, in place
@Override
- public Tuple3<Integer, DenseVector, Long> reduce(
- Tuple3<Integer, DenseVector, Long> v1, Tuple3<Integer, DenseVector, Long> v2) {
- for (int i = 0; i < v1.f1.size(); i++) {
- v1.f1.values[i] += v2.f1.values[i];
+ public KMeansModelData map(Tuple2<Integer[], DenseVector[]> tuple2) throws Exception {
+ double[] weights = new double[tuple2.f0.length];
+ for (int i = 0; i < tuple2.f0.length; i++) {
+ BLAS.scal(1.0 / tuple2.f0[i], tuple2.f1[i]); // NOTE(review): a centroid with no assigned points has count 0, so 1.0 / 0 is Infinity and its zero sum vector becomes NaN — confirm empty clusters are handled
+ weights[i] = tuple2.f0[i]; // weight = number of points assigned to the centroid
}
- return new Tuple3<>(v1.f0, v1.f1, v1.f2 + v2.f2);
- }
- }
- private static class CountAppender
- implements MapFunction<
- Tuple2<Integer, DenseVector>, Tuple3<Integer, DenseVector, Long>> {
- @Override
- public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value) {
- return Tuple3.of(value.f0, value.f1, 1L);
+ return new KMeansModelData(tuple2.f1, new DenseVector(weights));
}
}
- private static class SelectNearestCentroidOperator
- extends AbstractStreamOperator<Tuple2<Integer, DenseVector>>
+ private static class CentroidsUpdateAccumulator
+ extends AbstractStreamOperator<Tuple2<Integer[], DenseVector[]>>
implements TwoInputStreamOperator<
- DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
- IterationListener<Tuple2<Integer, DenseVector>> {
+ DenseVector, DenseVector[], Tuple2<Integer[], DenseVector[]>>,
+ IterationListener<Tuple2<Integer[], DenseVector[]>> {
private final DistanceMeasure distanceMeasure;
private ListState<DenseVector[]> centroids;
- private ListStateWithCache<DenseVector> points;
+ private ListStateWithCache<VectorWithNorm> points; // cached training points, each with its precomputed norm
- public SelectNearestCentroidOperator(DistanceMeasure distanceMeasure) {
+ public CentroidsUpdateAccumulator(DistanceMeasure distanceMeasure) {
super();
this.distanceMeasure = distanceMeasure;
}
@@ -281,7 +241,7 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
points =
new ListStateWithCache<>(
- new DenseVectorSerializer(),
+ new VectorWithNormSerializer(), // matches the VectorWithNorm element type of the points cache
getContainingTask(),
getRuntimeContext(),
context,
@@ -296,7 +256,7 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
@Override
public void processElement1(StreamRecord<DenseVector> streamRecord) throws Exception {
- points.add(streamRecord.getValue());
+ points.add(new VectorWithNorm(streamRecord.getValue())); // norm computed once on arrival, then reused in every round
}
@Override
@@ -307,44 +267,46 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
@Override
public void onEpochWatermarkIncremented(
- int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out)
+ int epochWatermark,
+ Context context,
+ Collector<Tuple2<Integer[], DenseVector[]>> out)
throws Exception {
DenseVector[] centroidValues =
Objects.requireNonNull(
OperatorStateUtils.getUniqueElement(centroids, "centroids")
.orElse(null));
- for (DenseVector point : points.get()) {
- int closestCentroidId =
- findClosestCentroidId(centroidValues, point, distanceMeasure);
- output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
+
+ VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[centroidValues.length]; // centroid norms computed once per round
+ for (int i = 0; i < centroidsWithNorm.length; i++) {
+ centroidsWithNorm[i] = new VectorWithNorm(centroidValues[i]);
+ }
+
+ DenseVector[] newCentroids = new DenseVector[centroidValues.length]; // per-centroid coordinate sums for this subtask
+ Integer[] counts = new Integer[centroidValues.length]; // boxed to match the Tuple2<Integer[], DenseVector[]> output type
+ Arrays.fill(counts, 0);
+ for (int i = 0; i < centroidValues.length; i++) {
+ newCentroids[i] = new DenseVector(centroidValues[i].size()); // starts at zero; stays zero if no point is assigned
+ }
+
+ for (VectorWithNorm point : points.get()) {
+ int closestCentroidId = distanceMeasure.findClosest(centroidsWithNorm, point);
+ BLAS.axpy(1.0, point.vector, newCentroids[closestCentroidId]); // newCentroids[id] += point
+ counts[closestCentroidId]++;
}
+ output.collect(new StreamRecord<>(Tuple2.of(counts, newCentroids))); // exactly one partial result per subtask per round
+
centroids.clear();
}
@Override
public void onIterationTerminated(
- Context context, Collector<Tuple2<Integer, DenseVector>> collector) {
+ Context context, Collector<Tuple2<Integer[], DenseVector[]>> collector) {
centroids.clear();
points.clear();
}
}
- protected static int findClosestCentroidId(
- DenseVector[] centroids, DenseVector point, DistanceMeasure distanceMeasure) {
- double minDistance = Double.MAX_VALUE;
- int closestCentroidId = -1;
- for (int i = 0; i < centroids.length; i++) {
- DenseVector centroid = centroids[i];
- double distance = distanceMeasure.distance(centroid, point);
- if (distance < minDistance) {
- minDistance = distance;
- closestCentroidId = i;
- }
- }
- return closestCentroidId;
- }
-
public static DataStream<DenseVector[]> selectRandomCentroids(
DataStream<DenseVector> data, int k, long seed) {
DataStream<DenseVector[]> resultStream =
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
index 9cfea62..5aa57ae 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
@@ -111,7 +112,7 @@ public class KMeansModel implements Model<KMeansModel>, KMeansModelParams<KMeans
private final int k;
- private DenseVector[] centroids;
+ private VectorWithNorm[] centroids;
public PredictLabelFunction(
String broadcastModelKey,
@@ -131,10 +132,14 @@ public class KMeansModel implements Model<KMeansModel>, KMeansModelParams<KMeans
(KMeansModelData)
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
Preconditions.checkArgument(modelData.centroids.length <= k);
- centroids = modelData.centroids;
+ centroids = new VectorWithNorm[modelData.centroids.length];
+ for (int i = 0; i < modelData.centroids.length; i++) {
+ centroids[i] = new VectorWithNorm(modelData.centroids[i]);
+ }
}
DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense();
- int closestCentroidId = KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+ int closestCentroidId =
+ distanceMeasure.findClosest(centroids, new VectorWithNorm(point));
return Row.join(dataPoint, Row.of(closestCentroidId));
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
index 4112527..2c8dbd1 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
@@ -35,6 +35,7 @@ import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
@@ -276,6 +277,10 @@ public class OnlineKMeans
KMeansModelData modelData =
OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
DenseVector[] centroids = modelData.centroids;
+ VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[modelData.centroids.length];
+ for (int i = 0; i < centroidsWithNorm.length; i++) {
+ centroidsWithNorm[i] = new VectorWithNorm(modelData.centroids[i]);
+ }
DenseVector weights = modelData.weights;
modelDataState.clear();
@@ -296,7 +301,7 @@ public class OnlineKMeans
}
for (DenseVector point : points) {
int closestCentroidId =
- KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+ distanceMeasure.findClosest(centroidsWithNorm, new VectorWithNorm(point));
counts[closestCentroidId]++;
BLAS.axpy(1.0, point, sums[closestCentroidId]);
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
index 6a7c852..43742f1 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
@@ -114,7 +115,7 @@ public class OnlineKMeansModel
private final int k;
- private DenseVector[] centroids;
+ private VectorWithNorm[] centroids;
// TODO: replace this with a complete solution of reading first model data from unbounded
// model data stream before processing the first predict data.
@@ -173,7 +174,8 @@ public class OnlineKMeansModel
return;
}
DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense();
- int closestCentroidId = KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+ int closestCentroidId =
+ distanceMeasure.findClosest(centroids, new VectorWithNorm(point));
output.collect(new StreamRecord<>(Row.join(dataPoint, Row.of(closestCentroidId))));
}
@@ -181,7 +183,10 @@ public class OnlineKMeansModel
public void processElement2(StreamRecord<KMeansModelData> streamRecord) throws Exception {
KMeansModelData modelData = streamRecord.getValue();
Preconditions.checkArgument(modelData.centroids.length <= k);
- centroids = modelData.centroids;
+ centroids = new VectorWithNorm[modelData.centroids.length];
+ for (int i = 0; i < centroids.length; i++) {
+ centroids[i] = new VectorWithNorm(modelData.centroids[i]);
+ }
modelDataVersion++;
for (Row dataPoint : bufferedPointsState.get()) {
processElement1(new StreamRecord<>(dataPoint));