You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/07/26 02:02:45 UTC
[flink-ml] branch master updated: [FLINK-28501] Add Transformer and Estimator for VectorIndexer
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new e31101f [FLINK-28501] Add Transformer and Estimator for VectorIndexer
e31101f is described below
commit e31101fdb614ea6366111f761e3176c5c2ab6ef0
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Tue Jul 26 10:02:41 2022 +0800
[FLINK-28501] Add Transformer and Estimator for VectorIndexer
This closes #134.
---
.../org/apache/flink/ml/linalg/DenseVector.java | 5 +
.../org/apache/flink/ml/linalg/SparseVector.java | 24 +-
.../java/org/apache/flink/ml/linalg/Vector.java | 3 +
.../apache/flink/ml/linalg/DenseVectorTest.java | 10 +
.../apache/flink/ml/linalg/SparseVectorTest.java | 13 +
.../ml/examples/feature/VectorIndexerExample.java | 81 ++++++
.../ml/feature/vectorindexer/VectorIndexer.java | 255 ++++++++++++++++++
.../feature/vectorindexer/VectorIndexerModel.java | 200 +++++++++++++++
.../vectorindexer/VectorIndexerModelData.java | 133 ++++++++++
.../vectorindexer/VectorIndexerModelParams.java | 34 +--
.../feature/vectorindexer/VectorIndexerParams.java | 46 ++++
.../apache/flink/ml/feature/VectorIndexerTest.java | 285 +++++++++++++++++++++
.../examples/ml/feature/vectorindexer_example.py | 80 ++++++
flink-ml-python/pyflink/ml/core/linalg.py | 27 +-
.../pyflink/ml/core/tests/test_linalg.py | 15 ++
.../ml/lib/feature/tests/test_vectorindexer.py | 103 ++++++++
.../pyflink/ml/lib/feature/vectorindexer.py | 127 +++++++++
17 files changed, 1415 insertions(+), 26 deletions(-)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
index b4f0603..ff08503 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
@@ -46,6 +46,11 @@ public class DenseVector implements Vector {
return values[i];
}
+    /** Writes {@code value} directly into the backing array at position {@code i}. */
+    @Override
+    public void set(int i, double value) {
+        this.values[i] = value;
+    }
+
@Override
public double[] toArray() {
return values;
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
index f241e30..3c043fc 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
@@ -29,8 +29,8 @@ import java.util.Objects;
@TypeInfo(SparseVectorTypeInfoFactory.class)
public class SparseVector implements Vector {
public final int n;
- public final int[] indices;
- public final double[] values;
+ public int[] indices;
+ public double[] values;
public SparseVector(int n, int[] indices, double[] values) {
this.n = n;
@@ -56,6 +56,26 @@ public class SparseVector implements Vector {
return 0.;
}
+    /**
+     * Sets the value at index {@code i}.
+     *
+     * <p>If {@code i} is already an active index, its stored value is overwritten in place (also
+     * when {@code value} is 0.0). Otherwise a non-zero {@code value} is inserted at the position
+     * reported by {@link Arrays#binarySearch(int[], int)}, which reallocates {@link #indices} and
+     * {@link #values}. A zero {@code value} for an inactive index leaves the vector unchanged.
+     * Assumes {@link #indices} is sorted in ascending order, as binarySearch requires.
+     *
+     * @param i the index to update; must be in {@code [0, n)}.
+     * @param value the new value.
+     * @throws IllegalArgumentException if {@code i} is outside {@code [0, n)}.
+     */
+    @Override
+    public void set(int i, double value) {
+        // Validate up front: previously a negative index was silently inserted, and an
+        // out-of-range index was only rejected when a non-zero value had to be inserted.
+        Preconditions.checkArgument(i >= 0 && i < n, "Index out of bounds: " + i);
+        int pos = Arrays.binarySearch(indices, i);
+        if (pos >= 0) {
+            values[pos] = value;
+        } else if (value != 0.0) {
+            // binarySearch returns -(insertionPoint) - 1 when the key is absent.
+            int insertPos = -pos - 1;
+            int tailLength = indices.length - insertPos;
+            int[] newIndices = new int[indices.length + 1];
+            double[] newValues = new double[indices.length + 1];
+            System.arraycopy(indices, 0, newIndices, 0, insertPos);
+            System.arraycopy(values, 0, newValues, 0, insertPos);
+            newIndices[insertPos] = i;
+            newValues[insertPos] = value;
+            System.arraycopy(indices, insertPos, newIndices, insertPos + 1, tailLength);
+            System.arraycopy(values, insertPos, newValues, insertPos + 1, tailLength);
+            this.indices = newIndices;
+            this.values = newValues;
+        }
+    }
+
@Override
public double[] toArray() {
double[] result = new double[n];
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
index 21718b3..976e9a7 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
@@ -29,6 +29,9 @@ public interface Vector extends Serializable {
/** Gets the value of the ith element. */
double get(int i);
+    /**
+     * Sets the value of the ith element.
+     *
+     * @param i the index of the element to update.
+     * @param value the new value of that element.
+     */
+    void set(int i, double value);
+
/** Converts the instance to a double array. */
double[] toArray();
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
index 12ebc61..427403d 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.ml.linalg;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
/** Tests the behavior of {@link DenseVector}. */
public class DenseVectorTest {
@@ -37,4 +38,13 @@ public class DenseVectorTest {
assertArrayEquals(denseVec.values, new double[] {1, 2, 3}, TOLERANCE);
assertArrayEquals(clonedDenseVec.values, new double[] {-1, 2, 3}, TOLERANCE);
}
+
+    /** Checks that a value written by {@link DenseVector#set} is read back by get. */
+    @Test
+    public void testGetAndSet() {
+        DenseVector vec = Vectors.dense(1, 2, 3);
+        final int idx = 0;
+        assertEquals(1, vec.get(idx), TOLERANCE);
+        vec.set(idx, 2);
+        assertEquals(2, vec.get(idx), TOLERANCE);
+    }
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
index 163586d..916963f 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
@@ -147,4 +147,17 @@ public class SparseVectorTest {
assertArrayEquals(clonedSparseVec.indices, new int[] {0, 2});
assertArrayEquals(clonedSparseVec.values, new double[] {-1, 3}, TOLERANCE);
}
+
+    /**
+     * Checks get/set on {@link SparseVector}, including that set keeps the active indices sorted
+     * and the values aligned with them when it has to insert a new entry.
+     */
+    @Test
+    public void testGetAndSet() {
+        SparseVector sparseVec = Vectors.sparse(4, new int[] {2}, new double[] {0.3});
+        assertEquals(0, sparseVec.get(0), TOLERANCE);
+        assertEquals(0.3, sparseVec.get(2), TOLERANCE);
+
+        // Overwriting an active index updates the value in place.
+        sparseVec.set(2, 0.5);
+        assertEquals(0.5, sparseVec.get(2), TOLERANCE);
+        assertArrayEquals(new int[] {2}, sparseVec.indices);
+
+        // Inserting before the first active index.
+        sparseVec.set(0, 0.1);
+        assertEquals(0.1, sparseVec.get(0), TOLERANCE);
+        assertArrayEquals(new int[] {0, 2}, sparseVec.indices);
+        assertArrayEquals(new double[] {0.1, 0.5}, sparseVec.values, TOLERANCE);
+
+        // Inserting after the last active index.
+        sparseVec.set(3, 0.7);
+        assertArrayEquals(new int[] {0, 2, 3}, sparseVec.indices);
+        assertArrayEquals(new double[] {0.1, 0.5, 0.7}, sparseVec.values, TOLERANCE);
+
+        // Inserting between active indices shifts the tail.
+        sparseVec.set(1, 0.2);
+        assertArrayEquals(new int[] {0, 1, 2, 3}, sparseVec.indices);
+        assertArrayEquals(new double[] {0.1, 0.2, 0.5, 0.7}, sparseVec.values, TOLERANCE);
+
+        // Writing 0.0 to an inactive index does not add an entry.
+        SparseVector otherVec = Vectors.sparse(4, new int[] {1}, new double[] {1.0});
+        otherVec.set(0, 0.0);
+        assertArrayEquals(new int[] {1}, otherVec.indices);
+        assertArrayEquals(new double[] {1.0}, otherVec.values, TOLERANCE);
+    }
}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorIndexerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorIndexerExample.java
new file mode 100644
index 0000000..b80506e
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorIndexerExample.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexer;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModel;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Simple program that creates a VectorIndexer instance and uses it for feature engineering. */
+public class VectorIndexerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Training rows: column 0 holds five distinct values, column 1 holds three (1, -1, 0).
+        List<Row> trainRows =
+                Arrays.asList(
+                        Row.of(Vectors.dense(1, 1)),
+                        Row.of(Vectors.dense(2, -1)),
+                        Row.of(Vectors.dense(3, 1)),
+                        Row.of(Vectors.dense(4, 0)),
+                        Row.of(Vectors.dense(5, 0)));
+
+        // Rows to be indexed by the fitted model.
+        List<Row> predictRows =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 2)),
+                        Row.of(Vectors.dense(0, 0)),
+                        Row.of(Vectors.dense(0, -1)));
+
+        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainRows)).as("input");
+        Table predictTable = tEnv.fromDataStream(env.fromCollection(predictRows)).as("input");
+
+        // With maxCategories = 3, only columns with at most three distinct values are indexed;
+        // unseen values go to an extra bucket because of the `keep` strategy.
+        VectorIndexer indexer =
+                new VectorIndexer()
+                        .setInputCol("input")
+                        .setOutputCol("output")
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID)
+                        .setMaxCategories(3);
+
+        VectorIndexerModel indexerModel = indexer.fit(trainTable);
+        Table result = indexerModel.transform(predictTable)[0];
+
+        CloseableIterator<Row> rows = result.execute().collect();
+        while (rows.hasNext()) {
+            Row row = rows.next();
+            Object inputValue = row.getField(indexer.getInputCol());
+            Object outputValue = row.getField(indexer.getOutputCol());
+            System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
new file mode 100644
index 0000000..74d6a6b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the vector indexing algorithm.
+ *
+ * <p>A vector indexer maps each column of the input vector into a continuous/categorical feature.
+ * Whether one feature is transformed into a continuous or categorical feature depends on the number
+ * of distinct values in this column. If the number of distinct values in one column is greater than
+ * a specified parameter (i.e., maxCategories), the corresponding output column is unchanged.
+ * Otherwise, it is transformed into a categorical value. For categorical outputs, seen values are
+ * mapped into [0, numDistinctValuesInThisColumn), and invalid entries kept by the `keep` option are
+ * mapped to numDistinctValuesInThisColumn.
+ *
+ * <p>The output model is organized in ascending order except that 0.0 is always mapped to 0 (for
+ * sparsity). We list two examples here:
+ *
+ * <ul>
+ *   <li>If one column contains {-1.0, 1.0}, then -1.0 will be encoded as 0 and 1.0 will be
+ *       encoded as 1.
+ *   <li>If one column contains {-1.0, 0.0, 1.0}, then -1.0 will be encoded as 1, 0.0 will be
+ *       encoded as 0 and 1.0 will be encoded as 2.
+ * </ul>
+ *
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we put the invalid entries in a
+ * special bucket, whose index is the number of distinct values in this column.
+ */
+public class VectorIndexer
+        implements Estimator<VectorIndexer, VectorIndexerModel>,
+                VectorIndexerParams<VectorIndexer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public VectorIndexer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public VectorIndexerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        int maxCategories = getMaxCategories();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        // Each subtask collects the distinct values of every column within its own partition.
+        DataStream<HashSet<Double>[]> localDistinctDoubles =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "computeDistinctDoublesOperator",
+                                TypeInformation.of(new TypeHint<HashSet<Double>[]>() {}),
+                                new ComputeDistinctDoublesOperator(getInputCol(), maxCategories));
+
+        // Merges the partial results. A null entry marks a column that already has more than
+        // maxCategories distinct values, so the merged entry is null if either side is null.
+        DataStream<HashSet<Double>[]> distinctDoubles =
+                DataStreamUtils.reduce(
+                        localDistinctDoubles,
+                        (ReduceFunction<HashSet<Double>[]>)
+                                (value1, value2) -> {
+                                    Preconditions.checkState(
+                                            value1.length == value2.length,
+                                            "The size of the all input vectors should be the same.");
+                                    for (int i = 0; i < value1.length; i++) {
+                                        if (value1[i] == null || value2[i] == null) {
+                                            value1[i] = null;
+                                        } else {
+                                            value1[i].addAll(value2[i]);
+                                            // A column past the limit can never become
+                                            // categorical again; drop its set early.
+                                            if (value1[i].size() > maxCategories) {
+                                                value1[i] = null;
+                                            }
+                                        }
+                                    }
+                                    return value1;
+                                });
+
+        DataStream<VectorIndexerModelData> modelData =
+                distinctDoubles.map(new ModelGenerator(maxCategories));
+        modelData.getTransformation().setParallelism(1);
+
+        VectorIndexerModel model =
+                new VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static VectorIndexer load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Computes the distinct doubles by columns. If the number of distinct values in one column is
+     * greater than maxCategories, the corresponding returned HashSet is null.
+     */
+    private static class ComputeDistinctDoublesOperator
+            extends AbstractStreamOperator<HashSet<Double>[]>
+            implements OneInputStreamOperator<Row, HashSet<Double>[]>, BoundedOneInput {
+        /** The name of input column. */
+        private final String inputCol;
+        /** Max number of categories. */
+        private final int maxCategories;
+        /** The distinct doubles of each column; null until the first record arrives. */
+        private HashSet<Double>[] doublesByColumn;
+        /** The state of doublesByColumn. */
+        private ListState<HashSet<Double>[]> doublesByColumnState;
+
+        public ComputeDistinctDoublesOperator(String inputCol, int maxCategories) {
+            this.inputCol = inputCol;
+            this.maxCategories = maxCategories;
+        }
+
+        @Override
+        public void endInput() {
+            // A subtask that received no records emits nothing.
+            if (doublesByColumn != null) {
+                output.collect(new StreamRecord<>(doublesByColumn));
+            }
+            doublesByColumnState.clear();
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void processElement(StreamRecord<Row> element) {
+            Vector vector = (Vector) element.getValue().getField(inputCol);
+            if (doublesByColumn == null) {
+                // The first record determines the number of columns.
+                doublesByColumn = new HashSet[vector.size()];
+                for (int i = 0; i < doublesByColumn.length; i++) {
+                    doublesByColumn[i] = new HashSet<>();
+                }
+            }
+
+            Preconditions.checkState(
+                    vector.size() == doublesByColumn.length,
+                    "The size of the all input vectors should be the same.");
+            double[] values = vector.toDense().values;
+            for (int i = 0; i < values.length; i++) {
+                if (doublesByColumn[i] != null) {
+                    doublesByColumn[i].add(values[i]);
+                    if (doublesByColumn[i].size() > maxCategories) {
+                        doublesByColumn[i] = null;
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            doublesByColumnState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "doublesByColumnState",
+                                            TypeInformation.of(
+                                                    new TypeHint<HashSet<Double>[]>() {})));
+
+            OperatorStateUtils.getUniqueElement(doublesByColumnState, "doublesByColumnState")
+                    .ifPresent(x -> doublesByColumn = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            // doublesByColumn is still null if no record has arrived yet; never store a null
+            // element in the list state.
+            if (doublesByColumn != null) {
+                doublesByColumnState.update(Collections.singletonList(doublesByColumn));
+            } else {
+                doublesByColumnState.clear();
+            }
+        }
+    }
+
+    /**
+     * Merges all the distinct doubles by columns and generates the {@link VectorIndexerModelData}.
+     */
+    private static class ModelGenerator
+            implements MapFunction<HashSet<Double>[], VectorIndexerModelData> {
+        private final int maxCategories;
+
+        public ModelGenerator(int maxCategories) {
+            this.maxCategories = maxCategories;
+        }
+
+        @Override
+        public VectorIndexerModelData map(HashSet<Double>[] distinctDoubles) {
+            Map<Integer, Map<Double, Integer>> categoryMaps = new HashMap<>();
+            for (int i = 0; i < distinctDoubles.length; i++) {
+                if (distinctDoubles[i] != null && distinctDoubles[i].size() <= maxCategories) {
+                    double[] values =
+                            distinctDoubles[i].stream().mapToDouble(Double::doubleValue).toArray();
+                    Arrays.sort(values);
+                    // If 0.0 exists, move it to the front and shift the smaller values right by
+                    // one, so that 0.0 is always encoded as 0.
+                    int index0 = Arrays.binarySearch(values, 0);
+                    if (index0 > 0) {
+                        System.arraycopy(values, 0, values, 1, index0);
+                        values[0] = 0;
+                    }
+                    Map<Double, Integer> valueAndIndex = new HashMap<>(values.length);
+                    for (int valueIdx = 0; valueIdx < values.length; valueIdx++) {
+                        valueAndIndex.put(values[valueIdx], valueIdx);
+                    }
+                    categoryMaps.put(i, valueAndIndex);
+                }
+            }
+
+            return new VectorIndexerModelData(categoryMaps);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java
new file mode 100644
index 0000000..3becd23
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which encodes input vector to an output vector using the model data computed by {@link
+ * VectorIndexer}.
+ *
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we put the invalid entries in a
+ * special bucket, whose index is the number of distinct values in this column.
+ */
+public class VectorIndexerModel
+ implements Model<VectorIndexerModel>, VectorIndexerModelParams<VectorIndexerModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ public VectorIndexerModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ @SuppressWarnings("unchecked, rawtypes")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ String inputCol = getInputCol();
+ String outputCol = getOutputCol();
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCol));
+
+ final String broadcastModelKey = "broadcastModelKey";
+ DataStream<VectorIndexerModelData> modelDataStream =
+ VectorIndexerModelData.getModelDataStream(modelDataTable);
+
+ DataStream<Row> result =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(tEnv.toDataStream(inputs[0])),
+ Collections.singletonMap(broadcastModelKey, modelDataStream),
+ inputList -> {
+ DataStream inputData = inputList.get(0);
+ return inputData.flatMap(
+ new FindIndex(broadcastModelKey, inputCol, getHandleInvalid()),
+ outputTypeInfo);
+ });
+
+ return new Table[] {tEnv.fromDataStream(result)};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ VectorIndexerModelData.getModelDataStream(modelDataTable),
+ path,
+ new VectorIndexerModelData.ModelDataEncoder());
+ }
+
+ public static VectorIndexerModel load(StreamTableEnvironment tEnv, String path)
+ throws IOException {
+ VectorIndexerModel model = ReadWriteUtils.loadStageParam(path);
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(
+ tEnv, path, new VectorIndexerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public VectorIndexerModel setModelData(Table... inputs) {
+ modelDataTable = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ /** Finds the index for the input vector using the model data. */
+ private static class FindIndex extends RichFlatMapFunction<Row, Row> {
+ private final String broadcastModelKey;
+ private final String inputCol;
+ private final String handleInValid;
+ private Map<Integer, Map<Double, Integer>> categoryMaps;
+
+ public FindIndex(String broadcastModelKey, String inputCol, String handleInValid) {
+ this.broadcastModelKey = broadcastModelKey;
+ this.inputCol = inputCol;
+ this.handleInValid = handleInValid;
+ }
+
+ @Override
+ public void flatMap(Row input, Collector<Row> out) {
+ if (categoryMaps == null) {
+ VectorIndexerModelData modelData =
+ (VectorIndexerModelData)
+ getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+ categoryMaps = modelData.categoryMaps;
+ }
+
+ Vector outputVector = ((Vector) input.getField(inputCol)).clone();
+ for (Map.Entry<Integer, Map<Double, Integer>> entry : categoryMaps.entrySet()) {
+ int columnId = entry.getKey();
+ Map<Double, Integer> mapping = entry.getValue();
+ double feature = outputVector.get(columnId);
+ Integer categoricalFeature = getMapping(feature, mapping, handleInValid);
+ if (categoricalFeature == null) {
+ return;
+ } else {
+ outputVector.set(columnId, categoricalFeature);
+ }
+ }
+
+ out.collect(Row.join(input, Row.of(outputVector)));
+ }
+ }
+
+ /**
+ * Maps the input feature to a categorical value using the mappings.
+ *
+ * @param feature The input continuous feature.
+ * @param mapping The mappings from continues features to categorical features.
+ * @param handleInValid The way to handle invalid features.
+ * @return The categorical value. Returns null if invalid values are skipped.
+ */
+ private static Integer getMapping(
+ double feature, Map<Double, Integer> mapping, String handleInValid) {
+ if (mapping.containsKey(feature)) {
+ return mapping.get(feature);
+ } else {
+ switch (handleInValid) {
+ case SKIP_INVALID:
+ return null;
+ case ERROR_INVALID:
+ throw new RuntimeException(
+ "The input contains unseen double: "
+ + feature
+ + ". See "
+ + HANDLE_INVALID
+ + " parameter for more options.");
+ case KEEP_INVALID:
+ return mapping.size();
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported " + HANDLE_INVALID + "type: " + handleInValid);
+ }
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelData.java
new file mode 100644
index 0000000..2d896eb
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelData.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Map;
+
+/**
+ * Model data of {@link VectorIndexerModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to DataStream, and classes
+ * to save/load model data.
+ */
+public class VectorIndexerModelData {
+    /**
+     * Index of feature values. Keys are column indices. Values are mapping from original continuous
+     * features values to 0-based categorical indices. If a feature is not in this map, it is
+     * treated as a continuous feature.
+     */
+    public Map<Integer, Map<Double, Integer>> categoryMaps;
+
+    public VectorIndexerModelData(Map<Integer, Map<Double, Integer>> categoryMaps) {
+        this.categoryMaps = categoryMaps;
+    }
+
+    public VectorIndexerModelData() {}
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    @SuppressWarnings("unchecked")
+    public static DataStream<VectorIndexerModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData)
+                .map(
+                        x ->
+                                new VectorIndexerModelData(
+                                        (Map<Integer, Map<Double, Integer>>) x.getField(0)));
+    }
+
+    /** Creates the serializer shared by {@link ModelDataEncoder} and {@link ModelDataDecoder}. */
+    private static MapSerializer<Integer, Map<Double, Integer>> createCategoryMapsSerializer() {
+        return new MapSerializer<>(
+                IntSerializer.INSTANCE,
+                new MapSerializer<>(DoubleSerializer.INSTANCE, IntSerializer.INSTANCE));
+    }
+
+    /** Data encoder for {@link VectorIndexerModel}. */
+    public static class ModelDataEncoder implements Encoder<VectorIndexerModelData> {
+
+        @Override
+        public void encode(VectorIndexerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputViewStreamWrapper outputViewStreamWrapper =
+                    new DataOutputViewStreamWrapper(outputStream);
+            createCategoryMapsSerializer()
+                    .serialize(modelData.categoryMaps, outputViewStreamWrapper);
+        }
+    }
+
+    /** Data decoder for {@link VectorIndexerModel}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<VectorIndexerModelData> {
+
+        @Override
+        public Reader<VectorIndexerModelData> createReader(
+                Configuration configuration, FSDataInputStream inputStream) {
+            // Built once per reader instead of on every read() call.
+            DataInputViewStreamWrapper inputViewStreamWrapper =
+                    new DataInputViewStreamWrapper(inputStream);
+            MapSerializer<Integer, Map<Double, Integer>> mapSerializer =
+                    createCategoryMapsSerializer();
+
+            return new Reader<VectorIndexerModelData>() {
+
+                @Override
+                public VectorIndexerModelData read() throws IOException {
+                    try {
+                        Map<Integer, Map<Double, Integer>> categoryMaps =
+                                mapSerializer.deserialize(inputViewStreamWrapper);
+                        return new VectorIndexerModelData(categoryMaps);
+                    } catch (EOFException e) {
+                        // Reaching the end of the stream means there are no more records.
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    inputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<VectorIndexerModelData> getProducedType() {
+            return TypeInformation.of(VectorIndexerModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelParams.java
similarity index 58%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelParams.java
index 21718b3..e9e89ee 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelParams.java
@@ -16,28 +16,16 @@
* limitations under the License.
*/
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.feature.vectorindexer;
-import java.io.Serializable;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
-/** A vector of double values. */
-public interface Vector extends Serializable {
-
- /** Gets the size of the vector. */
- int size();
-
- /** Gets the value of the ith element. */
- double get(int i);
-
- /** Converts the instance to a double array. */
- double[] toArray();
-
- /** Converts the instance to a dense vector. */
- DenseVector toDense();
-
- /** Converts the instance to a sparse vector. */
- SparseVector toSparse();
-
- /** Makes a deep copy of the vector. */
- Vector clone();
-}
+/**
+ * Params for {@link VectorIndexerModel}.
+ *
+ * <p>Consists of the input column, the output column and the handleInvalid strategy, which
+ * controls how values that were not seen during fitting are treated when transforming.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface VectorIndexerModelParams<T>
+ extends HasInputCol<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerParams.java
new file mode 100644
index 0000000..9af8b8a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerParams.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link VectorIndexer}.
+ *
+ * <p>On top of the model params, the estimator accepts {@link #MAX_CATEGORIES}, which decides
+ * whether each column of the input vectors is indexed as categorical or left continuous.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface VectorIndexerParams<T> extends VectorIndexerModelParams<T> {
+
+    /**
+     * Largest number of distinct values a column may have and still be indexed as a categorical
+     * feature. Must be at least 2; defaults to 20.
+     */
+    Param<Integer> MAX_CATEGORIES =
+            new IntParam(
+                    "maxCategories",
+                    "Threshold for the number of values a categorical feature can take (>= 2). "
+                            + "If a feature is found to have > maxCategories values, then it is declared continuous.",
+                    20,
+                    ParamValidators.gtEq(2));
+
+    /** Returns the configured maxCategories threshold. */
+    default int getMaxCategories() {
+        return get(MAX_CATEGORIES);
+    }
+
+    /** Sets the maxCategories threshold; the returned instance allows chained setter calls. */
+    default T setMaxCategories(int value) {
+        return set(MAX_CATEGORIES, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
new file mode 100644
index 0000000..17ffda5
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexer;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModel;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModelData;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link VectorIndexer} and {@link VectorIndexerModel}. */
+public class VectorIndexerTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainInputTable;
+    private Table testInputTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        // Column 0 has five distinct values (1..5); column 1 has three (-1, 0 and 1).
+        List<Row> trainInput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(1, 1)),
+                        Row.of(Vectors.dense(2, -1)),
+                        Row.of(Vectors.dense(3, 1)),
+                        Row.of(Vectors.dense(4, 0)),
+                        Row.of(Vectors.dense(5, 0)));
+        // 0 was never seen in column 0 and 2 was never seen in column 1.
+        List<Row> testInput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 2)),
+                        Row.of(Vectors.dense(0, 0)),
+                        Row.of(Vectors.dense(0, -1)));
+        trainInputTable = tEnv.fromDataStream(env.fromCollection(trainInput)).as("input");
+        testInputTable = tEnv.fromDataStream(env.fromCollection(testInput)).as("input");
+    }
+
+    @Test
+    public void testParam() {
+        VectorIndexer vectorIndexer = new VectorIndexer();
+        assertEquals("input", vectorIndexer.getInputCol());
+        assertEquals("output", vectorIndexer.getOutputCol());
+        assertEquals(20, vectorIndexer.getMaxCategories());
+        assertEquals(HasHandleInvalid.ERROR_INVALID, vectorIndexer.getHandleInvalid());
+
+        vectorIndexer
+                .setInputCol("test_input")
+                .setOutputCol("test_output")
+                .setMaxCategories(3)
+                .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+
+        assertEquals("test_input", vectorIndexer.getInputCol());
+        assertEquals("test_output", vectorIndexer.getOutputCol());
+        assertEquals(3, vectorIndexer.getMaxCategories());
+        assertEquals(HasHandleInvalid.KEEP_INVALID, vectorIndexer.getHandleInvalid());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        VectorIndexer vectorIndexer = new VectorIndexer();
+        Table output = vectorIndexer.fit(trainInputTable).transform(trainInputTable)[0];
+
+        assertEquals(
+                Arrays.asList(vectorIndexer.getInputCol(), vectorIndexer.getOutputCol()),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredictOnSparseInput() throws Exception {
+        List<Row> sparseTrainInput =
+                Arrays.asList(
+                        Row.of(Vectors.sparse(2, new int[] {0}, new double[] {1})),
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {2, -1})),
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {3, 1})),
+                        Row.of(Vectors.sparse(2, new int[] {0}, new double[] {4})),
+                        Row.of(Vectors.sparse(2, new int[] {0}, new double[] {5})));
+
+        List<Row> sparseTestInput =
+                Collections.singletonList(
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {0, 2})));
+        Table sparseTrainTable =
+                tEnv.fromDataStream(env.fromCollection(sparseTrainInput)).as("input");
+        Table sparseTestTable =
+                tEnv.fromDataStream(env.fromCollection(sparseTestInput)).as("input");
+
+        Table output =
+                new VectorIndexer()
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID)
+                        .setMaxCategories(3)
+                        .fit(sparseTrainTable)
+                        .transform(sparseTestTable)[0];
+
+        List<Row> expectedOutput =
+                Collections.singletonList(
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {0, 3})));
+        verifyPredictionResult(expectedOutput, output, "output");
+    }
+
+    @Test
+    public void testFitAndPredictWithLargeMaxCategories() throws Exception {
+        VectorIndexer vectorIndexer =
+                new VectorIndexer()
+                        .setMaxCategories(Integer.MAX_VALUE)
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+
+        // Both columns are categorical here, so unseen values land in the extra bucket whose
+        // index equals the column's number of distinct training values (5 and 3).
+        Table output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+        List<Row> expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(5, 3)),
+                        Row.of(Vectors.dense(5, 0)),
+                        Row.of(Vectors.dense(5, 1)));
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+    }
+
+    @Test
+    public void testFitAndPredictWithHandleInvalid() throws Exception {
+        Table output;
+        List<Row> expectedOutput;
+        // With maxCategories = 3 only column 1 is categorical; column 0 passes through unchanged.
+        VectorIndexer vectorIndexer = new VectorIndexer().setMaxCategories(3);
+
+        // Keeps invalid data.
+        expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 3)),
+                        Row.of(Vectors.dense(0, 0)),
+                        Row.of(Vectors.dense(0, 1)));
+        vectorIndexer.setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+
+        // Skips invalid data.
+        vectorIndexer.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        expectedOutput = Arrays.asList(Row.of(Vectors.dense(0, 0)), Row.of(Vectors.dense(0, 1)));
+        output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+
+        // Throws an exception on invalid data.
+        vectorIndexer.setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+        try {
+            output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+            IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+            fail("Expected the job to fail on the unseen value 2.0.");
+        } catch (Exception e) {
+            // Catch Exception rather than Throwable so that the AssertionError raised by fail()
+            // is not swallowed and reported as a misleading message mismatch.
+            assertEquals(
+                    "The input contains unseen double: 2.0. "
+                            + "See "
+                            + HasHandleInvalid.HANDLE_INVALID
+                            + " parameter for more options.",
+                    ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        VectorIndexer vectorIndexer =
+                new VectorIndexer().setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        vectorIndexer =
+                TestUtils.saveAndReload(
+                        tEnv, vectorIndexer, tempFolder.newFolder().getAbsolutePath());
+
+        VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
+        model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+
+        assertEquals(
+                Collections.singletonList("categoryMaps"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        Table output = model.transform(testInputTable)[0];
+        List<Row> expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(5, 3)),
+                        Row.of(Vectors.dense(5, 0)),
+                        Row.of(Vectors.dense(5, 1)));
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testGetModelData() throws Exception {
+        VectorIndexer vectorIndexer = new VectorIndexer().setMaxCategories(3);
+        VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
+        Table modelDataTable = model.getModelData()[0];
+
+        assertEquals(
+                Collections.singletonList("categoryMaps"),
+                modelDataTable.getResolvedSchema().getColumnNames());
+
+        List<VectorIndexerModelData> collectedModelData =
+                (List<VectorIndexerModelData>)
+                        IteratorUtils.toList(
+                                VectorIndexerModelData.getModelDataStream(modelDataTable)
+                                        .executeAndCollect());
+
+        // Only column 1 is categorical; 0.0 always maps to 0 and the rest follow ascending order.
+        assertEquals(1, collectedModelData.size());
+        HashMap<Double, Integer> column1ModelData = new HashMap<>();
+        column1ModelData.put(-1.0, 1);
+        column1ModelData.put(0.0, 0);
+        column1ModelData.put(1.0, 2);
+        assertEquals(
+                Collections.singletonMap(1, column1ModelData),
+                collectedModelData.get(0).categoryMaps);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        VectorIndexer vectorIndexer =
+                new VectorIndexer().setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
+
+        VectorIndexerModel newModel = new VectorIndexerModel();
+        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+        newModel.setModelData(model.getModelData());
+        Table output = newModel.transform(testInputTable)[0];
+
+        List<Row> expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(5, 3)),
+                        Row.of(Vectors.dense(5, 0)),
+                        Row.of(Vectors.dense(5, 1)));
+        verifyPredictionResult(expectedOutput, output, newModel.getOutputCol());
+    }
+
+    /**
+     * Collects column {@code outputCol} of {@code output} and compares it with {@code
+     * expectedOutput}, ordering both sides by the vector's hash code so that row order does not
+     * matter.
+     */
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(List<Row> expectedOutput, Table output, String outputCol)
+            throws Exception {
+        List<Row> collectedResult =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output.select(Expressions.$(outputCol)))
+                                .executeAndCollect());
+        compareResultCollections(
+                expectedOutput,
+                collectedResult,
+                Comparator.comparingInt(o -> (o.getField(0)).hashCode()));
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorindexer_example.py b/flink-ml-python/pyflink/examples/ml/feature/vectorindexer_example.py
new file mode 100644
index 0000000..a913444
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/vectorindexer_example.py
@@ -0,0 +1,80 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that trains a VectorIndexer model and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.vectorindexer import VectorIndexer
+from pyflink.table import StreamTableEnvironment
+
+# Sets up the stream execution environment and the table environment on top of it.
+env = StreamExecutionEnvironment.get_execution_environment()
+t_env = StreamTableEnvironment.create(env)
+
+
+def to_input_table(vectors):
+    # Wraps each vector in a one-field row and exposes the rows as a table with a
+    # single DenseVector column named 'input'.
+    rows = [(vector,) for vector in vectors]
+    return t_env.from_data_stream(
+        env.from_collection(
+            rows,
+            type_info=Types.ROW_NAMED(['input', ], [DenseVectorTypeInfo(), ])))
+
+
+# Generates input training and prediction data.
+train_table = to_input_table([
+    Vectors.dense(1, 1),
+    Vectors.dense(2, -1),
+    Vectors.dense(3, 1),
+    Vectors.dense(4, 0),
+    Vectors.dense(5, 0),
+])
+predict_table = to_input_table([
+    Vectors.dense(0, 2),
+    Vectors.dense(0, 0),
+    Vectors.dense(0, -1),
+])
+
+# Configures the estimator: values unseen during training go to an extra bucket ('keep'),
+# and a column is treated as categorical when it has at most 3 distinct values.
+vector_indexer = (
+    VectorIndexer()
+    .set_input_col('input')
+    .set_output_col('output')
+    .set_handle_invalid('keep')
+    .set_max_categories(3)
+)
+
+# Trains the model, then applies it to the prediction data.
+model = vector_indexer.fit(train_table)
+output = model.transform(predict_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+input_index = field_names.index(vector_indexer.get_input_col())
+output_index = field_names.index(vector_indexer.get_output_col())
+for row in t_env.to_data_stream(output).execute_and_collect():
+    print('Input Value: ' + str(row[input_index])
+          + '\tOutput Value: ' + str(row[output_index]))
diff --git a/flink-ml-python/pyflink/ml/core/linalg.py b/flink-ml-python/pyflink/ml/core/linalg.py
index 196a8f4..71619a3 100644
--- a/flink-ml-python/pyflink/ml/core/linalg.py
+++ b/flink-ml-python/pyflink/ml/core/linalg.py
@@ -305,6 +305,9 @@ class DenseVector(Vector):
def get(self, i: int):
return self._values[i]
+ def set(self, i: int, value: np.float64):
+ """
+ Sets the i-th element of this vector to the given value, modifying it in place.
+ """
+ self._values[i] = value
+
def to_array(self) -> np.ndarray:
return self._values
@@ -481,7 +484,29 @@ class SparseVector(Vector):
return self._size
def get(self, i: int):
- return self._values[self._indices.searchsorted(i)]
+ """
+ Returns the value at index i, or 0.0 when i is not an explicitly stored index.
+ """
+ # searchsorted (which assumes self._indices is sorted ascending) yields the slot
+ # where i is or would be inserted; it is a hit only if the stored index equals i.
+ idx = self._indices.searchsorted(i)
+ if idx < len(self._indices) and self._indices[idx] == i:
+ return self._values[idx]
+ else:
+ return 0.0
+
+    def set(self, i: int, value: np.float64):
+        """
+        Sets the value at index i.
+
+        If i is already a stored index, its value is overwritten in place (even with 0).
+        Otherwise a non-zero value is inserted at the position that keeps the indices
+        sorted; setting an unstored index to 0 is a no-op since it already reads as 0.
+        """
+        idx = self._indices.searchsorted(i)
+        if idx < len(self._indices) and self._indices[idx] == i:
+            self._values[idx] = value
+        elif value != 0:
+            # A negative i would otherwise be inserted at the front as a bogus index.
+            assert 0 <= i < self._size, \
+                'Index %d is out of range for a vector of size %d.' % (i, self._size)
+            cur_len = len(self._indices)
+            indices = np.zeros(cur_len + 1, dtype=np.int32)
+            values = np.zeros(cur_len + 1, dtype=np.float64)
+            # Entries before the insertion point stay where they are.
+            indices[:idx] = self._indices[:idx]
+            values[:idx] = self._values[:idx]
+            # The new entry takes slot idx.
+            indices[idx] = i
+            values[idx] = value
+            # The remaining entries shift one slot right. This must copy the whole slice;
+            # assigning the single element self._values[idx] would broadcast it over
+            # every shifted slot.
+            indices[idx + 1:] = self._indices[idx:]
+            values[idx + 1:] = self._values[idx:]
+            self._indices = indices
+            self._values = values
def to_array(self) -> np.ndarray:
"""
diff --git a/flink-ml-python/pyflink/ml/core/tests/test_linalg.py b/flink-ml-python/pyflink/ml/core/tests/test_linalg.py
index 2095b4a..3ea015c 100644
--- a/flink-ml-python/pyflink/ml/core/tests/test_linalg.py
+++ b/flink-ml-python/pyflink/ml/core/tests/test_linalg.py
@@ -77,3 +77,18 @@ class VectorTests(unittest.TestCase):
self.assertFalse(v2 == v4)
self.assertFalse(v1 == v5)
self.assertFalse(v1 == v6)
+
+    def test_get_set(self):
+        v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+        self.assertEqual(0.0, v1.get(0))
+        v1.set(0, 1.0)
+        self.assertEqual(1.0, v1.get(0))
+
+        v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+        self.assertEqual(0.0, v2.get(0))
+
+        # Inserting a new non-zero entry must not disturb the entries after it.
+        v2.set(0, 1.0)
+        self.assertEqual(1.0, v2.get(0))
+        self.assertEqual(1.0, v2.get(1))
+        self.assertEqual(0.0, v2.get(2))
+        self.assertEqual(5.5, v2.get(3))
+
+        # Overwriting an existing entry.
+        v2.set(1, 2.0)
+        self.assertEqual(2.0, v2.get(1))
+        self.assertEqual(5.5, v2.get(3))
+
+        # Setting an unstored index to zero keeps it reading as zero.
+        v2.set(2, 0.0)
+        self.assertEqual(0.0, v2.get(2))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
new file mode 100644
index 0000000..8e63427
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
@@ -0,0 +1,103 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.vectorindexer import VectorIndexer, VectorIndexerModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class VectorIndexerTest(PyFlinkMLTestCase):
+    """Tests the Python VectorIndexer estimator and VectorIndexerModel wrappers."""
+
+    def setUp(self):
+        super().setUp()
+        self.train_table = self._create_input_table([
+            Vectors.dense(1, 1),
+            Vectors.dense(2, -1),
+            Vectors.dense(3, 1),
+            Vectors.dense(4, 0),
+            Vectors.dense(5, 0),
+        ])
+        self.predict_table = self._create_input_table([
+            Vectors.dense(0, 2),
+            Vectors.dense(0, 0),
+            Vectors.dense(0, -1),
+        ])
+        # Expected predictions when both columns are categorical and unseen values are kept.
+        self.expected_output = [
+            Vectors.dense(5, 3),
+            Vectors.dense(5, 0),
+            Vectors.dense(5, 1),
+        ]
+
+    def _create_input_table(self, vectors):
+        # Wraps each vector in a one-field row and exposes the rows as a table with a
+        # single DenseVector column named 'input'.
+        rows = [(vector,) for vector in vectors]
+        return self.t_env.from_data_stream(
+            self.env.from_collection(
+                rows,
+                type_info=Types.ROW_NAMED(['input', ], [DenseVectorTypeInfo(), ])))
+
+    def test_param(self):
+        indexer = VectorIndexer()
+
+        # Defaults.
+        self.assertEqual('input', indexer.input_col)
+        self.assertEqual('output', indexer.output_col)
+        self.assertEqual(20, indexer.max_categories)
+        self.assertEqual('error', indexer.handle_invalid)
+
+        # Values applied through the chained setters.
+        (indexer
+            .set_input_col('test_input')
+            .set_output_col("test_output")
+            .set_max_categories(3)
+            .set_handle_invalid('skip'))
+
+        self.assertEqual('test_input', indexer.input_col)
+        self.assertEqual('test_output', indexer.output_col)
+        self.assertEqual(3, indexer.max_categories)
+        self.assertEqual('skip', indexer.handle_invalid)
+
+    def test_output_schema(self):
+        model = VectorIndexer().fit(self.train_table)
+        output = model.transform(self.predict_table)[0]
+        self.assertEqual(['input', 'output'], output.get_schema().get_field_names())
+
+    def test_save_load_predict(self):
+        estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_vectorindexer')
+        VectorIndexer().set_handle_invalid('keep').save(estimator_path)
+        vector_indexer = VectorIndexer.load(self.t_env, estimator_path)
+
+        model_path = os.path.join(self.temp_dir, 'test_save_load_predict_vectorindexer_model')
+        vector_indexer.fit(self.train_table).save(model_path)
+        self.env.execute('save_model')
+        model = VectorIndexerModel.load(self.t_env, model_path)
+
+        output_table = model.transform(self.predict_table)[0]
+        # Field 1 of each row is the indexed vector; sort by its second element so the
+        # comparison does not depend on row order.
+        predicted_results = sorted(
+            (row[1] for row in self.t_env.to_data_stream(output_table).execute_and_collect()),
+            key=lambda vector: vector[1])
+        expected_results = sorted(self.expected_output, key=lambda vector: vector[1])
+        self.assertEqual(expected_results, predicted_results)
diff --git a/flink-ml-python/pyflink/ml/lib/feature/vectorindexer.py b/flink-ml-python/pyflink/ml/lib/feature/vectorindexer.py
new file mode 100644
index 0000000..e397ecf
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/vectorindexer.py
@@ -0,0 +1,127 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import IntParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol, HasHandleInvalid
+
+
+class _VectorIndexerModelParams(JavaWithParams, HasInputCol, HasOutputCol, HasHandleInvalid):
+    """
+    Params for :class:`VectorIndexerModel`, also inherited by :class:`VectorIndexer`:
+    the input column, the output column and the strategy for handling invalid values.
+    """
+
+    def __init__(self, java_params):
+        super().__init__(java_params)
+
+
+class _VectorIndexerParams(_VectorIndexerModelParams):
+    """
+    Params for :class:`VectorIndexer`: the model params plus `max_categories`.
+    """
+
+    # A column is indexed as categorical only if it has at most this many distinct values.
+    MAX_CATEGORIES: IntParam = IntParam(
+        "max_categories",
+        "Threshold for the number of values a categorical feature can take (>= 2). "
+        + "If a feature is found to have > maxCategories values, then it is declared continuous.",
+        20,
+        ParamValidators.gt_eq(2)
+    )
+
+    def __init__(self, java_params):
+        super().__init__(java_params)
+
+    @property
+    def max_categories(self):
+        return self.get_max_categories()
+
+    def get_max_categories(self) -> int:
+        return self.get(self.MAX_CATEGORIES)
+
+    def set_max_categories(self, value: int):
+        updated = self.set(self.MAX_CATEGORIES, value)
+        return typing.cast(_VectorIndexerParams, updated)
+
+
+class VectorIndexerModel(JavaFeatureModel, _VectorIndexerModelParams):
+    """
+    A Model which encodes input vector to an output vector using the model data computed by
+    :class:`VectorIndexer`.
+
+    The `keep` option of :class:`HasHandleInvalid` means that we put the invalid entries in a
+    special bucket, whose index is the number of distinct values in this column.
+    """
+
+    def __init__(self, java_model=None):
+        super(VectorIndexerModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        # Sub-package of the Java feature package that holds the wrapped model class.
+        return "vectorindexer"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "VectorIndexerModel"
+
+
+class VectorIndexer(JavaFeatureEstimator, _VectorIndexerParams):
+    """
+    An Estimator which implements the vector indexing algorithm.
+
+    A vector indexer maps each column of the input vector into a continuous/categorical
+    feature. Whether one feature is transformed into a continuous or categorical feature
+    depends on the number of distinct values in this column. If the number of distinct
+    values in one column is greater than a specified parameter (i.e., `max_categories`),
+    the corresponding output column is unchanged. Otherwise, it is transformed into
+    a categorical value. For categorical outputs, the indices are in
+    [0, numDistinctValuesInThisColumn), plus numDistinctValuesInThisColumn for invalid
+    entries kept by the `keep` option.
+
+    The output model is organized in ascending order except that 0.0 is always mapped
+    to 0 (for sparsity). We list two examples here:
+
+    - If one column contains {-1.0, 1.0}, then -1.0 is encoded as 0 and 1.0 is
+      encoded as 1.
+    - If one column contains {-1.0, 0.0, 1.0}, then -1.0 is encoded as 1, 0.0 is
+      encoded as 0 and 1.0 is encoded as 2.
+
+    The `keep` option of :class:`HasHandleInvalid` means that we put the invalid entries
+    in a special bucket, whose index is the number of distinct values in this column.
+    """
+
+    def __init__(self):
+        super(VectorIndexer, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> VectorIndexerModel:
+        return VectorIndexerModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        # Sub-package of the Java feature package that holds the wrapped estimator class.
+        return "vectorindexer"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "VectorIndexer"