You are viewing a plain-text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/06/23 06:38:23 UTC

[flink-ml] branch master updated: [FLINK-27096] Optimize VectorAssembler performance

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new ce3e9cb  [FLINK-27096] Optimize VectorAssembler performance
ce3e9cb is described below

commit ce3e9cba4da873dd7b00f4f0f3deb0a4cdb72e0c
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Thu Jun 23 14:38:19 2022 +0800

    [FLINK-27096] Optimize VectorAssembler performance
    
    This closes #114.
---
 .../feature/vectorassembler/VectorAssembler.java   | 132 +++++++++++++--------
 1 file changed, 82 insertions(+), 50 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
index ac63eaf..c14e44e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
@@ -26,6 +26,7 @@ import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.SparseVector;
 import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -42,7 +43,6 @@ import org.apache.commons.lang3.ArrayUtils;
 
 import java.io.IOException;
 import java.util.HashMap;
-import java.util.LinkedHashMap;
 import java.util.Map;
 
 /**
@@ -90,14 +90,27 @@ public class VectorAssembler
         }
 
         @Override
-        public void flatMap(Row value, Collector<Row> out) throws Exception {
+        public void flatMap(Row value, Collector<Row> out) {
+            int nnz = 0;
+            int vectorSize = 0;
             try {
-                Object[] objects = new Object[inputCols.length];
-                for (int i = 0; i < objects.length; ++i) {
-                    objects[i] = value.getField(inputCols[i]);
+                for (String inputCol : inputCols) {
+                    Object object = value.getField(inputCol);
+                    Preconditions.checkNotNull(object, "Input column value should not be null.");
+                    if (object instanceof Number) {
+                        nnz += 1;
+                        vectorSize += 1;
+                    } else if (object instanceof SparseVector) {
+                        nnz += ((SparseVector) object).indices.length;
+                        vectorSize += ((SparseVector) object).size();
+                    } else if (object instanceof DenseVector) {
+                        nnz += ((DenseVector) object).size();
+                        vectorSize += ((DenseVector) object).size();
+                    } else {
+                        throw new IllegalArgumentException(
+                                "Input type has not been supported yet.");
+                    }
                 }
-                Vector assembledVector = assemble(objects);
-                out.collect(Row.join(value, Row.of(assembledVector)));
             } catch (Exception e) {
                 switch (handleInvalid) {
                     case ERROR_INVALID:
@@ -112,6 +125,13 @@ public class VectorAssembler
                                 "Unsupported " + HANDLE_INVALID + " type: " + handleInvalid);
                 }
             }
+
+            boolean toDense = nnz * RATIO > vectorSize;
+            Vector assembledVec =
+                    toDense
+                            ? assembleDense(inputCols, value, vectorSize)
+                            : assembleSparse(inputCols, value, vectorSize, nnz);
+            out.collect(Row.join(value, Row.of(assembledVec)));
         }
     }
 
@@ -129,57 +149,69 @@ public class VectorAssembler
         return paramMap;
     }
 
-    private static Vector assemble(Object[] objects) {
-        int offset = 0;
-        Map<Integer, Double> map = new LinkedHashMap<>(objects.length);
-        for (Object object : objects) {
-            Preconditions.checkNotNull(object, "Input column value should not be null.");
+    /** Assembles the input columns into a dense vector. */
+    private static Vector assembleDense(String[] inputCols, Row inputRow, int vectorSize) {
+        double[] values = new double[vectorSize];
+        int currentOffset = 0;
+
+        for (String inputCol : inputCols) {
+            Object object = inputRow.getField(inputCol);
             if (object instanceof Number) {
-                map.put(offset++, ((Number) object).doubleValue());
-            } else if (object instanceof Vector) {
-                offset = appendVector((Vector) object, map, offset);
+                values[currentOffset++] = ((Number) object).doubleValue();
+            } else if (object instanceof SparseVector) {
+                SparseVector sparseVector = (SparseVector) object;
+                for (int i = 0; i < sparseVector.indices.length; i++) {
+                    values[currentOffset + sparseVector.indices[i]] = sparseVector.values[i];
+                }
+                currentOffset += sparseVector.size();
+
             } else {
-                throw new IllegalArgumentException("Input type has not been supported yet.");
-            }
-        }
+                DenseVector denseVector = (DenseVector) object;
+                System.arraycopy(
+                        denseVector.values, 0, values, currentOffset, denseVector.values.length);
 
-        if (map.size() * RATIO > offset) {
-            DenseVector assembledVector = new DenseVector(offset);
-            for (int key : map.keySet()) {
-                assembledVector.values[key] = map.get(key);
+                currentOffset += denseVector.size();
             }
-            return assembledVector;
-        } else {
-            return convertMapToSparseVector(offset, map);
         }
+        return Vectors.dense(values);
     }
 
-    private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) {
-        if (vec instanceof SparseVector) {
-            SparseVector sparseVector = (SparseVector) vec;
-            int[] indices = sparseVector.indices;
-            double[] values = sparseVector.values;
-            for (int i = 0; i < indices.length; ++i) {
-                map.put(offset + indices[i], values[i]);
-            }
-            offset += sparseVector.size();
-        } else {
-            DenseVector denseVector = (DenseVector) vec;
-            for (int i = 0; i < denseVector.size(); ++i) {
-                map.put(offset++, denseVector.values[i]);
-            }
-        }
-        return offset;
-    }
+    /** Assembles the input columns into a sparse vector. */
+    private static Vector assembleSparse(
+            String[] inputCols, Row inputRow, int vectorSize, int nnz) {
+        int[] indices = new int[nnz];
+        double[] values = new double[nnz];
 
-    private static SparseVector convertMapToSparseVector(int size, Map<Integer, Double> map) {
-        int[] indices = new int[map.size()];
-        double[] values = new double[map.size()];
-        int offset = 0;
-        for (Map.Entry<Integer, Double> entry : map.entrySet()) {
-            indices[offset] = entry.getKey();
-            values[offset++] = entry.getValue();
+        int currentIndex = 0;
+        int currentOffset = 0;
+
+        for (String inputCol : inputCols) {
+            Object object = inputRow.getField(inputCol);
+            if (object instanceof Number) {
+                indices[currentOffset] = currentIndex;
+                values[currentOffset] = ((Number) object).doubleValue();
+                currentOffset++;
+                currentIndex++;
+            } else if (object instanceof SparseVector) {
+                SparseVector sparseVector = (SparseVector) object;
+                for (int i = 0; i < sparseVector.indices.length; i++) {
+                    indices[currentOffset + i] = sparseVector.indices[i] + currentIndex;
+                }
+                System.arraycopy(
+                        sparseVector.values, 0, values, currentOffset, sparseVector.values.length);
+                currentIndex += sparseVector.size();
+                currentOffset += sparseVector.indices.length;
+            } else {
+                DenseVector denseVector = (DenseVector) object;
+                for (int i = 0; i < denseVector.size(); ++i) {
+                    indices[currentOffset + i] = i + currentIndex;
+                }
+                System.arraycopy(
+                        denseVector.values, 0, values, currentOffset, denseVector.values.length);
+                currentIndex += denseVector.size();
+                currentOffset += denseVector.size();
+            }
         }
-        return new SparseVector(size, indices, values);
+        return new SparseVector(vectorSize, indices, values);
     }
 }