You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/06/20 03:43:19 UTC
[flink-ml] 02/04: [FLINK-27877] Reduce the length of the operator chain for generating input table
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 341df450831e4c426ff4f8049af8dc52fc0bb598
Author: zhangzp <zh...@gmail.com>
AuthorDate: Mon Jun 20 09:50:59 2022 +0800
[FLINK-27877] Reduce the length of the operator chain for generating input table
---
.../common/DenseVectorArrayGenerator.java | 114 +++++-------------
.../datagenerator/common/DenseVectorGenerator.java | 103 +++++-----------
.../datagenerator/common/InputTableGenerator.java | 66 ++++++++++
.../common/LabeledPointWithWeightGenerator.java | 134 +++++++--------------
.../datagenerator/common/RowGenerator.java | 77 ++++++++++++
.../flink/ml/benchmark/DataGeneratorTest.java | 45 ++++---
6 files changed, 275 insertions(+), 264 deletions(-)
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java
index 0f8b82f..c1b3a21 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java
@@ -18,100 +18,50 @@
package org.apache.flink.ml.benchmark.datagenerator.common;
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize;
import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
-import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.DataTypes;
-import org.apache.flink.table.api.Schema;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.util.NumberSequenceIterator;
+import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
/** A DataGenerator which creates a table of DenseVector array. */
-public class DenseVectorArrayGenerator
- implements InputDataGenerator<DenseVectorArrayGenerator>,
- HasArraySize<DenseVectorArrayGenerator>,
+public class DenseVectorArrayGenerator extends InputTableGenerator<DenseVectorArrayGenerator>
+ implements HasArraySize<DenseVectorArrayGenerator>,
HasVectorDim<DenseVectorArrayGenerator> {
- private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
- public DenseVectorArrayGenerator() {
- ParamUtils.initializeMapWithDefaultValues(paramMap, this);
- }
@Override
- public Table[] getData(StreamTableEnvironment tEnv) {
- StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
-
- DataStream<DenseVector[]> dataStream =
- env.fromParallelCollection(
- new NumberSequenceIterator(1L, getNumValues()),
- BasicTypeInfo.LONG_TYPE_INFO)
- .map(
- new GenerateRandomContinuousVectorArrayFunction(
- getSeed(), getVectorDim(), getArraySize()));
-
- Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector[].class)).build();
- Table dataTable = tEnv.fromDataStream(dataStream, schema);
- if (getColNames() != null) {
- Preconditions.checkState(getColNames().length == 1);
- Preconditions.checkState(getColNames()[0].length == 1);
- dataTable = dataTable.as(getColNames()[0][0]);
- }
-
- return new Table[] {dataTable};
- }
+ protected RowGenerator[] getRowGenerators() {
+ String[][] columnNames = getColNames();
+ Preconditions.checkState(columnNames.length == 1);
+ Preconditions.checkState(columnNames[0].length == 1);
+ int arraySize = getArraySize();
+ int vectorDim = getVectorDim();
- private static class GenerateRandomContinuousVectorArrayFunction
- extends RichMapFunction<Long, DenseVector[]> {
- private final int vectorDim;
- private final long initSeed;
- private final int arraySize;
- private Random random;
-
- private GenerateRandomContinuousVectorArrayFunction(
- long initSeed, int vectorDim, int arraySize) {
- this.vectorDim = vectorDim;
- this.initSeed = initSeed;
- this.arraySize = arraySize;
- }
-
- @Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
- int index = getRuntimeContext().getIndexOfThisSubtask();
- random = new Random(Tuple2.of(initSeed, index).hashCode());
- }
+ return new RowGenerator[] {
+ new RowGenerator(getNumValues(), getSeed()) {
+ @Override
+ protected Row nextRow() {
+ DenseVector[] result = new DenseVector[arraySize];
+ for (int i = 0; i < arraySize; i++) {
+ result[i] = new DenseVector(vectorDim);
+ for (int j = 0; j < vectorDim; j++) {
+ result[i].values[j] = random.nextDouble();
+ }
+ }
+ Row row = new Row(1);
+ row.setField(0, result);
+ return row;
+ }
- @Override
- public DenseVector[] map(Long value) {
- DenseVector[] result = new DenseVector[arraySize];
- for (int i = 0; i < arraySize; i++) {
- result[i] = new DenseVector(vectorDim);
- for (int j = 0; j < vectorDim; j++) {
- result[i].values[j] = random.nextDouble();
+ @Override
+ protected RowTypeInfo getRowTypeInfo() {
+ return new RowTypeInfo(
+ new TypeInformation[] {TypeInformation.of(DenseVector[].class)},
+ columnNames[0]);
}
}
- return result;
- }
- }
-
- @Override
- public Map<Param<?>, Object> getParamMap() {
- return paramMap;
+ };
}
}
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java
index 4117261..10eae84 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java
@@ -18,86 +18,43 @@
package org.apache.flink.ml.benchmark.datagenerator.common;
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
-import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.util.NumberSequenceIterator;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
/** A DataGenerator which creates a table of DenseVector. */
-public class DenseVectorGenerator
- implements InputDataGenerator<DenseVectorGenerator>, HasVectorDim<DenseVectorGenerator> {
- private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
- public DenseVectorGenerator() {
- ParamUtils.initializeMapWithDefaultValues(paramMap, this);
- }
+public class DenseVectorGenerator extends InputTableGenerator<DenseVectorGenerator>
+ implements HasVectorDim<DenseVectorGenerator> {
@Override
- public Table[] getData(StreamTableEnvironment tEnv) {
- StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
-
- DataStream<DenseVector> dataStream =
- env.fromParallelCollection(
- new NumberSequenceIterator(1L, getNumValues()),
- BasicTypeInfo.LONG_TYPE_INFO)
- .map(new RandomDenseVectorGenerator(getSeed(), getVectorDim()));
-
- Table dataTable = tEnv.fromDataStream(dataStream);
- if (getColNames() != null) {
- Preconditions.checkState(getColNames().length == 1);
- Preconditions.checkState(getColNames()[0].length == 1);
- dataTable = dataTable.as(getColNames()[0][0]);
- }
-
- return new Table[] {dataTable};
- }
-
- private static class RandomDenseVectorGenerator extends RichMapFunction<Long, DenseVector> {
- private final int vectorDim;
- private final long initSeed;
- private Random random;
-
- private RandomDenseVectorGenerator(long initSeed, int vectorDim) {
- this.vectorDim = vectorDim;
- this.initSeed = initSeed;
- }
-
- @Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
- int index = getRuntimeContext().getIndexOfThisSubtask();
- random = new Random(Tuple2.of(initSeed, index).hashCode());
- }
-
- @Override
- public DenseVector map(Long value) {
- double[] values = new double[vectorDim];
- for (int i = 0; i < vectorDim; i++) {
- values[i] = random.nextDouble();
+ public RowGenerator[] getRowGenerators() {
+ String[][] columnNames = getColNames();
+ Preconditions.checkState(columnNames.length == 1);
+ Preconditions.checkState(columnNames[0].length == 1);
+ int vectorDim = getVectorDim();
+
+ return new RowGenerator[] {
+ new RowGenerator(getNumValues(), getSeed()) {
+
+ @Override
+ protected Row nextRow() {
+ double[] values = new double[vectorDim];
+ for (int i = 0; i < values.length; i++) {
+ values[i] = random.nextDouble();
+ }
+ return Row.of(Vectors.dense(values));
+ }
+
+ @Override
+ protected RowTypeInfo getRowTypeInfo() {
+ return new RowTypeInfo(
+ new TypeInformation[] {DenseVectorTypeInfo.INSTANCE}, columnNames[0]);
+ }
}
- return Vectors.dense(values);
- }
- }
-
- @Override
- public Map<Param<?>, Object> getParamMap() {
- return paramMap;
+ };
}
}
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java
new file mode 100644
index 0000000..dd673a7
--- /dev/null
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.benchmark.datagenerator.common;
+
+import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Base class for data generators that produce one input table per {@link RowGenerator}. */
+public abstract class InputTableGenerator<T extends InputTableGenerator<T>>
+ implements InputDataGenerator<T> {
+ protected final Map<Param<?>, Object> paramMap = new HashMap<>(); // filled with param defaults by the constructor
+
+ public InputTableGenerator() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public final Table[] getData(StreamTableEnvironment tEnv) { // tables follow getRowGenerators() order
+ StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
+
+ RowGenerator[] rowGenerators = getRowGenerators();
+ Table[] dataTables = new Table[rowGenerators.length];
+ for (int i = 0; i < rowGenerators.length; i++) {
+ DataStream<Row> dataStream =
+ env.addSource(rowGenerators[i], "sourceOp-" + i) // each generator is its own source operator
+ .returns(rowGenerators[i].getRowTypeInfo()); // row type is supplied by the generator
+ dataTables[i] = tEnv.fromDataStream(dataStream);
+ }
+
+ return dataTables;
+ }
+
+ /** Gets the generators for all input tables; called on each getData() invocation. */
+ protected abstract RowGenerator[] getRowGenerators();
+
+ @Override
+ public final Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java
index 0e11071..dff9f07 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java
@@ -18,34 +18,19 @@
package org.apache.flink.ml.benchmark.datagenerator.common;
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
-import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.IntParam;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.ml.util.ParamUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
-import org.apache.flink.util.NumberSequenceIterator;
import org.apache.flink.util.Preconditions;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
/**
* A DataGenerator which creates a table of features, label and weight.
*
@@ -58,8 +43,8 @@ import java.util.Random;
* </ul>
*/
public class LabeledPointWithWeightGenerator
- implements InputDataGenerator<LabeledPointWithWeightGenerator>,
- HasVectorDim<LabeledPointWithWeightGenerator> {
+ extends InputTableGenerator<LabeledPointWithWeightGenerator>
+ implements HasVectorDim<LabeledPointWithWeightGenerator> {
public static final Param<Integer> FEATURE_ARITY =
new IntParam(
@@ -79,8 +64,6 @@ public class LabeledPointWithWeightGenerator
2,
ParamValidators.gtEq(0));
- private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
public LabeledPointWithWeightGenerator() {
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
@@ -102,79 +85,46 @@ public class LabeledPointWithWeightGenerator
}
@Override
- public Table[] getData(StreamTableEnvironment tEnv) {
- Preconditions.checkState(getColNames().length == 1);
- Preconditions.checkState(getColNames()[0].length == 3);
-
- StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
-
- DataStream<Row> dataStream =
- env.fromParallelCollection(
- new NumberSequenceIterator(1L, getNumValues()),
- BasicTypeInfo.LONG_TYPE_INFO)
- .map(
- new RandomLabeledPointWithWeightGenerator(
- getSeed(),
- getVectorDim(),
- getFeatureArity(),
- getLabelArity()),
- new RowTypeInfo(
- new TypeInformation[] {
- DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
- },
- getColNames()[0]));
-
- Table dataTable = tEnv.fromDataStream(dataStream);
-
- return new Table[] {dataTable};
- }
-
- @Override
- public Map<Param<?>, Object> getParamMap() {
- return paramMap;
- }
-
- private static class RandomLabeledPointWithWeightGenerator extends RichMapFunction<Long, Row> {
- private final long initSeed;
- private final int vectorDim;
- private final int featureArity;
- private final int labelArity;
- private Random random;
-
- private RandomLabeledPointWithWeightGenerator(
- long initSeed, int vectorDim, int featureArity, int labelArity) {
- this.initSeed = initSeed;
- this.vectorDim = vectorDim;
- this.featureArity = featureArity;
- this.labelArity = labelArity;
- }
-
- @Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
- int index = getRuntimeContext().getIndexOfThisSubtask();
- random = new Random(Tuple2.of(initSeed, index).hashCode());
- }
-
- @Override
- public Row map(Long ignored) {
- double[] features = new double[vectorDim];
- for (int i = 0; i < vectorDim; i++) {
- features[i] = getValue(featureArity);
- }
-
- double label = getValue(labelArity);
-
- double weight = random.nextDouble();
-
- return Row.of(Vectors.dense(features), label, weight);
- }
-
- private double getValue(int arity) {
- if (arity > 0) {
- return random.nextInt(arity);
+ protected RowGenerator[] getRowGenerators() {
+ String[][] colNames = getColNames();
+ Preconditions.checkState(colNames.length == 1);
+ Preconditions.checkState(colNames[0].length == 3);
+ int vectorDim = getVectorDim();
+ int featureArity = getFeatureArity();
+ int labelArity = getLabelArity();
+
+ return new RowGenerator[] {
+ new RowGenerator(getNumValues(), getSeed()) {
+ @Override
+ protected Row nextRow() {
+ double[] features = new double[vectorDim];
+ for (int i = 0; i < vectorDim; i++) {
+ features[i] = getValue(featureArity);
+ }
+
+ double label = getValue(labelArity);
+
+ double weight = random.nextDouble();
+
+ return Row.of(Vectors.dense(features), label, weight);
+ }
+
+ @Override
+ protected RowTypeInfo getRowTypeInfo() {
+ return new RowTypeInfo(
+ new TypeInformation[] {
+ DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
+ },
+ colNames[0]);
+ }
+
+ private double getValue(int arity) {
+ if (arity > 0) {
+ return random.nextInt(arity);
+ }
+ return random.nextDouble();
+ }
}
- return random.nextDouble();
- }
+ };
}
}
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java
new file mode 100644
index 0000000..55fe526
--- /dev/null
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.benchmark.datagenerator.common;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.types.Row;
+
+import java.util.Random;
+
+/** A parallel source that splits a fixed total number of generated rows across its subtasks. */
+public abstract class RowGenerator extends RichParallelSourceFunction<Row> {
+ /** Random instance used by subclasses to generate data; seeded per subtask in open(). */
+ protected Random random;
+ /** Number of values to generate in total, summed over all subtasks. */
+ private final long numValues;
+ /** Base seed; combined with the subtask index to seed {@link #random}. */
+ private final long initSeed;
+ /** Number of values to generate on this subtask (computed in open()). */
+ private long numValuesOnThisTask;
+ /** Whether this source is still running; cleared by cancel(). */
+ private volatile boolean isRunning = true;
+
+ public RowGenerator(long numValues, long initSeed) { // total row count and base seed
+ this.numValues = numValues;
+ this.initSeed = initSeed;
+ }
+
+ @Override
+ public final void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+ int taskIdx = getRuntimeContext().getIndexOfThisSubtask();
+ int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+ random = new Random(Tuple2.of(initSeed, taskIdx).hashCode()); // seed derived from (initSeed, subtask index)
+ long div = numValues / numTasks;
+ long mod = numValues % numTasks;
+ numValuesOnThisTask = mod > taskIdx ? div + 1 : div; // the first (numValues % numTasks) subtasks get one extra value
+ }
+
+ @Override
+ public final void run(SourceContext<Row> ctx) throws Exception {
+ long cnt = 0;
+ while (isRunning && cnt < numValuesOnThisTask) { // stops early once cancel() clears isRunning
+ ctx.collect(nextRow());
+ cnt++;
+ }
+ }
+
+ @Override
+ public final void cancel() {
+ isRunning = false;
+ }
+
+ /** Generates the next row; invoked by run() once per emitted value. */
+ protected abstract Row nextRow();
+
+ /** Returns the type information of the rows produced by {@link #nextRow()}. */
+ protected abstract RowTypeInfo getRowTypeInfo();
+}
diff --git a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java
index 937d25b..7d2883a 100644
--- a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java
+++ b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java
@@ -18,17 +18,22 @@
package org.apache.flink.ml.benchmark;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator;
import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorGenerator;
import org.apache.flink.ml.benchmark.datagenerator.common.LabeledPointWithWeightGenerator;
import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
+import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -36,11 +41,22 @@ import static org.junit.Assert.assertTrue;
/** Tests data generators. */
public class DataGeneratorTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamTableEnvironment tEnv;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+ }
+
@Test
public void testDenseVectorGenerator() {
- StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
DenseVectorGenerator generator =
new DenseVectorGenerator()
.setColNames(new String[] {"denseVector"})
@@ -51,7 +67,7 @@ public class DataGeneratorTest {
for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect();
it.hasNext(); ) {
Row row = it.next();
- assertEquals(row.getArity(), 1);
+ assertEquals(1, row.getArity());
DenseVector vector = (DenseVector) row.getField(generator.getColNames()[0][0]);
assertNotNull(vector);
assertEquals(vector.size(), generator.getVectorDim());
@@ -61,10 +77,7 @@ public class DataGeneratorTest {
}
@Test
- public void testDenseVectorArrayGenerator() throws Exception {
- StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
+ public void testDenseVectorArrayGenerator() {
DenseVectorArrayGenerator generator =
new DenseVectorArrayGenerator()
.setColNames(new String[] {"denseVectors"})
@@ -72,12 +85,13 @@ public class DataGeneratorTest {
.setVectorDim(10)
.setArraySize(20);
- DataStream<DenseVector[]> stream =
- tEnv.toDataStream(generator.getData(tEnv)[0], DenseVector[].class);
-
int count = 0;
- for (CloseableIterator<DenseVector[]> it = stream.executeAndCollect(); it.hasNext(); ) {
- DenseVector[] vectors = it.next();
+ for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect();
+ it.hasNext(); ) {
+ Row row = it.next();
+ assertEquals(1, row.getArity());
+ DenseVector[] vectors = (DenseVector[]) row.getField(generator.getColNames()[0][0]);
+ assertNotNull(vectors);
assertEquals(generator.getArraySize(), vectors.length);
for (DenseVector vector : vectors) {
assertEquals(vector.size(), generator.getVectorDim());
@@ -89,9 +103,6 @@ public class DataGeneratorTest {
@Test
public void testLabeledPointWithWeightGenerator() {
- StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
String featuresCol = "features";
String labelCol = "label";
String weightCol = "weight";