You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/06/20 03:43:19 UTC

[flink-ml] 02/04: [FLINK-27877] Reduce the length of the operator chain for generating input table

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 341df450831e4c426ff4f8049af8dc52fc0bb598
Author: zhangzp <zh...@gmail.com>
AuthorDate: Mon Jun 20 09:50:59 2022 +0800

    [FLINK-27877] Reduce the length of the operator chain for generating input table
---
 .../common/DenseVectorArrayGenerator.java          | 114 +++++-------------
 .../datagenerator/common/DenseVectorGenerator.java | 103 +++++-----------
 .../datagenerator/common/InputTableGenerator.java  |  66 ++++++++++
 .../common/LabeledPointWithWeightGenerator.java    | 134 +++++++--------------
 .../datagenerator/common/RowGenerator.java         |  77 ++++++++++++
 .../flink/ml/benchmark/DataGeneratorTest.java      |  45 ++++---
 6 files changed, 275 insertions(+), 264 deletions(-)

diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java
index 0f8b82f..c1b3a21 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java
@@ -18,100 +18,50 @@
 
 package org.apache.flink.ml.benchmark.datagenerator.common;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize;
 import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
-import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.DataTypes;
-import org.apache.flink.table.api.Schema;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.util.NumberSequenceIterator;
+import org.apache.flink.types.Row;
 import org.apache.flink.util.Preconditions;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
 /** A DataGenerator which creates a table of DenseVector array. */
