You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/12/20 03:52:54 UTC

[flink-ml] branch master updated: [FLINK-24557] Add Estimator and Transformer for K-nearest neighbor

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 26d5196  [FLINK-24557] Add Estimator and Transformer for K-nearest neighbor
26d5196 is described below

commit 26d519613aef33f8e6ee94acc396b138e7907ea5
Author: weibo <we...@alibaba-inc.com>
AuthorDate: Tue Dec 14 10:40:18 2021 +0800

    [FLINK-24557] Add Estimator and Transformer for K-nearest neighbor
    
    This closes #24.
---
 .../main/java/org/apache/flink/ml/linalg/BLAS.java |  37 ++++
 .../org/apache/flink/ml/linalg/DenseMatrix.java    |  85 ++++++++
 .../java/org/apache/flink/ml/linalg/Matrix.java    |  34 +++
 .../ml/linalg/typeinfo/DenseMatrixSerializer.java  | 143 +++++++++++++
 .../ml/linalg/typeinfo/DenseMatrixTypeInfo.java    |  91 ++++++++
 .../typeinfo/DenseMatrixTypeInfoFactory.java       |  40 ++++
 .../java/org/apache/flink/ml/linalg/BLASTest.java  |  11 +-
 .../apache/flink/ml/classification/knn/Knn.java    | 157 ++++++++++++++
 .../flink/ml/classification/knn/KnnModel.java      | 197 +++++++++++++++++
 .../flink/ml/classification/knn/KnnModelData.java  | 128 +++++++++++
 .../ml/classification/knn/KnnModelParams.java      |  43 ++++
 .../flink/ml/classification/knn/KnnParams.java     |  28 +++
 .../logisticregression/LogisticRegression.java     |   2 +-
 .../flink/ml/classification/knn/KnnTest.java       | 237 +++++++++++++++++++++
 14 files changed, 1231 insertions(+), 2 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
index 8afd301..6d9b6eb 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -52,4 +52,41 @@ public class BLAS {
     public static void scal(double a, DenseVector x) {
         JAVA_BLAS.dscal(x.size(), a, x.values, 1);
     }
+
+    /**
+     * y = alpha * matrix * x + beta * y or y = alpha * (matrix^T) * x + beta * y.
+     *
+     * @param alpha The alpha value.
+     * @param matrix Dense matrix with size m x n.
+     * @param transMatrix Whether to transpose the matrix before multiplying.
+     * @param x Dense vector with size n (size m when transMatrix is true).
+     * @param beta The beta value.
+     * @param y Dense vector with size m (size n when transMatrix is true); holds the result.
+     */
+    public static void gemv(
+            double alpha,
+            DenseMatrix matrix,
+            boolean transMatrix,
+            DenseVector x,
+            double beta,
+            DenseVector y) {
+        Preconditions.checkArgument(
+                transMatrix
+                        ? (matrix.numRows() == x.size() && matrix.numCols() == y.size())
+                        : (matrix.numRows() == y.size() && matrix.numCols() == x.size()),
+                "Matrix and vector size mismatched.");
+        final String trans = transMatrix ? "T" : "N";
+        JAVA_BLAS.dgemv(
+                trans,
+                matrix.numRows(),
+                matrix.numCols(),
+                alpha,
+                matrix.values,
+                matrix.numRows(),
+                x.values,
+                1,
+                beta,
+                y.values,
+                1);
+    }
 }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
