You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/03/29 12:41:51 UTC

[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #71: [FLINK-26443] Add benchmark framework

yunfengzhou-hub commented on a change in pull request #71:
URL: https://github.com/apache/flink-ml/pull/71#discussion_r837426927



##########
File path: flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/BenchmarkUtils.java
##########
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.benchmark;
+
+import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.benchmark.generator.DataGenerator;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.WithParams;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Pattern;
+
+/** Utility methods for benchmarks. */
+@SuppressWarnings("unchecked")
+public class BenchmarkUtils {
+    private static final String benchmarkNamePattern = "^[A-Za-z0-9][A-Za-z0-9_\\-]*$";
+
+    /** Compiled form of {@link #benchmarkNamePattern}, built once instead of on every lookup. */
+    private static final Pattern BENCHMARK_NAME_PATTERN = Pattern.compile(benchmarkNamePattern);
+
+    /**
+     * Loads benchmark paramMaps from the provided json map.
+     *
+     * <p>The map must contain {@code "_version": 1}. Every top-level key that is a valid benchmark
+     * name (see {@link #isValidBenchmarkName(String)}) is treated as a benchmark whose value is its
+     * parameter map. All other keys, such as {@code "_version"}, are skipped.
+     *
+     * @param jsonMap the parsed json configuration.
+     * @return A map whose key is the names of the loaded benchmarks, value is the parameters of the
+     *     benchmarks.
+     * @throws IllegalArgumentException if {@code "_version"} is missing or is not 1.
+     */
+    public static Map<String, Map<String, ?>> parseBenchmarkParams(Map<String, ?> jsonMap) {
+        Object version = jsonMap.get("_version");
+        // Comparing from the constant side is null-safe: a missing or null "_version" fails the
+        // precondition with a readable message instead of a NullPointerException.
+        Preconditions.checkArgument(
+                Integer.valueOf(1).equals(version),
+                "Unsupported benchmark config version %s, expected \"_version\": 1.",
+                version);
+
+        Map<String, Map<String, ?>> result = new HashMap<>();
+        for (Map.Entry<String, ?> entry : jsonMap.entrySet()) {
+            if (isValidBenchmarkName(entry.getKey())) {
+                result.put(entry.getKey(), (Map<String, ?>) entry.getValue());
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Checks whether a string is a valid benchmark name.
+     *
+     * <p>A valid benchmark name should only contain English letters, numbers, hyphens (-) and
+     * underscores (_). The name should not start with a hyphen or underscore.
+     */
+    public static boolean isValidBenchmarkName(String name) {
+        return BENCHMARK_NAME_PATTERN.matcher(name).matches();
+    }
+
+    /**
+     * Instantiates a benchmark from its parameter map and executes the benchmark in the provided
+     * environment.
+     *
+     * <p>The parameter map must contain exactly the keys {@code "stage"} and {@code "inputs"},
+     * optionally plus {@code "modelData"}. {@code "modelData"} is only accepted when the
+     * instantiated stage is a {@link Model}.
+     *
+     * @param name the name of the benchmark.
+     * @param tEnv the table environment in which the benchmark is executed.
+     * @param benchmarkParamsMap the parameters of the benchmark.
+     * @return Results of the executed benchmark.
+     * @throws IllegalArgumentException if the keys of {@code benchmarkParamsMap} are not a
+     *     supported combination, or if {@code "modelData"} is given for a non-Model stage.
+     */
+    public static BenchmarkResult runBenchmark(
+            String name, StreamTableEnvironment tEnv, Map<String, ?> benchmarkParamsMap)
+            throws Exception {
+        // Validate the exact key set up front. Dispatching on size() alone would let maps such as
+        // {"stage", "modelData"} through and pass a null parameter map to instantiateWithParams.
+        boolean hasModelData = benchmarkParamsMap.containsKey("modelData");
+        if (!benchmarkParamsMap.containsKey("stage")
+                || !benchmarkParamsMap.containsKey("inputs")
+                || benchmarkParamsMap.size() != (hasModelData ? 3 : 2)) {
+            throw new IllegalArgumentException(
+                    "Unsupported json map with keys " + benchmarkParamsMap.keySet());
+        }
+
+        Stage<?> stage =
+                (Stage<?>) instantiateWithParams((Map<String, ?>) benchmarkParamsMap.get("stage"));
+        if (hasModelData && !(stage instanceof Model)) {
+            throw new IllegalArgumentException(
+                    "Unsupported json map with keys "
+                            + benchmarkParamsMap.keySet()
+                            + ": modelData is only supported for Model stages, but got "
+                            + stage.getClass());
+        }
+
+        DataGenerator<?> inputsGenerator =
+                (DataGenerator<?>)
+                        instantiateWithParams(
+                                (Map<String, ?>) benchmarkParamsMap.get("inputs"));
+        if (!hasModelData) {
+            return runBenchmark(name, tEnv, stage, inputsGenerator);
+        }
+
+        DataGenerator<?> modelDataGenerator =
+                (DataGenerator<?>)
+                        instantiateWithParams(
+                                (Map<String, ?>) benchmarkParamsMap.get("modelData"));
+        return runBenchmark(name, tEnv, (Model<?>) stage, modelDataGenerator, inputsGenerator);
+    }
+
+    /**
+     * Executes a benchmark from a stage and an inputsGenerator in the provided environment.
+     *
+     * @return Results of the executed benchmark.
+     */
+    public static BenchmarkResult runBenchmark(
+            String name,
+            StreamTableEnvironment tEnv,
+            Stage<?> stage,
+            DataGenerator<?> inputsGenerator)
+            throws Exception {
+        Table[] inputTables = inputsGenerator.getData(tEnv);
+
+        Table[] outputTables;
+        if (stage instanceof Estimator) {
+            outputTables = ((Estimator<?, ?>) stage).fit(inputTables).getModelData();
+        } else if (stage instanceof AlgoOperator) {
+            outputTables = ((AlgoOperator<?>) stage).transform(inputTables);
+        } else {
+            throw new IllegalArgumentException("Unsupported Stage class " + stage.getClass());
+        }
+
+        for (Table table : outputTables) {
+            tEnv.toDataStream(table).addSink(new DiscardingSink<>());
+        }
+
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+        JobExecutionResult executionResult = env.execute();
+
+        BenchmarkResult result = new BenchmarkResult();
+        result.name = name;
+        result.executionTimeMillis = (double) executionResult.getNetRuntime(TimeUnit.MILLISECONDS);

Review comment:
       I can make `DataGenerator implements GeneratorParams`, which means each generator must have a `getNumData()` method. That should be a simpler solution.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org