-public class DenseVectorArrayGenerator
-        implements InputDataGenerator<DenseVectorArrayGenerator>,
-                HasArraySize<DenseVectorArrayGenerator>,
+public class DenseVectorArrayGenerator extends InputTableGenerator<DenseVectorArrayGenerator>
+        implements HasArraySize<DenseVectorArrayGenerator>,
                 HasVectorDim<DenseVectorArrayGenerator> {
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    public DenseVectorArrayGenerator() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
 
     @Override
-    public Table[] getData(StreamTableEnvironment tEnv) {
-        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
-
-        DataStream<DenseVector[]> dataStream =
-                env.fromParallelCollection(
-                                new NumberSequenceIterator(1L, getNumValues()),
-                                BasicTypeInfo.LONG_TYPE_INFO)
-                        .map(
-                                new GenerateRandomContinuousVectorArrayFunction(
-                                        getSeed(), getVectorDim(), getArraySize()));
-
-        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector[].class)).build();
-        Table dataTable = tEnv.fromDataStream(dataStream, schema);
-        if (getColNames() != null) {
-            Preconditions.checkState(getColNames().length == 1);
-            Preconditions.checkState(getColNames()[0].length == 1);
-            dataTable = dataTable.as(getColNames()[0][0]);
-        }
-
-        return new Table[] {dataTable};
-    }
+    protected RowGenerator[] getRowGenerators() {
+        String[][] columnNames = getColNames();
+        Preconditions.checkState(columnNames.length == 1);
+        Preconditions.checkState(columnNames[0].length == 1);
+        int arraySize = getArraySize();
+        int vectorDim = getVectorDim();
 
-    private static class GenerateRandomContinuousVectorArrayFunction
-            extends RichMapFunction<Long, DenseVector[]> {
-        private final int vectorDim;
-        private final long initSeed;
-        private final int arraySize;
-        private Random random;
-
-        private GenerateRandomContinuousVectorArrayFunction(
-                long initSeed, int vectorDim, int arraySize) {
-            this.vectorDim = vectorDim;
-            this.initSeed = initSeed;
-            this.arraySize = arraySize;
-        }
-
-        @Override
-        public void open(Configuration parameters) throws Exception {
-            super.open(parameters);
-            int index = getRuntimeContext().getIndexOfThisSubtask();
-            random = new Random(Tuple2.of(initSeed, index).hashCode());
-        }
+        return new RowGenerator[] {
+            new RowGenerator(getNumValues(), getSeed()) {
+                @Override
+                protected Row nextRow() {
+                    DenseVector[] result = new DenseVector[arraySize];
+                    for (int i = 0; i < arraySize; i++) {
+                        result[i] = new DenseVector(vectorDim);
+                        for (int j = 0; j < vectorDim; j++) {
+                            result[i].values[j] = random.nextDouble();
+                        }
+                    }
+                    Row row = new Row(1);
+                    row.setField(0, result);
+                    return row;
+                }
 
-        @Override
-        public DenseVector[] map(Long value) {
-            DenseVector[] result = new DenseVector[arraySize];
-            for (int i = 0; i < arraySize; i++) {
-                result[i] = new DenseVector(vectorDim);
-                for (int j = 0; j < vectorDim; j++) {
-                    result[i].values[j] = random.nextDouble();
+                @Override
+                protected RowTypeInfo getRowTypeInfo() {
+                    return new RowTypeInfo(
+                            new TypeInformation[] {TypeInformation.of(DenseVector[].class)},
+                            columnNames[0]);
                 }
             }
-            return result;
-        }
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
+        };
     }
 }
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java
index 4117261..10eae84 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java
@@ -18,86 +18,43 @@
 
 package org.apache.flink.ml.benchmark.datagenerator.common;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
-import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.util.NumberSequenceIterator;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.types.Row;
 import org.apache.flink.util.Preconditions;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
 /** A DataGenerator which creates a table of DenseVector. */
-public class DenseVectorGenerator
-        implements InputDataGenerator<DenseVectorGenerator>, HasVectorDim<DenseVectorGenerator> {
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    public DenseVectorGenerator() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
+public class DenseVectorGenerator extends InputTableGenerator<DenseVectorGenerator>
+        implements HasVectorDim<DenseVectorGenerator> {
 
     @Override
-    public Table[] getData(StreamTableEnvironment tEnv) {
-        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
-
-        DataStream<DenseVector> dataStream =
-                env.fromParallelCollection(
-                                new NumberSequenceIterator(1L, getNumValues()),
-                                BasicTypeInfo.LONG_TYPE_INFO)
-                        .map(new RandomDenseVectorGenerator(getSeed(), getVectorDim()));
-
-        Table dataTable = tEnv.fromDataStream(dataStream);
-        if (getColNames() != null) {
-            Preconditions.checkState(getColNames().length == 1);
-            Preconditions.checkState(getColNames()[0].length == 1);
-            dataTable = dataTable.as(getColNames()[0][0]);
-        }
-
-        return new Table[] {dataTable};
-    }
-
-    private static class RandomDenseVectorGenerator extends RichMapFunction<Long, DenseVector> {
-        private final int vectorDim;
-        private final long initSeed;
-        private Random random;
-
-        private RandomDenseVectorGenerator(long initSeed, int vectorDim) {
-            this.vectorDim = vectorDim;
-            this.initSeed = initSeed;
-        }
-
-        @Override
-        public void open(Configuration parameters) throws Exception {
-            super.open(parameters);
-            int index = getRuntimeContext().getIndexOfThisSubtask();
-            random = new Random(Tuple2.of(initSeed, index).hashCode());
-        }
-
-        @Override
-        public DenseVector map(Long value) {
-            double[] values = new double[vectorDim];
-            for (int i = 0; i < vectorDim; i++) {
-                values[i] = random.nextDouble();
+    public RowGenerator[] getRowGenerators() {
+        String[][] columnNames = getColNames();
+        Preconditions.checkState(columnNames.length == 1);
+        Preconditions.checkState(columnNames[0].length == 1);
+        int vectorDim = getVectorDim();
+
+        return new RowGenerator[] {
+            new RowGenerator(getNumValues(), getSeed()) {
+
+                @Override
+                protected Row nextRow() {
+                    double[] values = new double[vectorDim];
+                    for (int i = 0; i < values.length; i++) {
+                        values[i] = random.nextDouble();
+                    }
+                    return Row.of(Vectors.dense(values));
+                }
+
+                @Override
+                protected RowTypeInfo getRowTypeInfo() {
+                    return new RowTypeInfo(
+                            new TypeInformation[] {DenseVectorTypeInfo.INSTANCE}, columnNames[0]);
+                }
             }
-            return Vectors.dense(values);
-        }
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
+        };
     }
 }
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java
new file mode 100644
index 0000000..dd673a7
--- /dev/null
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.benchmark.datagenerator.common;
+
+import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Base class for data generators that build one input table per {@link RowGenerator}. */
+public abstract class InputTableGenerator<T extends InputTableGenerator<T>>
+        implements InputDataGenerator<T> {
+    protected final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public InputTableGenerator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public final Table[] getData(StreamTableEnvironment tEnv) {
+        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
+
+        RowGenerator[] rowGenerators = getRowGenerators();
+        Table[] dataTables = new Table[rowGenerators.length];
+        for (int i = 0; i < rowGenerators.length; i++) {
+            DataStream<Row> dataStream =
+                    env.addSource(rowGenerators[i], "sourceOp-" + i)
+                            .returns(rowGenerators[i].getRowTypeInfo());
+            dataTables[i] = tEnv.fromDataStream(dataStream);
+        }
+
+        return dataTables;
+    }
+
+    /** Returns the row generators; getData() creates one table per generator, in array order. */
+    protected abstract RowGenerator[] getRowGenerators();
+
+    @Override
+    public final Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java
index 0e11071..dff9f07 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java
@@ -18,34 +18,19 @@
 
 package org.apache.flink.ml.benchmark.datagenerator.common;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
 import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
-import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
 import org.apache.flink.ml.param.IntParam;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.param.ParamValidators;
 import org.apache.flink.ml.util.ParamUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.NumberSequenceIterator;
 import org.apache.flink.util.Preconditions;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
 /**
  * A DataGenerator which creates a table of features, label and weight.
  *
@@ -58,8 +43,8 @@ import java.util.Random;
  * </ul>
  */
 public class LabeledPointWithWeightGenerator
-        implements InputDataGenerator<LabeledPointWithWeightGenerator>,
-                HasVectorDim<LabeledPointWithWeightGenerator> {
+        extends InputTableGenerator<LabeledPointWithWeightGenerator>
+        implements HasVectorDim<LabeledPointWithWeightGenerator> {
 
     public static final Param<Integer> FEATURE_ARITY =
             new IntParam(
@@ -79,8 +64,6 @@ public class LabeledPointWithWeightGenerator
                     2,
                     ParamValidators.gtEq(0));
 
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
     public LabeledPointWithWeightGenerator() {
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
@@ -102,79 +85,46 @@ public class LabeledPointWithWeightGenerator
     }
 
     @Override
-    public Table[] getData(StreamTableEnvironment tEnv) {
-        Preconditions.checkState(getColNames().length == 1);
-        Preconditions.checkState(getColNames()[0].length == 3);
-
-        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
-
-        DataStream<Row> dataStream =
-                env.fromParallelCollection(
-                                new NumberSequenceIterator(1L, getNumValues()),
-                                BasicTypeInfo.LONG_TYPE_INFO)
-                        .map(
-                                new RandomLabeledPointWithWeightGenerator(
-                                        getSeed(),
-                                        getVectorDim(),
-                                        getFeatureArity(),
-                                        getLabelArity()),
-                                new RowTypeInfo(
-                                        new TypeInformation[] {
-                                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
-                                        },
-                                        getColNames()[0]));
-
-        Table dataTable = tEnv.fromDataStream(dataStream);
-
-        return new Table[] {dataTable};
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
-
-    private static class RandomLabeledPointWithWeightGenerator extends RichMapFunction<Long, Row> {
-        private final long initSeed;
-        private final int vectorDim;
-        private final int featureArity;
-        private final int labelArity;
-        private Random random;
-
-        private RandomLabeledPointWithWeightGenerator(
-                long initSeed, int vectorDim, int featureArity, int labelArity) {
-            this.initSeed = initSeed;
-            this.vectorDim = vectorDim;
-            this.featureArity = featureArity;
-            this.labelArity = labelArity;
-        }
-
-        @Override
-        public void open(Configuration parameters) throws Exception {
-            super.open(parameters);
-            int index = getRuntimeContext().getIndexOfThisSubtask();
-            random = new Random(Tuple2.of(initSeed, index).hashCode());
-        }
-
-        @Override
-        public Row map(Long ignored) {
-            double[] features = new double[vectorDim];
-            for (int i = 0; i < vectorDim; i++) {
-                features[i] = getValue(featureArity);
-            }
-
-            double label = getValue(labelArity);
-
-            double weight = random.nextDouble();
-
-            return Row.of(Vectors.dense(features), label, weight);
-        }
-
-        private double getValue(int arity) {
-            if (arity > 0) {
-                return random.nextInt(arity);
+    protected RowGenerator[] getRowGenerators() {
+        String[][] colNames = getColNames();
+        Preconditions.checkState(colNames.length == 1);
+        Preconditions.checkState(colNames[0].length == 3);
+        int vectorDim = getVectorDim();
+        int featureArity = getFeatureArity();
+        int labelArity = getLabelArity();
+
+        return new RowGenerator[] {
+            new RowGenerator(getNumValues(), getSeed()) {
+                @Override
+                protected Row nextRow() {
+                    double[] features = new double[vectorDim];
+                    for (int i = 0; i < vectorDim; i++) {
+                        features[i] = getValue(featureArity);
+                    }
+
+                    double label = getValue(labelArity);
+
+                    double weight = random.nextDouble();
+
+                    return Row.of(Vectors.dense(features), label, weight);
+                }
+
+                @Override
+                protected RowTypeInfo getRowTypeInfo() {
+                    return new RowTypeInfo(
+                            new TypeInformation[] {
+                                DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
+                            },
+                            colNames[0]);
+                }
+
+                private double getValue(int arity) {
+                    if (arity > 0) {
+                        return random.nextInt(arity);
+                    }
+                    return random.nextDouble();
+                }
             }
-            return random.nextDouble();
-        }
+        };
     }
 }
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java
new file mode 100644
index 0000000..55fe526
--- /dev/null
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.benchmark.datagenerator.common;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.types.Row;
+
+import java.util.Random;
+
+/** A parallel source that emits rows from {@link #nextRow()}, split across its subtasks. */
+public abstract class RowGenerator extends RichParallelSourceFunction<Row> {
+    /** Random instance for subclasses; seeded in open() from initSeed and the subtask index. */
+    protected Random random;
+    /** Total number of values to generate across all subtasks. */
+    private final long numValues;
+    /** Base seed, combined with the subtask index to seed {@link #random}. */
+    private final long initSeed;
+    /** Number of values to generate on this subtask; computed in open(). */
+    private long numValuesOnThisTask;
+    /** Whether this source is still running; cleared by cancel() to stop the loop in run(). */
+    private volatile boolean isRunning = true;
+
+    public RowGenerator(long numValues, long initSeed) {
+        this.numValues = numValues;
+        this.initSeed = initSeed;
+    }
+
+    @Override
+    public final void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        int taskIdx = getRuntimeContext().getIndexOfThisSubtask();
+        int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+        random = new Random(Tuple2.of(initSeed, taskIdx).hashCode());
+        long div = numValues / numTasks;
+        long mod = numValues % numTasks;
+        numValuesOnThisTask = mod > taskIdx ? div + 1 : div;
+    }
+
+    @Override
+    public final void run(SourceContext<Row> ctx) throws Exception {
+        long cnt = 0;
+        while (isRunning && cnt < numValuesOnThisTask) {
+            ctx.collect(nextRow());
+            cnt++;
+        }
+    }
+
+    @Override
+    public final void cancel() {
+        isRunning = false;
+    }
+
+    /** Generates the next row to emit; run() calls this once per emitted value. */
+    protected abstract Row nextRow();
+
+    /** Returns the output type information for this generator. */
+    protected abstract RowTypeInfo getRowTypeInfo();
+}
diff --git a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java
index 937d25b..7d2883a 100644
--- a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java
+++ b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java
@@ -18,17 +18,22 @@
 
 package org.apache.flink.ml.benchmark;
 
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator;
 import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorGenerator;
 import org.apache.flink.ml.benchmark.datagenerator.common.LabeledPointWithWeightGenerator;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.types.Row;
 import org.apache.flink.util.CloseableIterator;
 
+import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
@@ -36,11 +41,22 @@ import static org.junit.Assert.assertTrue;
 
 /** Tests data generators. */
 public class DataGeneratorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
     @Test
     public void testDenseVectorGenerator() {
-        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
         DenseVectorGenerator generator =
                 new DenseVectorGenerator()
                         .setColNames(new String[] {"denseVector"})
@@ -51,7 +67,7 @@ public class DataGeneratorTest {
         for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect();
                 it.hasNext(); ) {
             Row row = it.next();
-            assertEquals(row.getArity(), 1);
+            assertEquals(1, row.getArity());
             DenseVector vector = (DenseVector) row.getField(generator.getColNames()[0][0]);
             assertNotNull(vector);
             assertEquals(vector.size(), generator.getVectorDim());
@@ -61,10 +77,7 @@ public class DataGeneratorTest {
     }
 
     @Test
-    public void testDenseVectorArrayGenerator() throws Exception {
-        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
+    public void testDenseVectorArrayGenerator() {
         DenseVectorArrayGenerator generator =
                 new DenseVectorArrayGenerator()
                         .setColNames(new String[] {"denseVectors"})
@@ -72,12 +85,13 @@ public class DataGeneratorTest {
                         .setVectorDim(10)
                         .setArraySize(20);
 
-        DataStream<DenseVector[]> stream =
-                tEnv.toDataStream(generator.getData(tEnv)[0], DenseVector[].class);
-
         int count = 0;
-        for (CloseableIterator<DenseVector[]> it = stream.executeAndCollect(); it.hasNext(); ) {
-            DenseVector[] vectors = it.next();
+        for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect();
+                it.hasNext(); ) {
+            Row row = it.next();
+            assertEquals(1, row.getArity());
+            DenseVector[] vectors = (DenseVector[]) row.getField(generator.getColNames()[0][0]);
+            assertNotNull(vectors);
             assertEquals(generator.getArraySize(), vectors.length);
             for (DenseVector vector : vectors) {
                 assertEquals(vector.size(), generator.getVectorDim());
@@ -89,9 +103,6 @@ public class DataGeneratorTest {
 
     @Test
     public void testLabeledPointWithWeightGenerator() {
-        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
         String featuresCol = "features";
         String labelCol = "label";
         String weightCol = "weight";