You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/12/16 02:01:44 UTC

[GitHub] [flink-ml] Fanoid commented on a diff in pull request #191: [FLINK-30401] Add Estimator and Transformer for MinHashLSH

Fanoid commented on code in PR #191:
URL: https://github.com/apache/flink-ml/pull/191#discussion_r1050303584


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java:
##########
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/**
+ * Base class for LSH model.
+ *
+ * @param <T> class type of the LSHModel implementation itself.
+ */
+abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHParams<T> {
+    /** Key under which the model data stream is registered as a broadcast stream. */
+    private static final String MODEL_DATA_BC_KEY = "modelData";
+
+    /** Holds the values of all parameters of this model; filled with defaults on construction. */
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Stores the corresponding model data class of T. */
+    private final Class<?> modelDataClass;
+
+    /** The model data table, set via {@link #setModelData(Table...)}. */
+    protected Table modelDataTable;
+
+    /**
+     * Creates an LSH model and initializes its parameters with their default values.
+     *
+     * @param modelDataClass The class the model data table is converted into when this model is
+     *     applied; it corresponds to the concrete model type T.
+     */
+    public LSHModel(Class<?> modelDataClass) {
+        this.modelDataClass = modelDataClass;
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    /**
+     * Sets the model data of this model.
+     *
+     * @param inputs Exactly one table holding the model data.
+     * @return This model instance.
+     * @throws IllegalArgumentException If the number of given tables is not exactly one.
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public T setModelData(Table... inputs) {
+        // Same arity check as transform(); extra tables would otherwise be silently ignored.
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return (T) this;
+    }
+
+    /** Returns the model data of this model as a single-element table array. */
+    @Override
+    public Table[] getModelData() {
+        final Table[] modelData = {modelDataTable};
+        return modelData;
+    }
+
+    /** Returns the (mutable) map holding the values of all parameters of this model. */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.paramMap;
+    }
+
+    /**
+     * Applies this model to the single input table.
+     *
+     * @param inputs Exactly one input table containing the `inputCol` column.
+     * @return A single table with all input columns plus a `DenseVector[]` column named after
+     *     `outputCol`.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        final Table input = inputs[0];
+        final StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) input).getTableEnvironment();
+        final DataStream<?> modelDataStream = tEnv.toDataStream(modelDataTable, modelDataClass);
+
+        final RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(input.getResolvedSchema());
+        final TypeInformation<?>[] fieldTypes =
+                ArrayUtils.addAll(
+                        inputTypeInfo.getFieldTypes(), TypeInformation.of(DenseVector[].class));
+        final String[] fieldNames =
+                ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol());
+        final RowTypeInfo outputTypeInfo = new RowTypeInfo(fieldTypes, fieldNames);
+
+        final DataStream<Row> inputStream = tEnv.toDataStream(input);
+        final DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputStream),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelDataStream),
+                        streams -> {
+                            //noinspection unchecked
+                            final DataStream<Row> rows = (DataStream<Row>) streams.get(0);
+                            return rows.map(
+                                    new PredictOutputMapFunction(getInputCol()), outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Given a dataset and an item, approximately find at most k items which have the closest
+     * distance to the item. If the `outputCol` is missing in the given dataset, this method
+     * transforms the dataset with the model at first.
+     *
+     * @param dataset The dataset in which to to search for nearest neighbors.
+     * @param key The item to search for.
+     * @param k The maximum number of nearest neighbors.
+     * @param distCol The output column storing the distance between each neighbor and the key.
+     * @return A dataset containing at most k items closest to the key with a column named `distCol`
+     *     appended.
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment();
+        Table transformedTable =
+                (dataset.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataset
+                        : transform(dataset)[0];
+
+        DataStream<MinHashLSHModelData> modelData =
+                tEnv.toDataStream(modelDataTable, MinHashLSHModelData.class);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), distCol));
+
+        // Fetch items in the same bucket with key's, and calculate their distances to key.
+        DataStream<Row> filteredData =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(transformedTable)),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.flatMap(
+                                    new FilterBySameBucketsFlatMapFunction(
+                                            getInputCol(), getOutputCol(), key),
+                                    outputTypeInfo);
+                        });
+        DataStream<Row> partitionedTopKData =
+                DataStreamUtils.mapPartition(
+                        filteredData, new TopKMapPartitionFunction(distCol, k));
+        DataStream<Row> topKData =
+                DataStreamUtils.mapPartition(
+                        partitionedTopKData.rebalance(), new TopKMapPartitionFunction(distCol, k));
+        topKData.getTransformation().setOutputType(outputTypeInfo);
+        topKData.getTransformation().setParallelism(1);
+        return tEnv.fromDataStream(topKData);
+    }
+
+    /**
+     * Overload of {@link #approxNearestNeighbors(Table, Vector, int, String)} that names the
+     * distance column "distCol".
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k) {
+        final String defaultDistCol = "distCol";
+        return approxNearestNeighbors(dataset, key, k, defaultDistCol);
+    }
+
+    /**
+     * Join two datasets to approximately find all pairs of rows whose distances are smaller than
+     * or equal to the threshold. If the `outputCol` is missing in either dataset, this method
+     * transforms the dataset at first.
+     *
+     * @param datasetA One dataset.
+     * @param datasetB The other dataset.
+     * @param threshold The distance threshold.
+     * @param idCol A column in the two datasets to identify each row.
+     * @param distCol The output column storing the distance between each pair of rows.
+     * @return A joined dataset containing pairs of rows. The ids of the paired rows are in columns
+     *     "datasetA.id" and "datasetB.id", and a column named `distCol` is added to show the
+     *     distance between each pair.
+     */
+    public Table approxSimilarityJoin(
+            Table datasetA, Table datasetB, double threshold, String idCol, String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) datasetA).getTableEnvironment();
+
+        // Exploded rows have the fields (idCol, inputCol, "index", "hashValue").
+        DataStream<Row> explodedA = preprocessData(datasetA, idCol);
+        DataStream<Row> explodedB = preprocessData(datasetB, idCol);
+
+        // Convert the model data with the class of the concrete LSH model rather than a specific
+        // subclass, so that models with a different kind of model data are supported as well.
+        DataStream<?> modelData = tEnv.toDataStream(modelDataTable, modelDataClass);
+        // Pair rows of A and B whose keys match under IndexHashValueKeySelector; each pair row
+        // is (idA, idB, vectorA, vectorB).
+        DataStream<Row> sameBucketPairs =
+                explodedA
+                        .join(explodedB)
+                        .where(new IndexHashValueKeySelector())
+                        .equalTo(new IndexHashValueKeySelector())
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (r0, r1) ->
+                                        Row.of(
+                                                r0.getField(0),
+                                                r1.getField(0),
+                                                r0.getField(1),
+                                                r1.getField(1)));
+        // The join may emit the same pair more than once; keep one row per (idA, idB).
+        // NOTE(review): the key is declared as Tuple2<Integer, Integer>, which assumes the id
+        // column is of INT type — confirm, or derive the key type from the id column type.
+        DataStream<Row> distinctSameBucketPairs =
+                DataStreamUtils.reduce(
+                        sameBucketPairs.keyBy(
+                                new KeySelector<Row, Tuple2<Integer, Integer>>() {
+                                    @Override
+                                    public Tuple2<Integer, Integer> getKey(Row r) {
+                                        return Tuple2.of(r.getFieldAs(0), r.getFieldAs(1));
+                                    }
+                                }),
+                        (r0, r1) -> r0);
+
+        // NOTE(review): the id type is taken from datasetA only; datasetB is assumed to use the
+        // same type for its id column.
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(datasetA.getResolvedSchema());
+        TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol);
+        DataStream<Row> pairsWithDists =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(distinctSameBucketPairs),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.flatMap(
+                                    new FilterByDistanceFlatMapFunction(threshold),
+                                    new RowTypeInfo(
+                                            new TypeInformation[] {
+                                                idColType, idColType, Types.DOUBLE
+                                            },
+                                            new String[] {"datasetA.id", "datasetB.id", distCol}));
+                        });
+        return tEnv.fromDataStream(pairsWithDists);
+    }
+
+    /**
+     * An overloaded version of `approxSimilarityJoin` with "distCol" as default value of
+     * `distCol`.
+     */
+    public Table approxSimilarityJoin(
+            Table datasetA, Table datasetB, double threshold, String idCol) {
+        return approxSimilarityJoin(datasetA, datasetB, threshold, idCol, "distCol");
+    }
+
+    private DataStream<Row> preprocessData(Table dataTable, String idCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment();
+
+        dataTable =
+                (dataTable.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataTable
+                        : transform(dataTable)[0];
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema());
+        TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol);
+        final String indexCol = "index";
+        final String hashValueCol = "hashValue";
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            idColType,
+                            TypeInformation.of(Vector.class),
+                            Types.INT,
+                            TypeInformation.of(DenseVector.class)
+                        },
+                        new String[] {idCol, getInputCol(), indexCol, hashValueCol});
+
+        return tEnv.toDataStream(dataTable)
+                .flatMap(
+                        new ExplodeHashValuesFlatMapFunction(idCol, getInputCol(), getOutputCol()),
+                        outputTypeInfo);
+    }
+
+    private static class PredictOutputMapFunction extends RichMapFunction<Row, Row> {
+        private final String inputCol;
+
+        private MinHashLSHModelData modelData;
+
+        public PredictOutputMapFunction(String inputCol) {
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row value) throws Exception {
+            if (null == modelData) {
+                modelData =
+                        (MinHashLSHModelData)

Review Comment:
   Ah, my fault. Yes, there will be a `BucketedRandomProjectionLSH` that extends `LSH` with different model data,
   so I should use the base class `LSHScheme` instead of `MinHashLSHModelData`.
   
   I'll check this class and correct all such problems.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org