Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/12/15 14:09:26 UTC

[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #191: [FLINK-30401] Add Estimator and Transformer for MinHashLSH

zhipeng93 commented on code in PR #191:
URL: https://github.com/apache/flink-ml/pull/191#discussion_r1049588862


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Base class for estimators which implement LSH (Locality-sensitive hashing) algorithms.
+ *
+ * <p>See: <a
+ * href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">Locality-sensitive hashing</a>.
+ *
+ * @param <E> class type of the Estimator implementation itself.
+ * @param <M> class type of the Model this Estimator produces.
+ */
+abstract class LSH<E extends Estimator<E, M>, M extends LSHModel<M>>
+        implements Estimator<E, M>, LSHParams<E> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LSH() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public M fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Integer> inputDim =
+                getVectorColSize(tEnv.toDataStream(inputs[0]), getInputCol());
+        return createModel(inputDim, tEnv);
+    }
+
+    protected abstract M createModel(DataStream<Integer> inputDim, StreamTableEnvironment tEnv);
+
+    private DataStream<Integer> getVectorColSize(DataStream<Row> input, String vectorCol) {

Review Comment:
   How about renaming it to `getVectorSize()`? This method could also be made static.
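   
   A rough sketch of that suggestion, assuming the original body logic is kept and that `org.apache.flink.api.common.typeinfo.Types` is added to the imports:
   
   ```java
   private static DataStream<Integer> getVectorSize(DataStream<Row> input, String vectorCol) {
       // Maps each row to the size of its vector column.
       DataStream<Integer> sizes =
               input.map(row -> ((Vector) row.getField(vectorCol)).size()).returns(Types.INT);
       // Reduces to a single element; assumes all input vectors share the same dimension.
       return DataStreamUtils.reduce(sizes, (a, b) -> a);
   }
   ```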



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Base class for estimators which implement LSH (Locality-sensitive hashing) algorithms.

Review Comment:
   Let's add more explanation of LSH here.
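   
   For instance, something along these lines (a sketch of possible wording, not final):
   
   ```java
   /**
    * Base class for estimators which implement LSH (Locality-sensitive hashing) algorithms.
    *
    * <p>LSH hashes input items so that similar items map to the same hash bucket with high
    * probability, while dissimilar items are likely to land in different buckets. This makes it
    * possible to do approximate nearest neighbor search and approximate similarity join without
    * comparing every pair of items.
    */
   ```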



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -133,6 +135,15 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
         }
     }
 
+    public static <T, K> DataStream<T> reduce(KeyedStream<T, K> input, ReduceFunction<T> func) {

Review Comment:
   Let's use `KeyedStream.reduce(...)` instead. I am not sure whether `BatchGroupedOperator` works as expected in streaming mode.
   
   By the way, we need to add Javadocs for the public methods.
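   
   At the call site in `LSHModel#approxSimilarityJoin`, the suggestion would look roughly like this (a sketch reusing the surrounding imports):
   
   ```java
   // Deduplicates pairs with KeyedStream#reduce. Note: in batch execution mode this emits
   // one record per key, but in unbounded streaming mode it would emit incremental updates,
   // which is the semantics to double-check.
   DataStream<Row> distinctSameBucketPairs =
           sameBucketPairs
                   .keyBy(
                           new KeySelector<Row, Tuple2<Integer, Integer>>() {
                               @Override
                               public Tuple2<Integer, Integer> getKey(Row r) {
                                   return Tuple2.of(r.getFieldAs(0), r.getFieldAs(1));
                               }
                           })
                   .reduce((r0, r1) -> r0);
   ```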



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java:
##########
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.lsh.MinHashLSH;
+import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
+import org.apache.flink.ml.feature.lsh.MinHashLSHModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.table.api.Expressions.$;
+
+/** Tests {@link MinHashLSH} and {@link MinHashLSHModel}. */
+public class MinHashLSHTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    /**
+     * Default case for most tests.
+     *
+     * @return a tuple including the estimator, input data table, and output rows.
+     */
+    private Tuple3<MinHashLSH, Table, List<Row>> getDefaultCase() {
+        MinHashLSH lsh =
+                new MinHashLSH()
+                        .setInputCol("vec")
+                        .setOutputCol("hashes")
+                        .setSeed(2022L)
+                        .setNumHashTables(5)
+                        .setNumHashFunctionsPerTable(3);
+
+        List<Row> inputRows =
+                Arrays.asList(
+                        Row.of(
+                                0,
+                                Vectors.sparse(6, new int[] {0, 1, 2}, new double[] {1., 1., 1.})),
+                        Row.of(
+                                1,
+                                Vectors.sparse(6, new int[] {2, 3, 4}, new double[] {1., 1., 1.})),
+                        Row.of(
+                                2,
+                                Vectors.sparse(6, new int[] {0, 2, 4}, new double[] {1., 1., 1.})));
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.INT())
+                        .column("f1", DataTypes.of(SparseVector.class))
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(inputRows);
+        Table inputTable = tEnv.fromDataStream(dataStream, schema).as("id", "vec");
+
+        List<Row> outputRows =
+                convertToOutputFormat(
+                        Arrays.asList(
+                                new double[][] {
+                                    {1.73046954E8, 1.57275425E8, 6.90717571E8},
+                                    {5.02301169E8, 7.967141E8, 4.06089319E8},
+                                    {2.83652171E8, 1.97714719E8, 6.04731316E8},
+                                    {5.2181506E8, 6.36933726E8, 6.13894128E8},
+                                    {3.04301769E8, 1.113672955E9, 6.1388711E8}
+                                },
+                                new double[][] {
+                                    {1.73046954E8, 1.57275425E8, 6.7798584E7},
+                                    {6.38582806E8, 1.78703694E8, 4.06089319E8},
+                                    {6.232638E8, 9.28867E7, 9.92010642E8},
+                                    {2.461064E8, 1.12787481E8, 1.92180297E8},
+                                    {2.38162496E8, 1.552933319E9, 2.77995137E8}
+                                },
+                                new double[][] {
+                                    {1.73046954E8, 1.57275425E8, 6.90717571E8},
+                                    {1.453197722E9, 7.967141E8, 4.06089319E8},
+                                    {6.232638E8, 1.97714719E8, 6.04731316E8},
+                                    {2.461064E8, 1.12787481E8, 1.92180297E8},
+                                    {1.224130231E9, 1.113672955E9, 2.77995137E8}
+                                }));
+
+        return Tuple3.of(lsh, inputTable, outputRows);
+    }
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.getConfig().enableObjectReuse();
+        env.setParallelism(1);

Review Comment:
   Let's test the parallel case.
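   
   For example (a sketch; the rest of the setup stays the same):
   
   ```java
   // Exercises the same pipeline with more than one task slot.
   env.setParallelism(4);
   ```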



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHScheme.java:
##########
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+
+/**
+ * Interface for an LSH scheme. An LSH scheme should implement how to map a feature vector to
+ * multiple hash vectors, and how to calculate corresponding distance between two feature vectors.
+ */
+interface LSHScheme {
+    /**
+     * The hash function to map an input feature vector to multiple hash vectors.
+     *
+     * @param vec input vector.
+     * @return the mapping of LSH functions.
+     */
+    DenseVector[] hashFunction(Vector vec);
+
+    /**
+     * Calculate the distance between two different feature vectors using the corresponding distance

Review Comment:
   Javadoc is usually written in the third person. For example, `Calculates...`.
   
   Same for the other Javadocs.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Model data of {@link MinHashLSHModel}.
+ *
+ * <p>This class also provides classes to save/load model data.
+ */
+public class MinHashLSHModelData implements LSHScheme {
+
+    // A large prime smaller than sqrt(2^63 − 1)
+    private static final int HASH_PRIME = 2038074743;
+
+    public int numHashTables;
+    public int numHashFunctionsPerTable;
+    public int[] randCoeffA;
+    public int[] randCoeffB;
+
+    public MinHashLSHModelData() {}
+
+    public MinHashLSHModelData(
+            int numHashTables, int numHashFunctionsPerTable, int[] randCoeffA, int[] randCoeffB) {
+        this.numHashTables = numHashTables;
+        this.numHashFunctionsPerTable = numHashFunctionsPerTable;
+        this.randCoeffA = randCoeffA;
+        this.randCoeffB = randCoeffB;
+    }
+
+    public static MinHashLSHModelData generateModelData(
+            int numHashTables, int numHashFunctionsPerTable, int dim, long seed) {
+        Preconditions.checkArgument(
+                dim <= HASH_PRIME,
+                "The input vector dimension %d exceeds the threshold %s.",
+                dim,
+                HASH_PRIME);
+
+        Random random = new Random(seed);
+        int numHashFunctions = numHashTables * numHashFunctionsPerTable;
+        int[] randCoeffA = new int[numHashFunctions];
+        int[] randCoeffB = new int[numHashFunctions];
+        for (int i = 0; i < numHashFunctions; i += 1) {
+            randCoeffA[i] = 1 + random.nextInt(HASH_PRIME - 1);
+            randCoeffB[i] = random.nextInt(HASH_PRIME - 1);
+        }
+        return new MinHashLSHModelData(
+                numHashTables, numHashFunctionsPerTable, randCoeffA, randCoeffB);
+    }
+
+    static class ModelDataDecoder extends SimpleStreamFormat<MinHashLSHModelData> {
+        @Override
+        public Reader<MinHashLSHModelData> createReader(
+                Configuration configuration, FSDataInputStream fsDataInputStream)
+                throws IOException {
+            return new Reader<MinHashLSHModelData>() {
+                @Override
+                public MinHashLSHModelData read() throws IOException {
+                    try {
+                        DataInputViewStreamWrapper source =
+                                new DataInputViewStreamWrapper(fsDataInputStream);
+                        int numHashTables = IntSerializer.INSTANCE.deserialize(source);
+                        int numHashFunctionsPerTable = IntSerializer.INSTANCE.deserialize(source);
+                        int[] randCoeffA = IntPrimitiveArraySerializer.INSTANCE.deserialize(source);
+                        int[] randCoeffB = IntPrimitiveArraySerializer.INSTANCE.deserialize(source);
+                        return new MinHashLSHModelData(
+                                numHashTables, numHashFunctionsPerTable, randCoeffA, randCoeffB);
+                    } catch (EOFException e) {
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    fsDataInputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<MinHashLSHModelData> getProducedType() {
+            return TypeInformation.of(MinHashLSHModelData.class);
+        }
+    }
+
+    /**
+     * indices: indexes of data in vec whose values are not zero.
+     *
+     * <p>hashValue = (((1 + indices) * randCoefficientA + randCoefficientB) % HASH_PRIME).min.
+     */
+    @Override
+    public DenseVector[] hashFunction(Vector vec) {
+        if (vec instanceof DenseVector) {
+            return hashFunction((DenseVector) vec);
+        } else {
+            return hashFunction(vec.toSparse());
+        }
+    }
+
+    @Override
+    public double keyDistance(Vector x, Vector y) {
+        int[] xIndices = x.toSparse().indices;
+        int[] yIndices = y.toSparse().indices;
+        Preconditions.checkArgument(
+                xIndices.length + yIndices.length > 0,
+                "The union of two input sets must have at least 1 elements");
+        int px = 0, py = 0;
+        int intersectionSize = 0;
+        while (px < xIndices.length && py < yIndices.length) {
+            if (xIndices[px] == yIndices[py]) {
+                intersectionSize += 1;
+                px += 1;
+                py += 1;
+            } else if (xIndices[px] < yIndices[py]) {
+                px += 1;
+            } else {
+                py += 1;
+            }
+        }
+        int unionSize = xIndices.length + yIndices.length - intersectionSize;
+        return 1. - 1. * intersectionSize / unionSize;
+    }
+
+    private DenseVector[] hashFunction(SparseVector vec) {

Review Comment:
   The logic for sparse and dense vectors seems almost the same. Can the two paths be merged?
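   
   A possible merged version (a sketch; it assumes both paths only look at the non-zero indices, which `Vector#toSparse` exposes for either representation):
   
   ```java
   @Override
   public DenseVector[] hashFunction(Vector vec) {
       // Non-zero indices of the input, for both dense and sparse vectors.
       int[] indices = vec.toSparse().indices;
       Preconditions.checkArgument(indices.length > 0, "Must have at least 1 non-zero entry.");
       DenseVector[] hashValues = new DenseVector[numHashTables];
       for (int i = 0; i < numHashTables; i += 1) {
           hashValues[i] = new DenseVector(numHashFunctionsPerTable);
           for (int j = 0; j < numHashFunctionsPerTable; j += 1) {
               int k = i * numHashFunctionsPerTable + j;
               // hashValue = (((1 + index) * randCoeffA + randCoeffB) % HASH_PRIME).min
               long minHash = Long.MAX_VALUE;
               for (int index : indices) {
                   minHash =
                           Math.min(minHash, ((1L + index) * randCoeffA[k] + randCoeffB[k]) % HASH_PRIME);
               }
               hashValues[i].values[j] = minHash;
           }
       }
       return hashValues;
   }
   ```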



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSH.java:
##########
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+
+import java.io.IOException;
+
+/**
+ * An Estimator which implements the MinHash LSH algorithm, with Jaccard distance as its distance

Review Comment:
   Let's add some details about MinHash here.
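   
   For example (a sketch of possible wording):
   
   ```java
   /**
    * An Estimator which implements the MinHash LSH algorithm, with Jaccard distance as its
    * distance metric.
    *
    * <p>MinHash applies a family of random hash functions of the form h(x) = (a * x + b) % prime
    * to the non-zero indices of an input vector and keeps the minimum value per function. The
    * probability that two sets share a MinHash value equals their Jaccard similarity, which makes
    * the hashes suitable for approximate nearest neighbor search and approximate similarity join.
    */
   ```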



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Model data of {@link MinHashLSHModel}.
+ *
+ * <p>This class also provides classes to save/load model data.
+ */
+public class MinHashLSHModelData implements LSHScheme {
+
+    // A large prime smaller than sqrt(2^63 − 1)
+    private static final int HASH_PRIME = 2038074743;
+
+    public int numHashTables;
+    public int numHashFunctionsPerTable;
+    public int[] randCoeffA;
+    public int[] randCoeffB;
+
+    public MinHashLSHModelData() {}
+
+    public MinHashLSHModelData(
+            int numHashTables, int numHashFunctionsPerTable, int[] randCoeffA, int[] randCoeffB) {
+        this.numHashTables = numHashTables;
+        this.numHashFunctionsPerTable = numHashFunctionsPerTable;
+        this.randCoeffA = randCoeffA;
+        this.randCoeffB = randCoeffB;
+    }
+
+    public static MinHashLSHModelData generateModelData(

Review Comment:
   It is a bit confusing why the model data contains these coefficients. Is the model data deterministic when the dimension of the input is fixed? If so, the model data could simply be the `inputDim` and the parameters, right?



##########
docs/content/docs/operators/feature/minhashlsh.md:
##########
@@ -0,0 +1,276 @@
+---
+title: "MinHash LSH"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/minhashlsh.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## MinHash LSH
+
+MinHash LSH is a Locality Sensitive Hashing (LSH) scheme for Jaccard distance metric.
+The input features are sets of natural numbers represented as non-zero indices of vectors,
+either dense vectors or sparse vectors. Typically, sparse vectors are more efficient.
+
+### Input Columns
+
+| Param name | Type   | Default   | Description            |
+|:-----------|:-------|:----------|:-----------------------|
+| inputCol   | Vector | `"input"` | Features to be mapped. |
+
+### Output Columns
+
+| Param name | Type          | Default    | Description  |
+|:-----------|:--------------|:-----------|:-------------|
+| outputCol  | DenseVector[] | `"output"` | Hash values. |
+
+### Parameters
+
+| Key                     | Default    | Type    | Required | Description                                                        |
+|-------------------------|------------|---------|----------|--------------------------------------------------------------------|
+| inputCol                | `"input"`  | String  | no       | Input column name.                                                 |
+| outputCol               | `"output"` | String  | no       | Output column name.                                                |
+| seed                    | `null`     | Long    | no       | The random seed.                                                   |

Review Comment:
   The default value of `seed` does not seem to be null.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java:
##########
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/**
+ * Base class for LSH model.
+ *
+ * @param <T> class type of the LSHModel implementation itself.
+ */
+abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHParams<T> {
+    private static final String MODEL_DATA_BC_KEY = "modelData";
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Stores the corresponding model data class of T. */
+    private final Class<?> modelDataClass;
+
+    protected Table modelDataTable;
+
+    public LSHModel(Class<?> modelDataClass) {
+        this.modelDataClass = modelDataClass;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public T setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return (T) this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<?> modelData = tEnv.toDataStream(modelDataTable, modelDataClass);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?> outputType = TypeInformation.of(DenseVector[].class);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.map(
+                                    new PredictOutputMapFunction(getInputCol()), outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Given a dataset and an item, approximately find at most k items which have the closest
+     * distance to the item. If the `outputCol` is missing in the given dataset, this method
+     * transforms the dataset with the model first.
+     *
+     * @param dataset The dataset in which to search for nearest neighbors.
+     * @param key The item to search for.
+     * @param k The maximum number of nearest neighbors.
+     * @param distCol The output column storing the distance between each neighbor and the key.
+     * @return A dataset containing at most k items closest to the key with a column named `distCol`
+     *     appended.
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment();
+        Table transformedTable =
+                (dataset.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataset
+                        : transform(dataset)[0];
+
+        DataStream<MinHashLSHModelData> modelData =
+                tEnv.toDataStream(modelDataTable, MinHashLSHModelData.class);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), distCol));
+
+        // Fetch items that share a bucket with the key, and calculate their distances to the key.
+        DataStream<Row> filteredData =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(transformedTable)),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.flatMap(
+                                    new FilterBySameBucketsFlatMapFunction(
+                                            getInputCol(), getOutputCol(), key),
+                                    outputTypeInfo);
+                        });
+        DataStream<Row> partitionedTopKData =
+                DataStreamUtils.mapPartition(
+                        filteredData, new TopKMapPartitionFunction(distCol, k));
+        DataStream<Row> topKData =
+                DataStreamUtils.mapPartition(
+                        partitionedTopKData.rebalance(), new TopKMapPartitionFunction(distCol, k));
+        topKData.getTransformation().setOutputType(outputTypeInfo);
+        topKData.getTransformation().setParallelism(1);
+        return tEnv.fromDataStream(topKData);
+    }
+
+    /**
+     * An overloaded version of `approxNearestNeighbors` with "distCol" as the default value of
+     * `distCol`.
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k) {
+        return approxNearestNeighbors(dataset, key, k, "distCol");
+    }
+
+    /**
+     * Join two datasets to approximately find all pairs of rows whose distance are smaller than or
+     * equal to the threshold. If the `outputCol` is missing in either dataset, this method
+     * transforms the dataset first.
+     *
+     * @param datasetA One dataset.
+     * @param datasetB The other dataset.
+     * @param threshold The distance threshold.
+     * @param idCol A column in the two datasets to identify each row.
+     * @param distCol The output column storing the distance between each pair of rows.
+     * @return A joined dataset containing pairs of rows. The original rows are in columns
+     *     "datasetA" and "datasetB", and a column "distCol" is added to show the distance between
+     *     each pair.
+     */
+    public Table approxSimilarityJoin(
+            Table datasetA, Table datasetB, double threshold, String idCol, String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) datasetA).getTableEnvironment();
+
+        DataStream<Row> explodedA = preprocessData(datasetA, idCol);
+        DataStream<Row> explodedB = preprocessData(datasetB, idCol);
+
+        DataStream<MinHashLSHModelData> modelData =
+                tEnv.toDataStream(modelDataTable, MinHashLSHModelData.class);
+        DataStream<Row> sameBucketPairs =
+                explodedA
+                        .join(explodedB)
+                        .where(new IndexHashValueKeySelector())
+                        .equalTo(new IndexHashValueKeySelector())
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (r0, r1) ->
+                                        Row.of(
+                                                r0.getField(0),
+                                                r1.getField(0),
+                                                r0.getField(1),
+                                                r1.getField(1)));
+        DataStream<Row> distinctSameBucketPairs =
+                DataStreamUtils.reduce(
+                        sameBucketPairs.keyBy(
+                                new KeySelector<Row, Tuple2<Integer, Integer>>() {
+                                    @Override
+                                    public Tuple2<Integer, Integer> getKey(Row r) {
+                                        return Tuple2.of(r.getFieldAs(0), r.getFieldAs(1));
+                                    }
+                                }),
+                        (r0, r1) -> r0);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(datasetA.getResolvedSchema());
+        TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol);
+        DataStream<Row> pairsWithDists =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(distinctSameBucketPairs),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.flatMap(
+                                    new FilterByDistanceFlatMapFunction(threshold),
+                                    new RowTypeInfo(
+                                            new TypeInformation[] {
+                                                idColType, idColType, Types.DOUBLE
+                                            },
+                                            new String[] {"datasetA.id", "datasetB.id", distCol}));
+                        });
+        return tEnv.fromDataStream(pairsWithDists);
+    }
+
+    /**
+     * An overloaded version of `approxSimilarityJoin` with "distCol" as the default value of
+     * `distCol`.
+     */
+    public Table approxSimilarityJoin(
+            Table datasetA, Table datasetB, double threshold, String idCol) {
+        return approxSimilarityJoin(datasetA, datasetB, threshold, idCol, "distCol");
+    }
+
+    private DataStream<Row> preprocessData(Table dataTable, String idCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment();
+
+        dataTable =
+                (dataTable.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataTable
+                        : transform(dataTable)[0];
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema());
+        TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol);
+        final String indexCol = "index";
+        final String hashValueCol = "hashValue";
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            idColType,
+                            TypeInformation.of(Vector.class),
+                            Types.INT,
+                            TypeInformation.of(DenseVector.class)
+                        },
+                        new String[] {idCol, getInputCol(), indexCol, hashValueCol});
+
+        return tEnv.toDataStream(dataTable)
+                .flatMap(
+                        new ExplodeHashValuesFlatMapFunction(idCol, getInputCol(), getOutputCol()),
+                        outputTypeInfo);
+    }
+
+    private static class PredictOutputMapFunction extends RichMapFunction<Row, Row> {
+        private final String inputCol;
+
+        private MinHashLSHModelData modelData;
+
+        public PredictOutputMapFunction(String inputCol) {
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row value) throws Exception {
+            if (null == modelData) {
+                modelData =
+                        (MinHashLSHModelData)

Review Comment:
   What is the difference between `MinHashLSH` and `LSH`? Are we going to implement another class extending `LSH` that has exactly the same model data as `MinHashLSHModelData`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java:
##########
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/**
+ * Base class for LSH model.
+ *
+ * @param <T> class type of the LSHModel implementation itself.
+ */
+abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHParams<T> {
+    private static final String MODEL_DATA_BC_KEY = "modelData";
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Stores the corresponding model data class of T. */
+    private final Class<?> modelDataClass;
+
+    protected Table modelDataTable;
+
+    public LSHModel(Class<?> modelDataClass) {
+        this.modelDataClass = modelDataClass;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public T setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return (T) this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<?> modelData = tEnv.toDataStream(modelDataTable, modelDataClass);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?> outputType = TypeInformation.of(DenseVector[].class);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.map(
+                                    new PredictOutputMapFunction(getInputCol()), outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Given a dataset and an item, approximately find at most k items which have the closest
+     * distance to the item. If the `outputCol` is missing in the given dataset, this method
+     * transforms the dataset with the model first.
+     *
+     * @param dataset The dataset in which to search for nearest neighbors.
+     * @param key The item to search for.
+     * @param k The maximum number of nearest neighbors.
+     * @param distCol The output column storing the distance between each neighbor and the key.
+     * @return A dataset containing at most k items closest to the key with a column named `distCol`
+     *     appended.
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment();
+        Table transformedTable =
+                (dataset.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataset
+                        : transform(dataset)[0];
+
+        DataStream<MinHashLSHModelData> modelData =
+                tEnv.toDataStream(modelDataTable, MinHashLSHModelData.class);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), distCol));
+
+        // Fetch items that share a bucket with the key, and calculate their distances to the key.
+        DataStream<Row> filteredData =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(transformedTable)),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.flatMap(
+                                    new FilterBySameBucketsFlatMapFunction(
+                                            getInputCol(), getOutputCol(), key),
+                                    outputTypeInfo);
+                        });
+        DataStream<Row> partitionedTopKData =
+                DataStreamUtils.mapPartition(
+                        filteredData, new TopKMapPartitionFunction(distCol, k));
+        DataStream<Row> topKData =
+                DataStreamUtils.mapPartition(
+                        partitionedTopKData.rebalance(), new TopKMapPartitionFunction(distCol, k));

Review Comment:
   The `rebalance` here seems unnecessary since the parallelism is set to one.
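   
   i.e., roughly:
   
   ```java
   // With the final parallelism already forced to 1, the second pass can consume
   // partitionedTopKData directly; the extra shuffle adds no value.
   DataStream<Row> topKData =
           DataStreamUtils.mapPartition(
                   partitionedTopKData, new TopKMapPartitionFunction(distCol, k));
   ```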



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org