You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/08/01 03:55:29 UTC
[flink-ml] branch master updated: [FLINK-28601] Add Transformer for FeatureHasher
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 41cccd1 [FLINK-28601] Add Transformer for FeatureHasher
41cccd1 is described below
commit 41cccd144cf8bfdb219b8e6a736697f8a04ad5a8
Author: weibo <wb...@pku.edu.cn>
AuthorDate: Mon Aug 1 11:55:25 2022 +0800
[FLINK-28601] Add Transformer for FeatureHasher
This closes #133.
---
.../ml/examples/feature/FeatureHasherExample.java | 71 +++++++
.../flink/ml/common/param/HasCategoricalCols.java | 42 +++++
.../flink/ml/common/param/HasNumFeatures.java | 43 +++++
.../ml/feature/featurehasher/FeatureHasher.java | 209 +++++++++++++++++++++
.../feature/featurehasher/FeatureHasherParams.java | 32 ++++
.../apache/flink/ml/feature/FeatureHasherTest.java | 152 +++++++++++++++
...sembler_example.py => featurehasher_example.py} | 41 ++--
.../examples/ml/feature/vectorassembler_example.py | 2 +-
.../pyflink/ml/lib/feature/featurehasher.py | 67 +++++++
.../ml/lib/feature/tests/test_feature_hasher.py | 77 ++++++++
flink-ml-python/pyflink/ml/lib/param.py | 42 +++++
11 files changed, 754 insertions(+), 24 deletions(-)
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
new file mode 100644
index 0000000..c0f81c6
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/** Simple program that creates a FeatureHasher instance and uses it for feature engineering. */
+public class FeatureHasherExample {
+ public static void main(String[] args) {
+
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input data.
+ DataStream<Row> dataStream =
+ env.fromCollection(
+ Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false)));
+ Table inputDataTable = tEnv.fromDataStream(dataStream).as("id", "f0", "f1", "f2");
+
+ // Creates a FeatureHasher object and initializes its parameters.
+ FeatureHasher featureHash =
+ new FeatureHasher()
+ .setInputCols("f0", "f1", "f2")
+ .setCategoricalCols("f0", "f2")
+ .setOutputCol("vec")
+ .setNumFeatures(1000);
+
+ // Uses the FeatureHasher object for feature transformations.
+ Table outputTable = featureHash.transform(inputDataTable)[0];
+
+ // Extracts and displays the results.
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+ Row row = it.next();
+
+ Object[] inputValues = new Object[featureHash.getInputCols().length];
+ for (int i = 0; i < inputValues.length; i++) {
+ inputValues[i] = row.getField(featureHash.getInputCols()[i]);
+ }
+ Vector outputValue = (Vector) row.getField(featureHash.getOutputCol());
+
+ System.out.printf(
+ "Input Values: %s \tOutput Value: %s\n",
+ Arrays.toString(inputValues), outputValue);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasCategoricalCols.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasCategoricalCols.java
new file mode 100644
index 0000000..fcc340b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasCategoricalCols.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared categoricalCols param. */
+public interface HasCategoricalCols<T> extends WithParams<T> {
+ Param<String[]> CATEGORICAL_COLS =
+ new StringArrayParam(
+ "categoricalCols",
+ "Categorical column names.",
+ new String[] {},
+ ParamValidators.notNull());
+
+ default String[] getCategoricalCols() {
+ return get(CATEGORICAL_COLS);
+ }
+
+ default T setCategoricalCols(String... value) {
+ return set(CATEGORICAL_COLS, value);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasNumFeatures.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasNumFeatures.java
new file mode 100644
index 0000000..7d8b699
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasNumFeatures.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared num features param. */
+public interface HasNumFeatures<T> extends WithParams<T> {
+ Param<Integer> NUM_FEATURES =
+ new IntParam(
+ "numFeatures",
+ "The number of features. It will be the length of the output vector.",
+ 262144,
+ ParamValidators.gt(0));
+
+ default int getNumFeatures() {
+ return get(NUM_FEATURES);
+ }
+
+ default T setNumFeatures(int value) {
+ set(NUM_FEATURES, value);
+ return (T) this;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java
new file mode 100644
index 0000000..3ad50b7
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.featurehasher;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+import static org.apache.flink.shaded.guava30.com.google.common.hash.Hashing.murmur3_32;
+
+/**
+ * A Transformer that transforms a set of categorical or numerical features into a sparse vector of
+ * a specified dimension. The rules of hashing categorical columns and numerical columns are as
+ * follows:
+ *
+ * <ul>
+ * <li>For numerical columns, the index of this feature in the output vector is the hash value of
+ * the column name and its corresponding value is the same as the input.
+ * <li>For categorical columns, the index of this feature in the output vector is the hash value
+ * of the string "column_name=value" and the corresponding value is 1.0.
+ * </ul>
+ *
+ * <p>If multiple features are projected into the same column, the output values are accumulated.
+ * For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for details.
+ */
+public class FeatureHasher
+ implements Transformer<FeatureHasher>, FeatureHasherParams<FeatureHasher> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private static final org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction HASH =
+ murmur3_32(0);
+
+ public FeatureHasher() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ ResolvedSchema tableSchema = inputs[0].getResolvedSchema();
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(tableSchema);
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+ DataStream<Row> output =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new HashFunction(
+ getInputCols(),
+ generateCategoricalCols(
+ tableSchema, getInputCols(), getCategoricalCols()),
+ getNumFeatures()),
+ outputTypeInfo);
+ Table outputTable = tEnv.fromDataStream(output);
+ return new Table[] {outputTable};
+ }
+
+ /**
+ * The main logic for transforming the categorical and numerical features into a sparse vector.
+ * It uses MurmurHash3 to compute the transformed index in the output vector. If multiple
+ * features are projected to the same column, their values are accumulated.
+ */
+ private static class HashFunction implements MapFunction<Row, Row> {
+ private final String[] categoricalCols;
+ private final int numFeatures;
+ private final String[] numericCols;
+
+ public HashFunction(String[] inputCols, String[] categoricalCols, int numFeatures) {
+ this.categoricalCols = categoricalCols;
+ this.numFeatures = numFeatures;
+ this.numericCols = ArrayUtils.removeElements(inputCols, this.categoricalCols);
+ }
+
+ @Override
+ public Row map(Row row) {
+ TreeMap<Integer, Double> feature = new TreeMap<>();
+ for (String col : numericCols) {
+ if (null != row.getField(col)) {
+ double value = ((Number) row.getFieldAs(col)).doubleValue();
+ updateMap(col, value, feature, numFeatures);
+ }
+ }
+ for (String col : categoricalCols) {
+ if (null != row.getField(col)) {
+ updateMap(col + "=" + row.getField(col), 1.0, feature, numFeatures);
+ }
+ }
+ int nnz = feature.size();
+ int[] indices = new int[nnz];
+ double[] values = new double[nnz];
+ int pos = 0;
+ for (Map.Entry<Integer, Double> entry : feature.entrySet()) {
+ indices[pos] = entry.getKey();
+ values[pos] = entry.getValue();
+ pos++;
+ }
+ return Row.join(row, Row.of(new SparseVector(numFeatures, indices, values)));
+ }
+ }
+
+ private String[] generateCategoricalCols(
+ ResolvedSchema tableSchema, String[] inputCols, String[] categoricalCols) {
+ if (null == inputCols) {
+ return categoricalCols;
+ }
+ List<String> categoricalList = Arrays.asList(categoricalCols);
+ List<String> inputList = Arrays.asList(inputCols);
+ if (categoricalCols.length > 0 && !inputList.containsAll(categoricalList)) {
+ throw new IllegalArgumentException("CategoricalCols must be included in inputCols!");
+ }
+ List<DataType> dataColTypes = tableSchema.getColumnDataTypes();
+ List<String> dataColNames = tableSchema.getColumnNames();
+ List<DataType> inputColTypes = new ArrayList<>();
+ for (String col : inputCols) {
+ for (int i = 0; i < dataColNames.size(); ++i) {
+ if (col.equals(dataColNames.get(i))) {
+ inputColTypes.add(dataColTypes.get(i));
+ break;
+ }
+ }
+ }
+ List<String> resultColList = new ArrayList<>();
+ for (int i = 0; i < inputCols.length; i++) {
+ boolean included = categoricalList.contains(inputCols[i]);
+ if (included
+ || DataTypes.BOOLEAN().equals(inputColTypes.get(i))
+ || DataTypes.STRING().equals(inputColTypes.get(i))) {
+ resultColList.add(inputCols[i]);
+ }
+ }
+ return resultColList.toArray(new String[0]);
+ }
+
+ /**
+ * Updates the TreeMap that holds the index-value pairs of the final vector: the hashed index
+ * of the string is used as the key and the given value is accumulated into that entry.
+ *
+ * @param s the string to hash
+ * @param value the value to accumulate
+ */
+ private static void updateMap(
+ String s, double value, TreeMap<Integer, Double> feature, int numFeature) {
+ int hashValue = Math.abs(HASH.hashUnencodedChars(s).asInt());
+
+ int index = Math.floorMod(hashValue, numFeature);
+ if (feature.containsKey(index)) {
+ feature.put(index, feature.get(index) + value);
+ } else {
+ feature.put(index, value);
+ }
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static FeatureHasher load(StreamTableEnvironment env, String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasherParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasherParams.java
new file mode 100644
index 0000000..997d2d6
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasherParams.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.featurehasher;
+
+import org.apache.flink.ml.common.param.HasCategoricalCols;
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasNumFeatures;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params of {@link FeatureHasher}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FeatureHasherParams<T>
+ extends HasInputCols<T>, HasOutputCol<T>, HasCategoricalCols<T>, HasNumFeatures<T> {}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
new file mode 100644
index 0000000..941b3b1
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link FeatureHasher}. */
+public class FeatureHasherTest extends AbstractTestBase {
+
+ private StreamTableEnvironment tEnv;
+ private Table inputDataTable;
+
+ private static final List<Row> INPUT_DATA =
+ Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false));
+
+ private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
+ Vectors.sparse(1000, new int[] {607, 635, 913}, new double[] {1.0, 1.0, 1.0});
+ private static final SparseVector EXPECTED_OUTPUT_DATA_2 =
+ Vectors.sparse(1000, new int[] {242, 869, 913}, new double[] {1.0, 1.0, 1.0});
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+ DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+ inputDataTable = tEnv.fromDataStream(dataStream).as("id", "f0", "f1", "f2");
+ }
+
+ private void verifyOutputResult(Table output, String outputCol) throws Exception {
+ DataStream<Row> dataStream = tEnv.toDataStream(output);
+ List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+ assertEquals(2, results.size());
+ for (Row result : results) {
+ if (result.getField(0) == (Object) 0) {
+ assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+ } else if (result.getField(0) == (Object) 1) {
+ assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
+ } else {
+ throw new RuntimeException("unknown output value.");
+ }
+ }
+ }
+
+ @Test
+ public void testParam() {
+ FeatureHasher featureHasher = new FeatureHasher();
+ assertEquals("output", featureHasher.getOutputCol());
+ assertArrayEquals(new String[] {}, featureHasher.getCategoricalCols());
+ assertEquals(262144, featureHasher.getNumFeatures());
+ featureHasher
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCol("vec")
+ .setCategoricalCols("f0", "f2")
+ .setNumFeatures(1000);
+ assertArrayEquals(new String[] {"f0", "f1", "f2"}, featureHasher.getInputCols());
+ assertEquals("vec", featureHasher.getOutputCol());
+ assertArrayEquals(new String[] {"f0", "f2"}, featureHasher.getCategoricalCols());
+ assertEquals(1000, featureHasher.getNumFeatures());
+ }
+
+ @Test
+ public void testSaveLoadAndTransform() throws Exception {
+ FeatureHasher featureHash =
+ new FeatureHasher()
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCol("vec")
+ .setCategoricalCols("f0", "f2")
+ .setNumFeatures(1000);
+ FeatureHasher loadedFeatureHasher =
+ TestUtils.saveAndReload(
+ tEnv, featureHash, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ Table output = loadedFeatureHasher.transform(inputDataTable)[0];
+ verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
+ }
+
+ @Test
+ public void testCategoricalColsNotSet() throws Exception {
+ FeatureHasher featureHash =
+ new FeatureHasher()
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCol("vec")
+ .setNumFeatures(1000);
+ FeatureHasher loadedFeatureHasher =
+ TestUtils.saveAndReload(
+ tEnv, featureHash, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ Table output = loadedFeatureHasher.transform(inputDataTable)[0];
+ verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
+ }
+
+ @Test
+ public void testInputTypeConversion() throws Exception {
+ inputDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable);
+ assertArrayEquals(
+ new Class<?>[] {Integer.class, String.class, Integer.class, Boolean.class},
+ TestUtils.getColumnDataTypes(inputDataTable));
+
+ FeatureHasher featureHash =
+ new FeatureHasher()
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCol("vec")
+ .setCategoricalCols("f0", "f2")
+ .setNumFeatures(1000);
+ FeatureHasher loadedFeatureHasher =
+ TestUtils.saveAndReload(
+ tEnv, featureHash, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+ Table output = loadedFeatureHasher.transform(inputDataTable)[0];
+ verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py b/flink-ml-python/pyflink/examples/ml/feature/featurehasher_example.py
similarity index 61%
copy from flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
copy to flink-ml-python/pyflink/examples/ml/feature/featurehasher_example.py
index 8f45b17..582429c 100644
--- a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
+++ b/flink-ml-python/pyflink/examples/ml/feature/featurehasher_example.py
@@ -16,7 +16,7 @@
# limitations under the License.
################################################################################
-# Simple program that creates a VectorAssembler instance and uses it for feature
+# Simple program that creates a FeatureHasher instance and uses it for feature
# engineering.
#
# Before executing this program, please make sure you have followed Flink ML's
@@ -27,8 +27,7 @@
from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
-from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, SparseVectorTypeInfo
-from pyflink.ml.lib.feature.vectorassembler import VectorAssembler
+from pyflink.ml.lib.feature.featurehasher import FeatureHasher
from pyflink.table import StreamTableEnvironment
# create a new StreamExecutionEnvironment
@@ -40,32 +39,28 @@ t_env = StreamTableEnvironment.create(env)
# generate input data
input_data_table = t_env.from_data_stream(
env.from_collection([
- (Vectors.dense(2.1, 3.1),
- 1.0,
- Vectors.sparse(5, [3], [1.0])),
- (Vectors.dense(2.1, 3.1),
- 1.0,
- Vectors.sparse(5, [1, 2, 3, 4],
- [1.0, 2.0, 3.0, 4.0])),
+ (0, 'a', 1.0, True),
+ (1, 'c', 1.0, False),
],
type_info=Types.ROW_NAMED(
- ['vec', 'num', 'sparse_vec'],
- [DenseVectorTypeInfo(), Types.DOUBLE(), SparseVectorTypeInfo()])))
+ ['id', 'f0', 'f1', 'f2'],
+ [Types.INT(), Types.STRING(), Types.DOUBLE(), Types.BOOLEAN()])))
-# create a vector assembler object and initialize its parameters
-vector_assembler = VectorAssembler() \
- .set_input_cols('vec', 'num', 'sparse_vec') \
- .set_output_col('assembled_vec') \
- .set_handle_invalid('keep')
+# create a feature hasher object and initialize its parameters
+feature_hasher = FeatureHasher() \
+ .set_input_cols('f0', 'f1', 'f2') \
+ .set_categorical_cols('f0', 'f2') \
+ .set_output_col('vec') \
+ .set_num_features(1000)
-# use the vector assembler model for feature engineering
-output = vector_assembler.transform(input_data_table)[0]
+# use the feature hasher for feature engineering
+output = feature_hasher.transform(input_data_table)[0]
# extract and display the results
field_names = output.get_schema().get_field_names()
-input_values = [None for _ in vector_assembler.get_input_cols()]
+input_values = [None for _ in feature_hasher.get_input_cols()]
for result in t_env.to_data_stream(output).execute_and_collect():
- for i in range(len(vector_assembler.get_input_cols())):
- input_values[i] = result[field_names.index(vector_assembler.get_input_cols()[i])]
- output_value = result[field_names.index(vector_assembler.get_output_col())]
+ for i in range(len(feature_hasher.get_input_cols())):
+ input_values[i] = result[field_names.index(feature_hasher.get_input_cols()[i])]
+ output_value = result[field_names.index(feature_hasher.get_output_col())]
print('Input Values: ' + str(input_values) + '\tOutput Value: ' + str(output_value))
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
index 8f45b17..7ae15ce 100644
--- a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
+++ b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
@@ -58,7 +58,7 @@ vector_assembler = VectorAssembler() \
.set_output_col('assembled_vec') \
.set_handle_invalid('keep')
-# use the vector assembler model for feature engineering
+# use the vector assembler for feature engineering
output = vector_assembler.transform(input_data_table)[0]
# extract and display the results
diff --git a/flink-ml-python/pyflink/ml/lib/feature/featurehasher.py b/flink-ml-python/pyflink/ml/lib/feature/featurehasher.py
new file mode 100644
index 0000000..150cf1f
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/featurehasher.py
@@ -0,0 +1,67 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+from pyflink.ml.lib.param import HasInputCols, HasOutputCol, HasCategoricalCols, HasNumFeatures
+
+
+class _FeatureHasherParams(
+ JavaWithParams,
+ HasInputCols,
+ HasCategoricalCols,
+ HasOutputCol,
+ HasNumFeatures
+):
+ """
+ Params for :class:`FeatureHasher`.
+ """
+
+ def __init__(self, java_params):
+ super(_FeatureHasherParams, self).__init__(java_params)
+
+
+class FeatureHasher(JavaFeatureTransformer, _FeatureHasherParams):
+ """
+ A Transformer that transforms a set of categorical or numerical features into
+ a sparse vector of a specified dimension. The rules of hashing categorical
+ columns and numerical columns are as follows:
+
+ For numerical columns, the index of this feature in the output vector is the
+ hash value of the column name and its corresponding value is the same as the
+ input.
+
+ For categorical columns, the index of this feature in the output vector is
+ the hash value of the string "column_name=value" and the corresponding
+ value is 1.0.
+
+ If multiple features are projected into the same column, the output values
+ are accumulated. For the hashing trick, see
+ https://en.wikipedia.org/wiki/Feature_hashing for details.
+ """
+
+ def __init__(self, java_model=None):
+ super(FeatureHasher, self).__init__(java_model)
+
+ @classmethod
+ def _java_transformer_package_name(cls) -> str:
+ return "featurehasher"
+
+ @classmethod
+ def _java_transformer_class_name(cls) -> str:
+ return "FeatureHasher"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_feature_hasher.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_feature_hasher.py
new file mode 100644
index 0000000..7f32128
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_feature_hasher.py
@@ -0,0 +1,77 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors
+from pyflink.ml.lib.feature.featurehasher import FeatureHasher
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class FeatureHasherTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(FeatureHasherTest, self).setUp()
+ self.input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (0, 'a', 1.0, True),
+ (1, 'c', 1.0, False)
+ ],
+ type_info=Types.ROW_NAMED(
+ ['id', 'f0', 'f1', 'f2'],
+ [Types.INT(), Types.STRING(), Types.DOUBLE(), Types.BOOLEAN()])))
+
+ self.expected_output_data_1 = Vectors.sparse(1000, [607, 635, 913], [1.0, 1.0, 1.0])
+ self.expected_output_data_2 = Vectors.sparse(1000, [242, 869, 913], [1.0, 1.0, 1.0])
+
+ def test_param(self):
+ feature_hasher = FeatureHasher()
+
+ self.assertEqual('output', feature_hasher.output_col)  # values before any setter is called
+ self.assertEqual(262144, feature_hasher.num_features)
+
+ feature_hasher.set_input_cols('f0', 'f1', 'f2') \
+ .set_categorical_cols('f0', 'f2') \
+ .set_output_col('vec') \
+ .set_num_features(1000)
+
+ self.assertEqual(('f0', 'f1', 'f2'), feature_hasher.input_cols)
+ self.assertEqual(('f0', 'f2'), feature_hasher.categorical_cols)
+ self.assertEqual(1000, feature_hasher.num_features)
+ self.assertEqual('vec', feature_hasher.output_col)
+
+ def test_save_load_transform(self):
+ feature_hasher = FeatureHasher() \
+ .set_input_cols('f0', 'f1', 'f2') \
+ .set_categorical_cols('f0', 'f2') \
+ .set_output_col('vec') \
+ .set_num_features(1000)
+
+ path = os.path.join(self.temp_dir, 'test_save_load_transform_feature_hasher')
+ feature_hasher.save(path)
+ feature_hasher = FeatureHasher.load(self.t_env, path)  # transform with the reloaded instance
+
+ output_table = feature_hasher.transform(self.input_data_table)[0]
+ actual_outputs = [(result[0], result[4]) for result in  # (first field, fifth field) per row
+ self.t_env.to_data_stream(output_table).execute_and_collect()]
+ self.assertEqual(2, len(actual_outputs))
+ for actual_output in actual_outputs:
+ if actual_output[0] == 0:  # rows may arrive in any order; pick expected value by first field
+ self.assertEqual(self.expected_output_data_1, actual_output[1])
+ else:
+ self.assertEqual(self.expected_output_data_2, actual_output[1])
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index 64cd5ac..39c6277 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -155,6 +155,48 @@ class HasInputCols(WithParams, ABC):
return self.get_input_cols()
+class HasCategoricalCols(WithParams, ABC):
+ """
+ Base class for the shared categorical cols param (default: empty, validated as not null).
+ """
+ CATEGORICAL_COLS: Param[Tuple[str, ...]] = StringArrayParam(
+ "categorical_cols",
+ "Categorical column names.",
+ [],
+ ParamValidators.not_null())
+
+ def set_categorical_cols(self, *cols: str):
+ return self.set(self.CATEGORICAL_COLS, cols)
+
+ def get_categorical_cols(self) -> Tuple[str, ...]:
+ return self.get(self.CATEGORICAL_COLS)
+
+ @property
+ def categorical_cols(self) -> Tuple[str, ...]:
+ return self.get_categorical_cols()
+
+
+class HasNumFeatures(WithParams, ABC):
+ """
+ Base class for the shared num_features param (default: 262144, validated with gt(0)).
+ """
+ NUM_FEATURES: Param[int] = IntParam(
+ "num_features",
+ "Number of features.",
+ 262144,
+ ParamValidators.gt(0))
+
+ def set_num_features(self, num_features: int):
+ return self.set(self.NUM_FEATURES, num_features)
+
+ def get_num_features(self) -> int:
+ return self.get(self.NUM_FEATURES)
+
+ @property
+ def num_features(self) -> int:
+ return self.get_num_features()
+
+
class HasLabelCol(WithParams, ABC):
"""
Base class for the shared label column param.