Posted to issues@flink.apache.org by "lindong28 (via GitHub)" <gi...@apache.org> on 2023/03/04 11:11:34 UTC

[GitHub] [flink-ml] lindong28 commented on a diff in pull request #218: [FLINK-31306] Add Servable for PipelineModel

lindong28 commented on code in PR #218:
URL: https://github.com/apache/flink-ml/pull/218#discussion_r1125428753


##########
flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java:
##########
@@ -82,6 +85,33 @@ public static PipelineModel load(StreamTableEnvironment tEnv, String path) throw
                 ReadWriteUtils.loadPipeline(tEnv, path, PipelineModel.class.getName()));
     }
 
+    public static PipelineModelServable loadServable(String path) throws IOException {
+        return PipelineModelServable.load(path);
+    }
+
+    /**
+     * Whether all stages in the pipeline have corresponding {@link TransformerServable} so that the
+     * PipelineModel can be turned into a TransformerServable and used in an online inference
+     * program.
+     *
+     * @return true if all stages have corresponding TransformerServable, false if not.
+     */
+    public boolean supportServable() {
+        for (Stage stage : stages) {

Review Comment:
   nit: use `Stage<?> stage : stages` instead of the raw type `Stage stage : stages`.
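
To illustrate the nit, here is a minimal, self-contained sketch of iterating with a wildcard type instead of a raw type. The `Stage` interface and `Scaler` class below are hypothetical stand-ins, not the real Flink ML types; only the shape of the loop matters.

```java
import java.util.ArrayList;
import java.util.List;

final class WildcardLoopDemo {
    // Iterating with Stage<?> keeps the loop variable parameterized, so the
    // compiler has no raw type to warn about.
    static List<String> stageNames(List<Stage<?>> stages) {
        List<String> names = new ArrayList<>();
        for (Stage<?> stage : stages) {
            names.add(stage.name());
        }
        return names;
    }

    public static void main(String[] args) {
        List<Stage<?>> stages = new ArrayList<>();
        stages.add(new Scaler());
        System.out.println(stageNames(stages));
    }
}

// Hypothetical stand-in for a self-typed generic interface; not the Flink ML API.
interface Stage<T extends Stage<T>> {
    String name();
}

// Hypothetical stage implementation used only for this illustration.
final class Scaler implements Stage<Scaler> {
    @Override
    public String name() {
        return "scaler";
    }
}
```

With the raw `Stage stage : stages` form the loop still compiles, but generic type information is discarded and the compiler reports a raw-type warning under `-Xlint:rawtypes`.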



##########
flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/builder/ExampleServables.java:
##########
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.servable.builder;
+
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.api.TransformerServable;
+import org.apache.flink.ml.servable.types.DataTypes;
+import org.apache.flink.ml.util.FileUtils;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ServableReadWriteUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** Defines Servable subclasses to be used in unit tests. */
+public class ExampleServables {
+
+    /**
+     * A {@link TransformerServable} subclass that increments every value in the input dataframe by
+     * `delta` and outputs the resulting values.
+     */
+    public static class SumModelServable implements TransformerServable<SumModelServable> {
+
+        private static final String COL_NAME = "input";
+
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+        private int delta;
+
+        public SumModelServable() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        @Override
+        public DataFrame transform(DataFrame input) {
+            List<Row> outputRows = new ArrayList<>();
+            for (Row row : input.collect()) {
+                assert row.size() == 1;
+                int originValue = (Integer) row.get(0);
+                outputRows.add(new Row(Collections.singletonList(originValue + delta)));
+            }
+            return new DataFrame(
+                    Collections.singletonList(COL_NAME),
+                    Collections.singletonList(DataTypes.INT),
+                    outputRows);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+
+        public static SumModelServable load(String path) throws IOException {
+            SumModelServable servable =
+                    ServableReadWriteUtils.loadServableParam(path, SumModelServable.class);
+
+            Path modelDataPath = FileUtils.getDataPath(path);
+            try (FSDataInputStream fsDataInputStream =
+                    FileUtils.getModelDataInputStream(modelDataPath)) {
+                DataInputViewStreamWrapper dataInputViewStreamWrapper =
+                        new DataInputViewStreamWrapper(fsDataInputStream);
+                int delta = IntSerializer.INSTANCE.deserialize(dataInputViewStreamWrapper);
+                servable.setDelta(delta);
+            }
+            return servable;
+        }
+
+        public SumModelServable setDelta(int delta) {

Review Comment:
   Should we replace this method with `SumModelServable#setModelData(...)` so that users can set the model data using the public API designed in FLIP-289?
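
A rough sketch of what such a setter could look like, assuming the FLIP-289 API passes the model data as `InputStream` varargs and that the data is a single 4-byte big-endian int. The class name `SumServableSketch` and the `apply` helper are hypothetical, and plain `java.io` streams are used here instead of Flink's serializers so the example stands alone.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;

// Hypothetical servable; it only mirrors the idea of SumModelServable.
final class SumServableSketch {
    private int delta;

    // Assumed setter shape: model data arrives as one or more input streams.
    SumServableSketch setModelData(InputStream... modelDataInputs) throws IOException {
        if (modelDataInputs.length != 1) {
            throw new IllegalArgumentException(
                    "Expected exactly one model data stream but got " + modelDataInputs.length);
        }
        // Assumption: the model data is a single big-endian int.
        DataInputStream in = new DataInputStream(modelDataInputs[0]);
        this.delta = in.readInt();
        return this;
    }

    // Adds delta to a single value, standing in for transform(DataFrame).
    int apply(int value) {
        return value + delta;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        new DataOutputStream(bytes).writeInt(5);
        SumServableSketch servable =
                new SumServableSketch()
                        .setModelData(new ByteArrayInputStream(bytes.toByteArray()));
        System.out.println(servable.apply(10));
    }
}
```

With a setter of this shape, `load(path)` could open the model data file and delegate to `setModelData(...)`, so file-based loading and direct stream-based setup would share one code path.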



##########
flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java:
##########
@@ -110,6 +111,10 @@ public static SumModel load(StreamTableEnvironment tEnv, String path) throws IOE
             SumModel model = ReadWriteUtils.loadStageParam(path);
             return model.setModelData(modelDataTable);
         }
+
+        public static SumModelServable loadServable(String path) throws IOException {

Review Comment:
   Would it be useful to add a test similar to the following code snippet to cover this method?
   
   ```
   SumModel model = ...;
   model.save(path);
   SumModelServable servable = SumModelServable.load(path);
   Assert.assertEquals(expected, servable.transform(...));
   ```



##########
flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/FileUtils.java:
##########
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FileStatus;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.util.Preconditions;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.util.Map;
+
+/** Utility methods for file operations. */
+public class FileUtils {
+
+    /** Saves a given string to the specified file. */
+    public static void saveToFile(String pathStr, String content, boolean isOverwrite)
+            throws IOException {
+        Path path = new Path(pathStr);
+
+        // Creates parent directories if not already created.
+        FileSystem fs = mkdirs(path.getParent());
+
+        FileSystem.WriteMode writeMode = FileSystem.WriteMode.OVERWRITE;
+        if (!isOverwrite) {
+            writeMode = FileSystem.WriteMode.NO_OVERWRITE;
+            if (fs.exists(path)) {
+                throw new IOException("File " + path + " already exists.");
+            }
+        }
+        try (BufferedWriter writer =
+                new BufferedWriter(new OutputStreamWriter(fs.create(path, writeMode)))) {
+            writer.write(content);
+        }
+    }
+
+    public static FileSystem mkdirs(Path path) throws IOException {
+        FileSystem fs = path.getFileSystem();
+        fs.mkdirs(path);
+        return fs;
+    }
+
+    /**
+     * Loads the metadata from the metadata file under the given path.
+     *
+     * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
+     * match the className of the previously saved stage.
+     *
+     * @param path The parent directory of the metadata file to read from.
+     * @param expectedClassName The expected class name of the stage.
+     * @return A map from metadata name to metadata value.
+     */
+    public static Map<String, ?> loadMetadata(String path, String expectedClassName)
+            throws IOException {
+        Path metadataPath = new Path(path, "metadata");
+        FileSystem fs = metadataPath.getFileSystem();
+
+        StringBuilder buffer = new StringBuilder();
+        try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(metadataPath)))) {
+            String line;
+            while ((line = br.readLine()) != null) {
+                if (!line.startsWith("#")) {
+                    buffer.append(line);
+                }
+            }
+        }
+
+        @SuppressWarnings("unchecked")
+        Map<String, ?> result = JsonUtils.OBJECT_MAPPER.readValue(buffer.toString(), Map.class);
+
+        String className = (String) result.get("className");
+        if (!expectedClassName.isEmpty() && !expectedClassName.equals(className)) {
+            throw new RuntimeException(
+                    "Class name "
+                            + className
+                            + " does not match the expected class name "
+                            + expectedClassName
+                            + ".");
+        }
+
+        return result;
+    }
+
+    // Returns a string with value {parentPath}/stages/{stageIdx}, where the stageIdx is prefixed
+    // with zero or more `0` to have the same length as numStages. The resulting string can be
+    // used as the directory to save a stage of the Pipeline or PipelineModel.
+    public static String getPathForPipelineStage(int stageIdx, int numStages, String parentPath) {
+        String format =
+                String.format("stages%s%%0%dd", File.separator, String.valueOf(numStages).length());
+        String fileName = String.format(format, stageIdx);
+        return new Path(parentPath, fileName).toString();
+    }
+
+    /** Returns a subdirectory of the given path for saving/loading model data. */
+    public static Path getDataPath(String path) {
+        return new Path(path, "data");
+    }
+
+    /**
+     * Opens an FSDataInputStream to read the model data file in the directory. Only one model data
+     * file is expected to be in the directory.
+     *
+     * @param path The parent directory of the model data file.
+     * @return A FSDataInputStream to read the model data.
+     */
+    public static FSDataInputStream getModelDataInputStream(Path path) throws IOException {

Review Comment:
   Would the following function signature be simpler to use and more consistent with the existing `ReadWriteUtils#loadModelData`?
   
   ```
   public static InputStream loadModelData(String path) throws IOException
   ```
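
A minimal sketch of the suggested signature, written against `java.nio.file` rather than Flink's `FileSystem` so that it is self-contained. It assumes that `path` is the stage directory, that the model data lives in a `data` subdirectory (as `getDataPath` suggests), and that exactly one file is expected there. The class name `ModelDataLoaderSketch` is hypothetical.

```java
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class ModelDataLoaderSketch {

    // Opens the single model data file under {path}/data and returns it as a stream.
    static InputStream loadModelData(String path) throws IOException {
        Path dataDir = Paths.get(path, "data");
        List<Path> files;
        try (Stream<Path> entries = Files.list(dataDir)) {
            files = entries.filter(Files::isRegularFile).collect(Collectors.toList());
        }
        if (files.size() != 1) {
            throw new IOException(
                    "Expected exactly one model data file in "
                            + dataDir
                            + " but found "
                            + files.size());
        }
        return Files.newInputStream(files.get(0));
    }

    public static void main(String[] args) throws IOException {
        // Builds a throwaway directory layout and reads the int back.
        Path root = Files.createTempDirectory("servable");
        Path dataDir = Files.createDirectory(root.resolve("data"));
        Files.write(dataDir.resolve("part-0"), new byte[] {0, 0, 0, 7});
        try (InputStream in = loadModelData(root.toString())) {
            System.out.println(new DataInputStream(in).readInt());
        }
    }
}
```

Taking a `String` and returning a plain `InputStream` keeps callers free of file-system-specific types, which is the simplification the comment is asking about.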



##########
flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java:
##########
@@ -82,6 +85,33 @@ public static PipelineModel load(StreamTableEnvironment tEnv, String path) throw
                 ReadWriteUtils.loadPipeline(tEnv, path, PipelineModel.class.getName()));
     }
 
+    public static PipelineModelServable loadServable(String path) throws IOException {
+        return PipelineModelServable.load(path);
+    }
+
+    /**
+     * Whether all stages in the pipeline have corresponding {@link TransformerServable} so that the
+     * PipelineModel can be turned into a TransformerServable and used in an online inference
+     * program.
+     *
+     * @return true if all stages have corresponding TransformerServable, false if not.
+     */
+    public boolean supportServable() {
+        for (Stage stage : stages) {
+            if (!(stage instanceof Transformer)) {
+                return false;
+            }
+            Transformer transformer = (Transformer) stage;

Review Comment:
   nit: use `Transformer<?> transformer = (Transformer<?>) stage` to avoid the raw type.



##########
flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/types/DataTypes.java:
##########
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.servable.types;
+
+/** This class gives access to the most common types that are used to define DataFrames. */
+public class DataTypes {
+
+    public static final ScalarType BOOLEAN = new ScalarType(BasicType.BOOLEAN);
+
+    public static final ScalarType BYTE = new ScalarType(BasicType.BYTE);
+
+    public static final ScalarType SHORT = new ScalarType(BasicType.SHORT);
+
+    public static final ScalarType INT = new ScalarType(BasicType.INT);
+
+    public static final ScalarType LONG = new ScalarType(BasicType.LONG);
+
+    public static final ScalarType FLOAT = new ScalarType(BasicType.FLOAT);
+
+    public static final ScalarType DOUBLE = new ScalarType(BasicType.DOUBLE);
+
+    public static final ScalarType STRING = new ScalarType(BasicType.STRING);
+
+    public static final ScalarType BYTE_STRING = new ScalarType(BasicType.BYTE_STRING);
+
+    public static VectorType VECTOR(BasicType elementType) {

Review Comment:
   Should we name this method `vector(...)` to follow the existing code style guideline?