new file mode 100644
index 0000000..80c85b0
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfoFactory;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Column-major dense matrix. The entry values are stored in a single array of doubles with columns
+ * listed in sequence.
+ */
+@TypeInfo(DenseMatrixTypeInfoFactory.class)
+public class DenseMatrix implements Matrix {
+
+    /** Row dimension. */
+    private final int numRows;
+
+    /** Column dimension. */
+    private final int numCols;
+
+    /**
+     * Array for internal storage of elements.
+     *
+     * <p>The matrix data is stored in column major format internally.
+     */
+    public final double[] values;
+
+    /**
+     * Constructs an m-by-n matrix of zeros.
+     *
+     * @param numRows Number of rows.
+     * @param numCols Number of columns.
+     */
+    public DenseMatrix(int numRows, int numCols) {
+        this(numRows, numCols, new double[numRows * numCols]);
+    }
+
+    /**
+     * Constructs a matrix from a 1-D array. The data in the array should be organized in
+     * column-major order.
+     *
+     * @param numRows Number of rows.
+     * @param numCols Number of columns.
+     * @param values One-dimensional array of doubles; kept by reference, not copied.
+     */
+    public DenseMatrix(int numRows, int numCols, double[] values) {
+        Preconditions.checkArgument(values.length == numRows * numCols);
+        this.numRows = numRows;
+        this.numCols = numCols;
+        this.values = values;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numCols() {
+        return numCols;
+    }
+
+    @Override
+    public double get(int i, int j) {
+        Preconditions.checkArgument(i >= 0 && i < numRows && j >= 0 && j < numCols);
+        return values[numRows * j + i];
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java
new file mode 100644
index 0000000..d4e0897
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/** A matrix of double values. */
+public interface Matrix extends Serializable {
+
+    /** Gets number of rows. */
+    int numRows();
+
+    /** Gets number of columns. */
+    int numCols();
+
+    /** Gets value of the (i,j) element. */
+    double get(int i, int j);
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java
new file mode 100644
index 0000000..b25748d
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@link DenseMatrix}. */
+public final class DenseMatrixSerializer extends TypeSerializerSingleton<DenseMatrix> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY = new double[0];
+
+    public static final DenseMatrixSerializer INSTANCE = new DenseMatrixSerializer();
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public DenseMatrix createInstance() {
+        return new DenseMatrix(0, 0, EMPTY);
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from) {
+        return new DenseMatrix(
+                from.numRows(), from.numCols(), Arrays.copyOf(from.values, from.values.length));
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from, DenseMatrix reuse) {
+        if (from.values.length == reuse.values.length) {
+            System.arraycopy(from.values, 0, reuse.values, 0, from.values.length);
+            if (from.numCols() == reuse.numCols()) {
+                return reuse;
+            } else {
+                return new DenseMatrix(from.numRows(), from.numCols(), reuse.values);
+            }
+        }
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(DenseMatrix matrix, DataOutputView target) throws IOException {
+        if (matrix == null) {
+            throw new IllegalArgumentException("The matrix must not be null.");
+        }
+        final int len = matrix.values.length;
+        target.writeInt(matrix.numRows());
+        target.writeInt(matrix.numCols());
+        for (int i = 0; i < len; i++) {
+            target.writeDouble(matrix.values[i]);
+        }
+    }
+
+    @Override
+    public DenseMatrix deserialize(DataInputView source) throws IOException {
+        int m = source.readInt();
+        int n = source.readInt();
+        double[] values = new double[m * n];
+        deserializeDoubleArray(values, source, m * n);
+        return new DenseMatrix(m, n, values);
+    }
+
+    private static void deserializeDoubleArray(double[] dst, DataInputView source, int len)
+            throws IOException {
+        for (int i = 0; i < len; i++) {
+            dst[i] = source.readDouble();
+        }
+    }
+
+    @Override
+    public DenseMatrix deserialize(DenseMatrix reuse, DataInputView source) throws IOException {
+        int m = source.readInt();
+        int n = source.readInt();
+        double[] values = reuse.values;
+        if (values.length != m * n) {
+            double[] tmpValues = new double[m * n];
+            deserializeDoubleArray(tmpValues, source, m * n);
+            return new DenseMatrix(m, n, tmpValues);
+        }
+        deserializeDoubleArray(values, source, m * n);
+        return new DenseMatrix(m, n, values);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        int m = source.readInt();
+        target.writeInt(m);
+        int n = source.readInt();
+        target.writeInt(n);
+
+        target.write(source, m * n * Double.BYTES);
+    }
+
+    // ------------------------------------------------------------------------
+
+    @Override
+    public TypeSerializerSnapshot<DenseMatrix> snapshotConfiguration() {
+        return new DenseMatrixSerializerSnapshot();
+    }
+
+    /** Serializer configuration snapshot for compatibility and format evolution. */
+    @SuppressWarnings("WeakerAccess")
+    public static final class DenseMatrixSerializerSnapshot
+            extends SimpleTypeSerializerSnapshot<DenseMatrix> {
+        public DenseMatrixSerializerSnapshot() {
+            super(() -> INSTANCE);
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfo.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfo.java
new file mode 100644
index 0000000..04fba3f
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfo.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+/** A {@link TypeInformation} for the {@link DenseMatrix} type. */
+public class DenseMatrixTypeInfo extends TypeInformation<DenseMatrix> {
+    private static final long serialVersionUID = 1L;
+
+    public static final DenseMatrixTypeInfo INSTANCE = new DenseMatrixTypeInfo();
+
+    public DenseMatrixTypeInfo() {}
+
+    @Override
+    public int getArity() {
+        return 3;
+    }
+
+    @Override
+    public int getTotalFields() {
+        return 3;
+    }
+
+    @Override
+    public Class<DenseMatrix> getTypeClass() {
+        return DenseMatrix.class;
+    }
+
+    @Override
+    public boolean isBasicType() {
+        return false;
+    }
+
+    @Override
+    public boolean isTupleType() {
+        return false;
+    }
+
+    @Override
+    public boolean isKeyType() {
+        return false;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public TypeSerializer<DenseMatrix> createSerializer(ExecutionConfig executionConfig) {
+        return new DenseMatrixSerializer();
+    }
+
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public int hashCode() {
+        return getClass().hashCode();
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        return obj instanceof DenseMatrixTypeInfo;
+    }
+
+    @Override
+    public boolean canEqual(Object obj) {
+        return obj instanceof DenseMatrixTypeInfo;
+    }
+
+    @Override
+    public String toString() {
+        return "DenseMatrixType";
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfoFactory.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfoFactory.java
new file mode 100644
index 0000000..f0845cb
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfoFactory.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeinfo.TypeInfoFactory;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.lang.reflect.Type;
+import java.util.Map;
+
+/**
+ * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link
+ * DenseMatrix}.
+ */
+public class DenseMatrixTypeInfoFactory extends TypeInfoFactory<DenseMatrix> {
+
+    @Override
+    public TypeInformation<DenseMatrix> createTypeInfo(
+            Type t, Map<String, TypeInformation<?>> genericParameters) {
+        return new DenseMatrixTypeInfo();
+    }
+}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
index d799047..1bb103d 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
@@ -27,8 +27,9 @@ import static org.junit.Assert.assertEquals;
 public class BLASTest {
 
     private static final double TOLERANCE = 1e-7;
-
     private static final DenseVector inputDenseVec = Vectors.dense(1, -2, 3, 4, -5);
+    private static final DenseMatrix inputDenseMat =
+            new DenseMatrix(2, 5, new double[] {1, -2, 3, 4, -5, 1, -2, 3, 4, -5});
 
     @Test
     public void testAsum() {
@@ -61,4 +62,12 @@ public class BLASTest {
         double[] expectedResult = new double[] {2, -4, 6, 8, -10};
         assertArrayEquals(expectedResult, inputDenseVec.values, TOLERANCE);
     }
+
+    @Test
+    public void testGemv() {
+        DenseVector anotherDenseVec = Vectors.dense(1.0, 2.0);
+        BLAS.gemv(-2.0, inputDenseMat, false, inputDenseVec, 0.0, anotherDenseVec);
+        double[] expectedResult = new double[] {96.0, -60.0};
+        assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
+    }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
new file mode 100644
index 0000000..54fcf80
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        /* Tuple3 : <feature, label, norm square> */
+        DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNorm =
+                computeNormSquare(tEnv.toDataStream(inputs[0]));
+        DataStream<KnnModelData> modelData = genModelData(inputDataWithNorm);
+        KnnModel model = new KnnModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Knn load(StreamExecutionEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * Generates knn model data. For Euclidean distance, distance = sqrt((a - b)^2) = sqrt(a^2 +
+     * b^2 - 2ab), so the L2 norm square of each feature vector can be pre-calculated, and only a
+     * dot product is needed when calculating the distance to another feature vector. In addition,
+     * the feature vectors are assembled into a matrix so that BLAS can be used to accelerate the
+     * distance calculation.
+     *
+     * @param inputDataWithNormSqare Input data with norm square.
+     * @return Knn model.
+     */
+    private static DataStream<KnnModelData> genModelData(
+            DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNormSqare) {
+        DataStream<KnnModelData> modelData =
+                DataStreamUtils.mapPartition(
+                        inputDataWithNormSqare,
+                        new RichMapPartitionFunction<
+                                Tuple3<DenseVector, Double, Double>, KnnModelData>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple3<DenseVector, Double, Double>> dataPoints,
+                                    Collector<KnnModelData> out) {
+                                List<Tuple3<DenseVector, Double, Double>> bufferedDataPoints =
+                                        new ArrayList<>();
+                                for (Tuple3<DenseVector, Double, Double> dataPoint : dataPoints) {
+                                    bufferedDataPoints.add(dataPoint);
+                                }
+                                int featureDim = bufferedDataPoints.get(0).f0.size();
+                                DenseMatrix packedFeatures =
+                                        new DenseMatrix(featureDim, bufferedDataPoints.size());
+                                DenseVector normSquares =
+                                        new DenseVector(bufferedDataPoints.size());
+                                DenseVector labels = new DenseVector(bufferedDataPoints.size());
+                                int offset = 0;
+                                for (Tuple3<DenseVector, Double, Double> dataPoint :
+                                        bufferedDataPoints) {
+                                    System.arraycopy(
+                                            dataPoint.f0.values,
+                                            0,
+                                            packedFeatures.values,
+                                            offset * featureDim,
+                                            featureDim);
+                                    labels.values[offset] = dataPoint.f1;
+                                    normSquares.values[offset++] = dataPoint.f2;
+                                }
+                                out.collect(new KnnModelData(packedFeatures, normSquares, labels));
+                            }
+                        });
+        modelData.getTransformation().setParallelism(1);
+        return modelData;
+    }
+
+    /**
+     * Computes feature norm square.
+     *
+     * @param inputData Input data.
+     * @return Input data with norm square.
+     */
+    private DataStream<Tuple3<DenseVector, Double, Double>> computeNormSquare(
+            DataStream<Row> inputData) {
+        return inputData.map(
+                new MapFunction<Row, Tuple3<DenseVector, Double, Double>>() {
+                    @Override
+                    public Tuple3<DenseVector, Double, Double> map(Row value) {
+                        Double label = (Double) value.getField(getLabelCol());
+                        DenseVector feature = (DenseVector) value.getField(getFeaturesCol());
+                        return Tuple3.of(feature, label, Math.pow(BLAS.norm2(feature), 2));
+                    }
+                });
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
new file mode 100644
index 0000000..97aa965
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/**
+ * A Model which classifies data using the model data computed by {@link Knn}.
+ *
+ * <p>For every input row the model selects the {@code k} training points with the smallest
+ * distance to the row's feature vector and predicts the label that occurs most often among them.
+ * When several labels occur equally often, the label whose closest neighbor is nearest wins.
+ */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Sets the model data.
+     *
+     * @param inputs The model data tables; only the first table is used.
+     * @return This model.
+     */
+    @Override
+    public KnnModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    /**
+     * Appends a prediction column to the input table.
+     *
+     * @param inputs Exactly one table that contains the features column.
+     * @return A single table holding all input columns plus the prediction column of type double.
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        // The output row type is the input row type extended by the prediction column.
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<KnnModelData> modelData =
+                ReadWriteUtils.loadModelData(env, path, new KnnModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+        /** Scratch buffer with one distance per training point, reused across records. */
+        private DenseVector distanceVector;
+
+        public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            // The broadcast model is fetched lazily on the first record and then cached.
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                distanceVector = new DenseVector(knnModelData.labels.size());
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            double prediction = predictLabel(feature);
+            return Row.join(row, Row.of(prediction));
+        }
+
+        /**
+         * Predicts the label of one feature vector by majority vote among its nearest neighbors.
+         *
+         * @param feature The feature vector to classify.
+         * @return The most frequent label among the (at most) k nearest training points. Ties are
+         *     resolved in favor of the label whose closest neighbor is nearest. Returns -1.0 if
+         *     the model holds no training points.
+         */
+        private double predictLabel(DenseVector feature) {
+            double normSquare = Math.pow(BLAS.norm2(feature), 2);
+            // distance^2 = |x|^2 + |y|^2 - 2 * x.y; the gemv call supplies the "-2 * x.y" term
+            // for all training points at once.
+            BLAS.gemv(-2.0, knnModelData.packedFeatures, true, feature, 0.0, distanceVector);
+            for (int i = 0; i < distanceVector.size(); i++) {
+                // Math.abs guards against slightly negative values caused by rounding.
+                distanceVector.values[i] =
+                        Math.sqrt(
+                                Math.abs(
+                                        distanceVector.values[i]
+                                                + normSquare
+                                                + knnModelData.featureNormSquares.values[i]));
+            }
+            // Max-heap on distance: the head is always the farthest of the retained neighbors.
+            PriorityQueue<Tuple2<Double, Double>> nearestKNeighbors =
+                    new PriorityQueue<>(
+                            Comparator.comparingDouble(distanceAndLabel -> -distanceAndLabel.f0));
+            double[] labelValues = knnModelData.labels.values;
+            for (int i = 0; i < labelValues.length; ++i) {
+                if (nearestKNeighbors.size() < k) {
+                    nearestKNeighbors.add(Tuple2.of(distanceVector.get(i), labelValues[i]));
+                } else {
+                    Tuple2<Double, Double> currentFarthestNeighbor = nearestKNeighbors.peek();
+                    if (currentFarthestNeighbor.f0 > distanceVector.get(i)) {
+                        nearestKNeighbors.poll();
+                        nearestKNeighbors.add(Tuple2.of(distanceVector.get(i), labelValues[i]));
+                    }
+                }
+            }
+            // Polling yields the farthest neighbor first, so filling the array from the back
+            // leaves the labels ordered from nearest to farthest.
+            int numNeighbors = nearestKNeighbors.size();
+            double[] neighborLabels = new double[numNeighbors];
+            Map<Double, Integer> labelCounts = new HashMap<>(numNeighbors);
+            for (int i = numNeighbors - 1; i >= 0; i--) {
+                double label = nearestKNeighbors.poll().f1;
+                neighborLabels[i] = label;
+                labelCounts.merge(label, 1, Integer::sum);
+            }
+            // Scan nearest-first with a strict comparison: on equal counts the label that was
+            // seen first (i.e. has the nearest neighbor) is kept. This makes the result
+            // independent of the iteration order of the hash map.
+            int maxCount = 0;
+            double predictedLabel = -1.0;
+            for (double label : neighborLabels) {
+                int count = labelCounts.get(label);
+                if (count > maxCount) {
+                    maxCount = count;
+                    predictedLabel = label;
+                }
+            }
+            return predictedLabel;
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
new file mode 100644
index 0000000..89051e6
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link KnnModel}.
+ *
+ * <p>Besides holding the data, this class offers a helper that turns a model data table into a
+ * data stream, and an encoder/decoder pair used to save and load the model data.
+ */
+public class KnnModelData {
+
+    /** Training feature vectors packed into a single matrix. */
+    public DenseMatrix packedFeatures;
+    /** Norm square of each training feature vector. */
+    public DenseVector featureNormSquares;
+    /** Label of each training point. */
+    public DenseVector labels;
+
+    public KnnModelData() {}
+
+    public KnnModelData(
+            DenseMatrix packedFeatures, DenseVector featureNormSquares, DenseVector labels) {
+        this.packedFeatures = packedFeatures;
+        this.featureNormSquares = featureNormSquares;
+        this.labels = labels;
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * <p>The three fields of every row are read by position: packed features, feature norm
+     * squares, labels.
+     *
+     * @param modelDataTable The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<KnnModelData> getModelDataStream(Table modelDataTable) {
+        TableImpl tableImpl = (TableImpl) modelDataTable;
+        StreamTableEnvironment tEnv = (StreamTableEnvironment) tableImpl.getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(
+                        row -> {
+                            DenseMatrix features = (DenseMatrix) row.getField(0);
+                            DenseVector normSquares = (DenseVector) row.getField(1);
+                            DenseVector labelVector = (DenseVector) row.getField(2);
+                            return new KnnModelData(features, normSquares, labelVector);
+                        });
+    }
+
+    /** Encoder for {@link KnnModelData}. */
+    public static class ModelDataEncoder implements Encoder<KnnModelData> {
+        /** Writes the matrix, then the norm squares, then the labels to the stream. */
+        @Override
+        public void encode(KnnModelData knnModelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputView target = new DataOutputViewStreamWrapper(outputStream);
+            DenseMatrixSerializer matrixSerializer = DenseMatrixSerializer.INSTANCE;
+            DenseVectorSerializer vectorSerializer = DenseVectorSerializer.INSTANCE;
+            matrixSerializer.serialize(knnModelData.packedFeatures, target);
+            vectorSerializer.serialize(knnModelData.featureNormSquares, target);
+            vectorSerializer.serialize(knnModelData.labels, target);
+        }
+    }
+
+    /** Decoder for {@link KnnModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<KnnModelData> {
+        @Override
+        public Reader<KnnModelData> createReader(Configuration config, FSDataInputStream stream) {
+            return new Reader<KnnModelData>() {
+
+                private final DataInputView input = new DataInputViewStreamWrapper(stream);
+
+                /** Reads the next model data record, or returns null once the stream ends. */
+                @Override
+                public KnnModelData read() throws IOException {
+                    KnnModelData decoded;
+                    try {
+                        // Fields are read in the same order the encoder wrote them.
+                        DenseMatrix features = DenseMatrixSerializer.INSTANCE.deserialize(input);
+                        DenseVector normSquares =
+                                DenseVectorSerializer.INSTANCE.deserialize(input);
+                        DenseVector labelVector =
+                                DenseVectorSerializer.INSTANCE.deserialize(input);
+                        decoded = new KnnModelData(features, normSquares, labelVector);
+                    } catch (EOFException endOfStream) {
+                        decoded = null;
+                    }
+                    return decoded;
+                }
+
+                @Override
+                public void close() throws IOException {
+                    stream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<KnnModelData> getProducedType() {
+            return TypeInformation.of(KnnModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelParams.java
new file mode 100644
index 0000000..260a92e
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelParams.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link KnnModel}.
+ *
+ * <p>Adds the neighbor count {@code k} to the features and prediction column params.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface KnnModelParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {
+    /** The number of nearest neighbors; default 5, validated with {@code gt(0)}. */
+    Param<Integer> K =
+            new IntParam("k", "The number of nearest neighbors.", 5, ParamValidators.gt(0));
+
+    /** Returns the current value of {@link #K}. */
+    default Integer getK() {
+        return get(K);
+    }
+
+    /** Sets {@link #K} to the given value and returns the result of {@code set}. */
+    default T setK(Integer value) {
+        return set(K, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
new file mode 100644
index 0000000..0996178
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+
+/**
+ * Params for {@link Knn}.
+ *
+ * <p>Declares no params of its own; it combines the label column param with all params of
+ * {@link KnnModelParams}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface KnnParams<T> extends HasLabelCol<T>, KnnModelParams<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index a17269b..9df0cdf 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -311,7 +311,7 @@ public class LogisticRegression
 
         private ListState<LabeledPointWithWeight> trainDataState;
 
-        private Random random = new Random(2021);
+        private final Random random = new Random(2021);
 
         private List<LabeledPointWithWeight> miniBatchData;
 
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
new file mode 100644
index 0000000..2811a7c
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private Table predictData;
+    /** Training rows of (features, label). */
+    private static final List<Row> trainRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(2.0, 3.0), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(200.1, 300.1), 2.0),
+                            Row.of(Vectors.dense(200.2, 300.2), 2.0),
+                            Row.of(Vectors.dense(200.3, 300.3), 2.0),
+                            Row.of(Vectors.dense(200.4, 300.4), 2.0),
+                            Row.of(Vectors.dense(200.4, 300.4), 2.0),
+                            Row.of(Vectors.dense(200.6, 300.6), 2.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.3, 3.2), 1.0),
+                            Row.of(Vectors.dense(2.3, 3.2), 1.0),
+                            Row.of(Vectors.dense(2.8, 3.2), 3.0),
+                            Row.of(Vectors.dense(300., 3.2), 4.0),
+                            Row.of(Vectors.dense(2.2, 3.2), 1.0),
+                            Row.of(Vectors.dense(2.4, 3.2), 5.0),
+                            Row.of(Vectors.dense(2.5, 3.2), 5.0),
+                            Row.of(Vectors.dense(2.5, 3.2), 5.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0)));
+    /** Rows to predict; the second field is the label the prediction is compared against. */
+    private static final List<Row> predictRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(4.0, 4.1), 5.0),
+                            Row.of(Vectors.dense(300, 42), 2.0)));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(trainRows);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("features", "label");
+        DataStream<Row> predDataStream = env.fromCollection(predictRows);
+        predictData = tEnv.fromDataStream(predDataStream, schema).as("features", "label");
+    }
+
+    /**
+     * Executes the job behind {@code output} and checks that, for every row, the value of the
+     * prediction column equals the value of the label column.
+     *
+     * @param output The table produced by {@link KnnModel#transform}.
+     * @param labelCol Name of the column holding the expected label.
+     * @param predictionCol Name of the column holding the predicted label.
+     */
+    private static void verifyPredictionResult(Table output, String labelCol, String predictionCol)
+            throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Tuple2<Double, Double>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Double, Double>>() {
+                                    @Override
+                                    public Tuple2<Double, Double> map(Row row) {
+                                        return Tuple2.of(
+                                                (Double) row.getField(labelCol),
+                                                (Double) row.getField(predictionCol));
+                                    }
+                                });
+        List<Tuple2<Double, Double>> result = IteratorUtils.toList(stream.executeAndCollect());
+        // Guard against a vacuous pass: an empty result would skip every comparison below.
+        Assert.assertFalse("The prediction output must not be empty.", result.isEmpty());
+        for (Tuple2<Double, Double> labelAndPrediction : result) {
+            assertEquals(labelAndPrediction.f0, labelAndPrediction.f1);
+        }
+    }
+
+    @Test
+    public void testParam() {
+        Knn knn = new Knn();
+        assertEquals("features", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(5, (int) knn.getK());
+        assertEquals("prediction", knn.getPredictionCol());
+        knn.setLabelCol("test_label")
+                .setFeaturesCol("test_features")
+                .setK(4)
+                .setPredictionCol("test_prediction");
+        assertEquals("test_features", knn.getFeaturesCol());
+        assertEquals("test_label", knn.getLabelCol());
+        assertEquals(4, (int) knn.getK());
+        assertEquals("test_prediction", knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("test_label")
+                        .setFeaturesCol("test_features")
+                        .setK(4)
+                        .setPredictionCol("test_prediction");
+        // One name per argument, the same way the tables are renamed in before().
+        KnnModel model = knn.fit(trainData.as("test_features", "test_label"));
+        Table output = model.transform(predictData.as("test_features", "test_label"))[0];
+        assertEquals(
+                Arrays.asList("test_features", "test_label", "test_prediction"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    /**
+     * Fits on the two prediction rows only, so the model holds fewer training points than the
+     * default k of 5.
+     */
+    @Test
+    public void testFewerDistinctPointsThanCluster() throws Exception {
+        Knn knn = new Knn();
+        KnnModel model = knn.fit(predictData);
+        Table output = model.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        Knn knn = new Knn();
+        KnnModel knnModel = knn.fit(trainData);
+        Table output = knnModel.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        Knn knn = new Knn();
+        Knn loadedKnn =
+                StageTestUtils.saveAndReload(env, knn, tempFolder.newFolder().getAbsolutePath());
+        KnnModel knnModel = loadedKnn.fit(trainData);
+        knnModel =
+                StageTestUtils.saveAndReload(
+                        env, knnModel, tempFolder.newFolder().getAbsolutePath());
+        assertEquals(
+                Arrays.asList("packedFeatures", "featureNormSquares", "labels"),
+                knnModel.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = knnModel.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
+    }
+
+    @Test
+    public void testModelSaveLoadAndPredict() throws Exception {
+        Knn knn = new Knn();
+        KnnModel knnModel = knn.fit(trainData);
+        KnnModel newModel =
+                StageTestUtils.saveAndReload(
+                        env, knnModel, tempFolder.newFolder().getAbsolutePath());
+        Table output = newModel.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        Knn knn = new Knn();
+        KnnModel knnModel = knn.fit(trainData);
+        Table modelData = knnModel.getModelData()[0];
+        DataStream<Row> output = tEnv.toDataStream(modelData);
+        assertEquals("packedFeatures", modelData.getResolvedSchema().getColumnNames().get(0));
+        assertEquals("featureNormSquares", modelData.getResolvedSchema().getColumnNames().get(1));
+        assertEquals("labels", modelData.getResolvedSchema().getColumnNames().get(2));
+        List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+        KnnModelData data =
+                new KnnModelData(
+                        (DenseMatrix) modelRows.get(0).getField(0),
+                        (DenseVector) modelRows.get(0).getField(1),
+                        (DenseVector) modelRows.get(0).getField(2));
+        // The matrix has one row per feature dimension and one column per training point.
+        assertEquals(2, data.packedFeatures.numRows());
+        assertEquals(data.packedFeatures.numCols(), data.labels.size());
+        assertEquals(data.featureNormSquares.size(), data.labels.size());
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        Knn knn = new Knn();
+        KnnModel modelA = knn.fit(trainData);
+        Table modelData = modelA.getModelData()[0];
+        KnnModel modelB = new KnnModel().setModelData(modelData);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+        Table output = modelB.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
+    }
+}