You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@flink.apache.org by li...@apache.org on 2022/06/23 06:38:23 UTC
[flink-ml] branch master updated: [FLINK-27096] Optimize VectorAssembler performance
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new ce3e9cb [FLINK-27096] Optimize VectorAssembler performance
ce3e9cb is described below
commit ce3e9cba4da873dd7b00f4f0f3deb0a4cdb72e0c
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Thu Jun 23 14:38:19 2022 +0800
[FLINK-27096] Optimize VectorAssembler performance
This closes #114.
---
.../feature/vectorassembler/VectorAssembler.java | 132 +++++++++++++--------
1 file changed, 82 insertions(+), 50 deletions(-)
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
index ac63eaf..c14e44e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
@@ -26,6 +26,7 @@ import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
@@ -42,7 +43,6 @@ import org.apache.commons.lang3.ArrayUtils;
import java.io.IOException;
import java.util.HashMap;
-import java.util.LinkedHashMap;
import java.util.Map;
/**
@@ -90,14 +90,27 @@ public class VectorAssembler
}
@Override
- public void flatMap(Row value, Collector<Row> out) throws Exception {
+ public void flatMap(Row value, Collector<Row> out) {
+ int nnz = 0;
+ int vectorSize = 0;
try {
- Object[] objects = new Object[inputCols.length];
- for (int i = 0; i < objects.length; ++i) {
- objects[i] = value.getField(inputCols[i]);
+ for (String inputCol : inputCols) {
+ Object object = value.getField(inputCol);
+ Preconditions.checkNotNull(object, "Input column value should not be null.");
+ if (object instanceof Number) {
+ nnz += 1;
+ vectorSize += 1;
+ } else if (object instanceof SparseVector) {
+ nnz += ((SparseVector) object).indices.length;
+ vectorSize += ((SparseVector) object).size();
+ } else if (object instanceof DenseVector) {
+ nnz += ((DenseVector) object).size();
+ vectorSize += ((DenseVector) object).size();
+ } else {
+ throw new IllegalArgumentException(
+ "Input type has not been supported yet.");
+ }
}
- Vector assembledVector = assemble(objects);
- out.collect(Row.join(value, Row.of(assembledVector)));
} catch (Exception e) {
switch (handleInvalid) {
case ERROR_INVALID:
@@ -112,6 +125,13 @@ public class VectorAssembler
"Unsupported " + HANDLE_INVALID + " type: " + handleInvalid);
}
}
+
+ boolean toDense = nnz * RATIO > vectorSize;
+ Vector assembledVec =
+ toDense
+ ? assembleDense(inputCols, value, vectorSize)
+ : assembleSparse(inputCols, value, vectorSize, nnz);
+ out.collect(Row.join(value, Row.of(assembledVec)));
}
}
@@ -129,57 +149,69 @@ public class VectorAssembler
return paramMap;
}
- private static Vector assemble(Object[] objects) {
- int offset = 0;
- Map<Integer, Double> map = new LinkedHashMap<>(objects.length);
- for (Object object : objects) {
- Preconditions.checkNotNull(object, "Input column value should not be null.");
+ /** Assembles the input columns into a dense vector. */
+ private static Vector assembleDense(String[] inputCols, Row inputRow, int vectorSize) {
+ double[] values = new double[vectorSize];
+ int currentOffset = 0;
+
+ for (String inputCol : inputCols) {
+ Object object = inputRow.getField(inputCol);
if (object instanceof Number) {
- map.put(offset++, ((Number) object).doubleValue());
- } else if (object instanceof Vector) {
- offset = appendVector((Vector) object, map, offset);
+ values[currentOffset++] = ((Number) object).doubleValue();
+ } else if (object instanceof SparseVector) {
+ SparseVector sparseVector = (SparseVector) object;
+ for (int i = 0; i < sparseVector.indices.length; i++) {
+ values[currentOffset + sparseVector.indices[i]] = sparseVector.values[i];
+ }
+ currentOffset += sparseVector.size();
+
} else {
- throw new IllegalArgumentException("Input type has not been supported yet.");
- }
- }
+ DenseVector denseVector = (DenseVector) object;
+ System.arraycopy(
+ denseVector.values, 0, values, currentOffset, denseVector.values.length);
- if (map.size() * RATIO > offset) {
- DenseVector assembledVector = new DenseVector(offset);
- for (int key : map.keySet()) {
- assembledVector.values[key] = map.get(key);
+ currentOffset += denseVector.size();
}
- return assembledVector;
- } else {
- return convertMapToSparseVector(offset, map);
}
+ return Vectors.dense(values);
}
- private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) {
- if (vec instanceof SparseVector) {
- SparseVector sparseVector = (SparseVector) vec;
- int[] indices = sparseVector.indices;
- double[] values = sparseVector.values;
- for (int i = 0; i < indices.length; ++i) {
- map.put(offset + indices[i], values[i]);
- }
- offset += sparseVector.size();
- } else {
- DenseVector denseVector = (DenseVector) vec;
- for (int i = 0; i < denseVector.size(); ++i) {
- map.put(offset++, denseVector.values[i]);
- }
- }
- return offset;
- }
+ /** Assembles the input columns into a sparse vector. */
+ private static Vector assembleSparse(
+ String[] inputCols, Row inputRow, int vectorSize, int nnz) {
+ int[] indices = new int[nnz];
+ double[] values = new double[nnz];
- private static SparseVector convertMapToSparseVector(int size, Map<Integer, Double> map) {
- int[] indices = new int[map.size()];
- double[] values = new double[map.size()];
- int offset = 0;
- for (Map.Entry<Integer, Double> entry : map.entrySet()) {
- indices[offset] = entry.getKey();
- values[offset++] = entry.getValue();
+ int currentIndex = 0;
+ int currentOffset = 0;
+
+ for (String inputCol : inputCols) {
+ Object object = inputRow.getField(inputCol);
+ if (object instanceof Number) {
+ indices[currentOffset] = currentIndex;
+ values[currentOffset] = ((Number) object).doubleValue();
+ currentOffset++;
+ currentIndex++;
+ } else if (object instanceof SparseVector) {
+ SparseVector sparseVector = (SparseVector) object;
+ for (int i = 0; i < sparseVector.indices.length; i++) {
+ indices[currentOffset + i] = sparseVector.indices[i] + currentIndex;
+ }
+ System.arraycopy(
+ sparseVector.values, 0, values, currentOffset, sparseVector.values.length);
+ currentIndex += sparseVector.size();
+ currentOffset += sparseVector.indices.length;
+ } else {
+ DenseVector denseVector = (DenseVector) object;
+ for (int i = 0; i < denseVector.size(); ++i) {
+ indices[currentOffset + i] = i + currentIndex;
+ }
+ System.arraycopy(
+ denseVector.values, 0, values, currentOffset, denseVector.values.length);
+ currentIndex += denseVector.size();
+ currentOffset += denseVector.size();
+ }
}
- return new SparseVector(size, indices, values);
+ return new SparseVector(vectorSize, indices, values);
}
}