Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/05/10 09:32:23 UTC

[GitHub] [flink-ml] yunfengzhou-hub opened a new pull request, #97: [FLINK-27096] Improve DataCache and KMeans Performance

yunfengzhou-hub opened a new pull request, #97:
URL: https://github.com/apache/flink-ml/pull/97

   ## What is the purpose of the change
   
   This PR mainly improves the performance of `DataCacheReader` and `DataCacheWriter` by adding memory caches, and further optimizes the KMeans algorithm using the DataCache mechanism.
   
   ## Brief change log
   - Adds an option to the DataCache mechanism to cache records in memory in deserialized format instead of saving them directly to the file system. Records are written to the file system only when a checkpoint is triggered or there is not enough memory.
   - Changes `DataStreamUtils.mapPartition()` to buffer records with `DataCacheWriter` instead of `ListState`.
   - Uses the memory-caching functionality in `KMeans.SelectNearestCentroidOperator` to avoid OOM errors.
       - All other existing usages of `DataCacheWriter` are unchanged and continue to save records directly to the file system.
   - Adds `DataStreamUtils.sample()` and uses it in `KMeans.selectRandomCentroids()` to reduce memory usage.
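   A rough, plain-Java illustration of the memory-first caching idea in the first bullet follows. `SpillingCache` is a hypothetical class, not part of flink-ml, and it encodes `String` records as UTF-8 only to stay self-contained. It keeps records in memory until a byte budget would be exceeded, then moves everything to a temporary file; `spill()` stands in for what a checkpoint would trigger.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical illustration (not the flink-ml API): records stay in memory
 * until a byte budget would be exceeded, then everything moves to a file.
 */
final class SpillingCache implements AutoCloseable {
    private final long memoryBudgetBytes;
    private final List<byte[]> inMemory = new ArrayList<>();
    private long inMemoryBytes;
    private Path spillFile; // null while all records fit in memory
    private DataOutputStream spillOut;

    SpillingCache(long memoryBudgetBytes) {
        this.memoryBudgetBytes = memoryBudgetBytes;
    }

    void add(String record) throws IOException {
        byte[] bytes = record.getBytes(StandardCharsets.UTF_8);
        if (spillOut == null && inMemoryBytes + bytes.length <= memoryBudgetBytes) {
            inMemory.add(bytes); // fast path: keep the record in memory
            inMemoryBytes += bytes.length;
            return;
        }
        spill(); // no-op if we have already spilled
        writeRecord(spillOut, bytes);
    }

    /** Moves all in-memory records to a file, e.g. when a checkpoint is taken. */
    void spill() throws IOException {
        if (spillOut != null) {
            return;
        }
        spillFile = Files.createTempFile("cache", ".bin");
        spillOut = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(spillFile)));
        for (byte[] bytes : inMemory) {
            writeRecord(spillOut, bytes);
        }
        inMemory.clear();
        inMemoryBytes = 0;
    }

    boolean isSpilled() {
        return spillOut != null;
    }

    /** Replays every cached record in insertion order. */
    List<String> readAll() throws IOException {
        List<String> result = new ArrayList<>();
        if (spillOut == null) {
            for (byte[] bytes : inMemory) {
                result.add(new String(bytes, StandardCharsets.UTF_8));
            }
            return result;
        }
        spillOut.flush();
        try (DataInputStream in =
                new DataInputStream(new BufferedInputStream(Files.newInputStream(spillFile)))) {
            while (true) {
                int length;
                try {
                    length = in.readInt();
                } catch (EOFException endOfFile) {
                    return result;
                }
                byte[] bytes = new byte[length];
                in.readFully(bytes);
                result.add(new String(bytes, StandardCharsets.UTF_8));
            }
        }
    }

    private static void writeRecord(DataOutputStream out, byte[] bytes) throws IOException {
        out.writeInt(bytes.length); // length-prefixed framing
        out.write(bytes);
    }

    @Override
    public void close() throws IOException {
        if (spillOut != null) {
            spillOut.close();
            Files.deleteIfExists(spillFile);
        }
    }
}
```

   The real writer works with Flink `TypeSerializer`s and managed `MemorySegment`s rather than lists of `byte[]`, so treat this only as a model of the control flow.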
   
   
   ## Does this pull request potentially affect one of the following parts:
   - Dependencies (does it add or upgrade a dependency): (yes)
   - The public API, i.e., is any changed class annotated with @Public(Evolving): (no)
   - Does this pull request introduce a new feature? (no)
   - If yes, how is the feature documented? (N/A)


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890241966


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;
+
+        private StreamTask<?, ?> containingTask;
 
-        public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
+        private DataCacheWriter<IN> dataCacheWriter;
+
+        public MapPartitionOperator(
+                MapPartitionFunction<IN, OUT> mapPartitionFunc, TypeInformation<IN> inputType) {
             super(mapPartitionFunc);
+            this.inputType = inputType;
+        }
+
+        @Override
+        public void setup(
+                StreamTask<?, ?> containingTask,
+                StreamConfig config,
+                Output<StreamRecord<OUT>> output) {
+            super.setup(containingTask, config, output);
+
+            basePath =
+                    OperatorUtils.getDataCachePath(
+                            containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                            containingTask
+                                    .getEnvironment()
+                                    .getIOManager()
+                                    .getSpillingDirectoriesPaths());
+            this.config = config;

Review Comment:
   nits: I think we typically put simple assignments (e.g. `this.config = config`) before non-trivial initialization (e.g. `basePath = ...`). Could you update the code to follow this convention?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,162 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentSegmentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
 
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)

Review Comment:
   nits: would it be better to keep the previous name `priorFinishedSegments`?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;

Review Comment:
   It appears that we only need `config.getOperatorID()` from this config. Would it be simpler to just save the `OperatorID`?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -256,4 +392,81 @@ public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) {
             }
         }
     }
+
+    /*
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<T>
+            implements OneInputStreamOperator<T, T>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>(numSamples);
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T sample = streamRecord.getValue();

Review Comment:
   nits: a "sample" is something chosen from a collection. Since it is not certain that this value will be chosen, would it be simpler to just name it `value`?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;
+
+        private StreamTask<?, ?> containingTask;
 
-        public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
+        private DataCacheWriter<IN> dataCacheWriter;
+
+        public MapPartitionOperator(
+                MapPartitionFunction<IN, OUT> mapPartitionFunc, TypeInformation<IN> inputType) {
             super(mapPartitionFunc);
+            this.inputType = inputType;
+        }
+
+        @Override
+        public void setup(
+                StreamTask<?, ?> containingTask,
+                StreamConfig config,
+                Output<StreamRecord<OUT>> output) {
+            super.setup(containingTask, config, output);
+
+            basePath =
+                    OperatorUtils.getDataCachePath(
+                            containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                            containingTask
+                                    .getEnvironment()
+                                    .getIOManager()
+                                    .getSpillingDirectoriesPaths());
+            this.config = config;
+            this.containingTask = containingTask;
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws Exception {
             super.initializeState(context);
-            ListStateDescriptor<IN> descriptor =
-                    new ListStateDescriptor<>(
-                            "inputState",
-                            getOperatorConfig()
-                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
-            valuesState = context.getOperatorStateStore().getListState(descriptor);
+
+            List<StatePartitionStreamProvider> inputs =
+                    IteratorUtils.toList(context.getRawOperatorStateInputs().iterator());
+            Preconditions.checkState(
+                    inputs.size() < 2, "The input from raw operator state should be one or zero.");
+
+            List<Segment> priorFinishedSegments = new ArrayList<>();
+            if (inputs.size() > 0) {
+
+                InputStream inputStream = inputs.get(0).getStream();
+
+                DataCacheSnapshot dataCacheSnapshot =
+                        DataCacheSnapshot.recover(
+                                inputStream,
+                                basePath.getFileSystem(),
+                                OperatorUtils.createDataCacheFileGenerator(
+                                        basePath, "cache", config.getOperatorID()));
+
+                priorFinishedSegments = dataCacheSnapshot.getSegments();
+            }
+
+            dataCacheWriter =
+                    new DataCacheWriter<>(
+                            inputType.createSerializer(containingTask.getExecutionConfig()),
+                            basePath.getFileSystem(),
+                            OperatorUtils.createDataCacheFileGenerator(
+                                    basePath, "cache", config.getOperatorID()),
+                            priorFinishedSegments);
         }
 
         @Override
-        public void endInput() throws Exception {
-            userFunction.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
-            valuesState.clear();
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+
+            dataCacheWriter.writeSegmentsToFiles();
+            DataCacheSnapshot dataCacheSnapshot =
+                    new DataCacheSnapshot(
+                            basePath.getFileSystem(), null, dataCacheWriter.getSegments());
+            context.getRawOperatorStateOutput().startNewPartition();
+            dataCacheSnapshot.writeTo(context.getRawOperatorStateOutput());
         }
 
         @Override
         public void processElement(StreamRecord<IN> input) throws Exception {
-            valuesState.add(input.getValue());
+            dataCacheWriter.addRecord(input.getValue());
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            List<Segment> pendingSegments = dataCacheWriter.getSegments();

Review Comment:
   nits: It is not clear what `pending` means in this context. Would it be simpler to just name it `segments`?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;
+
+        private StreamTask<?, ?> containingTask;

Review Comment:
   It appears that we only need `containingTask.getExecutionConfig()`. Would it be simpler to just save the `ExecutionConfig`?





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890784808


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -254,58 +272,150 @@ public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value
                             DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
                     IterationListener<Tuple2<Integer, DenseVector>> {
         private final DistanceMeasure distanceMeasure;
-        private ListState<DenseVector> points;
-        private ListState<DenseVector[]> centroids;
+        private ListState<DenseVector[]> centroidsState;
+        private DenseVector[] centroids;
+
+        private Path basePath;
+        private OperatorID operatorID;
+        private MemorySegmentPool segmentPool;
+        private DataCacheWriter<DenseVector> dataCacheWriter;
 
         public SelectNearestCentroidOperator(DistanceMeasure distanceMeasure) {
+            super();
             this.distanceMeasure = distanceMeasure;
         }
 
+        @Override
+        public void setup(
+                StreamTask<?, ?> containingTask,
+                StreamConfig config,
+                Output<StreamRecord<Tuple2<Integer, DenseVector>>> output) {
+            super.setup(containingTask, config, output);
+
+            operatorID = config.getOperatorID();
+
+            MemoryManager memoryManager = getContainingTask().getEnvironment().getMemoryManager();

Review Comment:
   Could you refactor the code so that algorithm developers won't need to handle so many details?





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891817094


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T sample = streamRecord.getValue();
+            count++;
+
+            // Code below is inspired by the Reservoir Sampling algorithm.
+            if (samples.size() < numSamples) {
+                samples.add(sample);
+            } else {
+                if (random.nextInt(count) < numSamples) {
+                    samples.set(random.nextInt(numSamples), sample);
+                }
+            }
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            Collections.shuffle(samples, random);

Review Comment:
   Got it. Since we do not provide any guarantee about the order of the returned samples, it is acceptable not to shuffle them. I'll remove this line.
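   For reference, the `processElement` logic quoted above is a variant of reservoir sampling: the first `numSamples` values are always kept; afterwards the n-th value is kept with probability `numSamples / n` and overwrites a uniformly chosen slot, so each existing sample is evicted with probability `1 / n`, the same as in the classic single-draw Algorithm R. The standalone sketch below (a hypothetical `ReservoirSampler`, with no Flink state handling) also shows why the output order is not random: when the stream has fewer than `numSamples` elements, values come back in arrival order.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical standalone version of the sampling logic in SamplingOperator. */
final class ReservoirSampler<T> {
    private final int numSamples;
    private final Random random;
    private final List<T> samples;
    private int count;

    ReservoirSampler(int numSamples, long seed) {
        this.numSamples = numSamples;
        this.random = new Random(seed);
        this.samples = new ArrayList<>(numSamples);
    }

    void add(T value) {
        count++;
        if (samples.size() < numSamples) {
            // The first numSamples values are always kept, in arrival order.
            samples.add(value);
        } else if (random.nextInt(count) < numSamples) {
            // Keep the value with probability numSamples / count and let it
            // replace a uniformly chosen existing sample.
            samples.set(random.nextInt(numSamples), value);
        }
    }

    /** The returned order is unspecified; callers must not rely on it. */
    List<T> getSamples() {
        return Collections.unmodifiableList(samples);
    }
}
```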





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889837722


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,61 +18,37 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
-import org.apache.flink.core.fs.Path;
+import org.apache.flink.annotation.Internal;
 
-import java.io.Serializable;
-import java.util.Objects;
+/** A segment contains the information about a cache unit. */
+@Internal
+class Segment {
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+    private FileSegment fileSegment;

Review Comment:
   Following our offline discussion, I'll make `path` a `final` field in `Segment`, since it always needs to be set.





[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

yunfengzhou-hub commented on PR #97:
URL: https://github.com/apache/flink-ml/pull/97#issuecomment-1146612587

   Thanks for the comments. I have updated the PR according to the comments.




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r888753223


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -90,18 +90,18 @@ public void writeTo(OutputStream checkpointOutputStream) throws IOException {
             }
 
             dos.writeBoolean(fileSystem.isDistributedFS());
+            for (Segment segment : segments) {
+                persistSegmentToDisk(segment);
+            }
             if (fileSystem.isDistributedFS()) {
                 // We only need to record the segments itself
                 serializeSegments(segments, dos);
             } else {
                 // We have to copy the whole streams.
-                int totalRecords = segments.stream().mapToInt(Segment::getCount).sum();
-                long totalSize = segments.stream().mapToLong(Segment::getSize).sum();
-                checkState(totalRecords >= 0, "overflowed: " + totalRecords);
-                dos.writeInt(totalRecords);
-                dos.writeLong(totalSize);
-
+                dos.writeInt(segments.size());
                 for (Segment segment : segments) {
+                    dos.writeInt(segment.getCount());

Review Comment:
   Because the maximum size of a segment is limited, for example by the maximum file size allowed by the underlying file system. If we merged all segments into one during a snapshot, we might hit such limits and fail. Treating each segment separately avoids these errors without adding much overhead to the snapshot process.
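   A minimal sketch of such a per-segment layout, assuming each segment is stored as its record count, its byte length and its raw bytes. Only the leading segment count and the per-segment record count appear in the diff above; the `long` size field and the `SegmentSnapshotFormat`/`Seg` names are assumptions for illustration. Because every segment carries its own header, no single length field has to describe the combined size of all segments, and each segment can be restored to a separate file.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical per-segment snapshot layout (not flink-ml's exact format). */
final class SegmentSnapshotFormat {
    /** A cached segment: its record count and serialized bytes. */
    static final class Seg {
        final int count;
        final byte[] bytes;

        Seg(int count, byte[] bytes) {
            this.count = count;
            this.bytes = bytes;
        }
    }

    static void write(List<Seg> segments, DataOutputStream out) throws IOException {
        out.writeInt(segments.size());
        for (Seg s : segments) {
            // Each segment has its own header, so sizes are never summed.
            out.writeInt(s.count);
            out.writeLong(s.bytes.length);
            out.write(s.bytes);
        }
    }

    static List<Seg> read(DataInputStream in) throws IOException {
        int numSegments = in.readInt();
        List<Seg> segments = new ArrayList<>(numSegments);
        for (int i = 0; i < numSegments; i++) {
            int count = in.readInt();
            long size = in.readLong();
            byte[] bytes = new byte[Math.toIntExact(size)];
            in.readFully(bytes);
            // On restore, each segment could be written back to its own file.
            segments.add(new Seg(count, bytes));
        }
        return segments;
    }
}
```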





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r888738791


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {
+                int numPages = (int) ((expectedSize + pageSize - 1) / pageSize);
+                segments.addAll(memoryManager.allocatePages(getKey(), numPages));
+            }
+        }
+
+        public Object getKey() {
+            return key;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(int capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int requiredSegmentNum = (capacity - 1) / pageSize + 2 - segments.size();
+            if (requiredSegmentNum > 0) {
+                segments.addAll(memoryManager.allocatePages(getKey(), requiredSegmentNum));
+            }
+        }
+
+        private void writeRecursive(byte[] b, int off, int len) {

Review Comment:
   Do you mean we should allow `len` to be zero, or prohibit it? I think we should allow it, to stay consistent with the lower-level `OutputStream` contract, which permits a zero `len`. The current code already handles that case.
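   To illustrate the zero-length case, here is a minimal sketch of a paged output stream in the spirit of `ManagedMemoryOutputStream`. It is hypothetical: `PagedOutputStream` is an illustrative name, and plain `byte[]` pages stand in for Flink `MemorySegment`s. The loop copies while bytes remain, so `write(b, off, 0)` falls straight through and changes nothing. This handles the case without special-casing it, and it also avoids recursion.
   
   ```java
   import java.io.OutputStream;
   import java.util.ArrayList;
   import java.util.List;
   
   /** Hypothetical paged stream; byte[] pages stand in for Flink MemorySegments. */
   final class PagedOutputStream extends OutputStream {
       private final int pageSize;
       private final List<byte[]> pages = new ArrayList<>();
       private int size; // total bytes written so far
   
       PagedOutputStream(int pageSize) {
           if (pageSize <= 0) {
               throw new IllegalArgumentException("pageSize must be positive");
           }
           this.pageSize = pageSize;
       }
   
       @Override
       public void write(int b) {
           write(new byte[] {(byte) b}, 0, 1);
       }
   
       @Override
       public void write(byte[] b, int off, int len) {
           // Same bounds checks as the OutputStream contract; len == 0 is legal.
           if (off < 0 || len < 0 || len > b.length - off) {
               throw new IndexOutOfBoundsException();
           }
           // A zero-length write skips the loop entirely and leaves the stream unchanged.
           while (len > 0) {
               int pageIndex = size / pageSize;
               int pageOffset = size % pageSize;
               if (pageIndex == pages.size()) {
                   pages.add(new byte[pageSize]); // allocate lazily, one page at a time
               }
               int chunk = Math.min(len, pageSize - pageOffset);
               System.arraycopy(b, off, pages.get(pageIndex), pageOffset, chunk);
               off += chunk;
               len -= chunk;
               size += chunk;
           }
       }
   
       int size() { return size; }
   
       int pageCount() { return pages.size(); }
   
       /** Copies the written bytes into one array, for inspection. */
       byte[] toByteArray() {
           byte[] result = new byte[size];
           for (int i = 0, copied = 0; copied < size; i++) {
               int chunk = Math.min(pageSize, size - copied);
               System.arraycopy(pages.get(i), 0, result, copied, chunk);
               copied += chunk;
           }
           return result;
       }
   }
   ```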





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r878863091


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final MemoryManager memoryManager;
+
+    private final Path path;
+
+    private final List<T> cache;
+
+    private final TypeSerializer<T> serializer;
+
+    private long inMemorySize;
+
+    private int count;
+
+    private long reservedMemorySize;
+
+    public MemorySegmentWriter(
+            Path path, MemoryManager memoryManager, TypeSerializer<T> serializer, long expectedSize)
+            throws MemoryReservationException {
+        this.serializer = serializer;
+        Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.cache = new ArrayList<>();
+        this.inMemorySize = 0L;
+        this.count = 0;
+        this.memoryManager = memoryManager;
+
+        if (expectedSize > 0) {
+            memoryManager.reserveMemory(this.path, expectedSize);
+        }
+        this.reservedMemorySize = expectedSize;
+    }
+
+    @Override
+    public boolean addRecord(T record) {

Review Comment:
   After digging into how `MemoryManager` is used [1][2], I don't think it is used appropriately here.
   
   The code snippet here seems to cache records on the Java heap while reserving off-heap memory. If I understand [1] [2] correctly:
   - When `MemoryManager` is used to manipulate managed memory, we are mostly dealing with off-heap memory.
   - The managed memory for each operator is fixed once the job graph is generated, i.e., it is not allocated dynamically.
   - Managed memory usage should be declared explicitly in the job graph before the operator uses it. Otherwise, the job may hit OOM when deployed in a container.
   
   As I see it, there are basically two options for caching the data:
   - Cache it in `task heap` (i.e., in a `list`): This is simple and easy to implement, but the downside is that we cannot control the size of cached elements `statically`, so the program may not be robust: `task heap` is shared across the JVM, and we have no idea how others are using the JVM heap memory.
   - Cache it `off-heap` (for example, using managed memory): In this case, we need to declare the managed memory usage to the job graph via `Transformation#declareManagedMemoryUseCaseAtOperatorScope` or `Transformation#declareManagedMemoryUseCaseAtSlotScope`, and get the fraction of managed memory as in [3].
   
   
   I would suggest going with option 2, but this needs more discussion with the runtime developers.
   
   [1] https://cwiki.apache.org/confluence/display/FLINK/FLIP-53%3A+Fine+Grained+Operator+Resource+Management
   [2] https://cwiki.apache.org/confluence/display/FLINK/FLIP-141%3A+Intra-Slot+Managed+Memory+Sharing
   [3] https://github.com/apache/flink/blob/18a967f8ad7b22c2942e227fb84f08f552660b5a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/sort/SortOperator.java#L79
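   Option 2 hinges on each operator receiving a fixed share of managed memory that is known when the job is deployed, rather than probing how much memory happens to be free at runtime. The sketch below models that idea in plain Java. `OperatorMemoryBudget` and its methods are hypothetical. In Flink itself, the share would come from declaring the use case (e.g., via `Transformation#declareManagedMemoryUseCaseAtOperatorScope`) and reading the resulting fraction from the operator config, as the `SortOperator` linked in [3] does. In this model, a request that does not fit is refused, and the caller spills to disk instead of failing.
   
   ```java
   /** Hypothetical model of a fixed, fraction-based managed-memory budget for one operator. */
   final class OperatorMemoryBudget {
       private final long capacityBytes;
       private long usedBytes;
   
       /**
        * @param slotManagedMemoryBytes managed memory of the slot (assumed known at deployment)
        * @param fraction share declared for this operator, in [0, 1]
        */
       OperatorMemoryBudget(long slotManagedMemoryBytes, double fraction) {
           if (slotManagedMemoryBytes < 0 || fraction < 0.0 || fraction > 1.0) {
               throw new IllegalArgumentException("invalid budget");
           }
           // Fixed once at construction; it never grows at runtime.
           this.capacityBytes = (long) Math.floor(slotManagedMemoryBytes * fraction);
       }
   
       /** Reserves bytes if they fit; the caller spills to disk when this returns false. */
       boolean tryReserve(long bytes) {
           if (bytes < 0) {
               throw new IllegalArgumentException("bytes < 0");
           }
           if (bytes > capacityBytes - usedBytes) {
               return false;
           }
           usedBytes += bytes;
           return true;
       }
   
       /** Returns previously reserved bytes to the budget. */
       void release(long bytes) {
           if (bytes < 0 || bytes > usedBytes) {
               throw new IllegalArgumentException("over-release");
           }
           usedBytes -= bytes;
       }
   
       long capacity() { return capacityBytes; }
   
       long available() { return capacityBytes - usedBytes; }
   }
   ```
   
   The point of the design is that "is there enough memory?" has a deterministic, per-operator answer, which is what makes the limit enforceable inside a container.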





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r876654438


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,121 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-
-import javax.annotation.Nullable;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
-    private final TypeSerializer<T> serializer;
+    private final MemoryManager memoryManager;
 
-    private final FileSystem fileSystem;
+    private final TypeSerializer<T> serializer;
 
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    private SegmentReader<T> currentReader;
+
+    private MemorySegmentWriter<T> cacheWriter;
+
+    private int segmentIndex;
 
     public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+            TypeSerializer<T> serializer, MemoryManager memoryManager, List<Segment> segments) {
+        this(serializer, memoryManager, segments, new Tuple2<>(0, 0));
     }
 
     public DataCacheReader(
             TypeSerializer<T> serializer,
-            FileSystem fileSystem,
+            MemoryManager memoryManager,
             List<Segment> segments,
-            Tuple2<Integer, Integer> readerPosition)
-            throws IOException {
-
+            Tuple2<Integer, Integer> readerPosition) {
+        this.memoryManager = memoryManager;
         this.serializer = serializer;
-        this.fileSystem = fileSystem;
         this.segments = segments;
+        this.segmentIndex = readerPosition.f0;
+
+        createSegmentReaderAndCache(readerPosition.f0, readerPosition.f1);
+    }
+
+    private void createSegmentReaderAndCache(int index, int startOffset) {
+        try {
+            cacheWriter = null;
 
-        if (readerPosition.f0 < segments.size()) {
-            this.currentSegmentReader = new SegmentReader(readerPosition.f0, readerPosition.f1);
+            if (index >= segments.size()) {
+                currentReader = null;
+                return;
+            }
+
+            currentReader = SegmentReader.create(serializer, segments.get(index), startOffset);
+
+            boolean shouldCacheInMemory =
+                    startOffset == 0
+                            && currentReader instanceof FsSegmentReader
+                            && MemoryUtils.isMemoryEnoughForCache(memoryManager);
+
+            if (shouldCacheInMemory) {
+                cacheWriter =
+                        new MemorySegmentWriter<>(
+                                segments.get(index).path,
+                                memoryManager,
+                                segments.get(index).inMemorySize);
+            }
+        } catch (MemoryAllocationException e) {
+            cacheWriter = null;
+        } catch (IOException e) {
+            throw new RuntimeException(e);
         }
     }
 
     @Override
     public boolean hasNext() {
-        return currentSegmentReader != null && currentSegmentReader.hasNext();
+        return currentReader != null && currentReader.hasNext();
     }
 
     @Override
     public T next() {
         try {
-            T next = currentSegmentReader.next();
-
-            if (!currentSegmentReader.hasNext()) {
-                currentSegmentReader.close();
-                if (currentSegmentReader.index < segments.size() - 1) {
-                    currentSegmentReader = new SegmentReader(currentSegmentReader.index + 1, 0);
-                } else {
-                    currentSegmentReader = null;
+            T record = currentReader.next();
+
+            if (cacheWriter != null) {
+                if (!cacheWriter.addRecord(record)) {
+                    cacheWriter
+                            .finish()
+                            .ifPresent(x -> memoryManager.releaseMemory(x.path, x.inMemorySize));
+                    cacheWriter = null;
+                }
+            }
+
+            if (!currentReader.hasNext()) {
+                currentReader.close();
+                if (cacheWriter != null) {
+                    cacheWriter
+                            .finish()
+                            .ifPresent(
+                                    x -> {
+                                        x.fsSize = segments.get(segmentIndex).fsSize;
+                                        segments.set(segmentIndex, x);

Review Comment:
   A reader should not modify the data stored in a file, or the metadata about that file. It can just create an in-memory cache of a segment's data and attach the cache metadata to that segment. I'll modify the implementation to better reflect this.
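   A minimal sketch of that separation, under simplifying assumptions: records are read from a file-backed source and copied into an in-memory cache as a side effect. At the end, the cache is either handed back as a separate result or has already been abandoned because it outgrew its budget. The source and its metadata are never modified. `CachingReader`, the record-count budget, and returning `Optional<List<T>>` as the "cache metadata" are all hypothetical stand-ins for `DataCacheReader` and `MemorySegmentWriter`.
   
   ```java
   import java.util.ArrayList;
   import java.util.Collections;
   import java.util.Iterator;
   import java.util.List;
   import java.util.NoSuchElementException;
   import java.util.Optional;
   
   /** Hypothetical read-through cache: iterates a source and optionally builds an in-memory copy. */
   final class CachingReader<T> implements Iterator<T> {
       private final Iterator<T> source; // stands in for a file-backed segment reader
       private final int maxCachedRecords; // stands in for a memory budget
       private List<T> cache = new ArrayList<>(); // null once caching has been abandoned
   
       CachingReader(Iterator<T> source, int maxCachedRecords) {
           this.source = source;
           this.maxCachedRecords = maxCachedRecords;
       }
   
       @Override
       public boolean hasNext() {
           return source.hasNext();
       }
   
       @Override
       public T next() {
           if (!source.hasNext()) {
               throw new NoSuchElementException();
           }
           T record = source.next();
           if (cache != null) {
               if (cache.size() < maxCachedRecords) {
                   cache.add(record);
               } else {
                   cache = null; // budget exceeded: drop the partial cache, keep reading the source
               }
           }
           return record;
       }
   
       /** The complete cache, if it was fully built; the source itself is left untouched. */
       Optional<List<T>> finishedCache() {
           if (cache == null || source.hasNext()) {
               return Optional.empty();
           }
           return Optional.of(Collections.unmodifiableList(cache));
       }
   }
   ```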





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r884534511


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   Per our offline discussion, I'll adopt a buffered stream to improve disk I/O performance.
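   A sketch of what adopting a buffered stream could look like, written against plain `java.io` so it runs standalone. A `BufferedOutputStream` sits between the record serializer and the raw sink, so many small writes become a few large ones. `BufferedSegmentWriter`, the 64 KiB buffer size, and serializing records as `long`s are assumptions for illustration. The real writer would wrap Flink's `FSDataOutputStream` and use a `TypeSerializer`.
   
   ```java
   import java.io.BufferedOutputStream;
   import java.io.DataOutputStream;
   import java.io.IOException;
   import java.io.OutputStream;
   import java.io.UncheckedIOException;
   
   /** Hypothetical segment writer that buffers serialized records before they reach the sink. */
   final class BufferedSegmentWriter implements AutoCloseable {
       private static final int BUFFER_SIZE = 64 * 1024; // assumed; worth tuning per file system
   
       private final DataOutputStream out;
       private int count;
   
       /** @param sink the raw stream, e.g. a file opened with Files.newOutputStream(path, CREATE_NEW) */
       BufferedSegmentWriter(OutputStream sink) {
           this.out = new DataOutputStream(new BufferedOutputStream(sink, BUFFER_SIZE));
       }
   
       /** Serializes one record into the buffer (a simplification of a TypeSerializer call). */
       void addRecord(long record) {
           try {
               out.writeLong(record);
               count++;
           } catch (IOException e) {
               throw new UncheckedIOException(e);
           }
       }
   
       int count() { return count; }
   
       /** Pushes buffered bytes to the sink and returns the total number of bytes written. */
       long finish() {
           try {
               // A position-based size (like getPos() on the sink) is only accurate after this flush.
               out.flush();
               return out.size();
           } catch (IOException e) {
               throw new UncheckedIOException(e);
           }
       }
   
       @Override
       public void close() {
           try {
               out.close();
           } catch (IOException e) {
               throw new UncheckedIOException(e);
           }
       }
   }
   ```
   
   Against Flink's API, the analogous arrangement would presumably be `new DataOutputViewStreamWrapper(new BufferedOutputStream(outputStream))`, with a flush before calling `outputStream.getPos()`.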





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r884536006


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemoryUtils.java:
##########
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.memory.MemoryManager;
+
+/** Utility variables and methods for memory operation. */
+@Internal
+class MemoryUtils {
+    // Cache is not suggested if over 80% of memory has been occupied.
+    private static final double CACHE_MEMORY_THRESHOLD = 0.2;

Review Comment:
   Per the offline discussion with @zhipeng93, I'll use `declareManagedMemoryUseCaseAtOperatorScope` to decide how much memory these operators can use.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r876660941


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -89,57 +104,31 @@ public List<Segment> getFinishSegments() {
         return finishSegments;
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
-        }
-
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
-    }
-
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
-
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
-        }
-
-        public void addRecord(T record) throws IOException {
-            serializer.serialize(record, outputView);
-            currentSegmentCount += 1;
-        }
+    private SegmentWriter<T> createSegmentWriter(
+            SupplierWithException<Path, IOException> pathGenerator, MemoryManager memoryManager)
+            throws IOException {
+        boolean shouldCacheInMemory = MemoryUtils.isMemoryEnoughForCache(memoryManager);
 
-        public Optional<Segment> finish() throws IOException {
-            this.outputStream.flush();
-            long size = outputStream.getPos();
-            this.outputStream.close();
-
-            if (currentSegmentCount > 0) {
-                return Optional.of(new Segment(path, currentSegmentCount, size));
-            } else {
-                // If there are no records, we tend to directly delete this file
-                fileSystem.delete(path, false);
-                return Optional.empty();
+        if (shouldCacheInMemory) {
+            try {
+                return new MemorySegmentWriter<>(pathGenerator.get(), memoryManager);
+            } catch (MemoryAllocationException e) {
+                return new FsSegmentWriter<>(serializer, pathGenerator.get());
             }
         }
+        return new FsSegmentWriter<>(serializer, pathGenerator.get());
     }
 
     public void cleanup() throws IOException {
-        finishCurrentSegment();
+        finish();
         for (Segment segment : finishSegments) {
-            fileSystem.delete(segment.getPath(), false);
+            if (segment.isOnDisk()) {
+                fileSystem.delete(segment.path, false);
+            }
+            if (segment.isInMemory()) {
+                memoryManager.releaseMemory(segment.path, segment.inMemorySize);

Review Comment:
   Memory is reserved in `MemorySegmentWriter`, where I had used the writer object as the key rather than `segment.path`. I'll correct this.
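   The fix hinges on a simple invariant: memory must be released under the same owner key it was reserved with. Flink's `MemoryManager` exposes `reserveMemory(Object owner, long size)` and `releaseMemory(Object owner, long size)`. The plain-Java ledger below is a hypothetical stand-in (`ReservationLedger` is not a Flink class) in which a key mismatch fails loudly instead of silently leaving the reservation in place.
   
   ```java
   import java.util.HashMap;
   import java.util.Map;
   
   /** Hypothetical per-owner reservation ledger with a fixed total capacity. */
   final class ReservationLedger {
       private final long capacity;
       private final Map<Object, Long> reserved = new HashMap<>();
       private long total;
   
       ReservationLedger(long capacity) {
           this.capacity = capacity;
       }
   
       void reserve(Object owner, long size) {
           if (size < 0 || size > capacity - total) {
               throw new IllegalStateException("cannot reserve " + size + " bytes");
           }
           reserved.merge(owner, size, Long::sum);
           total += size;
       }
   
       void release(Object owner, long size) {
           Long held = reserved.get(owner);
           if (size < 0 || held == null || held < size) {
               // Releasing under a different key than the one used to reserve ends up here.
               throw new IllegalStateException("owner " + owner + " does not hold " + size + " bytes");
           }
           if (held == size) {
               reserved.remove(owner);
           } else {
               reserved.put(owner, held - size);
           }
           total -= size;
       }
   
       long reservedBy(Object owner) { return reserved.getOrDefault(owner, 0L); }
   
       long total() { return total; }
   }
   ```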





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890966838


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +151,44 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
 
-        public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
+        private DataCacheListState<IN> dataCacheListState;
+
+        public MapPartitionOperator(
+                MapPartitionFunction<IN, OUT> mapPartitionFunc, TypeInformation<IN> inputType) {
             super(mapPartitionFunc);
+            this.inputType = inputType;

Review Comment:
   Instead of passing `inputType` to create the serializer, would it be simpler to keep the previous approach and create the serializer with the following code?
   
   ```
   getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader())
   ```





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r874503293


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)

Review Comment:
   When the size of the records to cache is known in advance (e.g., when caching a segment from the file system), `expectedSize` can be used to reserve all the required memory at once. I'll add the corresponding implementation.
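   A sketch of how `expectedSize` could be used, assuming a page-based allocator. When the size is known in advance, the writer rounds it up to whole pages and reserves them in one call. It falls back to incremental reservation only if the estimate turns out to be too small. `PageReservation` and its `reserveCalls` counter are hypothetical. The ceiling division matches `(expectedSize + pageSize - 1) / pageSize` in the `ManagedMemoryOutputStream` constructor quoted earlier in this thread.
   
   ```java
   /** Hypothetical page accounting that prefers one up-front reservation when the size is known. */
   final class PageReservation {
       private final int pageSize;
       private long reservedPages;
       private long usedBytes;
       private int reserveCalls; // number of round trips to the (simulated) allocator
   
       PageReservation(int pageSize, long expectedSize) {
           if (pageSize <= 0 || expectedSize < 0) {
               throw new IllegalArgumentException("pageSize must be > 0 and expectedSize >= 0");
           }
           this.pageSize = pageSize;
           if (expectedSize > 0) {
               reservePages(pagesFor(expectedSize)); // one bulk reservation for the known size
           }
       }
   
       /** Ceiling division: pages needed to hold {@code bytes}. */
       long pagesFor(long bytes) {
           return (bytes + pageSize - 1) / pageSize;
       }
   
       /** Accounts for {@code bytes} more data, reserving extra pages only when needed. */
       void append(long bytes) {
           if (bytes < 0) {
               throw new IllegalArgumentException("bytes < 0");
           }
           usedBytes += bytes;
           long missing = pagesFor(usedBytes) - reservedPages;
           if (missing > 0) {
               reservePages(missing); // estimate too small (or absent): grow incrementally
           }
       }
   
       private void reservePages(long pages) {
           // A real implementation would request these pages from the memory manager here.
           reservedPages += pages;
           reserveCalls++;
       }
   
       long reservedPages() { return reservedPages; }
   
       int reserveCalls() { return reserveCalls; }
   }
   ```
   
   With a 4096-byte page, writing 10,000 bytes in 100-byte records costs one reservation call when the size is known up front, versus three when it is not.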





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r874553395


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   My experiments show that with `outputView = new DataOutputViewStreamWrapper(outputStream)`, disk I/O performance degrades noticeably. One possible cause is that `DataOutputViewStreamWrapper` or the `FSDataOutputStream` subclass flushes bytes out every time `write()` is invoked, as suggested by how often `outputStream.getPos()` updates, but I have not investigated further. Shall we leave this implementation as it is for now, and see whether @zhipeng93 can offer a simpler implementation with better performance? I'll leave a TODO here for now.
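   One way to check the hypothesis that every `write()` reaches the underlying stream is to count the calls that actually arrive at the sink. The probe below is a hypothetical diagnostic in plain `java.io`. `CountingSink` and `WriteCallProbe` are illustrative names, and `java.io.DataOutputStream` stands in for `DataOutputViewStreamWrapper`. Without a buffer, the sink receives at least one call per record; the exact number depends on how the JDK implements `DataOutputStream.writeInt`. With a `BufferedOutputStream` in between, the same 4,000 bytes arrive in a single call. How expensive each call is depends on the concrete `FSDataOutputStream`, which this sketch does not model.
   
   ```java
   import java.io.BufferedOutputStream;
   import java.io.DataOutputStream;
   import java.io.IOException;
   import java.io.OutputStream;
   import java.io.UncheckedIOException;
   
   /** Hypothetical probe: counts how many write calls actually reach the underlying sink. */
   final class CountingSink extends OutputStream {
       int writeCalls;
       long bytes;
   
       @Override
       public void write(int b) {
           writeCalls++;
           bytes++;
       }
   
       @Override
       public void write(byte[] b, int off, int len) {
           writeCalls++;
           bytes += len;
       }
   }
   
   final class WriteCallProbe {
       /** Writes {@code n} ints, optionally through a buffer, and returns the sink's counters. */
       static CountingSink writeInts(int n, boolean buffered) {
           CountingSink sink = new CountingSink();
           OutputStream target = buffered ? new BufferedOutputStream(sink, 8192) : sink;
           // Closing the DataOutputStream flushes any buffered bytes to the sink.
           try (DataOutputStream out = new DataOutputStream(target)) {
               for (int i = 0; i < n; i++) {
                   out.writeInt(i);
               }
           } catch (IOException e) {
               throw new UncheckedIOException(e);
           }
           return sink;
       }
   }
   ```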





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889669817


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,61 +18,37 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
-import org.apache.flink.core.fs.Path;
+import org.apache.flink.annotation.Internal;
 
-import java.io.Serializable;
-import java.util.Objects;
+/** A segment contains the information about a cache unit. */
+@Internal
+class Segment {
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+    private FileSegment fileSegment;
 
-    private final Path path;
+    private MemorySegment memorySegment;
 
-    /** The count of the records in the file. */
-    private final int count;
-
-    /** The total length of file. */
-    private final long size;
-
-    public Segment(Path path, int count, long size) {
-        this.path = path;
-        this.count = count;
-        this.size = size;
-    }
-
-    public Path getPath() {
-        return path;
+    Segment(FileSegment fileSegment) {
+        this.fileSegment = fileSegment;
     }
 
-    public int getCount() {
-        return count;
+    Segment(MemorySegment memorySegment) {
+        this.memorySegment = memorySegment;
     }
 
-    public long getSize() {
-        return size;
+    void setFileSegment(FileSegment fileSegment) {
+        this.fileSegment = fileSegment;
     }
 
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) {
-            return true;
-        }
-
-        if (!(o instanceof Segment)) {
-            return false;
-        }
-
-        Segment segment = (Segment) o;
-        return count == segment.count && size == segment.size && Objects.equals(path, segment.path);
+    FileSegment getFileSegment() {
+        return fileSegment;
     }
 
-    @Override
-    public int hashCode() {
-        return Objects.hash(path, count, size);
+    void setMemorySegment(MemorySegment memorySegment) {

Review Comment:
   It is used in `DataCache.tryCacheSegmentToMemory`. I also just found a bug in that method: it should call `segment.setMemorySegment(x.getMemorySegment())`. I'll fix it.
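A minimal, dependency-free sketch of the fix described above, assuming a simplified model in which a `Segment` holds an optional file part and an optional memory part. `FileSegment`, `MemorySegment` (unrelated to Flink's `org.apache.flink.core.memory.MemorySegment`), `loadIntoMemory`, and `tryCacheSegmentsToMemory` are hypothetical stand-ins, not the PR's actual code.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for the PR's FileSegment: a path plus the bytes stored there.
final class FileSegment {
    final String path;
    final byte[] contents; // simulates the file's contents; no real I/O in this sketch

    FileSegment(String path, byte[] contents) {
        this.path = path;
        this.contents = contents;
    }
}

// Hypothetical stand-in for the PR's MemorySegment (not Flink's core MemorySegment).
final class MemorySegment {
    final byte[] data;

    MemorySegment(byte[] data) {
        this.data = data;
    }
}

// Simplified Segment from the diff: either part may be null.
final class Segment {
    private FileSegment fileSegment;
    private MemorySegment memorySegment;

    Segment(FileSegment fileSegment) {
        this.fileSegment = fileSegment;
    }

    Segment(MemorySegment memorySegment) {
        this.memorySegment = memorySegment;
    }

    FileSegment getFileSegment() { return fileSegment; }
    MemorySegment getMemorySegment() { return memorySegment; }
    void setFileSegment(FileSegment fileSegment) { this.fileSegment = fileSegment; }
    void setMemorySegment(MemorySegment memorySegment) { this.memorySegment = memorySegment; }
}

final class CacheSketch {
    // Hypothetical loader: copies a file-backed segment into a new memory-only Segment.
    static Segment loadIntoMemory(FileSegment file) {
        return new Segment(new MemorySegment(Arrays.copyOf(file.contents, file.contents.length)));
    }

    // For each file-only segment, attach the memory part of the freshly loaded
    // segment x to the original segment, keeping its file part intact.
    static void tryCacheSegmentsToMemory(List<Segment> segments) {
        for (Segment segment : segments) {
            if (segment.getMemorySegment() != null || segment.getFileSegment() == null) {
                continue; // already cached, or nothing on disk to load
            }
            Segment x = loadIntoMemory(segment.getFileSegment());
            segment.setMemorySegment(x.getMemorySegment());
        }
    }
}
```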




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891139035


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -167,26 +183,69 @@ public static DataCacheSnapshot recover(
             if (isDistributedFS) {
                 segments = deserializeSegments(dis);
             } else {
-                int totalRecords = dis.readInt();
-                long totalSize = dis.readLong();
-
-                Path path = pathGenerator.get();
-                try (FSDataOutputStream outputStream =
-                        fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
-
-                    BoundedInputStream inputStream =
-                            new BoundedInputStream(checkpointInputStream, totalSize);
-                    inputStream.setPropagateClose(false);
-                    IOUtils.copyBytes(inputStream, outputStream, false);
-                    inputStream.close();
+                int segmentNum = dis.readInt();
+                segments = new ArrayList<>(segmentNum);
+                for (int i = 0; i < segmentNum; i++) {
+                    int count = dis.readInt();
+                    long fsSize = dis.readLong();
+                    Path path = pathGenerator.get();
+                    try (FSDataOutputStream outputStream =

Review Comment:
   According to our offline discussion, the original purpose of merging segments into one was to reduce the number of files created. @gaoyunhaii also agrees that a file merged from several segments might exceed the maximum file size allowed by the file system, so he considers our current practice of preserving the original segments acceptable. The current use cases of DataCache do not generate many small segments, so the performance impact is not significant for now.
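To illustrate the size concern with hypothetical numbers: the sketch below rolls over to a new segment once the current one reaches a soft size limit (the `DataCacheWriter` diff defines `MAX_SEGMENT_SIZE = 1L << 30`; a 100-byte limit is used here), then reports the size of the single file that merging all segments would produce. `SegmentRoller` is an illustrative assumption, not `DataCacheWriter`'s actual rollover logic.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model: tracks segment sizes in bytes, rolling over at a soft limit.
final class SegmentRoller {
    private final long softLimit;
    private final List<Long> finishedSizes = new ArrayList<>();
    private long currentSize = 0L;

    SegmentRoller(long softLimit) {
        this.softLimit = softLimit;
    }

    // Appends one serialized record. The limit is "soft" because a record is
    // never split, so a segment may end slightly above the limit.
    void addRecord(int recordBytes) {
        currentSize += recordBytes;
        if (currentSize >= softLimit) {
            finishedSizes.add(currentSize);
            currentSize = 0L;
        }
    }

    // Finishes the in-progress segment (if non-empty) and returns all segment sizes.
    List<Long> finish() {
        if (currentSize > 0) {
            finishedSizes.add(currentSize);
            currentSize = 0L;
        }
        return new ArrayList<>(finishedSizes);
    }

    // Size of the single file that merging all segments would produce.
    static long mergedSize(List<Long> sizes) {
        long total = 0L;
        for (long size : sizes) {
            total += size;
        }
        return total;
    }
}
```

With 25 records of 10 bytes each, every preserved segment stays at or below 100 bytes, while a merged file would hold all 250 bytes.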




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891014714


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing persisted records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /**
+     * The count of records in the file at the path if the file size is not zero, otherwise the
+     * count of records in the cache.
+     */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /**
+     * The total length of file containing persisted records. Its value is 0 iff the segment has not
+     * been written to the given path.
+     */
+    private long fsSize = 0L;
+
+    /**
+     * The memory segments containing cached records. This list is empty iff the segment has not
+     * been cached in memory.
+     */
+    private List<MemorySegment> cache = new ArrayList<>();
+
+    Segment(Path path, int count, long fsSize) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
+        this.count = count;
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
 
-    public Segment(Path path, int count, long size) {
-        this.path = path;
+    Segment(Path path, int count, List<MemorySegment> cache) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
         this.count = count;
-        this.size = size;
+        this.cache = checkNotNull(cache);
+    }
+
+    void setCache(List<MemorySegment> cache) {
+        this.cache = checkNotNull(cache);

Review Comment:
   nits: we typically don't explicitly check whether the input argument is null in such cases. Could you update the code for consistency and simplicity?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -167,26 +183,69 @@ public static DataCacheSnapshot recover(
             if (isDistributedFS) {
                 segments = deserializeSegments(dis);
             } else {
-                int totalRecords = dis.readInt();
-                long totalSize = dis.readLong();
-
-                Path path = pathGenerator.get();
-                try (FSDataOutputStream outputStream =
-                        fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
-
-                    BoundedInputStream inputStream =
-                            new BoundedInputStream(checkpointInputStream, totalSize);
-                    inputStream.setPropagateClose(false);
-                    IOUtils.copyBytes(inputStream, outputStream, false);
-                    inputStream.close();
+                int segmentNum = dis.readInt();
+                segments = new ArrayList<>(segmentNum);
+                for (int i = 0; i < segmentNum; i++) {
+                    int count = dis.readInt();
+                    long fsSize = dis.readLong();
+                    Path path = pathGenerator.get();
+                    try (FSDataOutputStream outputStream =

Review Comment:
   Prior to this PR, when we recover from a snapshot of multiple smaller segments, we might merge these segments into one segment. We no longer do this after this PR. Could you double check its performance impact with @gaoyunhaii?
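A simplified sketch of the per-segment recovery loop in the diff, assuming each segment's raw bytes directly follow its `count`/`fsSize` header in the snapshot stream. In-memory byte arrays stand in for the files that `pathGenerator` and `FileSystem.create()` would produce; `RecoveredSegment` and `SnapshotSketch` are hypothetical names.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Simplified recovered segment: record count plus the raw bytes of one cache file.
final class RecoveredSegment {
    final int count;
    final byte[] bytes;

    RecoveredSegment(int count, byte[] bytes) {
        this.count = count;
        this.bytes = bytes;
    }
}

final class SnapshotSketch {
    // Writes segmentNum, then for each segment: count, fsSize, and fsSize raw bytes.
    static byte[] write(List<RecoveredSegment> segments) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(bos);
        dos.writeInt(segments.size());
        for (RecoveredSegment segment : segments) {
            dos.writeInt(segment.count);
            dos.writeLong(segment.bytes.length);
            dos.write(segment.bytes);
        }
        dos.flush();
        return bos.toByteArray();
    }

    // Recovers each segment into its own byte array (one "file" per segment)
    // instead of concatenating them into a single merged segment.
    static List<RecoveredSegment> recover(byte[] snapshot) throws IOException {
        DataInputStream dis = new DataInputStream(new ByteArrayInputStream(snapshot));
        int segmentNum = dis.readInt();
        List<RecoveredSegment> segments = new ArrayList<>(segmentNum);
        for (int i = 0; i < segmentNum; i++) {
            int count = dis.readInt();
            long fsSize = dis.readLong();
            byte[] bytes = new byte[Math.toIntExact(fsSize)];
            dis.readFully(bytes);
            segments.add(new RecoveredSegment(count, bytes));
        }
        return segments;
    }
}
```

Under this model the number of recovered files equals the number of snapshotted segments, whereas the removed code copied `totalSize` bytes into a single file.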



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing persisted records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /**
+     * The count of records in the file at the path if the file size is not zero, otherwise the
+     * count of records in the cache.
+     */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /**
+     * The total length of file containing persisted records. Its value is 0 iff the segment has not
+     * been written to the given path.
+     */
+    private long fsSize = 0L;
+
+    /**
+     * The memory segments containing cached records. This list is empty iff the segment has not
+     * been cached in memory.
+     */
+    private List<MemorySegment> cache = new ArrayList<>();
+
+    Segment(Path path, int count, long fsSize) {
+        this.path = checkNotNull(path);

Review Comment:
   nits: the code probably looks nicer if we assign all variables before checking their values.
   
   Same for the other constructors.
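A dependency-free sketch of the constructors after this nit: all fields are assigned first, then validated. `String` stands in for `Path`, `byte[]` for `MemorySegment`, and `Objects.requireNonNull` plus a local `checkArgument` helper stand in for Flink's `Preconditions`; `SegmentSketch` is a hypothetical name. The setter also omits the null check, per the earlier nit on `setCache()`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

final class SegmentSketch {
    private final String path; // stands in for org.apache.flink.core.fs.Path
    private final int count;
    private long fsSize = 0L;
    private List<byte[]> cache = new ArrayList<>(); // stands in for List<MemorySegment>

    SegmentSketch(String path, int count, long fsSize) {
        // Assign all fields first...
        this.path = path;
        this.count = count;
        this.fsSize = fsSize;
        // ...then validate them.
        Objects.requireNonNull(path);
        checkArgument(count > 0);
        checkArgument(fsSize > 0);
    }

    SegmentSketch(String path, int count, List<byte[]> cache) {
        this.path = path;
        this.count = count;
        this.cache = cache;
        Objects.requireNonNull(path);
        checkArgument(count > 0);
    }

    // No explicit null check on the argument.
    void setCache(List<byte[]> cache) {
        this.cache = cache;
    }

    String getPath() { return path; }
    int getCount() { return count; }
    long getFsSize() { return fsSize; }
    List<byte[]> getCache() { return cache; }

    // Local stand-in for Preconditions.checkArgument(boolean).
    private static void checkArgument(boolean condition) {
        if (!condition) {
            throw new IllegalArgumentException();
        }
    }
}
```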



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/ListStateWithCache.java:
##########
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.table.runtime.util.LazyMemorySegmentPool;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A {@link ListState} child class that records data and replays them on required.
+ *
+ * <p>This class basically stores data in file system, and provides the option to cache them in
+ * memory. In order to use the memory caching option, users need to allocate certain managed memory
+ * for the wrapper operator through {@link
+ * org.apache.flink.api.dag.Transformation#declareManagedMemoryUseCaseAtOperatorScope}.
+ *
+ * <p>NOTE: Users need to explicitly invoke this class's {@link
+ * #snapshotState(StateSnapshotContext)} method in order to store the recorded data in snapshot.
+ */
+public class ListStateWithCache<T> implements ListState<T> {
+
+    /** The tool to serialize/deserialize records. */
+    private final TypeSerializer<T> serializer;
+
+    /** The path of the directory that holds the files containing recorded data. */
+    private final Path basePath;
+
+    /** The data cache writer for the received records. */
+    private final DataCacheWriter<T> dataCacheWriter;
+
+    @SuppressWarnings("unchecked")
+    public ListStateWithCache(
+            TypeSerializer<T> serializer,
+            StreamTask<?, ?> containingTask,
+            StreamingRuntimeContext runtimeContext,
+            StateInitializationContext stateInitializationContext,
+            OperatorID operatorID)
+            throws IOException {
+        this.serializer = serializer;
+
+        MemorySegmentPool segmentPool = null;
+        double fraction =
+                containingTask
+                        .getConfiguration()
+                        .getManagedMemoryFractionOperatorUseCaseOfSlot(
+                                ManagedMemoryUseCase.OPERATOR,
+                                runtimeContext.getTaskManagerRuntimeInfo().getConfiguration(),
+                                runtimeContext.getUserCodeClassLoader());
+        if (fraction > 0) {
+            MemoryManager memoryManager = containingTask.getEnvironment().getMemoryManager();
+            segmentPool =
+                    new LazyMemorySegmentPool(
+                            containingTask,
+                            memoryManager,
+                            memoryManager.computeNumberOfPages(fraction));
+        }
+
+        basePath =
+                OperatorUtils.getDataCachePath(
+                        containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                        containingTask
+                                .getEnvironment()
+                                .getIOManager()
+                                .getSpillingDirectoriesPaths());
+
+        List<StatePartitionStreamProvider> inputs =
+                IteratorUtils.toList(
+                        stateInitializationContext.getRawOperatorStateInputs().iterator());
+        Preconditions.checkState(
+                inputs.size() < 2, "The input from raw operator state should be one or zero.");
+
+        List<Segment> priorFinishedSegments = new ArrayList<>();
+        if (inputs.size() > 0) {
+            DataCacheSnapshot dataCacheSnapshot =
+                    DataCacheSnapshot.recover(
+                            inputs.get(0).getStream(),
+                            basePath.getFileSystem(),
+                            OperatorUtils.createDataCacheFileGenerator(
+                                    basePath, "cache", operatorID));
+
+            if (segmentPool != null) {
+                dataCacheSnapshot.tryReadSegmentsToMemory(serializer, segmentPool);
+            }
+
+            priorFinishedSegments = dataCacheSnapshot.getSegments();
+        }
+
+        this.dataCacheWriter =
+                new DataCacheWriter<>(
+                        serializer,
+                        basePath.getFileSystem(),
+                        OperatorUtils.createDataCacheFileGenerator(basePath, "cache", operatorID),
+                        segmentPool,
+                        priorFinishedSegments);
+    }
+
+    public void snapshotState(StateSnapshotContext context) throws Exception {

Review Comment:
   Since snapshot() and add() reuse the same serializer, which is not thread-safe, could you double-check that snapshot() and add() won't be invoked concurrently?
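If `snapshotState()` and `add()` are always called from the operator's task thread, they cannot overlap and sharing one serializer is safe. If that cannot be guaranteed, a common remedy is to give each caller its own instance via `TypeSerializer#duplicate()`. The dependency-free sketch below models that idea; `ReusingSerializer` and `CacheStateSketch` are hypothetical classes, not Flink APIs.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical stateful serializer: it reuses one internal buffer, so a single
// instance must not be used by two threads at once.
final class ReusingSerializer {
    // Shared scratch buffer reused across calls: the source of thread-unsafety.
    private byte[] buffer = new byte[16];

    byte[] serialize(String record) {
        byte[] bytes = record.getBytes(StandardCharsets.UTF_8);
        if (bytes.length > buffer.length) {
            buffer = new byte[bytes.length];
        }
        // A concurrent call could overwrite buffer between these two statements.
        System.arraycopy(bytes, 0, buffer, 0, bytes.length);
        return Arrays.copyOf(buffer, bytes.length);
    }

    // Mirrors the idea of TypeSerializer#duplicate(): an independent copy
    // that another thread can use safely.
    ReusingSerializer duplicate() {
        return new ReusingSerializer();
    }
}

final class CacheStateSketch {
    private final ReusingSerializer serializer = new ReusingSerializer();
    // Separate instance reserved for snapshotting, in case add() and snapshot() overlap.
    private final ReusingSerializer snapshotSerializer = serializer.duplicate();

    byte[] add(String record) {
        return serializer.serialize(record);
    }

    byte[] snapshot(String record) {
        return snapshotSerializer.serialize(record);
    }
}
```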



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -253,59 +258,76 @@ public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value
             implements TwoInputStreamOperator<
                             DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
                     IterationListener<Tuple2<Integer, DenseVector>> {
+
         private final DistanceMeasure distanceMeasure;
-        private ListState<DenseVector> points;
-        private ListState<DenseVector[]> centroids;
+
+        private ListState<DenseVector[]> centroidsState;

Review Comment:
   nits: could we keep the original name `centroids` for simplicity? Same for `points`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -160,6 +162,9 @@ public IterationBodyResult process(
                                             BasicTypeInfo.INT_TYPE_INFO,
                                             DenseVectorTypeInfo.INSTANCE),
                                     new SelectNearestCentroidOperator(distanceMeasure));
+            centroidIdAndPoints
+                    .getTransformation()
+                    .declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase.OPERATOR, 64);

Review Comment:
   Instead of calling this API explicitly, could this be specified in the Flink job configuration, as described in the following documentation? https://nightlies.apache.org/flink/flink-docs-master/docs/deployment/memory/mem_setup_tm/#consumer-weights




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890707072


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,73 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing persisted records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /**
+     * The count of records in the file at the path if the file size is not zero, otherwise the
+     * count of records in the cache.
+     */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing persisted records. */
+    private long fsSize = -1L;
+
+    /** The memory segments containing cached records. */

Review Comment:
   Currently, the cache list is null rather than empty when the segment is not cached. I think this better matches Java's default field values. How do you think we should treat null versus an empty list?
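One possible convention, sketched under the assumption that null means "not cached": normalize an empty list to null in the setter and expose an explicit check, so callers never need to distinguish the two states. `CachedSegmentSketch`, `isCached()`, and `getCacheOrEmpty()` are hypothetical names, and `byte[]` stands in for `MemorySegment`.

```java
import java.util.Collections;
import java.util.List;

final class CachedSegmentSketch {
    // Defaults to null, i.e. "not cached in memory".
    private List<byte[]> cache;

    void setCache(List<byte[]> cache) {
        // Treat an empty list the same as null so there is a single "not cached" state.
        this.cache = (cache == null || cache.isEmpty()) ? null : cache;
    }

    boolean isCached() {
        return cache != null;
    }

    // Callers that only iterate over the cache never see null.
    List<byte[]> getCacheOrEmpty() {
        return cache == null ? Collections.emptyList() : cache;
    }
}
```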




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890665195


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,162 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentSegmentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
 
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);

Review Comment:
   nits: would it be simpler to use `this.finishedSegments = new ArrayList<>(priorFinishedSegments)`, similar to the approach before this PR?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentReader.java:
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.List;
+
+/** A class that reads data cached in memory. */
+@Internal
+class MemorySegmentReader<T> implements SegmentReader<T> {
+
+    /** The tool to deserialize bytes into records. */
+    private final TypeSerializer<T> serializer;
+
+    /** The wrapper view of the input stream of memory segments to be used in TypeSerializer API. */
+    private final DataInputView inputView;
+
+    /** The total number of records contained in the segments. */
+    private final int totalCount;
+
+    /** The number of records that have been read so far. */
+    private int count;
+
+    MemorySegmentReader(TypeSerializer<T> serializer, Segment segment, int startOffset)
+            throws IOException {
+        ManagedMemoryInputStream inputStream = new ManagedMemoryInputStream(segment.getCache());
+        this.inputView = new DataInputViewStreamWrapper(inputStream);
+        this.serializer = serializer;
+        this.totalCount = segment.getCount();
+        this.count = 0;
+
+        for (int ignored = 0; ignored < startOffset; ignored++) {

Review Comment:
   nits: could we replace `ignored` with `i` for consistency with similar code in the `FileSegmentReader` constructor?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FileSegmentWriter.java:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/** A class that writes cache data to a target file in given file system. */
+@Internal
+class FileSegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The path to the target file. */
+    private final Path path;
+
+    /** The output stream that writes to the target file. */
+    private final FSDataOutputStream outputStream;
+
+    /** A buffer that wraps the output stream to optimize performance. */
+    private final BufferedOutputStream bufferedOutputStream;

Review Comment:
   Could we remove this class variable, similar to how we handle `bufferedInputStream` in `FileSegmentReader`?
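A minimal sketch of the suggestion, assuming the buffer is only needed as an intermediate layer: the `BufferedOutputStream` is created inline and reached solely through the outer `DataOutputStream`, so flushing or closing the outer stream also flushes or closes the buffer. A plain `OutputStream` stands in for `FSDataOutputStream`, and `BufferedRecordWriter` is a hypothetical name; whether the real class still needs its `outputStream` field is not addressed here.

```java
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

final class BufferedRecordWriter implements AutoCloseable {
    // The only stream field: the buffer lives inside it, not in its own field.
    private final DataOutputStream outputView;
    private int count;

    BufferedRecordWriter(OutputStream target) {
        this.outputView = new DataOutputStream(new BufferedOutputStream(target));
    }

    void addRecord(long record) throws IOException {
        outputView.writeLong(record);
        count++;
    }

    int getCount() {
        return count;
    }

    // Flushing the outer stream also flushes the wrapped BufferedOutputStream.
    void flush() throws IOException {
        outputView.flush();
    }

    // Closing the outer stream flushes and closes the buffer and the target.
    @Override
    public void close() throws IOException {
        outputView.close();
    }
}
```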



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentReader.java:
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.List;
+
+/** A class that reads data cached in memory. */
+@Internal
+class MemorySegmentReader<T> implements SegmentReader<T> {
+
+    /** The tool to deserialize bytes into records. */
+    private final TypeSerializer<T> serializer;
+
+    /** The wrapper view of the input stream of memory segments to be used in TypeSerializer API. */
+    private final DataInputView inputView;
+
+    /** The total number of records contained in the segments. */
+    private final int totalCount;
+
+    /** The number of records that have been read so far. */
+    private int count;
+
+    MemorySegmentReader(TypeSerializer<T> serializer, Segment segment, int startOffset)
+            throws IOException {
+        ManagedMemoryInputStream inputStream = new ManagedMemoryInputStream(segment.getCache());
+        this.inputView = new DataInputViewStreamWrapper(inputStream);
+        this.serializer = serializer;
+        this.totalCount = segment.getCount();
+        this.count = 0;
+
+        for (int ignored = 0; ignored < startOffset; ignored++) {
+            next();
+        }
+    }
+
+    @Override
+    public boolean hasNext() {
+        return count < totalCount;
+    }
+
+    @Override
+    public T next() throws IOException {
+        T ret = serializer.deserialize(inputView);

Review Comment:
   nits: could we replace `ret` with `value` for consistency with `FileSegmentReader::next()`?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,162 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provides memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentSegmentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
 
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentSegmentWriter = createSegmentWriter();
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public void addRecord(T record) throws IOException {
+        if (!currentSegmentWriter.addRecord(record)) {
+            currentSegmentWriter.finish().ifPresent(finishedSegments::add);
+            currentSegmentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentSegmentWriter.addRecord(record));
+        }
     }
 
+    /** Finishes adding records and closes resources occupied for adding records. */
     public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
-    }
+        if (currentSegmentWriter == null) {
+            return finishedSegments;
+        }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
+        currentSegmentWriter.finish().ifPresent(finishedSegments::add);
+        currentSegmentWriter = null;
+        return finishedSegments;
     }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+    /**
+     * Flushes all added records to segments and returns a list of segments containing all cached
+     * records.
+     */
+    public List<Segment> getSegments() throws IOException {
+        finishCurrentSegmentIfAny();
+        return finishedSegments;
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    private void finishCurrentSegmentIfAny() throws IOException {

Review Comment:
   nits: it is not very clear what `IfAny` means, and `IfAny` is rarely used in method names. How about renaming it to either `finishCurrentSegment()` or `finishCurrentSegmentIfExists()`?
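For context on the `DataCacheWriter` hunk above: `addRecord` first offers each record to the current (possibly memory-backed) writer. When that writer reports it is full, `addRecord` seals the writer's records into a finished segment and continues with a `FileSegmentWriter`. The sketch below models only that control flow under simplifying assumptions. `Writer`, `bounded` and `FallbackCacheWriter` are hypothetical names, and records are `int`s. A writer with a fixed record capacity stands in for a `MemorySegmentPool`-backed `MemorySegmentWriter`. The fallback writer here is unbounded, whereas the real `FileSegmentWriter` also rolls over at `MAX_SEGMENT_SIZE`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Hypothetical, simplified model of DataCacheWriter's memory-first, file-fallback flow. */
final class FallbackCacheWriter {

    /** Stand-in for SegmentWriter: addRecord returns false once the writer cannot take more. */
    interface Writer {
        boolean addRecord(int record);

        Optional<List<Integer>> finish();
    }

    /** A writer holding at most {@code capacity} records, modelling a bounded memory pool. */
    static Writer bounded(int capacity) {
        return new Writer() {
            private final List<Integer> records = new ArrayList<>();

            @Override
            public boolean addRecord(int record) {
                if (records.size() >= capacity) {
                    return false;
                }
                records.add(record);
                return true;
            }

            @Override
            public Optional<List<Integer>> finish() {
                // Like the PR's writers, an empty writer yields no segment.
                if (records.isEmpty()) {
                    return Optional.empty();
                }
                return Optional.of(records);
            }
        };
    }

    private final List<List<Integer>> finishedSegments = new ArrayList<>();
    private Writer currentWriter;

    FallbackCacheWriter(Writer initialWriter) {
        this.currentWriter = initialWriter;
    }

    void addRecord(int record) {
        if (!currentWriter.addRecord(record)) {
            // Seal the full writer's records, then continue with the (here unbounded) fallback.
            currentWriter.finish().ifPresent(finishedSegments::add);
            currentWriter = bounded(Integer.MAX_VALUE);
            if (!currentWriter.addRecord(record)) {
                throw new IllegalStateException("fallback writer rejected a record");
            }
        }
    }

    List<List<Integer>> finish() {
        if (currentWriter != null) {
            currentWriter.finish().ifPresent(finishedSegments::add);
            currentWriter = null;
        }
        return finishedSegments;
    }
}
```

With an initial capacity of 2, adding records 0 to 4 yields two segments, `[0, 1]` and `[2, 3, 4]`.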



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FileSegmentWriter.java:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/** A class that writes cache data to a target file in a given file system. */
+@Internal
+class FileSegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The path to the target file. */
+    private final Path path;
+
+    /** The output stream that writes to the target file. */
+    private final FSDataOutputStream outputStream;
+
+    /** A buffer that wraps the output stream to optimize performance. */
+    private final BufferedOutputStream bufferedOutputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    FileSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.outputStream = path.getFileSystem().create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.bufferedOutputStream = new BufferedOutputStream(outputStream);
+        this.outputView = new DataOutputViewStreamWrapper(bufferedOutputStream);
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        serializer.serialize(record, outputView);
+        count++;
+        return true;
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        bufferedOutputStream.flush();

Review Comment:
   This line can be removed because `outputStream.flush()` will recursively call `bufferedOutputStream.flush()`.
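For reference on the flush direction in `java.io`: `BufferedOutputStream.flush()` writes its internal buffer to the wrapped stream and then calls the wrapped stream's `flush()`. Calling `flush()` on the wrapped stream alone leaves the buffered bytes where they are. The sketch below demonstrates this with a `ByteArrayOutputStream` standing in for the `FSDataOutputStream`, an assumption made so that it runs without Flink. `FlushDirectionDemo` is a hypothetical name. In this setup, the call on the outer `bufferedOutputStream` is the one that guarantees buffered bytes reach the underlying stream.

```java
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

final class FlushDirectionDemo {
    /**
     * Writes 3 bytes through a BufferedOutputStream and returns how many of them have reached the
     * wrapped stream after a single flush() call on either the outer or the inner stream.
     */
    static int bytesVisibleAfterFlush(boolean flushOuter) {
        // ByteArrayOutputStream stands in for FSDataOutputStream.
        ByteArrayOutputStream inner = new ByteArrayOutputStream();
        BufferedOutputStream buffered = new BufferedOutputStream(inner);
        try {
            buffered.write(new byte[] {1, 2, 3}); // small write: stays in the buffer
            if (flushOuter) {
                buffered.flush(); // drains the buffer into inner, then calls inner.flush()
            } else {
                inner.flush(); // the inner stream knows nothing about the outer buffer
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return inner.size();
    }

    public static void main(String[] args) {
        System.out.println(bytesVisibleAfterFlush(false)); // 0
        System.out.println(bytesVisibleAfterFlush(true)); // 3
    }
}
```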



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+class MemorySegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The pre-allocated file system path for storing the cached records. */
+    private final Path path;
+
+    /** The pool to allocate memory segments from. */
+    private final MemorySegmentPool segmentPool;
+
+    /** The output stream to write serialized content to memory segments. */
+    private final ManagedMemoryOutputStream outputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    MemorySegmentWriter(
+            TypeSerializer<T> serializer,
+            Path path,
+            MemorySegmentPool segmentPool,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.path = path;
+        this.segmentPool = segmentPool;
+        this.outputStream = new ManagedMemoryOutputStream(segmentPool, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            if (e.getCause() instanceof MemoryAllocationException) {
+                return false;
+            }
+            throw e;
+        }
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(new Segment(path, count, outputStream.getSegments()));
+        } else {
+            segmentPool.returnAll(outputStream.getSegments());
+            return Optional.empty();
+        }
+    }
+
+    /** An output stream subclass that accepts bytes and writes them to memory segments. */
+    private static class ManagedMemoryOutputStream extends OutputStream {
+
+        /** The pool to allocate memory segments from. */
+        private final MemorySegmentPool segmentPool;
+
+        /** The number of bytes in a memory segment. */
+        private final int pageSize;
+
+        /** The memory segments containing written bytes. */
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        /** The index of the segment that currently accepts written bytes. */
+        private int segmentIndex;
+
+        /** The number of bytes in the current segment that have been written. */
+        private int segmentOffset;
+
+        /** The number of bytes that have been written so far. */
+        private long globalOffset;
+
+        /** The number of bytes that have been allocated so far. */
+        private long allocatedBytes;
+
+        public ManagedMemoryOutputStream(MemorySegmentPool segmentPool, long expectedSize)
+                throws MemoryAllocationException {
+            this.segmentPool = segmentPool;
+            this.pageSize = segmentPool.pageSize();
+
+            Preconditions.checkArgument(expectedSize >= 0);

Review Comment:
   nits: It seems unnecessary to check this and use Math.max(..) below. Since it does not cause any harm to have expectedSize < 0, would it be simpler to remove this line?
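The body of `ensureCapacity` is not part of this hunk. The sketch below is a hypothetical `PageAllocator` that uses `byte[]` pages in place of `MemorySegment`s taken from a `MemorySegmentPool`. It assumes that `ensureCapacity` appends whole pages until `allocatedBytes` covers the requested size. Under that assumption, a non-positive `expectedSize` would be harmless because the loop condition is already false, so nothing is allocated.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for ManagedMemoryOutputStream's page bookkeeping. */
final class PageAllocator {
    private final int pageSize;
    private final List<byte[]> pages = new ArrayList<>();
    private long allocatedBytes;

    PageAllocator(int pageSize) {
        this.pageSize = pageSize;
    }

    /** Allocates whole pages until at least {@code requiredBytes} are available. */
    void ensureCapacity(long requiredBytes) {
        // For requiredBytes <= allocatedBytes, including any negative value, the body never runs.
        while (allocatedBytes < requiredBytes) {
            pages.add(new byte[pageSize]);
            allocatedBytes += pageSize;
        }
    }

    int pageCount() {
        return pages.size();
    }
}
```

With 32-byte pages, `ensureCapacity(-5)` leaves zero pages, and a later `ensureCapacity(33)` allocates two.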



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,194 @@
+            ensureCapacity(Math.max(expectedSize, 1L));
+        }
+
+        public long getPos() {
+            return globalOffset;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(@Nullable byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+
+            while (len > 0) {

Review Comment:
   It seems simpler to use the following code, which does not need an explicit `break`.
   
   ```
   while (len > 0) {
       int currentLen = Math.min(len, pageSize - segmentOffset);
       segments.get(segmentIndex).put(segmentOffset, b, off, currentLen);
       segmentOffset += currentLen;
       globalOffset += currentLen;
       if (segmentOffset >= pageSize) {
           segmentIndex++;
           segmentOffset = 0;
       }
       off += currentLen;
       len -= currentLen;
   }
   ```
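Here is a self-contained version of the suggested loop for trying it outside Flink. The hypothetical `PagedWriter` uses `byte[]` pages and `System.arraycopy` in place of `MemorySegment.put`, and a simplified `ensureCapacity` in place of the pool. As in the PR, capacity for the whole write is reserved before copying, which is why the loop needs neither a bounds check nor a `break`.

```java
import java.util.ArrayList;
import java.util.List;

/** Simplified, hypothetical model of ManagedMemoryOutputStream.write(byte[], int, int). */
final class PagedWriter {
    final int pageSize;
    final List<byte[]> segments = new ArrayList<>();
    int segmentIndex;
    int segmentOffset;
    long globalOffset;

    PagedWriter(int pageSize) {
        this.pageSize = pageSize;
    }

    /** Stand-in for ensureCapacity(): appends pages until {@code required} bytes fit. */
    private void ensureCapacity(long required) {
        while ((long) segments.size() * pageSize < required) {
            segments.add(new byte[pageSize]);
        }
    }

    void write(byte[] b, int off, int len) {
        ensureCapacity(globalOffset + len);
        while (len > 0) {
            // Copy as much as fits into the current page.
            int currentLen = Math.min(len, pageSize - segmentOffset);
            System.arraycopy(b, off, segments.get(segmentIndex), segmentOffset, currentLen);
            segmentOffset += currentLen;
            globalOffset += currentLen;
            if (segmentOffset >= pageSize) {
                // The current page is full: move on to the next, already allocated, page.
                segmentIndex++;
                segmentOffset = 0;
            }
            off += currentLen;
            len -= currentLen;
        }
    }
}
```

Writing 10 bytes with a 4-byte page size allocates three pages, fills the first two, and leaves the write position at offset 2 of the third.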



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentReader.java:
##########
@@ -0,0 +1,126 @@
+        count++;
+        return ret;
+    }
+
+    @Override
+    public void close() {}
+
+    /** An input stream subclass that reads bytes from memory segments. */
+    private static class ManagedMemoryInputStream extends InputStream {
+
+        /** The memory segments to read bytes from. */
+        private final List<MemorySegment> segments;
+
+        /** The index of the segment that is currently being read. */
+        private int segmentIndex;
+
+        /** The number of bytes that have been read from current segment so far. */
+        private int segmentOffset;
+
+        public ManagedMemoryInputStream(List<MemorySegment> segments) {
+            this.segments = segments;
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+        }
+
+        @Override
+        public int read() throws IOException {
+            int ret = segments.get(segmentIndex).get(segmentOffset) & 0xff;
+            segmentOffset += 1;
+            if (segmentOffset >= segments.get(segmentIndex).size()) {
+                segmentIndex++;
+                segmentOffset = 0;
+            }
+            return ret;
+        }
+
+        @Override
+        public int read(byte[] b, int off, int len) throws IOException {
+            int readLen = 0;
+
+            while (len > 0 && segmentIndex < segments.size()) {

Review Comment:
   nits: the following code is probably simpler, since it does not need an explicit `break`.
   
   ```
   while (len > 0 && segmentIndex < segments.size()) {
       int currentLen = Math.min(segments.get(segmentIndex).size() - segmentOffset, len);
       segments.get(segmentIndex).get(segmentOffset, b, off, currentLen);
       segmentOffset += currentLen;
       if (segmentOffset >= segments.get(segmentIndex).size()) {
           segmentIndex++;
           segmentOffset = 0;
       }
   
       readLen += currentLen;
       off += currentLen;
       len -= currentLen;
   }
   ```
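Here is a self-contained version of this loop. It keeps the `segmentIndex < segments.size()` guard from the original loop condition, so a request for more bytes than remain stops at the last page instead of throwing `IndexOutOfBoundsException`. It also returns -1 at the end of the data, as the `InputStream.read(byte[], int, int)` contract requires. `PagedReader` is a hypothetical name, and `byte[]` pages with `System.arraycopy` stand in for `MemorySegment.get`, an assumption made so that it runs without Flink.

```java
import java.util.List;

/** Simplified, hypothetical model of ManagedMemoryInputStream.read(byte[], int, int). */
final class PagedReader {
    private final List<byte[]> segments;
    private int segmentIndex;
    private int segmentOffset;

    PagedReader(List<byte[]> segments) {
        this.segments = segments;
    }

    /** Returns the number of bytes copied, or -1 once every page has been consumed. */
    int read(byte[] b, int off, int len) {
        if (len == 0) {
            return 0;
        }
        if (segmentIndex >= segments.size()) {
            return -1;
        }
        int readLen = 0;
        // The index check ends the loop after the last page instead of throwing.
        while (len > 0 && segmentIndex < segments.size()) {
            byte[] current = segments.get(segmentIndex);
            int currentLen = Math.min(current.length - segmentOffset, len);
            System.arraycopy(current, segmentOffset, b, off, currentLen);
            segmentOffset += currentLen;
            if (segmentOffset >= current.length) {
                segmentIndex++;
                segmentOffset = 0;
            }
            readLen += currentLen;
            off += currentLen;
            len -= currentLen;
        }
        return readLen;
    }
}
```

With pages `{1, 2, 3}` and `{4, 5}`, three reads of up to 4 bytes return 4, then 1, then -1.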



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -167,26 +183,69 @@ public static DataCacheSnapshot recover(
             if (isDistributedFS) {
                 segments = deserializeSegments(dis);
             } else {
-                int totalRecords = dis.readInt();
-                long totalSize = dis.readLong();
-
-                Path path = pathGenerator.get();
-                try (FSDataOutputStream outputStream =
-                        fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
-
-                    BoundedInputStream inputStream =
-                            new BoundedInputStream(checkpointInputStream, totalSize);
-                    inputStream.setPropagateClose(false);
-                    IOUtils.copyBytes(inputStream, outputStream, false);
-                    inputStream.close();
+                int segmentNum = dis.readInt();
+                segments = new ArrayList<>(segmentNum);
+                for (int i = 0; i < segmentNum; i++) {
+                    int count = dis.readInt();
+                    long fsSize = dis.readLong();
+                    Path path = pathGenerator.get();
+                    try (FSDataOutputStream outputStream =
+                            fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
+
+                        BoundedInputStream boundedInputStream =
+                                new BoundedInputStream(checkpointInputStream, fsSize);
+                        boundedInputStream.setPropagateClose(false);
+                        IOUtils.copyBytes(boundedInputStream, outputStream, false);
+                        boundedInputStream.close();
+                    }
+                    segments.add(new Segment(path, count, fsSize));
                 }
-                segments = Collections.singletonList(new Segment(path, totalRecords, totalSize));
             }
 
             return new DataCacheSnapshot(fileSystem, readerPosition, segments);
         }
     }
 
+    /**
+     * Makes an attempt to cache the segments in memory.
+     *
+     * <p>The attempt is made at segment granularity, which means that only some of the segments
+     * might be cached.
+     *
+     * <p>This method does not throw an exception if there is not enough memory to cache a
+     * segment.
+     */
+    public <T> void tryReadSegmentsToMemory(
+            TypeSerializer<T> serializer, MemorySegmentPool segmentPool) throws IOException {
+        boolean cacheSuccess;
+        for (Segment segment : segments) {
+            if (segment.getCache() != null) {
+                continue;
+            }
+
+            SegmentReader<T> reader = new FileSegmentReader<>(serializer, segment, 0);
+            SegmentWriter<T> writer;
+            try {
+                writer =
+                        new MemorySegmentWriter<>(
+                                serializer, segment.getPath(), segmentPool, segment.getFsSize());
+            } catch (MemoryAllocationException e) {
+                continue;
+            }
+
+            cacheSuccess = true;
+            while (cacheSuccess && reader.hasNext()) {
+                if (!writer.addRecord(reader.next())) {
+                    writer.finish().ifPresent(x -> segmentPool.returnAll(x.getCache()));
+                    cacheSuccess = false;
+                }
+            }
+            if (cacheSuccess) {
+                writer.finish().ifPresent(x -> segment.setCache(x.getCache()));

Review Comment:
   If `cacheSuccess == true`, then `writer.finish()` must return a non-empty segment; otherwise there is a bug.
   
   How about we just do `segment.setCache(writer.finish().get().getCache())` here to simplify the code and surface such a bug?
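The two styles behave differently on an empty `Optional`: `ifPresent` silently does nothing, while `get()` throws `NoSuchElementException`, so an unexpected empty result is reported rather than ignored. In the sketch below, `Optional<String>` stands in for the `Optional<Segment>` returned by `writer.finish()`, and `OptionalGetDemo` is a hypothetical name. On Java 10 and later, the no-argument `orElseThrow()` behaves like `get()` with a more explicit name.

```java
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

final class OptionalGetDemo {
    /** ifPresent skips an empty Optional, so a missing segment would go unnoticed. */
    static boolean ifPresentRuns(Optional<String> finished) {
        AtomicBoolean ran = new AtomicBoolean(false);
        finished.ifPresent(x -> ran.set(true));
        return ran.get();
    }

    /** get() fails fast on an empty Optional, surfacing the unexpected state. */
    static boolean getThrows(Optional<String> finished) {
        try {
            finished.get();
            return false;
        } catch (NoSuchElementException e) {
            return true;
        }
    }
}
```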
   



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,162 @@
+            return finishedSegments;
+        }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
+        currentSegmentWriter.finish().ifPresent(finishedSegments::add);
+        currentSegmentWriter = null;
+        return finishedSegments;
     }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+    /**
+     * Flushes all added records to segments and returns a list of segments containing all cached
+     * records.
+     */
+    public List<Segment> getSegments() throws IOException {
+        finishCurrentSegmentIfAny();
+        return finishedSegments;
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    private void finishCurrentSegmentIfAny() throws IOException {
+        if (currentSegmentWriter == null) {
+            return;
         }
 
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
+        currentSegmentWriter.finish().ifPresent(finishedSegments::add);
+        currentSegmentWriter = createSegmentWriter();
     }
 
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
-
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFsSize() > 0) {
+                fileSystem.delete(segment.getPath(), false);
+            }
+            if (segment.getCache() != null) {
+                segmentPool.returnAll(segment.getCache());
+            }
         }
+        finishedSegments.clear();
+    }
 
-        public void addRecord(T record) throws IOException {
-            serializer.serialize(record, outputView);
-            currentSegmentCount += 1;
-        }
+    /** Write the segments in this writer to files on disk. */
+    public void writeSegmentsToFiles() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFsSize() > 0) {
+                continue;
+            }
 
-        public Optional<Segment> finish() throws IOException {
-            this.outputStream.flush();
-            long size = outputStream.getPos();
-            this.outputStream.close();
-
-            if (currentSegmentCount > 0) {
-                return Optional.of(new Segment(path, currentSegmentCount, size));
-            } else {
-                // If there are no records, we tend to directly delete this file
-                fileSystem.delete(path, false);
-                return Optional.empty();
+            SegmentReader<T> reader = new MemorySegmentReader<>(serializer, segment, 0);
+            SegmentWriter<T> writer = new FileSegmentWriter<>(serializer, segment.getPath());
+            while (reader.hasNext()) {
+                writer.addRecord(reader.next());
             }
+            writer.finish().ifPresent(x -> segment.setFsSize(x.getFsSize()));

Review Comment:
   `writer.finish()` is guaranteed to return a non-empty segment here, right?
   
   How about using `segment.setFsSize(writer.finish().get().getFsSize())` instead? It simplifies the code and surfaces a bug immediately if the segment is ever empty.
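The suggestion can be illustrated with a small, self-contained sketch. `Segment` and `finish` below are simplified, hypothetical stand-ins, not the PR's classes. The sketch uses `orElseThrow` with an explicit message as a variant of the suggested `.get()`. Both fail fast on an empty `Optional` (`get()` throws `NoSuchElementException`), whereas `ifPresent` silently skips `setFsSize`.

```java
import java.util.Optional;

final class FinishExample {
    // Hypothetical stand-in for the PR's Segment; only the on-disk size matters here.
    static final class Segment {
        final long fsSize;

        Segment(long fsSize) {
            this.fsSize = fsSize;
        }
    }

    // Hypothetical stand-in for SegmentWriter#finish(): empty when no record was written.
    static Optional<Segment> finish(int recordCount, long bytesWritten) {
        return recordCount > 0 ? Optional.of(new Segment(bytesWritten)) : Optional.empty();
    }

    // Rewriting a non-empty in-memory segment to a file must yield a segment, so an empty
    // Optional indicates a bug: fail fast instead of silently skipping setFsSize().
    static long fsSizeAfterCopy(int recordCount, long bytesWritten) {
        return finish(recordCount, bytesWritten)
                .orElseThrow(() -> new IllegalStateException("Expected a non-empty segment"))
                .fsSize;
    }
}
```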



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+class MemorySegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The pre-allocated path to hold cached records into file system. */
+    private final Path path;
+
+    /** The pool to allocate memory segments from. */
+    private final MemorySegmentPool segmentPool;
+
+    /** The output stream to write serialized content to memory segments. */
+    private final ManagedMemoryOutputStream outputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    MemorySegmentWriter(
+            TypeSerializer<T> serializer,
+            Path path,
+            MemorySegmentPool segmentPool,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.path = path;
+        this.segmentPool = segmentPool;
+        this.outputStream = new ManagedMemoryOutputStream(segmentPool, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            if (e.getCause() instanceof MemoryAllocationException) {
+                return false;
+            }
+            throw e;
+        }
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(new Segment(path, count, outputStream.getSegments()));
+        } else {
+            segmentPool.returnAll(outputStream.getSegments());
+            return Optional.empty();
+        }
+    }
+
+    /** An output stream subclass that accepts bytes and writes them to memory segments. */
+    private static class ManagedMemoryOutputStream extends OutputStream {
+
+        /** The pool to allocate memory segments from. */
+        private final MemorySegmentPool segmentPool;
+
+        /** The number of bytes in a memory segment. */
+        private final int pageSize;
+
+        /** The memory segments containing written bytes. */
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        /** The index of the segment that currently accepts written bytes. */
+        private int segmentIndex;
+
+        /** The number of bytes in the current segment that have been written. */
+        private int segmentOffset;
+
+        /** The number of bytes that have been written so far. */
+        private long globalOffset;
+
+        /** The number of bytes that have been allocated so far. */
+        private long allocatedBytes;
+
+        public ManagedMemoryOutputStream(MemorySegmentPool segmentPool, long expectedSize)
+                throws MemoryAllocationException {
+            this.segmentPool = segmentPool;
+            this.pageSize = segmentPool.pageSize();
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            ensureCapacity(Math.max(expectedSize, 1L));
+        }
+
+        public long getPos() {
+            return globalOffset;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(@Nullable byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);

Review Comment:
   Would it be simpler to make `MemoryAllocationException` a subclass of `RuntimeException` and throw/catch `MemoryAllocationException` directly, instead of wrapping it inside an `IOException`?
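A minimal sketch of the suggested pattern follows. All class names here are hypothetical, and a `DataOutputStream` stands in for the serializer's `DataOutputView`. The unchecked exception passes through the I/O stack unwrapped, so `addRecord` catches it by type instead of inspecting `IOException#getCause()`.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Hypothetical unchecked exception. Flink's own MemoryAllocationException is currently checked.
final class SegmentAllocationException extends RuntimeException {
    SegmentAllocationException(String message) {
        super(message);
    }
}

// Hypothetical stream that accepts a fixed number of bytes, mimicking an exhausted segment pool.
final class BoundedOutputStream extends OutputStream {
    private final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    private final int capacity;

    BoundedOutputStream(int capacity) {
        this.capacity = capacity;
    }

    @Override
    public void write(int b) {
        if (bytes.size() >= capacity) {
            throw new SegmentAllocationException("No more memory segments available");
        }
        bytes.write(b);
    }
}

final class UncheckedAllocationWriter {
    private final DataOutputStream out;
    int count;

    UncheckedAllocationWriter(int capacityBytes) {
        this.out = new DataOutputStream(new BoundedOutputStream(capacityBytes));
    }

    /** Returns false when memory runs out; genuine I/O failures still propagate as IOException. */
    boolean addRecord(long record) throws IOException {
        try {
            out.writeLong(record); // stands in for serializer.serialize(record, outputView)
            count++;
            return true;
        } catch (SegmentAllocationException e) {
            // Caught by type: no need to unwrap IOException#getCause().
            return false;
        }
    }
}
```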





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891126951


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T sample = streamRecord.getValue();
+            count++;
+
+            // Code below is inspired by the Reservoir Sampling algorithm.
+            if (samples.size() < numSamples) {
+                samples.add(sample);
+            } else {
+                if (random.nextInt(count) < numSamples) {
+                    samples.set(random.nextInt(numSamples), sample);
+                }
+            }
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            Collections.shuffle(samples, random);

Review Comment:
   > the first arriving element, if sampled, will always be the first returning element
   
   What is the problem with this situation? I think it is fine if the sampled elements preserve their order within each worker. In Spark's `RDD#sample` [1], for example, the order within each partition is also preserved.
   
   [1] https://github.com/apache/spark/blob/6026dd25748fd79caeedc083f99d5c954fb3a19f/core/src/main/scala/org/apache/spark/rdd/RDD.scala#L554
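For reference on the ordering question, here is a minimal, single-threaded sketch of reservoir sampling (Algorithm R) with an optional final shuffle. It is not the PR's `SamplingOperator`: it draws one random index per element and replaces that slot directly, whereas the operator draws two (`nextInt(count)`, then `nextInt(numSamples)`). Both variants keep each new element with probability numSamples / count. The class and parameter names are hypothetical. Without the shuffle, a surviving element keeps the slot it was first assigned, so the first element, if it survives, is always returned first.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

final class ReservoirSampler {
    /** Uniformly samples up to numSamples elements from input; shuffles the result if asked. */
    static <T> List<T> sample(Iterable<T> input, int numSamples, long seed, boolean shuffle) {
        Random random = new Random(seed);
        List<T> samples = new ArrayList<>(numSamples);
        int count = 0;
        for (T element : input) {
            count++;
            if (samples.size() < numSamples) {
                // Fill the reservoir with the first numSamples elements, in arrival order.
                samples.add(element);
            } else {
                // Keep the element with probability numSamples / count, in a uniformly chosen slot.
                int slot = random.nextInt(count);
                if (slot < numSamples) {
                    samples.set(slot, element);
                }
            }
        }
        if (shuffle) {
            // Removes the correlation between arrival order and position in the result.
            Collections.shuffle(samples, random);
        }
        return samples;
    }
}
```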





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889790107


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -94,6 +120,26 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
         }
     }
 
+    /**
+     * Takes a randomly sampled subset of elements in a bounded data stream.
+     *
+     * <p>If the number of elements in the stream is smaller than expected number of samples, all
+     * elements will be included in the sample.
+     *
+     * @param input The input data stream.
+     * @param numSamples The number of elements to be sampled.
+     * @param randomSeed The seed to randomly pick elements as sample.
+     * @return A data stream containing a list of the sampled elements.
+     */
+    public static <T> DataStream<List<T>> sample(
+            DataStream<T> input, int numSamples, long randomSeed) {
+        return input.transform(
+                        "samplingOperator",
+                        Types.LIST(input.getType()),
+                        new SamplingOperator<>(numSamples, randomSeed))
+                .setParallelism(1);

Review Comment:
   Following the offline discussion, I agree that we should not change the parallelism here, so that the method stays generic. I'll make the change.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891839965


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -160,6 +162,9 @@ public IterationBodyResult process(
                                             BasicTypeInfo.INT_TYPE_INFO,
                                             DenseVectorTypeInfo.INSTANCE),
                                     new SelectNearestCentroidOperator(distanceMeasure));
+            centroidIdAndPoints
+                    .getTransformation()
+                    .declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase.OPERATOR, 64);

Review Comment:
   Following the offline discussion, we will create a utility method similar to `ExecNodeUtil.setManagedMemoryWeight` so that all calls to `declareManagedMemoryUseCaseAtOperatorScope` in Flink ML go through one place. The weight values should reference existing values in Flink, such as `StreamExecWindowAggregate.WINDOW_AGG_MEMORY_RATIO`.
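A sketch of what such a helper could look like follows. It assumes the helper mirrors `ExecNodeUtil.setManagedMemoryWeight` by converting a byte budget into a MiB weight and refusing to declare a weight twice. So that it runs without Flink on the classpath, `UseCase`, `FakeTransformation` and `ManagedMemoryWeightUtil` are hypothetical stand-ins. A real helper would call `declareManagedMemoryUseCaseAtOperatorScope` on the operator's `Transformation`, as the diff above does.

```java
import java.util.EnumMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical stand-in for Flink's ManagedMemoryUseCase enum.
enum UseCase {
    OPERATOR
}

// Hypothetical stand-in for a Flink Transformation. The declare method returns the weight that was
// previously declared for the use case, if any (our assumption about the real method's contract).
final class FakeTransformation {
    private final Map<UseCase, Integer> weights = new EnumMap<>(UseCase.class);

    Optional<Integer> declareManagedMemoryUseCaseAtOperatorScope(UseCase useCase, int weight) {
        return Optional.ofNullable(weights.put(useCase, weight));
    }

    Optional<Integer> weightOf(UseCase useCase) {
        return Optional.ofNullable(weights.get(useCase));
    }
}

// Hypothetical Flink ML helper: the single place where operator-scope weights are declared.
final class ManagedMemoryWeightUtil {
    static void setManagedMemoryWeight(FakeTransformation transformation, long memoryBytes) {
        if (memoryBytes <= 0) {
            return; // nothing to declare
        }
        // Express the budget in MiB, rounding small positive budgets up to a weight of 1.
        int weightInMebibytes = (int) Math.max(1L, memoryBytes >> 20);
        Optional<Integer> previous =
                transformation.declareManagedMemoryUseCaseAtOperatorScope(
                        UseCase.OPERATOR, weightInMebibytes);
        if (previous.isPresent()) {
            throw new IllegalStateException("Managed memory weight has already been set.");
        }
    }
}
```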





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r876662781


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemoryUtils.java:
##########
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.memory.MemoryManager;
+
+/** Utility variables and methods for memory operation. */
+@Internal
+class MemoryUtils {
+    // Cache is not suggested if over 80% of memory has been occupied.
+    private static final double CACHE_MEMORY_THRESHOLD = 0.2;

Review Comment:
   Spark uses `spark.memory.storageFraction` to configure the proportion of memory used for storage, i.e., for caching RDD results. For now I'll change this magic number to 0.5, the default value of `spark.memory.storageFraction`, to align with Spark's practice. Later we may consider making these magic numbers Flink ML configuration parameters.
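Read together with the code comment above it, the threshold appears to be the minimum fraction of memory that must still be free before caching is attempted. Under that reading, 0.2 means "stop caching once more than 80% is occupied" and 0.5 means "stop once more than half is occupied". The sketch below assumes that interpretation. Its signature is hypothetical: the PR's `MemoryUtils.isMemoryEnoughForCache` takes a `MemoryManager` rather than raw byte counts.

```java
final class CacheMemoryPolicy {
    /** Assumed meaning: cache only while at least this fraction of managed memory is free. */
    static final double CACHE_MEMORY_THRESHOLD = 0.5;

    /** Hypothetical signature taking raw byte counts instead of a MemoryManager. */
    static boolean isMemoryEnoughForCache(long availableBytes, long totalBytes) {
        if (totalBytes <= 0) {
            return false; // no managed memory configured: never cache
        }
        return (double) availableBytes / totalBytes >= CACHE_MEMORY_THRESHOLD;
    }
}
```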





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r879052071


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,122 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-
-import javax.annotation.Nullable;
+import org.apache.flink.runtime.memory.MemoryManager;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
-    private final TypeSerializer<T> serializer;
+    private final MemoryManager memoryManager;
 
-    private final FileSystem fileSystem;
+    private final TypeSerializer<T> serializer;
 
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    private SegmentReader<T> currentReader;
+
+    private SegmentWriter<T> cacheWriter;
+
+    private int segmentIndex;
 
     public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+            TypeSerializer<T> serializer, MemoryManager memoryManager, List<Segment> segments) {
+        this(serializer, memoryManager, segments, new Tuple2<>(0, 0));
     }
 
     public DataCacheReader(
             TypeSerializer<T> serializer,
-            FileSystem fileSystem,
+            MemoryManager memoryManager,
             List<Segment> segments,
-            Tuple2<Integer, Integer> readerPosition)
-            throws IOException {
-
+            Tuple2<Integer, Integer> readerPosition) {
+        this.memoryManager = memoryManager;
         this.serializer = serializer;
-        this.fileSystem = fileSystem;
         this.segments = segments;
+        this.segmentIndex = readerPosition.f0;
+
+        createSegmentReaderAndCache(readerPosition.f0, readerPosition.f1);
+    }
+
+    private void createSegmentReaderAndCache(int index, int startOffset) {
+        try {
+            cacheWriter = null;
 
-        if (readerPosition.f0 < segments.size()) {
-            this.currentSegmentReader = new SegmentReader(readerPosition.f0, readerPosition.f1);
+            if (index >= segments.size()) {
+                currentReader = null;
+                return;
+            }
+
+            currentReader = SegmentReader.create(serializer, segments.get(index), startOffset);
+
+            boolean shouldCacheInMemory =
+                    startOffset == 0
+                            && currentReader instanceof FsSegmentReader
+                            && MemoryUtils.isMemoryEnoughForCache(memoryManager);
+
+            if (shouldCacheInMemory) {
+                cacheWriter =
+                        SegmentWriter.create(
+                                segments.get(index).getPath(),
+                                memoryManager,
+                                serializer,
+                                segments.get(index).getFsSize(),
+                                true,
+                                false);
+            }
+
+        } catch (IOException e) {
+            throw new RuntimeException(e);
         }
     }
 
     @Override
     public boolean hasNext() {
-        return currentSegmentReader != null && currentSegmentReader.hasNext();
+        return currentReader != null && currentReader.hasNext();
     }
 
     @Override
     public T next() {
         try {
-            T next = currentSegmentReader.next();
-
-            if (!currentSegmentReader.hasNext()) {
-                currentSegmentReader.close();
-                if (currentSegmentReader.index < segments.size() - 1) {
-                    currentSegmentReader = new SegmentReader(currentSegmentReader.index + 1, 0);
-                } else {
-                    currentSegmentReader = null;
+            T record = currentReader.next();
+
+            if (cacheWriter != null) {
+                if (!cacheWriter.addRecord(record)) {

Review Comment:
   When the Flink job is restored from a snapshot, the in-memory caches previously created by the segment writer in `DataCacheWriter` no longer exist. This cache writer re-creates them during the reading process.
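The idea can be sketched without Flink types. While records are replayed from the file, each one is also teed into an in-memory copy. If memory runs out, the partial copy is dropped and the file remains the source of truth. `CachingReplayReader` is a hypothetical name, an `Iterator` stands in for an `FsSegmentReader`, and a record-count capacity stands in for available memory segments. Discarding the partial cache on failure is an assumption about the behavior; the diff is truncated at that point.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

final class CachingReplayReader<T> implements Iterator<T> {
    private final Iterator<T> fileRecords; // stand-in for an FsSegmentReader
    private final int cacheCapacity; // stand-in for the memory available to the cache
    private List<T> cache = new ArrayList<>(); // set to null once caching is abandoned

    CachingReplayReader(Iterator<T> fileRecords, int cacheCapacity) {
        this.fileRecords = fileRecords;
        this.cacheCapacity = cacheCapacity;
    }

    @Override
    public boolean hasNext() {
        return fileRecords.hasNext();
    }

    @Override
    public T next() {
        T record = fileRecords.next();
        if (cache != null) {
            if (cache.size() < cacheCapacity) {
                cache.add(record); // tee the record into memory while reading it from the file
            } else {
                cache = null; // out of memory: drop the partial cache, the file stays authoritative
            }
        }
        return record;
    }

    /** Returns the rebuilt in-memory copy, or null if it is incomplete or was abandoned. */
    List<T> rebuiltCache() {
        return cache != null && !fileRecords.hasNext() ? cache : null;
    }
}
```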





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r888749310


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,104 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
-/** Records the data received and replayed them on required. */
+/** Records the data received and replays them on required. */
 public class DataCacheWriter<T> {
 
-    private final TypeSerializer<T> serializer;
-
     private final FileSystem fileSystem;
 
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final TypeSerializer<T> serializer;
+
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
-            SupplierWithException<Path, IOException> pathGenerator)
+            SupplierWithException<Path, IOException> pathGenerator,
+            LimitedSizeMemoryManager memoryManager)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, memoryManager, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
+            LimitedSizeMemoryManager memoryManager,
             List<Segment> priorFinishedSegments)
             throws IOException {
         this.serializer = serializer;
         this.fileSystem = fileSystem;
         this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this.memoryManager = memoryManager;
+        this.finishedSegments = new ArrayList<>(priorFinishedSegments);
+        this.currentWriter =
+                SegmentWriter.create(
+                        pathGenerator.get(), this.memoryManager, serializer, 0L, true, true);
     }
 
     public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+        boolean success = currentWriter.addRecord(record);

Review Comment:
   I believe such a case is possible. For example, suppose we have written the size of a vector to the last bytes of a memory segment, and then fail to write the vector's values because no more segments are available to this operator in the memory manager. In this case the code creates an `FsSegmentWriter` and rewrites the size of the vector to the file.
   
   This case can occur, but it does no harm to the program: the stray `size` value in the segment is not tracked, so it is never accessed. It does not waste memory either, since space is allocated at the granularity of whole segments, and it is released later together with the valid values in that segment.
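The scenario can be reproduced in miniature. In this sketch a `ByteBuffer` plays the role of the operator's memory segments, and `BoundedVectorCache` is a hypothetical name, not the PR's `MemorySegmentWriter`. The vector's length fits in the buffer but one of its values overflows. Only completed records are counted, so the stray length bytes are never read back. The `sealed` flag models the writer being finished and replaced by a file writer after the first failure.

```java
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

final class BoundedVectorCache {
    private final ByteBuffer buffer; // stand-in for the operator's memory segments
    private int count; // only fully written records are tracked
    private boolean sealed; // after a failed write, this cache accepts no more records

    BoundedVectorCache(int capacityBytes) {
        this.buffer = ByteBuffer.allocate(capacityBytes);
    }

    /** Returns false if the vector does not fit; the caller then switches to a file writer. */
    boolean addRecord(double[] vector) {
        if (sealed) {
            return false;
        }
        try {
            buffer.putInt(vector.length); // the length may still fit ...
            for (double value : vector) {
                buffer.putDouble(value); // ... while a value overflows
            }
            count++;
            return true;
        } catch (BufferOverflowException e) {
            sealed = true; // the stray bytes stay in the buffer but are never read
            return false;
        }
    }

    /** Reads back exactly `count` records, so a trailing partial record is ignored. */
    List<double[]> readAll() {
        ByteBuffer view = buffer.duplicate();
        view.flip();
        List<double[]> records = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            double[] vector = new double[view.getInt()];
            for (int j = 0; j < vector.length; j++) {
                vector[j] = view.getDouble();
            }
            records.add(vector);
        }
        return records;
    }
}
```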





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889646062


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final List<Segment> finishedSegments;
+
+    private SegmentWriter<T> currentWriter;
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        for (Segment segment : finishedSegments) {
+            tryCacheSegmentToMemory(segment);
+        }
+        this.currentWriter = createSegmentWriter();
+    }
+
+    public void addRecord(T record) throws IOException {
+        try {
+            currentWriter.addRecord(record);
+        } catch (SegmentNoVacancyException e) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            currentWriter.addRecord(record);
+        }
+    }
+
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public void finish() throws IOException {
+        if (currentWriter == null) {
+            return;
+        }
+
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+    }
+
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFileSegment() != null) {
+                fileSystem.delete(segment.getFileSegment().getPath(), false);
+            }
+            if (segment.getMemorySegment() != null) {
+                segmentPool.returnAll(segment.getMemorySegment().getCache());
+            }
+        }
+        finishedSegments.clear();
+    }
+
+    private void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {

Review Comment:
   Will `finishCurrentSegmentIfAny()` ever be called after `finish()` has been called? If not, would it be simpler to skip the `currentWriter == null` check here, since that case should never happen?
   
   We can add `assert(currentWriter != null)` if needed.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheIterator.java:
##########
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+
+/** Reads the cached data from a list of segments. */
+@Internal
+public class DataCacheIterator<T> implements Iterator<T> {

Review Comment:
   It appears that `DataCacheIterator` owns the APIs to read data cache segments, while `DataCache` owns the APIs to write data cache segments, as well as the APIs to recover and replay data cache snapshots.
   
   Would it be more consistent with other class names (e.g. `SegmentReader`, `SegmentWriter`) to rename `DataCacheIterator`/`DataCache` to `DataCacheReader`/`DataCacheWriter` respectively?
   
   Would it still be possible to decouple `DataCacheWriter` from `DataCacheSnapshot`?
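   
   One way the suggested split could look, as a sketch only: in-memory lists stand in for file and memory segments, and `SegmentListWriter`, `SegmentListReader` and `SegmentListSnapshot` are hypothetical names, not the PR's classes. The writer and the reader only exchange a list of finished segments, so neither needs to know about the snapshot; the snapshot depends on them instead.
   
   ```java
   import java.util.ArrayList;
   import java.util.Collections;
   import java.util.Iterator;
   import java.util.List;
   import java.util.NoSuchElementException;

   /** Hypothetical writer: appends records and seals them into finished segments. */
   class SegmentListWriter<T> {
       private final int segmentCapacity;
       private final List<List<T>> finishedSegments;
       private List<T> current = new ArrayList<>();

       /** Starts from previously finished segments, e.g. ones recovered from a snapshot. */
       SegmentListWriter(int segmentCapacity, List<List<T>> recoveredSegments) {
           this.segmentCapacity = segmentCapacity;
           this.finishedSegments = new ArrayList<>(recoveredSegments);
       }

       void addRecord(T record) {
           current.add(record);
           if (current.size() == segmentCapacity) {
               finishedSegments.add(current);
               current = new ArrayList<>();
           }
       }

       /** Seals the open segment, if non-empty, and returns all finished segments. */
       List<List<T>> finish() {
           if (!current.isEmpty()) {
               finishedSegments.add(current);
               current = new ArrayList<>();
           }
           return Collections.unmodifiableList(finishedSegments);
       }
   }

   /** Hypothetical reader: iterates over finished segments in order. */
   class SegmentListReader<T> implements Iterator<T> {
       private final Iterator<List<T>> segments;
       private Iterator<T> withinSegment = Collections.emptyIterator();

       SegmentListReader(List<List<T>> segments) {
           this.segments = segments.iterator();
       }

       @Override
       public boolean hasNext() {
           // Skip over exhausted segments until a record is available or none remain.
           while (!withinSegment.hasNext() && segments.hasNext()) {
               withinSegment = segments.next().iterator();
           }
           return withinSegment.hasNext();
       }

       @Override
       public T next() {
           if (!hasNext()) {
               throw new NoSuchElementException();
           }
           return withinSegment.next();
       }
   }

   /** Hypothetical snapshot: holds segments and builds readers/writers from them. */
   final class SegmentListSnapshot<T> {
       final List<List<T>> segments;

       SegmentListSnapshot(List<List<T>> segments) {
           this.segments = new ArrayList<>(segments);
       }

       SegmentListReader<T> replay() {
           return new SegmentListReader<>(segments);
       }

       SegmentListWriter<T> recover(int segmentCapacity) {
           return new SegmentListWriter<>(segmentCapacity, segments);
       }
   }
   ```
   
   In this shape, recovery only needs the writer's constructor to accept already-finished segments, which keeps the snapshot format out of the writer entirely.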



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final List<Segment> finishedSegments;
+
+    private SegmentWriter<T> currentWriter;
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        for (Segment segment : finishedSegments) {
+            tryCacheSegmentToMemory(segment);
+        }
+        this.currentWriter = createSegmentWriter();
+    }
+
+    public void addRecord(T record) throws IOException {
+        try {
+            currentWriter.addRecord(record);
+        } catch (SegmentNoVacancyException e) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            currentWriter.addRecord(record);
+        }
+    }
+
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public void finish() throws IOException {
+        if (currentWriter == null) {

Review Comment:
   Would `finish()` ever be called twice on the same `DataCache` instance? If not, would it be simpler to skip this check?
   
   We can add `assert(currentWriter != null)` if needed.
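   
   If `finish()` must be called exactly once, a fail-fast alternative could look like the sketch below. It uses a local stand-in for `Preconditions.checkState` (which `DataCache.java` already imports statically) so it runs without Flink; `FinishOnceCache` and its `StringBuilder` "writer" are hypothetical. Unlike a Java `assert`, which is only evaluated when the JVM runs with `-ea`, a precondition check always runs.
   
   ```java
   import java.util.ArrayList;
   import java.util.List;

   /** Hypothetical cache whose finish() must be called exactly once. */
   class FinishOnceCache {
       private final List<String> finishedSegments = new ArrayList<>();
       private StringBuilder currentWriter = new StringBuilder(); // stand-in for a SegmentWriter

       void addRecord(String record) {
           checkState(currentWriter != null, "addRecord() called after finish()");
           currentWriter.append(record);
       }

       void finish() {
           // Fail fast on a second call instead of silently returning.
           checkState(currentWriter != null, "finish() called more than once");
           if (currentWriter.length() > 0) {
               finishedSegments.add(currentWriter.toString());
           }
           currentWriter = null;
       }

       List<String> getFinishedSegments() {
           return finishedSegments;
       }

       /** Local stand-in for org.apache.flink.util.Preconditions#checkState(boolean, Object). */
       private static void checkState(boolean condition, String message) {
           if (!condition) {
               throw new IllegalStateException(message);
           }
       }
   }
   ```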



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+class MemorySegmentWriter<T> implements SegmentWriter<T> {
+
+    private final TypeSerializer<T> serializer;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    MemorySegmentWriter(
+            TypeSerializer<T> serializer, MemorySegmentPool segmentPool, long expectedSize)
+            throws SegmentNoVacancyException {
+        this.serializer = serializer;
+        this.segmentPool = segmentPool;
+        this.outputStream = new ManagedMemoryOutputStream(segmentPool, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public void addRecord(T record) throws IOException {
+        serializer.serialize(record, outputView);
+        count++;
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            new org.apache.flink.iteration.datacache.nonkeyed.MemorySegment(
+                                    outputStream.getSegments(), count)));
+        } else {
+            segmentPool.returnAll(outputStream.getSegments());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final MemorySegmentPool segmentPool;
+
+        private final int pageSize;
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private long globalOffset;

Review Comment:
   Can you add Javadoc for these private variables, similar to what we did in `FileSegment.java`?
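   
   For illustration, the fields might be documented along these lines. This is a sketch under assumptions: the field meanings are inferred from the quoted code rather than confirmed by the author, plain `byte[]` pages stand in for Flink `MemorySegment`s, and allocation is unbounded (there is no `MemorySegmentPool`), so `PagedOutputStreamSketch` is hypothetical.
   
   ```java
   import java.io.OutputStream;
   import java.util.ArrayList;
   import java.util.List;

   /** Hypothetical paged stream; byte[] pages stand in for Flink MemorySegments. */
   class PagedOutputStreamSketch extends OutputStream {
       /** Size in bytes of every page, i.e. the pool's page size. */
       private final int pageSize;

       /** Pages allocated so far, in the order they are filled. */
       private final List<byte[]> pages = new ArrayList<>();

       /** Index in {@link #pages} of the page that the next byte is written to. */
       private int pageIndex;

       /** Offset within the current page at which the next byte is written. */
       private int pageOffset;

       /** Total number of bytes written to this stream across all pages. */
       private long globalOffset;

       PagedOutputStreamSketch(int pageSize) {
           this.pageSize = pageSize;
       }

       @Override
       public void write(int b) {
           write(new byte[] {(byte) b}, 0, 1);
       }

       @Override
       public void write(byte[] b, int off, int len) {
           ensureCapacity(globalOffset + len);
           while (len > 0) {
               if (pageOffset == pageSize) { // current page is full, move to the next one
                   pageIndex++;
                   pageOffset = 0;
               }
               int n = Math.min(len, pageSize - pageOffset);
               System.arraycopy(b, off, pages.get(pageIndex), pageOffset, n);
               pageOffset += n;
               off += n;
               len -= n;
               globalOffset += n;
           }
       }

       /** Allocates pages until {@code capacity} bytes fit (ceiling division). */
       private void ensureCapacity(long capacity) {
           long requiredPages = (capacity + pageSize - 1) / pageSize;
           while (pages.size() < requiredPages) {
               pages.add(new byte[pageSize]);
           }
       }

       long size() {
           return globalOffset;
       }

       int pageCount() {
           return pages.size();
       }

       /** Copies the written bytes out of the pages, for inspection. */
       byte[] toByteArray() {
           byte[] out = new byte[(int) globalOffset];
           for (int i = 0, copied = 0; copied < out.length; i++) {
               int n = Math.min(pageSize, out.length - copied);
               System.arraycopy(pages.get(i), 0, out, copied, n);
               copied += n;
           }
           return out;
       }
   }
   ```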



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final List<Segment> finishedSegments;
+
+    private SegmentWriter<T> currentWriter;
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        for (Segment segment : finishedSegments) {
+            tryCacheSegmentToMemory(segment);
+        }
+        this.currentWriter = createSegmentWriter();
+    }
+
+    public void addRecord(T record) throws IOException {
+        try {
+            currentWriter.addRecord(record);
+        } catch (SegmentNoVacancyException e) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            currentWriter.addRecord(record);
+        }
+    }
+
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public void finish() throws IOException {
+        if (currentWriter == null) {
+            return;
+        }
+
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+    }
+
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFileSegment() != null) {
+                fileSystem.delete(segment.getFileSegment().getPath(), false);
+            }
+            if (segment.getMemorySegment() != null) {
+                segmentPool.returnAll(segment.getMemorySegment().getCache());
+            }
+        }
+        finishedSegments.clear();
+    }
+
+    private void finishCurrentSegmentIfAny() throws IOException {

Review Comment:
   Is there any use case where we need to write to a `DataCache` instance after having read from it?
   
   If not, would it be simpler to just call `finish()` once before reading from the `DataCache`?
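   
   If writes and reads never interleave, the lifecycle could be enforced as in the sketch below: `iterator()` refuses to run before `finish()`, so the read path never needs a "finish the current segment if any" step. `TwoPhaseCacheSketch` is a hypothetical, list-backed stand-in for `DataCache`, not the PR's implementation.
   
   ```java
   import java.util.ArrayList;
   import java.util.Collections;
   import java.util.Iterator;
   import java.util.List;

   /** Hypothetical two-phase cache: every write happens before finish(), every read after. */
   class TwoPhaseCacheSketch<T> implements Iterable<T> {
       private final List<T> finishedRecords = new ArrayList<>();
       private List<T> currentSegment = new ArrayList<>(); // set to null by finish()

       void addRecord(T record) {
           if (currentSegment == null) {
               throw new IllegalStateException("cannot write after finish()");
           }
           currentSegment.add(record);
       }

       /** Seals the write phase; called exactly once, before any read. */
       void finish() {
           if (currentSegment == null) {
               throw new IllegalStateException("finish() already called");
           }
           finishedRecords.addAll(currentSegment);
           currentSegment = null;
       }

       @Override
       public Iterator<T> iterator() {
           // Reads require finish(), so there is never an open segment to flush here.
           if (currentSegment != null) {
               throw new IllegalStateException("call finish() before reading");
           }
           return Collections.unmodifiableList(finishedRecords).iterator();
       }
   }
   ```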



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+class MemorySegmentWriter<T> implements SegmentWriter<T> {
+
+    private final TypeSerializer<T> serializer;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    MemorySegmentWriter(
+            TypeSerializer<T> serializer, MemorySegmentPool segmentPool, long expectedSize)
+            throws SegmentNoVacancyException {
+        this.serializer = serializer;
+        this.segmentPool = segmentPool;
+        this.outputStream = new ManagedMemoryOutputStream(segmentPool, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public void addRecord(T record) throws IOException {
+        serializer.serialize(record, outputView);
+        count++;
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            new org.apache.flink.iteration.datacache.nonkeyed.MemorySegment(
+                                    outputStream.getSegments(), count)));
+        } else {
+            segmentPool.returnAll(outputStream.getSegments());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final MemorySegmentPool segmentPool;
+
+        private final int pageSize;
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private long globalOffset;
+
+        public ManagedMemoryOutputStream(MemorySegmentPool segmentPool, long expectedSize)
+                throws SegmentNoVacancyException {
+            this.segmentPool = segmentPool;
+            this.pageSize = segmentPool.pageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            ensureCapacity(Math.max(expectedSize, 1L));
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(@Nullable byte[] b, int off, int len) throws IOException {
+            ensureCapacity(globalOffset + len);
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(long capacity) throws SegmentNoVacancyException {
+            Preconditions.checkArgument(capacity > 0);
+            int required =
+                    (int) (capacity % pageSize == 0 ? capacity / pageSize : capacity / pageSize + 1)
+                            - segments.size();
+
+            List<MemorySegment> allocatedSegments = new ArrayList<>();
+            for (int i = 0; i < required; i++) {
+                MemorySegment memorySegment = segmentPool.nextSegment();
+                if (memorySegment == null) {
+                    segmentPool.returnAll(allocatedSegments);
+                    throw new SegmentNoVacancyException(new MemoryAllocationException());

Review Comment:
   In general, we don't consider insufficient memory in the memory pool to be an IO failure, which is why `MemoryAllocationException` is not a subclass of `IOException`.
   
   It seems better to avoid converting `MemoryAllocationException` into an `IOException`.
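   
   A sketch of reporting a memory shortage without any exception, assuming the pool behaves as in the quoted code (`nextSegment()` returns `null` when empty, and `returnAll()` takes back a partial allocation). `PagePoolSketch` and `PageAllocatorSketch` are hypothetical stand-ins for `MemorySegmentPool` and `ManagedMemoryOutputStream`. A caller can treat `false` as ordinary control flow and switch to a file writer, similar to the `boolean success = currentWriter.addRecord(record);` line in the updated `addRecord()` diff.
   
   ```java
   import java.util.ArrayDeque;
   import java.util.ArrayList;
   import java.util.Deque;
   import java.util.List;

   /** Stand-in for MemorySegmentPool: nextSegment() returns null once the pool is empty. */
   class PagePoolSketch {
       private final Deque<byte[]> free = new ArrayDeque<>();

       PagePoolSketch(int pages, int pageSize) {
           for (int i = 0; i < pages; i++) {
               free.push(new byte[pageSize]);
           }
       }

       byte[] nextSegment() {
           return free.poll(); // null when no page is left
       }

       void returnAll(List<byte[]> pages) {
           pages.forEach(free::push);
       }

       int available() {
           return free.size();
       }
   }

   /** Hypothetical allocator that reports a shortage with a boolean, not an exception. */
   class PageAllocatorSketch {
       private final PagePoolSketch pool;
       private final List<byte[]> pages = new ArrayList<>();

       PageAllocatorSketch(PagePoolSketch pool) {
           this.pool = pool;
       }

       /**
        * Tries to hold at least {@code requiredPages} pages. On shortage it returns the
        * partial allocation to the pool and reports false, leaving its own state unchanged.
        */
       boolean tryEnsurePages(int requiredPages) {
           List<byte[]> allocated = new ArrayList<>();
           while (pages.size() + allocated.size() < requiredPages) {
               byte[] page = pool.nextSegment();
               if (page == null) {
                   pool.returnAll(allocated);
                   return false; // the caller may switch to a file writer instead
               }
               allocated.add(page);
           }
           pages.addAll(allocated);
           return true;
       }

       int pageCount() {
           return pages.size();
       }
   }
   ```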



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;

Review Comment:
   It is probably more readable to mark this field as `@Nullable`, since the three-argument constructor passes `null` for it.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,61 +18,37 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
-import org.apache.flink.core.fs.Path;
+import org.apache.flink.annotation.Internal;
 
-import java.io.Serializable;
-import java.util.Objects;
+/** A segment contains the information about a cache unit. */
+@Internal
+class Segment {
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+    private FileSegment fileSegment;
 
-    private final Path path;
+    private MemorySegment memorySegment;
 
-    /** The count of the records in the file. */
-    private final int count;
-
-    /** The total length of file. */
-    private final long size;
-
-    public Segment(Path path, int count, long size) {
-        this.path = path;
-        this.count = count;
-        this.size = size;
-    }
-
-    public Path getPath() {
-        return path;
+    Segment(FileSegment fileSegment) {
+        this.fileSegment = fileSegment;
     }
 
-    public int getCount() {
-        return count;
+    Segment(MemorySegment memorySegment) {
+        this.memorySegment = memorySegment;
     }
 
-    public long getSize() {
-        return size;
+    void setFileSegment(FileSegment fileSegment) {
+        this.fileSegment = fileSegment;
     }
 
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) {
-            return true;
-        }
-
-        if (!(o instanceof Segment)) {
-            return false;
-        }
-
-        Segment segment = (Segment) o;
-        return count == segment.count && size == segment.size && Objects.equals(path, segment.path);
+    FileSegment getFileSegment() {
+        return fileSegment;
     }
 
-    @Override
-    public int hashCode() {
-        return Objects.hash(path, count, size);
+    void setMemorySegment(MemorySegment memorySegment) {

Review Comment:
   Since this method is never called as of now, would it be simpler to remove this method and declare `memorySegment` as final?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FileSegmentWriter.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+class FileSegmentWriter<T> implements SegmentWriter<T> {
+
+    private static final long FILE_SIZE_LIMIT = 1L << 30; // 1GB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FileSystem fileSystem;
+
+    private final FSDataOutputStream outputStream;
+
+    private final BufferedOutputStream bufferedOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    FileSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.bufferedOutputStream = new BufferedOutputStream(outputStream);
+        this.outputView = new DataOutputViewStreamWrapper(bufferedOutputStream);
+    }
+
+    @Override
+    public void addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= FILE_SIZE_LIMIT) {
+            throw new SegmentNoVacancyException();

Review Comment:
   An exception (including IOException) usually indicates that "something is wrong and we need to either recover from it or fail fast".
   
   Since we expect a segment to have a limited size, reaching that limit is a normal condition, so it is probably better not to use an exception to signal it. How about having `addRecord(...)` return a boolean, which is false if the record could not be written because of the size limit?
   
   Then we could also reduce the number of classes by removing `SegmentNoVacancyException`.
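A minimal sketch of the suggested contract in plain Java, without Flink. The hypothetical `BoundedSegmentWriter` buffers serialized `long` records in memory, with a byte limit standing in for `FILE_SIZE_LIMIT`, and reports a full segment through the return value of `addRecord(...)`. The equally hypothetical `RollingCacheWriter` reacts by starting a new segment; an exception is kept only for the abnormal case where even a fresh segment refuses a record.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical writer: a full segment is reported via the return value, not an exception. */
class BoundedSegmentWriter {
    private final long sizeLimit;
    private final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    private final DataOutputStream out = new DataOutputStream(bytes);
    private int count;

    BoundedSegmentWriter(long sizeLimit) {
        this.sizeLimit = sizeLimit;
    }

    /** Returns false, without writing anything, if the segment has reached its size limit. */
    boolean addRecord(long record) throws IOException {
        if (bytes.size() >= sizeLimit) {
            return false;
        }
        out.writeLong(record);
        count++;
        return true;
    }

    int getCount() {
        return count;
    }
}

/** Hypothetical caller: rolls over to a new segment whenever the current one is full. */
class RollingCacheWriter {
    private final long sizeLimit;
    private final List<BoundedSegmentWriter> finished = new ArrayList<>();
    private BoundedSegmentWriter current;

    RollingCacheWriter(long sizeLimit) {
        this.sizeLimit = sizeLimit;
        this.current = new BoundedSegmentWriter(sizeLimit);
    }

    void addRecord(long record) throws IOException {
        if (!current.addRecord(record)) {
            finished.add(current);
            current = new BoundedSegmentWriter(sizeLimit);
            if (!current.addRecord(record)) {
                // Something is genuinely wrong here, so failing fast is appropriate.
                throw new IOException("An empty segment refused the record.");
            }
        }
    }

    List<BoundedSegmentWriter> finish() {
        List<BoundedSegmentWriter> all = new ArrayList<>(finished);
        all.add(current);
        return all;
    }
}
```

As with the `getPos() >= FILE_SIZE_LIMIT` check in the diff, the limit is tested before writing, so a segment may exceed it by at most one record.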



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,61 +18,37 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
-import org.apache.flink.core.fs.Path;
+import org.apache.flink.annotation.Internal;
 
-import java.io.Serializable;
-import java.util.Objects;
+/** A segment contains the information about a cache unit. */
+@Internal
+class Segment {
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+    private FileSegment fileSegment;

Review Comment:
   Since every segment will eventually be persisted to disk, would it be simpler to declare this variable as final and instantiate it before/when constructing this Segment instance?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,61 +18,37 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
-import org.apache.flink.core.fs.Path;
+import org.apache.flink.annotation.Internal;
 
-import java.io.Serializable;
-import java.util.Objects;
+/** A segment contains the information about a cache unit. */
+@Internal
+class Segment {
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+    private FileSegment fileSegment;
 
-    private final Path path;
+    private MemorySegment memorySegment;

Review Comment:
   Instead of creating a new class `MemorySegment` whose name collides with a class in `org.apache.flink.core.memory.MemorySegment`, would it be simpler to just put `List<org.apache.flink.core.memory.MemorySegment>` inside `Segment`?
   
   The `count` in `MemorySegment` and `FileSegment` can also be moved to `Segment` if we need it.
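A minimal sketch of the shape these `Segment` suggestions point to, under stated assumptions: `SegmentSketch` is hypothetical, the file path is a plain `String`, and `byte[]` pages stand in for `org.apache.flink.core.memory.MemorySegment` so that the sketch runs without Flink. The count lives in the segment itself, and every field is final and fixed at construction, so no setters are needed.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/** Hypothetical immutable segment: path, count and cached pages are fixed at construction. */
final class SegmentSketch {
    private final String path; // where the segment is (or will be) persisted
    private final int count; // number of records in the segment
    private final List<byte[]> cachedPages; // stand-in for List<MemorySegment>; empty if not cached

    SegmentSketch(String path, int count, List<byte[]> cachedPages) {
        this.path = Objects.requireNonNull(path);
        this.count = count;
        // Defensive copy, so later changes to the caller's list are not visible here.
        this.cachedPages = Collections.unmodifiableList(new ArrayList<>(cachedPages));
    }

    String getPath() {
        return path;
    }

    int getCount() {
        return count;
    }

    List<byte[]> getCachedPages() {
        return cachedPages;
    }

    boolean isCachedInMemory() {
        return !cachedPages.isEmpty();
    }
}
```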




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r885138130


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;

Review Comment:
   `countState` seems unnecessary since we can get the count from `samples.size()`



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T sample = streamRecord.getValue();
+            count++;
+
+            // Code below is inspired by the Reservoir Sampling algorithm.
+            if (samples.size() < numSamples) {
+                samples.add(sample);
+            } else {
+                if (random.nextInt(count) < numSamples) {
+                    samples.set(random.nextInt(numSamples), sample);
+                }
+            }
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            Collections.shuffle(samples, random);

Review Comment:
   Is shuffle necessary here?



##########
flink-ml-dist/pom.xml:
##########
@@ -55,6 +55,15 @@ under the License.
             <scope>compile</scope>
         </dependency>
 
+        <!-- Java Object Layout Dependencies -->
+
+        <dependency>

Review Comment:
   Is this still necessary?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();

Review Comment:
   nit: `new ArrayList<>(numSamples)`.



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.

Review Comment:
   There are many different sampling methods, and here we implement only one of them. Could you update the Javadoc to explain that we are doing uniform sampling without replacement?
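For reference, a minimal sketch of uniform sampling without replacement over a stream of unknown length (reservoir sampling, Algorithm R) in plain Java, without Flink operator state; `ReservoirSampler` is a hypothetical name. Once the reservoir is full, the i-th element (1-based) is kept with probability numSamples / i and replaces a uniformly chosen slot, so each of the n elements ends up in the sample with probability numSamples / n. Drawing a single index `j` and reusing it as the slot is equivalent in distribution to the two `nextInt` calls in `processElement` in the diff.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical reservoir sampler: uniform sampling without replacement (Algorithm R). */
class ReservoirSampler<T> {
    private final int numSamples;
    private final Random random;
    private final List<T> samples;
    private int count; // number of elements seen so far, not the number kept

    ReservoirSampler(int numSamples, long seed) {
        this.numSamples = numSamples;
        this.random = new Random(seed);
        this.samples = new ArrayList<>(numSamples);
    }

    void add(T element) {
        count++;
        if (samples.size() < numSamples) {
            samples.add(element);
        } else {
            // j is uniform in [0, count), so the element is kept with probability numSamples / count,
            // and, when kept, j is also a uniformly chosen slot to replace.
            int j = random.nextInt(count);
            if (j < numSamples) {
                samples.set(j, element);
            }
        }
    }

    List<T> getSamples() {
        return Collections.unmodifiableList(samples);
    }
}
```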



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java:
##########
@@ -75,9 +76,11 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept
 
         final int len = vector.values.length;
         target.writeInt(len);
+        ByteBuffer buffer = ByteBuffer.allocate(len << 3);

Review Comment:
   We should probably follow the implementation in `ObjectOutputStream#writeInts`. The current implementation introduces some extra issues:
   - Allocating one ByteBuffer for each instance may introduce more garbage collection.
   - It may lead to OOM when the DenseVector is large.
   - Suppose we are writing a DenseVector to a memory segment, and the memory segment runs out of memory right after `len` has been written, so the content is then written to an FsSegment. Does this still work correctly?
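A minimal sketch of a chunked approach in the spirit of `ObjectOutputStream#writeInts` (as we understand it, values are converted into a reusable fixed-size buffer that is flushed when full), applied to doubles. `ChunkedDoubleWriter` and its 1 KiB buffer size are hypothetical choices. The bytes are big-endian, so the output is identical to calling `DataOutput#writeDouble` per value, while memory use stays bounded regardless of vector length. This addresses the first two points only; whether a record may be split across a memory segment and a file segment depends on the cache writer.

```java
import java.io.DataOutput;
import java.io.IOException;

/** Hypothetical writer that serializes a double[] through a reusable, fixed-size buffer. */
class ChunkedDoubleWriter {
    private static final int BUFFER_SIZE = 1024; // bytes; must be a multiple of 8
    private final byte[] buffer = new byte[BUFFER_SIZE];

    void write(double[] values, DataOutput target) throws IOException {
        target.writeInt(values.length);
        int pos = 0;
        for (double value : values) {
            if (pos == BUFFER_SIZE) {
                // Flush a full buffer before encoding the next value.
                target.write(buffer, 0, pos);
                pos = 0;
            }
            long bits = Double.doubleToLongBits(value);
            // Big-endian encoding, matching DataOutput#writeDouble.
            for (int shift = 56; shift >= 0; shift -= 8) {
                buffer[pos++] = (byte) (bits >>> shift);
            }
        }
        target.write(buffer, 0, pos);
    }
}
```

Because the scratch buffer is per-instance state, a Flink `TypeSerializer` built this way would need `duplicate()` to return a new instance rather than `this`.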



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -90,18 +90,18 @@ public void writeTo(OutputStream checkpointOutputStream) throws IOException {
             }
 
             dos.writeBoolean(fileSystem.isDistributedFS());
+            for (Segment segment : segments) {
+                persistSegmentToDisk(segment);
+            }
             if (fileSystem.isDistributedFS()) {
                 // We only need to record the segments itself
                 serializeSegments(segments, dos);
             } else {
                 // We have to copy the whole streams.
-                int totalRecords = segments.stream().mapToInt(Segment::getCount).sum();
-                long totalSize = segments.stream().mapToLong(Segment::getSize).sum();
-                checkState(totalRecords >= 0, "overflowed: " + totalRecords);
-                dos.writeInt(totalRecords);
-                dos.writeLong(totalSize);
-
+                dos.writeInt(segments.size());
                 for (Segment segment : segments) {
+                    dos.writeInt(segment.getCount());

Review Comment:
   Why do you write the size/count of each segment rather than treating them as a whole (as in the original version)?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {

Review Comment:
   Can we allocate the segments aggressively? Since the memory is allocated statically, we could allocate all the segments here even if the expected size is zero.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {
+                int numPages = (int) ((expectedSize + pageSize - 1) / pageSize);
+                segments.addAll(memoryManager.allocatePages(getKey(), numPages));
+            }
+        }
+
+        public Object getKey() {
+            return key;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(int capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int requiredSegmentNum = (capacity - 1) / pageSize + 2 - segments.size();
+            if (requiredSegmentNum > 0) {
+                segments.addAll(memoryManager.allocatePages(getKey(), requiredSegmentNum));
+            }
+        }
+
+        private void writeRecursive(byte[] b, int off, int len) {

Review Comment:
   If `len` is zero, it seems to throw an exception here.
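A minimal sketch of a paged in-memory `OutputStream` whose `write` returns early when `len` is zero and copies iteratively across page boundaries, so neither a capacity check nor the page copy ever sees an empty request. `PagedOutputStream` is hypothetical, and `byte[]` pages allocated lazily stand in for Flink `MemorySegment`s obtained from the memory manager.

```java
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical paged output stream; byte[] pages stand in for MemorySegments. */
class PagedOutputStream extends OutputStream {
    private final int pageSize;
    private final List<byte[]> pages = new ArrayList<>();
    private int position; // total number of bytes written so far

    PagedOutputStream(int pageSize) {
        this.pageSize = pageSize;
    }

    @Override
    public void write(int b) throws IOException {
        write(new byte[] {(byte) b}, 0, 1);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        if (off < 0 || len < 0 || len > b.length - off) {
            throw new IndexOutOfBoundsException();
        }
        if (len == 0) {
            return; // Nothing to write; leave capacity and pages untouched.
        }
        while (len > 0) {
            int pageIndex = position / pageSize;
            int pageOffset = position % pageSize;
            if (pageIndex == pages.size()) {
                pages.add(new byte[pageSize]); // allocate one page when the current one is full
            }
            int n = Math.min(len, pageSize - pageOffset);
            System.arraycopy(b, off, pages.get(pageIndex), pageOffset, n);
            position += n;
            off += n;
            len -= n;
        }
    }

    int size() {
        return position;
    }

    int numPages() {
        return pages.size();
    }

    byte[] toByteArray() {
        byte[] result = new byte[position];
        for (int copied = 0; copied < position; copied += pageSize) {
            int n = Math.min(pageSize, position - copied);
            System.arraycopy(pages.get(copied / pageSize), 0, result, copied, n);
        }
        return result;
    }
}
```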



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java:
##########
@@ -75,9 +76,11 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept
 
         final int len = vector.values.length;
         target.writeInt(len);
+        ByteBuffer buffer = ByteBuffer.allocate(len << 3);

Review Comment:
   Moreover, should we also update SparseVectorSerializer? I am also OK with leaving a TODO there.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,104 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
-/** Records the data received and replayed them on required. */
+/** Records the data received and replays them on required. */
 public class DataCacheWriter<T> {
 
-    private final TypeSerializer<T> serializer;
-
     private final FileSystem fileSystem;
 
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final TypeSerializer<T> serializer;
+
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
-            SupplierWithException<Path, IOException> pathGenerator)
+            SupplierWithException<Path, IOException> pathGenerator,
+            LimitedSizeMemoryManager memoryManager)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, memoryManager, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
+            LimitedSizeMemoryManager memoryManager,
             List<Segment> priorFinishedSegments)
             throws IOException {
         this.serializer = serializer;
         this.fileSystem = fileSystem;
         this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this.memoryManager = memoryManager;
+        this.finishedSegments = new ArrayList<>(priorFinishedSegments);
+        this.currentWriter =
+                SegmentWriter.create(
+                        pathGenerator.get(), this.memoryManager, serializer, 0L, true, true);
     }
 
     public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+        boolean success = currentWriter.addRecord(record);

Review Comment:
   Since the writer may not be atomic, is there a case where one vector is split, with part of it (e.g., its size) written to memory and the rest written to disk?
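One way to rule out a record being split between memory and disk is to make the memory writer all-or-nothing: serialize the record into a scratch buffer first, and copy it into the cache only if it fits, so a `false` return leaves no partial bytes behind. A minimal sketch, assuming a hypothetical `AtomicMemoryWriter` whose fixed byte capacity stands in for the pages the memory manager can grant.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical fixed-capacity memory writer whose addRecord is all-or-nothing. */
class AtomicMemoryWriter {
    private final byte[] memory; // stand-in for the pages granted by the memory manager
    private final ByteArrayOutputStream scratch = new ByteArrayOutputStream();
    private final DataOutputStream scratchView = new DataOutputStream(scratch);
    private int used;
    private int count;

    AtomicMemoryWriter(int capacity) {
        this.memory = new byte[capacity];
    }

    /** Writes the whole record, or writes nothing and returns false if it does not fit. */
    boolean addRecord(double[] vector) throws IOException {
        scratch.reset();
        scratchView.writeInt(vector.length);
        for (double v : vector) {
            scratchView.writeDouble(v);
        }
        if (scratch.size() > memory.length - used) {
            return false; // nothing was copied, so the caller can write the full record elsewhere
        }
        System.arraycopy(scratch.toByteArray(), 0, memory, used, scratch.size());
        used += scratch.size();
        count++;
        return true;
    }

    int getUsedBytes() {
        return used;
    }

    int getCount() {
        return count;
    }
}
```

The price of this design is one extra copy per record; in exchange, the memory segment and the file segment never each hold part of the same vector.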



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {
+                int numPages = (int) ((expectedSize + pageSize - 1) / pageSize);
+                segments.addAll(memoryManager.allocatePages(getKey(), numPages));
+            }
+        }
+
+        public Object getKey() {
+            return key;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(int capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int requiredSegmentNum = (capacity - 1) / pageSize + 2 - segments.size();

Review Comment:
   Is it `(capacity - 1) / pageSize + 1 - segments.size()`?  Moreover, do you think the following code is easier to understand?
   
   ```
   private void ensureCapacity(int capacity) throws MemoryAllocationException {
       if (capacity > pageSize * segments.size()) {
           int requiredSegmentNum = capacity % pageSize == 0 ? capacity / pageSize : capacity / pageSize + 1;
           if (requiredSegmentNum > segments.size()) {
               segments.addAll(memoryManager.allocatePages(getKey(), requiredSegmentNum - segments.size()));
           }
       }
   }
   ```
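   As a standalone check of the arithmetic above (plain Java; the class and method names are illustrative only): for positive `capacity`, both `(capacity - 1) / pageSize + 1` and the ternary form compute the ceiling of `capacity / pageSize`, while the diff's `+ 2` variant is always one page larger. The sketch assumes `segments.size() == 0` so the subtraction can be ignored.

   ```java
   final class PageCountCheck {
       // Pages needed to hold `capacity` bytes (ceiling division); valid for capacity > 0.
       static int pagesNeeded(int capacity, int pageSize) {
           return (capacity - 1) / pageSize + 1;
       }

       // The reviewer's alternative spelling of the same ceiling division.
       static int pagesNeededAlt(int capacity, int pageSize) {
           return capacity % pageSize == 0 ? capacity / pageSize : capacity / pageSize + 1;
       }

       // The expression from the diff, evaluated with segments.size() == 0.
       static int pagesFromDiff(int capacity, int pageSize) {
           return (capacity - 1) / pageSize + 2;
       }
   }
   ```

   For example, with `pageSize = 4096` and `capacity = 4096`, `pagesNeeded` returns 1 while `pagesFromDiff` returns 2, so the diff's expression appears to allocate one page more than the written bytes require.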
   





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r888724585


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;

Review Comment:
   `count` and `countState` represent the number of elements that the subtask has received so far, which is greater than or equal to `samples.size()`.
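   For context, one common way to keep a fixed-size uniform sample of a bounded stream is reservoir sampling (Algorithm R). The sketch below assumes that technique (the diff does not show how `SamplingOperator` selects elements) and uses hypothetical names; it illustrates why the received-element counter keeps growing after the sample list stops at `numSamples` entries.

   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.Random;

   // Hypothetical reservoir sampler (Algorithm R); not the PR's SamplingOperator.
   final class ReservoirSampler<T> {
       private final int numSamples;
       private final Random random;
       private final List<T> samples = new ArrayList<>();
       // Number of elements received so far (the role of `count`); never smaller than samples.size().
       private int count;

       ReservoirSampler(int numSamples, long seed) {
           this.numSamples = numSamples;
           this.random = new Random(seed);
       }

       void add(T element) {
           count++;
           if (samples.size() < numSamples) {
               // Until the reservoir is full, every element is kept.
               samples.add(element);
               return;
           }
           // Afterwards, the element replaces a random slot with probability numSamples / count.
           int index = random.nextInt(count);
           if (index < numSamples) {
               samples.set(index, element);
           }
       }

       int getCount() {
           return count;
       }

       List<T> getSamples() {
           return samples;
       }
   }
   ```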





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r869946600


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,65 +19,80 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
-/** Records the data received and replayed them on required. */
+/** Records the data received and replays them on required. */
 public class DataCacheWriter<T> {
 
-    private final TypeSerializer<T> serializer;
-
     private final FileSystem fileSystem;
 
     private final SupplierWithException<Path, IOException> pathGenerator;
 
+    private final MemoryManager memoryManager;
+
+    private final TypeSerializer<T> serializer;
+
     private final List<Segment> finishSegments;
 
-    private SegmentWriter currentSegment;
+    private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
-            SupplierWithException<Path, IOException> pathGenerator)
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemoryManager memoryManager)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, memoryManager, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
+            MemoryManager memoryManager,
             List<Segment> priorFinishedSegments)
             throws IOException {
         this.serializer = serializer;
         this.fileSystem = fileSystem;
         this.pathGenerator = pathGenerator;
-
+        this.memoryManager = memoryManager;
         this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this.currentWriter = createSegmentWriter(pathGenerator, this.memoryManager);
     }
 
     public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+        boolean success = currentWriter.addRecord(record);
+        if (!success) {
+            finishCurrentSegment();
+            success = currentWriter.addRecord(record);
+            Preconditions.checkState(success);
+        }
     }
 
     public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+        if (currentWriter != null) {
+            currentWriter.finish().ifPresent(finishSegments::add);
+            currentWriter = null;

Review Comment:
   Nit: this line (`currentWriter = null;`) is redundant.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   Is it possible to remove `objectOutputStream` and `byteArrayOutputStream`, and just use `outputView = new DataOutputViewStreamWrapper(outputStream)`?
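   A JDK-only analogue of the suggestion, using `java.io.DataOutputStream` in the role that `DataOutputViewStreamWrapper` plays around the file stream (class and method names are hypothetical). Writing through the wrapper emits only the serialized record bytes, whereas routing the same writes through an `ObjectOutputStream` adds its stream header and block-data framing on top.

   ```java
   import java.io.ByteArrayOutputStream;
   import java.io.DataOutputStream;
   import java.io.IOException;
   import java.io.ObjectOutputStream;

   final class StreamOverheadDemo {
       // Writes the records straight through a DataOutputStream: only the record bytes.
       static int directBytes(int[] records) throws IOException {
           ByteArrayOutputStream sink = new ByteArrayOutputStream();
           try (DataOutputStream out = new DataOutputStream(sink)) {
               for (int r : records) {
                   out.writeInt(r);
               }
           }
           return sink.size();
       }

       // Wraps the same sink in an ObjectOutputStream first, as the reviewed constructor does.
       static int objectStreamBytes(int[] records) throws IOException {
           ByteArrayOutputStream sink = new ByteArrayOutputStream();
           try (ObjectOutputStream out = new ObjectOutputStream(sink)) {
               for (int r : records) {
                   out.writeInt(r);
               }
           }
           return sink.size();
       }
   }
   ```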



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -253,58 +261,125 @@ public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value
                             DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
                     IterationListener<Tuple2<Integer, DenseVector>> {
         private final DistanceMeasure distanceMeasure;
-        private ListState<DenseVector> points;
-        private ListState<DenseVector[]> centroids;
+        private ListState<DenseVector[]> centroidsState;
+        private DenseVector[] centroids;
+
+        private Path basePath;
+        private StreamConfig config;
+        private DataCacheWriter<DenseVector> dataCacheWriter;
 
         public SelectNearestCentroidOperator(DistanceMeasure distanceMeasure) {
+            super();
             this.distanceMeasure = distanceMeasure;
         }
 
+        @Override
+        public void setup(
+                StreamTask<?, ?> containingTask,
+                StreamConfig config,
+                Output<StreamRecord<Tuple2<Integer, DenseVector>>> output) {
+            super.setup(containingTask, config, output);
+
+            basePath =
+                    OperatorUtils.getDataCachePath(
+                            containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                            containingTask
+                                    .getEnvironment()
+                                    .getIOManager()
+                                    .getSpillingDirectoriesPaths());
+
+            this.config = config;
+        }
+
         @Override
         public void initializeState(StateInitializationContext context) throws Exception {
             super.initializeState(context);
-            points =
-                    context.getOperatorStateStore()
-                            .getListState(new ListStateDescriptor<>("points", DenseVector.class));
 
             TypeInformation<DenseVector[]> type =
                     ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
-            centroids =
+            centroidsState =
                     context.getOperatorStateStore()
                             .getListState(new ListStateDescriptor<>("centroids", type));
+            centroids =
+                    OperatorStateUtils.getUniqueElement(centroidsState, "centroids").orElse(null);
+
+            List<StatePartitionStreamProvider> inputs =
+                    IteratorUtils.toList(context.getRawOperatorStateInputs().iterator());
+            Preconditions.checkState(
+                    inputs.size() < 2, "The input from raw operator state should be one or zero.");
+
+            List<Segment> priorFinishedSegments = new ArrayList<>();
+            if (inputs.size() > 0) {
+                InputStream inputStream = inputs.get(0).getStream();
+
+                DataCacheSnapshot dataCacheSnapshot =
+                        DataCacheSnapshot.recover(
+                                inputStream,
+                                basePath.getFileSystem(),
+                                OperatorUtils.createDataCacheFileGenerator(
+                                        basePath, "cache", config.getOperatorID()));
+
+                priorFinishedSegments = dataCacheSnapshot.getSegments();
+            }
+
+            dataCacheWriter =
+                    new DataCacheWriter<>(
+                            DenseVectorSerializer.INSTANCE,
+                            basePath.getFileSystem(),
+                            OperatorUtils.createDataCacheFileGenerator(
+                                    basePath, "cache", config.getOperatorID()),
+                            getContainingTask().getEnvironment().getMemoryManager(),
+                            priorFinishedSegments);
         }
 
         @Override
         public void processElement1(StreamRecord<DenseVector> streamRecord) throws Exception {
-            points.add(streamRecord.getValue());
+            dataCacheWriter.addRecord(streamRecord.getValue());
+            if (centroids != null) {
+                DenseVector point = streamRecord.getValue();
+                int closestCentroidId = findClosestCentroidId(centroids, point, distanceMeasure);
+                output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
+            }
         }
 
         @Override
         public void processElement2(StreamRecord<DenseVector[]> streamRecord) throws Exception {
-            centroids.add(streamRecord.getValue());
+            Preconditions.checkState(centroids == null);
+            centroidsState.add(streamRecord.getValue());
+            centroids = streamRecord.getValue();
+
+            dataCacheWriter.finishCurrentSegment();

Review Comment:
   It seems inefficient to create and delete an empty file in each round. How about we optimize `finishCurrentSegment()` (and maybe rename it) so that it does not create an empty file when `count == 0`?
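   A minimal sketch of the suggested optimization, assuming the writer opens its file lazily on the first record. `LazyFileSegmentWriter` is hypothetical and uses `java.nio.file` instead of Flink's `FileSystem`; `CREATE_NEW` stands in for `WriteMode.NO_OVERWRITE`. If no record arrives, `finish()` returns an empty `Optional` without ever creating or deleting a file.

   ```java
   import java.io.BufferedOutputStream;
   import java.io.DataOutputStream;
   import java.io.IOException;
   import java.nio.file.Files;
   import java.nio.file.Path;
   import java.nio.file.StandardOpenOption;
   import java.util.Optional;

   // Hypothetical writer that only touches the file system once a record is written.
   final class LazyFileSegmentWriter {
       private final Path path;
       private DataOutputStream out; // null until the first record arrives
       private int count;

       LazyFileSegmentWriter(Path path) {
           this.path = path;
       }

       void addRecord(long record) throws IOException {
           if (out == null) {
               // Create the file only when there is something to write; fail if it already exists.
               out = new DataOutputStream(new BufferedOutputStream(
                       Files.newOutputStream(path, StandardOpenOption.CREATE_NEW)));
           }
           out.writeLong(record);
           count++;
       }

       /** Returns the number of records written, or empty if no file was ever created. */
       Optional<Integer> finish() throws IOException {
           if (out == null) {
               return Optional.empty();
           }
           out.close();
           return Optional.of(count);
       }
   }
   ```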



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB

Review Comment:
   Why do we choose 100MB as the segment size? How does it compare with Spark ML's approach?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)

Review Comment:
   Do we need the `expectedSize` and `MemoryAllocationException` here?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,121 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-
-import javax.annotation.Nullable;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
-    private final TypeSerializer<T> serializer;
+    private final MemoryManager memoryManager;
 
-    private final FileSystem fileSystem;
+    private final TypeSerializer<T> serializer;
 
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    private SegmentReader<T> currentReader;
+
+    private MemorySegmentWriter<T> cacheWriter;
+
+    private int segmentIndex;
 
     public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+            TypeSerializer<T> serializer, MemoryManager memoryManager, List<Segment> segments) {
+        this(serializer, memoryManager, segments, new Tuple2<>(0, 0));
     }
 
     public DataCacheReader(
             TypeSerializer<T> serializer,
-            FileSystem fileSystem,
+            MemoryManager memoryManager,
             List<Segment> segments,
-            Tuple2<Integer, Integer> readerPosition)
-            throws IOException {
-
+            Tuple2<Integer, Integer> readerPosition) {
+        this.memoryManager = memoryManager;
         this.serializer = serializer;
-        this.fileSystem = fileSystem;
         this.segments = segments;
+        this.segmentIndex = readerPosition.f0;
+
+        createSegmentReaderAndCache(readerPosition.f0, readerPosition.f1);
+    }
+
+    private void createSegmentReaderAndCache(int index, int startOffset) {
+        try {
+            cacheWriter = null;
 
-        if (readerPosition.f0 < segments.size()) {
-            this.currentSegmentReader = new SegmentReader(readerPosition.f0, readerPosition.f1);
+            if (index >= segments.size()) {
+                currentReader = null;
+                return;
+            }
+
+            currentReader = SegmentReader.create(serializer, segments.get(index), startOffset);
+
+            boolean shouldCacheInMemory =
+                    startOffset == 0
+                            && currentReader instanceof FsSegmentReader
+                            && MemoryUtils.isMemoryEnoughForCache(memoryManager);
+
+            if (shouldCacheInMemory) {
+                cacheWriter =
+                        new MemorySegmentWriter<>(
+                                segments.get(index).path,
+                                memoryManager,
+                                segments.get(index).inMemorySize);
+            }
+        } catch (MemoryAllocationException e) {
+            cacheWriter = null;
+        } catch (IOException e) {
+            throw new RuntimeException(e);
         }
     }
 
     @Override
     public boolean hasNext() {
-        return currentSegmentReader != null && currentSegmentReader.hasNext();
+        return currentReader != null && currentReader.hasNext();
     }
 
     @Override
     public T next() {
         try {
-            T next = currentSegmentReader.next();
-
-            if (!currentSegmentReader.hasNext()) {
-                currentSegmentReader.close();
-                if (currentSegmentReader.index < segments.size() - 1) {
-                    currentSegmentReader = new SegmentReader(currentSegmentReader.index + 1, 0);
-                } else {
-                    currentSegmentReader = null;
+            T record = currentReader.next();
+
+            if (cacheWriter != null) {
+                if (!cacheWriter.addRecord(record)) {
+                    cacheWriter
+                            .finish()
+                            .ifPresent(x -> memoryManager.releaseMemory(x.path, x.inMemorySize));
+                    cacheWriter = null;
+                }
+            }
+
+            if (!currentReader.hasNext()) {
+                currentReader.close();
+                if (cacheWriter != null) {
+                    cacheWriter
+                            .finish()
+                            .ifPresent(
+                                    x -> {
+                                        x.fsSize = segments.get(segmentIndex).fsSize;
+                                        segments.set(segmentIndex, x);

Review Comment:
   This method replaces a file-only segment with an in-memory segment. It is a bit counter-intuitive to make such changes in a `XXXReader` class.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemoryUtils.java:
##########
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.memory.MemoryManager;
+
+/** Utility variables and methods for memory operation. */
+@Internal
+class MemoryUtils {
+    // Cache is not suggested if over 80% of memory has been occupied.
+    private static final double CACHE_MEMORY_THRESHOLD = 0.2;

Review Comment:
   How does Spark ML decide when to switch from memory to disk? Does Spark also use such a magic number?
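   Note that the constant and its comment agree only if the threshold is applied to the free fraction of managed memory (at least 20% free, i.e. at most 80% occupied). A hypothetical sketch of such a check, since the body of `isMemoryEnoughForCache` is not shown in this diff; it takes plain byte counts rather than a `MemoryManager`:

   ```java
   // Hypothetical threshold check; not the PR's MemoryUtils implementation.
   final class CacheThresholdCheck {
       // Fraction of managed memory that must still be free for caching to proceed.
       static final double CACHE_MEMORY_THRESHOLD = 0.2;

       static boolean isMemoryEnoughForCache(long availableBytes, long totalBytes) {
           return availableBytes >= CACHE_MEMORY_THRESHOLD * totalBytes;
       }
   }
   ```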



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)
+            throws MemoryAllocationException {
+        Preconditions.checkNotNull(memoryManager);
+        this.segment = new Segment();
+        this.segment.path = path;
+        this.segment.cache = new ArrayList<>();
+        this.segment.inMemorySize = 0L;
+        this.memoryManager = memoryManager;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (!MemoryUtils.isMemoryEnoughForCache(memoryManager)) {
+            return false;
+        }
+
+        long recordSize = GraphLayout.parseInstance(record).totalSize();

Review Comment:
   Does Spark ML also estimate the record size in this way? Given that we invoke this method for every record, we probably need to make sure its overhead is low enough even for small records. 
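   One way to bound the per-record cost, sketched as a hypothetical helper rather than what the PR or Spark is known to do: call the expensive estimator (for example a `GraphLayout.parseInstance(record).totalSize()` call) only on every `sampleInterval`-th record and extrapolate from the average sampled size. The extrapolation is only accurate when record sizes are roughly uniform.

   ```java
   import java.util.function.ToLongFunction;

   // Hypothetical amortized size estimator; names and policy are illustrative.
   final class SampledSizeEstimator<T> {
       private final ToLongFunction<T> expensiveEstimator;
       private final int sampleInterval;
       private long recordCount;
       private long sampledCount;
       private long sampledBytes;

       SampledSizeEstimator(ToLongFunction<T> expensiveEstimator, int sampleInterval) {
           if (sampleInterval <= 0) {
               throw new IllegalArgumentException("sampleInterval must be positive");
           }
           this.expensiveEstimator = expensiveEstimator;
           this.sampleInterval = sampleInterval;
       }

       void add(T record) {
           // Measure the first record and then every sampleInterval-th record.
           if (recordCount % sampleInterval == 0) {
               sampledBytes += expensiveEstimator.applyAsLong(record);
               sampledCount++;
           }
           recordCount++;
       }

       long estimatedTotalBytes() {
           if (sampledCount == 0) {
               return 0L;
           }
           // Average sampled size multiplied by the number of records seen.
           return (long) ((double) sampledBytes / sampledCount * recordCount);
       }
   }
   ```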



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */
 public class Segment implements Serializable {
 
-    private final Path path;
+    /** The pre-allocated path on disk to persist the records. */
+    Path path;

Review Comment:
   It is not easy to understand how this class is used now that its fields are non-final and non-private. The code looks a bit hacky.
   
   Is there any way to improve it, e.g. by still using private final member variables?
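   A sketch of one possible direction, under assumptions: keep the fields `private final` and model the transition from an on-disk segment to a cached one as a method that returns a new instance. `ImmutableSegment`, `withInMemorySize`, and the convention that a size of 0 means "absent" are hypothetical; the field names are borrowed from the diff.

   ```java
   import java.nio.file.Path;
   import java.util.Objects;

   // Hypothetical immutable segment descriptor.
   final class ImmutableSegment {
       private final Path path;
       private final int count;
       private final long fsSize;       // 0 if the segment has not been written to disk
       private final long inMemorySize; // 0 if the segment is not cached in memory

       ImmutableSegment(Path path, int count, long fsSize, long inMemorySize) {
           this.path = Objects.requireNonNull(path);
           this.count = count;
           this.fsSize = fsSize;
           this.inMemorySize = inMemorySize;
       }

       Path getPath() { return path; }
       int getCount() { return count; }
       long getFsSize() { return fsSize; }
       long getInMemorySize() { return inMemorySize; }
       boolean isOnDisk() { return fsSize > 0; }
       boolean isInMemory() { return inMemorySize > 0; }

       /** Returns a copy that is additionally cached in memory; this instance is unchanged. */
       ImmutableSegment withInMemorySize(long newInMemorySize) {
           return new ImmutableSegment(path, count, fsSize, newInMemorySize);
       }
   }
   ```

   With this shape, the update in `DataCacheReader` could become something like `segments.set(segmentIndex, segments.get(segmentIndex).withInMemorySize(size))` instead of copying `fsSize` onto another mutable instance.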



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -89,57 +104,31 @@ public List<Segment> getFinishSegments() {
         return finishSegments;
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
-        }
-
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
-    }
-
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
-
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
-        }
-
-        public void addRecord(T record) throws IOException {
-            serializer.serialize(record, outputView);
-            currentSegmentCount += 1;
-        }
+    private SegmentWriter<T> createSegmentWriter(
+            SupplierWithException<Path, IOException> pathGenerator, MemoryManager memoryManager)
+            throws IOException {
+        boolean shouldCacheInMemory = MemoryUtils.isMemoryEnoughForCache(memoryManager);
 
-        public Optional<Segment> finish() throws IOException {
-            this.outputStream.flush();
-            long size = outputStream.getPos();
-            this.outputStream.close();
-
-            if (currentSegmentCount > 0) {
-                return Optional.of(new Segment(path, currentSegmentCount, size));
-            } else {
-                // If there are no records, we tend to directly delete this file
-                fileSystem.delete(path, false);
-                return Optional.empty();
+        if (shouldCacheInMemory) {
+            try {
+                return new MemorySegmentWriter<>(pathGenerator.get(), memoryManager);
+            } catch (MemoryAllocationException e) {
+                return new FsSegmentWriter<>(serializer, pathGenerator.get());
             }
         }
+        return new FsSegmentWriter<>(serializer, pathGenerator.get());
     }
 
     public void cleanup() throws IOException {
-        finishCurrentSegment();
+        finish();
         for (Segment segment : finishSegments) {
-            fileSystem.delete(segment.getPath(), false);
+            if (segment.isOnDisk()) {
+                fileSystem.delete(segment.path, false);
+            }
+            if (segment.isInMemory()) {
+                memoryManager.releaseMemory(segment.path, segment.inMemorySize);

Review Comment:
   This call subtracts `segment.inMemorySize` from the memory reserved under owner `segment.path` in `memoryManager`'s bookkeeping.
   
   I suppose there should be a corresponding `reserveMemory(...)` with owner = `segment.path`. Where is that call?
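   To illustrate the pairing being asked about, here is a toy owner-keyed bookkeeping model in plain Java (`OwnerBookkeeping` is hypothetical and is not Flink's `MemoryManager`). Each release is checked against what was previously reserved under an equal owner key, so a `releaseMemory(segment.path, ...)` call only balances out if an earlier reservation used `segment.path` as its owner. This model rejects an unmatched release; the real `MemoryManager` may handle that case differently.

   ```java
   import java.util.HashMap;
   import java.util.Map;

   // Hypothetical owner-keyed memory bookkeeping.
   final class OwnerBookkeeping {
       private final long totalBytes;
       private final Map<Object, Long> reservedByOwner = new HashMap<>();
       private long reservedBytes;

       OwnerBookkeeping(long totalBytes) {
           this.totalBytes = totalBytes;
       }

       void reserve(Object owner, long bytes) {
           if (reservedBytes + bytes > totalBytes) {
               throw new IllegalStateException("not enough memory to reserve " + bytes + " bytes");
           }
           reservedByOwner.merge(owner, bytes, Long::sum);
           reservedBytes += bytes;
       }

       void release(Object owner, long bytes) {
           long current = reservedByOwner.getOrDefault(owner, 0L);
           if (bytes > current) {
               // Releasing more than this owner reserved indicates a missing reserve call.
               throw new IllegalStateException(
                       "release of " + bytes + " bytes exceeds reservation " + current + " for " + owner);
           }
           if (current == bytes) {
               reservedByOwner.remove(owner);
           } else {
               reservedByOwner.put(owner, current - bytes);
           }
           reservedBytes -= bytes;
       }

       long availableBytes() {
           return totalBytes - reservedBytes;
       }
   }
   ```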



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/ForwardInputsOfLastRound.java:
##########
@@ -42,19 +41,13 @@ public void flatMap(T value, Collector<T> out) {
 
     @Override
     public void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> out) {
-        valuesInLastEpoch = valuesInCurrentEpoch;
-        valuesInCurrentEpoch = new ArrayList<>();
+        valuesInCurrentEpoch.clear();
     }
 
     @Override
     public void onIterationTerminated(Context context, Collector<T> out) {
-        for (T value : valuesInLastEpoch) {

Review Comment:
   Prior to this PR, when the following sequence of statements is called, two values would be collected. After this PR, no values would be collected. In other words, this changes the behavior of the public API introduced in FLIP-176. Is this expected?
   
   ```
   flatMap(x1, ...)
   
   flatMap(x2, ...)
   
   onEpochWatermarkIncremented(...)
   
   onIterationTerminated(...)
   ```
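   The sequence above can be replayed against a simplified, stdlib-only model of both versions. `OldForwarder` mirrors the removed lines (swap buffers on each epoch-watermark increment, emit the last finished epoch on termination). `NewForwarder` is an assumption inferred from the diff and this comment: one buffer, cleared on each increment, whose remaining contents are emitted on termination. Both classes are hypothetical, and a `List` stands in for Flink's `Collector`.
   
   ```java
   import java.util.ArrayList;
   import java.util.List;
   
   /** Simplified model of the pre-PR logic: keep the buffer of the last finished epoch. */
   class OldForwarder<T> {
       private List<T> valuesInCurrentEpoch = new ArrayList<>();
       private List<T> valuesInLastEpoch = new ArrayList<>();
   
       void flatMap(T value) {
           valuesInCurrentEpoch.add(value);
       }
   
       void onEpochWatermarkIncremented() {
           // The finished epoch becomes the "last" epoch; start a fresh buffer.
           valuesInLastEpoch = valuesInCurrentEpoch;
           valuesInCurrentEpoch = new ArrayList<>();
       }
   
       void onIterationTerminated(List<T> out) {
           out.addAll(valuesInLastEpoch);
       }
   }
   
   /** Assumed post-PR logic: a single buffer that is cleared on every increment. */
   class NewForwarder<T> {
       private final List<T> valuesInCurrentEpoch = new ArrayList<>();
   
       void flatMap(T value) {
           valuesInCurrentEpoch.add(value);
       }
   
       void onEpochWatermarkIncremented() {
           valuesInCurrentEpoch.clear();
       }
   
       void onIterationTerminated(List<T> out) {
           out.addAll(valuesInCurrentEpoch);
       }
   }
   ```
   
   Under these assumptions the old model emits `x1` and `x2`, while the new one emits nothing.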



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */

Review Comment:
   After this PR, the Javadoc says this class represents a cache unit, but the implementation shows that it could also represent a file, which seems confusing. Could you clarify?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r874532378


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)
+            throws MemoryAllocationException {
+        Preconditions.checkNotNull(memoryManager);
+        this.segment = new Segment();
+        this.segment.path = path;
+        this.segment.cache = new ArrayList<>();
+        this.segment.inMemorySize = 0L;
+        this.memoryManager = memoryManager;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (!MemoryUtils.isMemoryEnoughForCache(memoryManager)) {
+            return false;
+        }
+
+        long recordSize = GraphLayout.parseInstance(record).totalSize();

Review Comment:
   Spark uses its own classes, such as `SizeTracker` and `SizeEstimator`, to calculate the size of each object. If we followed Spark's practice, we would also need to add such infrastructure code, which would add a fair amount of work to this PR.
   
   `GraphLayout` uses `java.lang.instrument.Instrumentation`, which is the standard way for Java programs to get object sizes, so I think its overhead should be acceptable.
   
   I did not use `Instrumentation` directly because that would require users to add extra Java options when running Flink ML programs, for example:
   ```shell
   java -javaagent:target/flink-ml-agent-2.1-SNAPSHOT.jar flink-ml-uber.jar <className>
   ```
   The additional `-javaagent` option makes the command more complex, which would degrade the user experience without bringing much performance improvement.
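   For comparison, the idea behind Spark's approach is to avoid measuring every record: the whole structure is measured only when the record count reaches a geometrically growing threshold, and the size in between is extrapolated linearly from the last two measurements. Below is a minimal, stdlib-only sketch of that idea, not Spark's code. `SampledSizeTracker` and its growth rate of 1.1 are assumptions; the `measure` function could, for example, wrap JOL's `GraphLayout.parseInstance(...).totalSize()`, which the PR currently calls once per record.
   
   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.function.ToLongFunction;
   
   /** Hypothetical tracker: the expensive measurement runs only at geometrically spaced counts. */
   class SampledSizeTracker<T> {
       private static final double SAMPLE_GROWTH_RATE = 1.1; // Assumed growth rate.
   
       private final List<T> records = new ArrayList<>();
       private final ToLongFunction<List<T>> measure; // Expensive: walks the whole list.
       private long nextSampleNum = 1;
       private long lastSampleNum = 0;
       private long lastSampleSize = 0;
       private double bytesPerRecord = 0;
       int measurements = 0; // Exposed only to show how rarely measure() runs.
   
       SampledSizeTracker(ToLongFunction<List<T>> measure) {
           this.measure = measure;
       }
   
       void add(T record) {
           records.add(record);
           if (records.size() >= nextSampleNum) {
               takeSample();
           }
       }
   
       private void takeSample() {
           long size = measure.applyAsLong(records);
           measurements++;
           long num = records.size();
           // Average growth per record between the previous and the current measurement.
           bytesPerRecord = Math.max(0.0, (double) (size - lastSampleSize) / (num - lastSampleNum));
           lastSampleNum = num;
           lastSampleSize = size;
           nextSampleNum = (long) Math.ceil(num * SAMPLE_GROWTH_RATE);
       }
   
       /** Last measured size plus a linear extrapolation for records added since then. */
       long estimateSize() {
           return lastSampleSize + (long) (bytesPerRecord * (records.size() - lastSampleNum));
       }
   }
   ```
   
   With a constant per-record size the estimate stays exact, while `measure` runs only a few dozen times for 1,000 records.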



[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889803526


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final List<Segment> finishedSegments;
+
+    private SegmentWriter<T> currentWriter;
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        for (Segment segment : finishedSegments) {
+            tryCacheSegmentToMemory(segment);
+        }
+        this.currentWriter = createSegmentWriter();
+    }
+
+    public void addRecord(T record) throws IOException {
+        try {
+            currentWriter.addRecord(record);
+        } catch (SegmentNoVacancyException e) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            currentWriter.addRecord(record);
+        }
+    }
+
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public void finish() throws IOException {
+        if (currentWriter == null) {
+            return;
+        }
+
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+    }
+
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFileSegment() != null) {
+                fileSystem.delete(segment.getFileSegment().getPath(), false);
+            }
+            if (segment.getMemorySegment() != null) {
+                segmentPool.returnAll(segment.getMemorySegment().getCache());
+            }
+        }
+        finishedSegments.clear();
+    }
+
+    private void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {

Review Comment:
   It is possible. For example, in `ReplayOperator`, `finish()` is invoked in `onEpochWatermarkIncremented` while `finishCurrentSegmentIfAny()` is invoked in `snapshotState`. Since `snapshotState` can happen before or after `onEpochWatermarkIncremented`, `finish()` and `finishCurrentSegmentIfAny()` might be invoked repeatedly and in any order, so the `if` condition is unavoidable. If we removed the `if` from `DataCacheWriter`, users would still need to add it to their own code.
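   The guard can be illustrated with a small, stdlib-only model in which `finish()` and `finishCurrentSegmentIfAny()` are called repeatedly and in either order. `GuardedCache` is a hypothetical simplification of `DataCache`: a segment is just a list of records, `finish()` plays the role of the epoch-watermark callback and `finishCurrentSegmentIfAny()` the snapshot-time call. With the guards, extra calls are no-ops and no empty segment is ever recorded.
   
   ```java
   import java.util.ArrayList;
   import java.util.List;
   
   /** Hypothetical simplification of DataCache: segments are plain record lists. */
   class GuardedCache<T> {
       final List<List<T>> finishedSegments = new ArrayList<>();
       private List<T> currentWriter = new ArrayList<>(); // Null once finish() has run.
   
       void addRecord(T record) {
           if (currentWriter == null) {
               throw new IllegalStateException("finish() has already been called");
           }
           currentWriter.add(record);
       }
   
       /** End of input, e.g. on an epoch-watermark increment. Safe to call more than once. */
       void finish() {
           if (currentWriter == null) {
               return;
           }
           if (!currentWriter.isEmpty()) {
               finishedSegments.add(currentWriter);
           }
           currentWriter = null;
       }
   
       /** Snapshot time: seal the open segment, if it has records, and keep accepting input. */
       void finishCurrentSegmentIfAny() {
           if (currentWriter == null || currentWriter.isEmpty()) {
               return; // Already finished, or nothing written since the last seal.
           }
           finishedSegments.add(currentWriter);
           currentWriter = new ArrayList<>();
       }
   }
   ```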



[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891145317


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -160,6 +162,9 @@ public IterationBodyResult process(
                                             BasicTypeInfo.INT_TYPE_INFO,
                                             DenseVectorTypeInfo.INSTANCE),
                                     new SelectNearestCentroidOperator(distanceMeasure));
+            centroidIdAndPoints
+                    .getTransformation()
+                    .declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase.OPERATOR, 64);

Review Comment:
   The methods in the referenced link specify the proportion of managed memory shared among all usages of `ManagedMemoryUseCase.OPERATOR`, while `declareManagedMemoryUseCaseAtOperatorScope` here further specifies the proportion of that memory this operator can use, relative to all other operators that declare managed memory through this method. The two are configurations at different scopes and cannot replace each other.
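   As a rough worked example of how the two scopes combine, the sketch below uses a simplified model: cluster-level consumer weights (for example the `taskmanager.memory.managed.consumer-weights` option) split a slot's managed memory among the use cases present, and the operator-scope weights passed to `declareManagedMemoryUseCaseAtOperatorScope` split the `OPERATOR` share among the operators that declared it. `ManagedMemorySplit` and all weights other than the `64` from this diff are hypothetical, and Flink's real fraction computation (slot sharing groups, rounding) is more involved.
   
   ```java
   import java.util.Map;
   
   /** Hypothetical two-level split of a slot's managed memory (simplified model). */
   final class ManagedMemorySplit {
       private ManagedMemorySplit() {}
   
       /**
        * @param consumerWeights weight of each use case present in the slot
        * @param useCase the use case the operator declared, e.g. "OPERATOR"
        * @param operatorWeights operator-scope weights of all operators declaring that use case
        * @param operator the operator whose fraction is computed
        * @return the fraction of the slot's managed memory available to the operator
        */
       static double fraction(
               Map<String, Integer> consumerWeights,
               String useCase,
               Map<String, Integer> operatorWeights,
               String operator) {
           int totalConsumer = consumerWeights.values().stream().mapToInt(Integer::intValue).sum();
           int totalOperator = operatorWeights.values().stream().mapToInt(Integer::intValue).sum();
           // Level 1: share of the use case. Level 2: share of the operator within it.
           double useCaseShare = (double) consumerWeights.get(useCase) / totalConsumer;
           double operatorShare = (double) operatorWeights.get(operator) / totalOperator;
           return useCaseShare * operatorShare;
       }
   }
   ```
   
   With `OPERATOR:70` and `PYTHON:30`, and two operators each declaring weight 64, each operator would get about 0.7 × 0.5 = 0.35 of the slot's managed memory under this model.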



[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889837384


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheIterator.java:
##########
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+
+/** Reads the cached data from a list of segments. */
+@Internal
+public class DataCacheIterator<T> implements Iterator<T> {

Review Comment:
   Per our offline discussion, I'll restore the original `DataCacheWriter`, `DataCacheSnapshot` and `DataCacheReader`.



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r876709362


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */
 public class Segment implements Serializable {
 
-    private final Path path;
+    /** The pre-allocated path on disk to persist the records. */
+    Path path;
 
-    /** The count of the records in the file. */
-    private final int count;
+    /** The number of records in the file. */
+    int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The size of the records in file. */
+    long fsSize;
 
-    public Segment(Path path, int count, long size) {
+    /** The size of the records in memory. */
+    transient long inMemorySize;
+
+    /** The cached records in memory. */
+    transient List<Object> cache;
+
+    /** The serializer for the records. */
+    transient TypeSerializer<Object> serializer;
+
+    Segment() {}
+
+    Segment(Path path, int count, long fsSize) {
         this.path = path;
         this.count = count;
-        this.size = size;
-    }
-
-    public Path getPath() {
-        return path;
+        this.fsSize = fsSize;
     }
 
-    public int getCount() {
-        return count;
+    boolean isOnDisk() throws IOException {

Review Comment:
   What about using `MemorySegment` and `FsSegment` instead?
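   One possible reading of this suggestion is a small hierarchy in which each kind of cache unit carries only the fields it needs. A minimal, stdlib-only sketch follows. `CacheSegment`, `FsSegment` and `InMemorySegment` are hypothetical names (the last one avoids a clash with Flink's existing `org.apache.flink.core.memory.MemorySegment`), and `java.nio.file.Path` stands in for Flink's `Path`.
   
   ```java
   import java.nio.file.Path;
   import java.util.Collections;
   import java.util.List;
   
   /** Hypothetical common view of a cache unit. */
   interface CacheSegment {
       int count();
   }
   
   /** A cache unit persisted as a single file. */
   final class FsSegment implements CacheSegment {
       final Path path;
       final int count;
       final long fsSize; // File size in bytes.
   
       FsSegment(Path path, int count, long fsSize) {
           this.path = path;
           this.count = count;
           this.fsSize = fsSize;
       }
   
       @Override
       public int count() {
           return count;
       }
   }
   
   /** A cache unit held in memory as deserialized records. */
   final class InMemorySegment<T> implements CacheSegment {
       final List<T> records;
       final long inMemorySize; // Estimated bytes held by the records.
       final Path spillPath; // Pre-allocated path, used only if the records must be persisted.
   
       InMemorySegment(List<T> records, long inMemorySize, Path spillPath) {
           this.records = Collections.unmodifiableList(records);
           this.inMemorySize = inMemorySize;
           this.spillPath = spillPath;
       }
   
       @Override
       public int count() {
           return records.size();
       }
   }
   ```
   
   Code such as `cleanup()` could then dispatch on the segment type (delete the file of an `FsSegment`, release the memory of an `InMemorySegment`) instead of checking `isOnDisk()` and `isInMemory()` on a single class.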



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -94,6 +120,26 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
         }
     }
 
+    /**
+     * Takes a randomly sampled subset of elements in a bounded data stream.
+     *
+     * <p>If the number of elements in the stream is smaller than expected number of samples, all
+     * elements will be included in the sample.
+     *
+     * @param input The input data stream.
+     * @param numSamples The number of elements to be sampled.
+     * @param randomSeed The seed to randomly pick elements as sample.
+     * @return A data stream containing a list of the sampled elements.
+     */
+    public static <T> DataStream<List<T>> sample(
+            DataStream<T> input, int numSamples, long randomSeed) {
+        return input.transform(
+                        "samplingOperator",
+                        Types.LIST(input.getType()),
+                        new SamplingOperator<>(numSamples, randomSeed))
+                .setParallelism(1);

Review Comment:
   The semantics of `sample` do not seem to require changing the operator's parallelism, right? Moreover, we should probably do distributed sampling for better performance.
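   One standard technique for distributed uniform sampling is bottom-k sampling: give every element an independent uniform random key, keep the `k` smallest keys on each subtask, then keep the `k` smallest keys among those candidates. The result is a uniform sample of `min(k, n)` elements without replacement, and the final step only sees `k` candidates per subtask. The stdlib-only sketch below shows the technique on in-memory partitions; it is not the PR's `SamplingOperator`, and `BottomKSampler` is a hypothetical name.
   
   ```java
   import java.util.ArrayList;
   import java.util.Comparator;
   import java.util.List;
   import java.util.PriorityQueue;
   import java.util.Random;
   
   /** Hypothetical bottom-k sampler: uniform sampling that can be split across partitions. */
   final class BottomKSampler {
       private BottomKSampler() {}
   
       /** An element tagged with its random sampling key. */
       static final class Keyed<T> {
           final double key;
           final T value;
   
           Keyed(double key, T value) {
               this.key = key;
               this.value = value;
           }
       }
   
       /** Local step, one per subtask: keep the k elements with the smallest random keys. */
       static <T> List<Keyed<T>> localSample(Iterable<T> partition, int k, long seed) {
           Random random = new Random(seed);
           // Max-heap on the key, so the largest retained key is evicted first.
           PriorityQueue<Keyed<T>> heap =
                   new PriorityQueue<>(Comparator.comparingDouble((Keyed<T> e) -> e.key).reversed());
           for (T value : partition) {
               double key = random.nextDouble();
               if (heap.size() < k) {
                   heap.add(new Keyed<>(key, value));
               } else if (k > 0 && key < heap.peek().key) {
                   heap.poll();
                   heap.add(new Keyed<>(key, value));
               }
           }
           return new ArrayList<>(heap);
       }
   
       /** Global step: the k smallest keys over all subtasks' candidates. */
       static <T> List<T> merge(List<List<Keyed<T>>> candidates, int k) {
           List<Keyed<T>> all = new ArrayList<>();
           for (List<Keyed<T>> part : candidates) {
               all.addAll(part);
           }
           all.sort(Comparator.comparingDouble((Keyed<T> e) -> e.key));
           List<T> sample = new ArrayList<>();
           for (int i = 0; i < Math.min(k, all.size()); i++) {
               sample.add(all.get(i).value);
           }
           return sample;
       }
   }
   ```
   
   In a Flink job the local step would run at the input's parallelism and only the merge at parallelism 1. Each subtask would need its own seed (for example derived from `randomSeed` and the subtask index) so that partitions do not draw identical key sequences.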



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,106 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.util.Preconditions;
 
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/**
+ * A segment contains the information about a cache unit.
+ *
+ * <p>If the unit is persisted in a file on disk, this class provides the number of records in the
+ * unit, the path to the file, and the size of the file.
+ *
+ * <p>If the unit is cached in memory, this class provides the number of records, the cached
+ * objects, and information to persist them on disk, including the pre-allocated path, and the type
+ * serializer.
+ */
+@Internal
 public class Segment implements Serializable {
 
+    /** The pre-allocated path to persist records on disk. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The number of records in the file. */

Review Comment:
   nit: in the file --> in the segment



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,122 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-
-import javax.annotation.Nullable;
+import org.apache.flink.runtime.memory.MemoryManager;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
-    private final TypeSerializer<T> serializer;
+    private final MemoryManager memoryManager;
 
-    private final FileSystem fileSystem;
+    private final TypeSerializer<T> serializer;
 
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    private SegmentReader<T> currentReader;
+
+    private SegmentWriter<T> cacheWriter;
+
+    private int segmentIndex;
 
     public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+            TypeSerializer<T> serializer, MemoryManager memoryManager, List<Segment> segments) {
+        this(serializer, memoryManager, segments, new Tuple2<>(0, 0));
     }
 
     public DataCacheReader(
             TypeSerializer<T> serializer,
-            FileSystem fileSystem,
+            MemoryManager memoryManager,
             List<Segment> segments,
-            Tuple2<Integer, Integer> readerPosition)
-            throws IOException {
-
+            Tuple2<Integer, Integer> readerPosition) {
+        this.memoryManager = memoryManager;
         this.serializer = serializer;
-        this.fileSystem = fileSystem;
         this.segments = segments;
+        this.segmentIndex = readerPosition.f0;
+
+        createSegmentReaderAndCache(readerPosition.f0, readerPosition.f1);
+    }
+
+    private void createSegmentReaderAndCache(int index, int startOffset) {
+        try {
+            cacheWriter = null;
 
-        if (readerPosition.f0 < segments.size()) {
-            this.currentSegmentReader = new SegmentReader(readerPosition.f0, readerPosition.f1);
+            if (index >= segments.size()) {
+                currentReader = null;
+                return;
+            }
+
+            currentReader = SegmentReader.create(serializer, segments.get(index), startOffset);
+
+            boolean shouldCacheInMemory =
+                    startOffset == 0
+                            && currentReader instanceof FsSegmentReader
+                            && MemoryUtils.isMemoryEnoughForCache(memoryManager);
+
+            if (shouldCacheInMemory) {
+                cacheWriter =
+                        SegmentWriter.create(
+                                segments.get(index).getPath(),
+                                memoryManager,
+                                serializer,
+                                segments.get(index).getFsSize(),
+                                true,
+                                false);
+            }
+
+        } catch (IOException e) {
+            throw new RuntimeException(e);
         }
     }
 
     @Override
     public boolean hasNext() {
-        return currentSegmentReader != null && currentSegmentReader.hasNext();
+        return currentReader != null && currentReader.hasNext();
     }
 
     @Override
     public T next() {
         try {
-            T next = currentSegmentReader.next();
-
-            if (!currentSegmentReader.hasNext()) {
-                currentSegmentReader.close();
-                if (currentSegmentReader.index < segments.size() - 1) {
-                    currentSegmentReader = new SegmentReader(currentSegmentReader.index + 1, 0);
-                } else {
-                    currentSegmentReader = null;
+            T record = currentReader.next();
+
+            if (cacheWriter != null) {
+                if (!cacheWriter.addRecord(record)) {

Review Comment:
   I am a bit confused about adding a `cacheWriter` here. Could you explain its purpose?
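   Going only by this hunk, `cacheWriter` appears to act as a read-through cache: when a file-backed segment is read from offset 0 and memory looks sufficient, each record read from disk is also copied into memory so that a later pass could skip the disk, and caching is abandoned once `addRecord` returns `false`. The stdlib-only sketch below models that interpretation. `ReadThroughIterator` and its record-count `memoryBudget` are hypothetical, and what the PR does with a completed cache is not visible here.
   
   ```java
   import java.util.ArrayList;
   import java.util.Iterator;
   import java.util.List;
   
   /** Hypothetical read-through iterator: copies records into memory while reading them. */
   class ReadThroughIterator<T> implements Iterator<T> {
       private final Iterator<T> diskReader; // Stands in for an FsSegmentReader.
       private final int memoryBudget; // Max number of records allowed in the cache.
       private List<T> cache; // Null once caching is skipped or abandoned.
   
       ReadThroughIterator(Iterator<T> diskReader, int memoryBudget, boolean startsAtOffsetZero) {
           this.diskReader = diskReader;
           this.memoryBudget = memoryBudget;
           // A segment read from the middle cannot be cached completely, so skip caching.
           this.cache = startsAtOffsetZero ? new ArrayList<>() : null;
       }
   
       @Override
       public boolean hasNext() {
           return diskReader.hasNext();
       }
   
       @Override
       public T next() {
           T record = diskReader.next();
           if (cache != null) {
               if (cache.size() < memoryBudget) {
                   cache.add(record);
               } else {
                   cache = null; // Out of budget: drop the partial cache, keep reading from disk.
               }
           }
           return record;
       }
   
       /** The fully cached segment, or null if caching was skipped, abandoned or incomplete. */
       List<T> completedCache() {
           return (cache != null && !diskReader.hasNext()) ? cache : null;
       }
   }
   ```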



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   Wrapping the `FSDataOutputStream` in a `BufferedOutputStream`, instead of using the `objectOutputStream`, can improve performance here. The intuition is that `objectOutputStream` uses a buffer for each record to reduce the number of `DataOutputStream#write` calls, while `BufferedOutputStream` uses one buffer for (possibly) many records.
   
   A code example could be:
   
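   ```
   import org.apache.flink.api.common.typeutils.TypeSerializer;
   import org.apache.flink.core.fs.FSDataOutputStream;
   import org.apache.flink.core.fs.FileSystem;
   import org.apache.flink.core.fs.Path;
   import org.apache.flink.core.memory.DataOutputView;
   import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
   import org.apache.flink.ml.linalg.DenseVector;
   import org.apache.flink.ml.linalg.Vectors;
   import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
   
   import org.junit.Test;
   
   import java.io.BufferedOutputStream;
   import java.io.ByteArrayOutputStream;
   import java.io.IOException;
   import java.io.ObjectOutputStream;
   
   public class SegmentWriterBenchmark {
       // Option 1: serialize all records into one BufferedOutputStream.
       void test1(final int numTries) throws IOException {
           Path path = new Path("/tmp/result1");
           TypeSerializer<DenseVector> serializer = DenseVectorSerializer.INSTANCE;
           DenseVector record = Vectors.dense(new double[100]);
   
           // Compared with the initial implementation, wrap the output stream in a buffer.
           // Closing the buffered stream flushes it and closes the underlying file stream.
           try (BufferedOutputStream bufferedOutputStream =
                   new BufferedOutputStream(
                           path.getFileSystem().create(path, FileSystem.WriteMode.OVERWRITE))) {
               DataOutputView outputView = new DataOutputViewStreamWrapper(bufferedOutputStream);
               for (int i = 0; i < numTries; i++) {
                   serializer.serialize(record, outputView);
               }
           }
       }
   
       // Option 2: serialize each record to a byte array, then write it with an ObjectOutputStream.
       void test2(final int numTries) throws IOException {
           Path path = new Path("/tmp/result2");
           TypeSerializer<DenseVector> serializer = DenseVectorSerializer.INSTANCE;
           DenseVector record = Vectors.dense(new double[100]);
   
           try (FSDataOutputStream outputStream =
                   path.getFileSystem().create(path, FileSystem.WriteMode.OVERWRITE)) {
               ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
               DataOutputView outputView = new DataOutputViewStreamWrapper(byteArrayOutputStream);
               for (int i = 0; i < numTries; i++) {
                   serializer.serialize(record, outputView);
                   byte[] bytes = byteArrayOutputStream.toByteArray();
                   ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream);
                   objectOutputStream.writeObject(bytes);
                   byteArrayOutputStream.reset();
               }
           }
       }
   
       @Test
       public void test() throws IOException {
           int numTries = 1000000;
           long time = System.currentTimeMillis();
           test1(numTries);
           System.out.println("Option-1: " + (System.currentTimeMillis() - time));
   
           time = System.currentTimeMillis();
           test2(numTries);
           System.out.println("Option-2: " + (System.currentTimeMillis() - time));
       }
   }
   ```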
   
   The results (in milliseconds) are:
   ```
   Option-1: 3005
   Option-2: 14898
   ```
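   The same buffering idea in plain `java.io`: a single `BufferedOutputStream` batches many small record writes into few writes on the target stream, and one `DataOutputStream` on top serializes records without `ObjectOutputStream`'s per-object framing. A minimal sketch, assuming each record is a `double[]` written as a length prefix followed by its values (a hypothetical format, not `DenseVectorSerializer`'s); `BufferedSegmentWriter` is a hypothetical name.
   
   ```java
   import java.io.BufferedOutputStream;
   import java.io.DataOutputStream;
   import java.io.IOException;
   import java.io.OutputStream;
   
   /** Hypothetical segment writer: one buffer shared by all records, flushed on close. */
   class BufferedSegmentWriter implements AutoCloseable {
       private final DataOutputStream out;
       private int count;
   
       BufferedSegmentWriter(OutputStream target) {
           // 64 KiB buffer: many small writes become few writes on the target stream.
           this.out = new DataOutputStream(new BufferedOutputStream(target, 64 * 1024));
       }
   
       void addRecord(double[] vector) throws IOException {
           out.writeInt(vector.length);
           for (double v : vector) {
               out.writeDouble(v);
           }
           count++;
       }
   
       int count() {
           return count;
       }
   
       /** Bytes written so far, including bytes still sitting in the buffer. */
       long size() {
           return out.size();
       }
   
       @Override
       public void close() throws IOException {
           out.close(); // Flushes the buffer, then closes the target stream.
       }
   }
   ```
   
   With this format, 1,000 vectors of 100 doubles take 1,000 × (4 + 8 × 100) = 804,000 bytes, and the target stream receives them in buffer-sized chunks rather than one write per value.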
   
   



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,106 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.util.Preconditions;
 
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/**
+ * A segment contains the information about a cache unit.
+ *
+ * <p>If the unit is persisted in a file on disk, this class provides the number of records in the
+ * unit, the path to the file, and the size of the file.
+ *
+ * <p>If the unit is cached in memory, this class provides the number of records, the cached
+ * objects, and information to persist them on disk, including the pre-allocated path, and the type
+ * serializer.
+ */
+@Internal

Review Comment:
   Would using `MemorySegment` and `FsSegment` respectively be clearer?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)
+            throws MemoryAllocationException {
+        Preconditions.checkNotNull(memoryManager);
+        this.segment = new Segment();
+        this.segment.path = path;
+        this.segment.cache = new ArrayList<>();
+        this.segment.inMemorySize = 0L;
+        this.memoryManager = memoryManager;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (!MemoryUtils.isMemoryEnoughForCache(memoryManager)) {
+            return false;
+        }
+
+        long recordSize = GraphLayout.parseInstance(record).totalSize();
+
+        try {
+            memoryManager.reserveMemory(this, recordSize);
+        } catch (MemoryReservationException e) {
+            return false;
+        }
+
+        this.segment.cache.add(record);
+        segment.inMemorySize += recordSize;
+
+        this.segment.count++;
+        return true;
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (segment.count > 0) {
+            return Optional.of(segment);
+        } else {
+            memoryManager.releaseMemory(segment.path, segment.inMemorySize);

Review Comment:
   Is `releaseMemory` needed here? It seems that `inMemorySize` should always be zero at this point.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final MemoryManager memoryManager;
+
+    private final Path path;
+
+    private final List<T> cache;
+
+    private final TypeSerializer<T> serializer;
+
+    private long inMemorySize;
+
+    private int count;
+
+    private long reservedMemorySize;
+
+    public MemorySegmentWriter(
+            Path path, MemoryManager memoryManager, TypeSerializer<T> serializer, long expectedSize)
+            throws MemoryReservationException {
+        this.serializer = serializer;
+        Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.cache = new ArrayList<>();
+        this.inMemorySize = 0L;
+        this.count = 0;
+        this.memoryManager = memoryManager;
+
+        if (expectedSize > 0) {
+            memoryManager.reserveMemory(this.path, expectedSize);
+        }
+        this.reservedMemorySize = expectedSize;
+    }
+
+    @Override
+    public boolean addRecord(T record) {

Review Comment:
   After digging into the usage of `MemoryManager` [1][2], the way it is used here does not seem appropriate to me.
   
   The code snippet here seems to cache the records on the Java heap, while reserving memory from off-heap managed memory. If I understand [1][2] correctly:
   - When using `MemoryManager` to manipulate managed memory, we are mostly dealing with off-heap memory.
   - The managed memory for each operator is fixed once the job graph is generated, i.e., it is not dynamically allocated.
   - The usage of managed memory should be explicitly declared in the job graph before the operator uses it. Otherwise, it may lead to OOM when deployed in a container.
   
   As I see it, there are basically two options to cache the data:
   - Cache it in the task heap (i.e., in a `List`): This is simple to implement, but we cannot control the size of the cached elements statically, so the program may not be robust: the task heap is shared by everything running in the JVM, and we have no idea how others are using the heap memory. Moreover, we need to write the `List` to state for recovery.
   - Cache it off-heap (for example, using managed memory): In this case, we need to declare the usage of managed memory in the job graph via `Transformation#declareManagedMemoryUseCaseAtOperatorScope` or `Transformation#declareManagedMemoryUseCaseAtSlotScope`, and obtain the fraction of managed memory as in [3].
   
   
   I would suggest going with option 2, but this needs more discussion with the runtime folks.
   
   [1] https://cwiki.apache.org/confluence/display/FLINK/FLIP-53%3A+Fine+Grained+Operator+Resource+Management
   [2] https://cwiki.apache.org/confluence/display/FLINK/FLIP-141%3A+Intra-Slot+Managed+Memory+Sharing
   [3] https://github.com/apache/flink/blob/18a967f8ad7b22c2942e227fb84f08f552660b5a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/sort/SortOperator.java#L79
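   
   To illustrate the "fixed budget" property of option 2, here is a stdlib-only sketch. `FixedBudgetCache` and its parameters are hypothetical: it does not use Flink's `MemoryManager` or reserve off-heap pages, it only shows that the budget is derived once from a declared fraction and that records which do not fit are rejected, so the caller can spill them to disk instead of growing the heap.
   
   ```java
   import java.util.ArrayList;
   import java.util.List;
   
   /** Hypothetical cache whose byte budget is fixed when the operator starts. */
   final class FixedBudgetCache {
       private final long budgetBytes;
       private final List<byte[]> records = new ArrayList<>();
       private long usedBytes;
   
       /**
        * @param managedMemoryFraction fraction declared for this operator (assumption: in [0, 1])
        * @param slotManagedMemoryBytes managed memory available to the slot (assumption)
        */
       FixedBudgetCache(double managedMemoryFraction, long slotManagedMemoryBytes) {
           if (managedMemoryFraction < 0 || managedMemoryFraction > 1) {
               throw new IllegalArgumentException("fraction must be in [0, 1]");
           }
           // Computed once; never grows at runtime.
           this.budgetBytes = (long) (managedMemoryFraction * slotManagedMemoryBytes);
       }
   
       /** Returns false if the serialized record does not fit; the caller should spill it. */
       boolean tryAdd(byte[] serializedRecord) {
           if (usedBytes + serializedRecord.length > budgetBytes) {
               return false;
           }
           records.add(serializedRecord);
           usedBytes += serializedRecord.length;
           return true;
       }
   
       long usedBytes() {
           return usedBytes;
       }
   
       int size() {
           return records.size();
       }
   }
   ```
   
   For example, with a fraction of 0.5 and 100 bytes of slot memory, the budget is 50 bytes: a 30-byte record fits, a second 30-byte record is rejected, and a 20-byte record still fits.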





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r876690459


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */

Review Comment:
   This class contains all information about a cache unit, regardless of whether the cache is on disk, in memory, or both. I'll modify the doc and the API to improve its readability.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r879052166


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -94,6 +120,26 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
         }
     }
 
+    /**
+     * Takes a randomly sampled subset of elements in a bounded data stream.
+     *
+     * <p>If the number of elements in the stream is smaller than expected number of samples, all
+     * elements will be included in the sample.
+     *
+     * @param input The input data stream.
+     * @param numSamples The number of elements to be sampled.
+     * @param randomSeed The seed to randomly pick elements as sample.
+     * @return A data stream containing a list of the sampled elements.
+     */
+    public static <T> DataStream<List<T>> sample(
+            DataStream<T> input, int numSamples, long randomSeed) {
+        return input.transform(
+                        "samplingOperator",
+                        Types.LIST(input.getType()),
+                        new SamplingOperator<>(numSamples, randomSeed))
+                .setParallelism(1);

Review Comment:
   `SamplingOperator` returns only one element, a `List<T>`, so retaining the parallelism of the upstream operator seems meaningless.
   
   That said, I agree that distributed sampling can improve performance, so I'll make the change that way.
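   
   One possible way to keep the sampling distributed, shown as a hedged sketch rather than what the PR ends up implementing, is bottom-k sampling: each subtask tags its elements with i.i.d. uniform random keys and keeps the `k` smallest, and a single-parallelism step keeps the `k` smallest keys across all subtasks. Since every element gets an independent uniform key, the result is a uniform sample without replacement. `BottomKSampling`, `localSample` and `merge` are hypothetical names; in a Flink job the two phases would map to a parallel operator followed by one with parallelism 1.
   
   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.PriorityQueue;
   import java.util.Random;
   
   /** Hypothetical sketch of two-phase (distributed) sampling via random priorities. */
   final class BottomKSampling {
   
       /** A value tagged with a uniformly random priority key. */
       static final class Tagged<V> {
           final double key;
           final V value;
   
           Tagged(double key, V value) {
               this.key = key;
               this.value = value;
           }
       }
   
       /** Phase 1, on each subtask: keep the k values with the smallest random keys. */
       static <V> List<Tagged<V>> localSample(Iterable<V> partition, int k, Random random) {
           // Max-heap on key, so the largest retained key is evicted first.
           PriorityQueue<Tagged<V>> heap =
                   new PriorityQueue<>((a, b) -> Double.compare(b.key, a.key));
           for (V value : partition) {
               double key = random.nextDouble();
               if (heap.size() < k) {
                   heap.add(new Tagged<>(key, value));
               } else if (k > 0 && key < heap.peek().key) {
                   heap.poll();
                   heap.add(new Tagged<>(key, value));
               }
           }
           return new ArrayList<>(heap);
       }
   
       /** Phase 2, with parallelism 1: the k smallest keys overall form the sample. */
       static <V> List<V> merge(List<List<Tagged<V>>> localSamples, int k) {
           List<Tagged<V>> all = new ArrayList<>();
           for (List<Tagged<V>> local : localSamples) {
               all.addAll(local);
           }
           // Sorting by random key also yields a random output order.
           all.sort((a, b) -> Double.compare(a.key, b.key));
           List<V> result = new ArrayList<>();
           for (int i = 0; i < Math.min(k, all.size()); i++) {
               result.add(all.get(i).value);
           }
           return result;
       }
   }
   ```
   
   Each subtask only ships at most `k` tagged elements to the final step, so the single-parallelism part stays cheap regardless of the input size.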





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r884536704


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)
+            throws MemoryAllocationException {
+        Preconditions.checkNotNull(memoryManager);
+        this.segment = new Segment();
+        this.segment.path = path;
+        this.segment.cache = new ArrayList<>();
+        this.segment.inMemorySize = 0L;
+        this.memoryManager = memoryManager;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (!MemoryUtils.isMemoryEnoughForCache(memoryManager)) {
+            return false;
+        }
+
+        long recordSize = GraphLayout.parseInstance(record).totalSize();

Review Comment:
   Following our offline discussion, I'll store the records in memory in serialized form, so size estimation is no longer needed.
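   
   A minimal sketch of why serialized storage removes the need for size estimation, assuming a record is a `double[]` (a stand-in for `DenseVector`) and using plain `java.io` streams instead of Flink's `TypeSerializer`/`DataOutputView`. `SerializedRecord` is a hypothetical name. The payload size is simply the length of the byte array (4 + 8n bytes here), with no object-graph walk such as `GraphLayout.parseInstance(...)`.
   
   ```java
   import java.io.ByteArrayInputStream;
   import java.io.ByteArrayOutputStream;
   import java.io.DataInputStream;
   import java.io.DataOutputStream;
   import java.io.IOException;
   
   /** Hypothetical: caches a double[] record as bytes so its payload size is known exactly. */
   final class SerializedRecord {
   
       static byte[] serialize(double[] values) throws IOException {
           ByteArrayOutputStream bytes = new ByteArrayOutputStream();
           try (DataOutputStream out = new DataOutputStream(bytes)) {
               out.writeInt(values.length); // 4-byte length prefix
               for (double v : values) {
                   out.writeDouble(v); // 8 bytes per element
               }
           }
           return bytes.toByteArray();
       }
   
       static double[] deserialize(byte[] serialized) throws IOException {
           try (DataInputStream in =
                   new DataInputStream(new ByteArrayInputStream(serialized))) {
               double[] values = new double[in.readInt()];
               for (int i = 0; i < values.length; i++) {
                   values[i] = in.readDouble();
               }
               return values;
           }
       }
   }
   ```
   
   For a 3-element vector, the cached payload is exactly 28 bytes, which can be charged directly against a memory budget.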





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r888730986


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T sample = streamRecord.getValue();
+            count++;
+
+            // Code below is inspired by the Reservoir Sampling algorithm.
+            if (samples.size() < numSamples) {
+                samples.add(sample);
+            } else {
+                if (random.nextInt(count) < numSamples) {
+                    samples.set(random.nextInt(numSamples), sample);
+                }
+            }
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            Collections.shuffle(samples, random);

Review Comment:
   I think it is necessary. Otherwise, the first element to arrive, if it remains in the sample, will always be returned first, because the first `numSamples` elements fill the sample in arrival order, while elements that arrive later are placed at random positions.
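   
   A standalone sketch of the sampling logic from the diff above, in plain Java rather than as a Flink operator (`ReservoirSampler` is a hypothetical name), showing where the final shuffle matters.
   
   ```java
   import java.util.ArrayList;
   import java.util.Collections;
   import java.util.List;
   import java.util.Random;
   
   /** Sketch of the SamplingOperator logic: reservoir sampling, then a final shuffle. */
   final class ReservoirSampler<T> {
       private final int numSamples;
       private final Random random;
       private final List<T> samples = new ArrayList<>();
       private int count;
   
       ReservoirSampler(int numSamples, long randomSeed) {
           this.numSamples = numSamples;
           this.random = new Random(randomSeed);
       }
   
       void add(T element) {
           count++;
           if (samples.size() < numSamples) {
               // The first numSamples elements fill the reservoir in arrival order.
               samples.add(element);
           } else if (random.nextInt(count) < numSamples) {
               // Later elements are kept with probability numSamples / count,
               // replacing a uniformly chosen slot.
               samples.set(random.nextInt(numSamples), element);
           }
       }
   
       /** Without this shuffle, early survivors would keep their arrival positions. */
       List<T> finish() {
           Collections.shuffle(samples, random);
           return samples;
       }
   }
   ```
   
   The shuffle only changes the order of the returned list, not which elements are sampled.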





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889987118


##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java:
##########
@@ -20,29 +20,36 @@
 package org.apache.flink.ml.linalg.typeinfo;
 
 import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
-import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.util.Bits;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Objects;
 
 /** Specialized serializer for {@link DenseVector}. */
-public final class DenseVectorSerializer extends TypeSerializerSingleton<DenseVector> {
+public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
 
     private static final long serialVersionUID = 1L;
 
     private static final double[] EMPTY = new double[0];
 
-    public static final DenseVectorSerializer INSTANCE = new DenseVectorSerializer();
+    private final byte[] buf = new byte[1024];

Review Comment:
   Instead of adding a buffer inside serializers and making them non-singleton, would it be simpler to just let the caller use `BufferedOutputStream`?
   
   For example, `NaiveBayesModelData::ModelDataEncoder::encode()` can do this:
   
   ```
   DataOutputViewStreamWrapper outputViewStreamWrapper =
           new DataOutputViewStreamWrapper(new BufferedOutputStream(outputStream));
   ```
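   
   To illustrate the suggestion with plain `java.io` classes, the sketch below counts how many writes reach the underlying stream with and without a `BufferedOutputStream` in between. `DataOutputStream` stands in for `DataOutputViewStreamWrapper`, a counting stream stands in for the real output stream, and `BufferingDemo` is a hypothetical name.
   
   ```java
   import java.io.BufferedOutputStream;
   import java.io.DataOutputStream;
   import java.io.IOException;
   import java.io.OutputStream;
   
   /** Sketch: how many writes reach the underlying stream, with and without buffering. */
   final class BufferingDemo {
   
       /** Stand-in for a file or network stream whose write calls are expensive. */
       static final class CountingOutputStream extends OutputStream {
           int writeCalls;
   
           @Override
           public void write(int b) {
               writeCalls++;
           }
   
           @Override
           public void write(byte[] b, int off, int len) {
               writeCalls++;
           }
       }
   
       /** Encodes values like a model-data encoder might; returns the number of sink writes. */
       static int encode(double[] values, boolean buffered) throws IOException {
           CountingOutputStream sink = new CountingOutputStream();
           OutputStream target = buffered ? new BufferedOutputStream(sink) : sink;
           DataOutputStream out = new DataOutputStream(target);
           out.writeInt(values.length);
           for (double v : values) {
               out.writeDouble(v);
           }
           out.flush(); // pushes any buffered bytes down to the sink
           return sink.writeCalls;
       }
   }
   ```
   
   With this approach the serializer can stay a stateless singleton; the caller decides whether to buffer, and must flush before the stream is considered complete.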
   





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890701884


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FileSegmentWriter.java:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/** A class that writes cache data to a target file in given file system. */
+@Internal
+class FileSegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The path to the target file. */
+    private final Path path;
+
+    /** The output stream that writes to the target file. */
+    private final FSDataOutputStream outputStream;
+
+    /** A buffer that wraps the output stream to optimize performance. */
+    private final BufferedOutputStream bufferedOutputStream;

Review Comment:
   Since, as discussed in the comment above, we cannot remove the `bufferedOutputStream.flush()` call in `finish()`, we cannot remove this variable either.
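   
   A small stdlib sketch of why the flush matters: bytes written through a `BufferedOutputStream` are not visible in the underlying stream until the buffer is flushed, so `finish()` needs a reference to the buffered stream to flush it before the segment can be considered complete. The `ByteArrayOutputStream` stands in for the segment file, and `FlushDemo` is a hypothetical name.
   
   ```java
   import java.io.BufferedOutputStream;
   import java.io.ByteArrayOutputStream;
   import java.io.IOException;
   
   /** Sketch: bytes written through a BufferedOutputStream stay in its buffer until flushed. */
   final class FlushDemo {
   
       /** Returns {bytes visible before flush, bytes visible after flush}. */
       static int[] visibleBytes(int numBytes) throws IOException {
           ByteArrayOutputStream file = new ByteArrayOutputStream(); // stand-in for the file
           BufferedOutputStream buffered = new BufferedOutputStream(file);
           buffered.write(new byte[numBytes]); // small writes are kept in the buffer
           int beforeFlush = file.size();
           buffered.flush(); // hands the buffered bytes to the underlying stream
           int afterFlush = file.size();
           return new int[] {beforeFlush, afterFlush};
       }
   }
   ```
   
   For a 100-byte write, nothing reaches the underlying stream before the flush, and all 100 bytes are there after it.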





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890739744


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -59,87 +80,100 @@ public DataCacheWriter(
             SupplierWithException<Path, IOException> pathGenerator,
             List<Segment> priorFinishedSegments)
             throws IOException {
-        this.serializer = serializer;
+        this(serializer, fileSystem, pathGenerator, null, priorFinishedSegments);
+    }
+
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> priorFinishedSegments)
+            throws IOException {
         this.fileSystem = fileSystem;
         this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>(priorFinishedSegments);
+        this.currentSegmentWriter = createSegmentWriter();
     }
 
     public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
-    }
-
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+        if (!currentSegmentWriter.addRecord(record)) {
+            currentSegmentWriter.finish().ifPresent(finishedSegments::add);
+            currentSegmentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentSegmentWriter.addRecord(record));
+        }
     }
 
+    /** Finishes adding records and closes resources occupied for adding records. */
     public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
-    }
+        if (currentSegmentWriter == null) {
+            return finishedSegments;
+        }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
+        currentSegmentWriter.finish().ifPresent(finishedSegments::add);
+        currentSegmentWriter = null;
+        return finishedSegments;
     }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+    /**
+     * Flushes all added records to segments and returns a list of segments containing all cached
+     * records.
+     */
+    public List<Segment> getSegments() throws IOException {
+        finishCurrentSegmentIfExists();
+        return finishedSegments;
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    private void finishCurrentSegmentIfExists() throws IOException {
+        if (currentSegmentWriter == null) {
+            return;
         }
 
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
+        currentSegmentWriter.finish().ifPresent(finishedSegments::add);
+        currentSegmentWriter = createSegmentWriter();
     }
 
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
-
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {

Review Comment:
   Would it be more consistent with `State::clear()` to rename this method to `clear()`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -254,58 +272,150 @@ public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value
                             DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
                     IterationListener<Tuple2<Integer, DenseVector>> {
         private final DistanceMeasure distanceMeasure;
-        private ListState<DenseVector> points;
-        private ListState<DenseVector[]> centroids;
+        private ListState<DenseVector[]> centroidsState;
+        private DenseVector[] centroids;

Review Comment:
   Does using `centroids` improve performance? If not, it seems simpler not to add this variable.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891140907


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/ListStateWithCache.java:
##########
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.table.runtime.util.LazyMemorySegmentPool;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A {@link ListState} child class that records data and replays them on required.
+ *
+ * <p>This class basically stores data in file system, and provides the option to cache them in
+ * memory. In order to use the memory caching option, users need to allocate certain managed memory
+ * for the wrapper operator through {@link
+ * org.apache.flink.api.dag.Transformation#declareManagedMemoryUseCaseAtOperatorScope}.
+ *
+ * <p>NOTE: Users need to explicitly invoke this class's {@link
+ * #snapshotState(StateSnapshotContext)} method in order to store the recorded data in snapshot.
+ */
+public class ListStateWithCache<T> implements ListState<T> {
+
+    /** The tool to serialize/deserialize records. */
+    private final TypeSerializer<T> serializer;
+
+    /** The path of the directory that holds the files containing recorded data. */
+    private final Path basePath;
+
+    /** The data cache writer for the received records. */
+    private final DataCacheWriter<T> dataCacheWriter;
+
+    @SuppressWarnings("unchecked")
+    public ListStateWithCache(
+            TypeSerializer<T> serializer,
+            StreamTask<?, ?> containingTask,
+            StreamingRuntimeContext runtimeContext,
+            StateInitializationContext stateInitializationContext,
+            OperatorID operatorID)
+            throws IOException {
+        this.serializer = serializer;
+
+        MemorySegmentPool segmentPool = null;
+        double fraction =
+                containingTask
+                        .getConfiguration()
+                        .getManagedMemoryFractionOperatorUseCaseOfSlot(
+                                ManagedMemoryUseCase.OPERATOR,
+                                runtimeContext.getTaskManagerRuntimeInfo().getConfiguration(),
+                                runtimeContext.getUserCodeClassLoader());
+        if (fraction > 0) {
+            MemoryManager memoryManager = containingTask.getEnvironment().getMemoryManager();
+            segmentPool =
+                    new LazyMemorySegmentPool(
+                            containingTask,
+                            memoryManager,
+                            memoryManager.computeNumberOfPages(fraction));
+        }
+
+        basePath =
+                OperatorUtils.getDataCachePath(
+                        containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                        containingTask
+                                .getEnvironment()
+                                .getIOManager()
+                                .getSpillingDirectoriesPaths());
+
+        List<StatePartitionStreamProvider> inputs =
+                IteratorUtils.toList(
+                        stateInitializationContext.getRawOperatorStateInputs().iterator());
+        Preconditions.checkState(
+                inputs.size() < 2, "The input from raw operator state should be one or zero.");
+
+        List<Segment> priorFinishedSegments = new ArrayList<>();
+        if (inputs.size() > 0) {
+            DataCacheSnapshot dataCacheSnapshot =
+                    DataCacheSnapshot.recover(
+                            inputs.get(0).getStream(),
+                            basePath.getFileSystem(),
+                            OperatorUtils.createDataCacheFileGenerator(
+                                    basePath, "cache", operatorID));
+
+            if (segmentPool != null) {
+                dataCacheSnapshot.tryReadSegmentsToMemory(serializer, segmentPool);
+            }
+
+            priorFinishedSegments = dataCacheSnapshot.getSegments();
+        }
+
+        this.dataCacheWriter =
+                new DataCacheWriter<>(
+                        serializer,
+                        basePath.getFileSystem(),
+                        OperatorUtils.createDataCacheFileGenerator(basePath, "cache", operatorID),
+                        segmentPool,
+                        priorFinishedSegments);
+    }
+
+    public void snapshotState(StateSnapshotContext context) throws Exception {

Review Comment:
   I have also confirmed with @gaoyunhaii that an operator's `processElement` and `snapshotState` methods are never invoked concurrently, so there is no concurrency issue here. More information on Flink's task slots and resources can be found at https://nightlies.apache.org/flink/flink-docs-master/docs/concepts/flink-architecture/#task-slots-and-resources.
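To illustrate why no locking is needed, here is a simplified model. It assumes, based on my understanding of Flink's mailbox-based task execution, that record processing and checkpoint snapshots both run as actions on the same single task thread. `MailboxTaskSketch` and its methods are hypothetical names, not Flink APIs; a single-thread executor stands in for the task mailbox, so the unsynchronized list is only ever touched by one thread.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Hypothetical model of a task whose operator actions all run on one mailbox thread. */
final class MailboxTaskSketch implements AutoCloseable {
    // A single thread executes every action in FIFO order, so actions never overlap.
    private final ExecutorService mailboxThread = Executors.newSingleThreadExecutor();

    // Unsynchronized state: safe only because the mailbox thread is its sole accessor.
    private final List<Integer> bufferedRecords = new ArrayList<>();

    /** Models processElement: enqueues an action that appends the record. */
    Future<?> processElement(int record) {
        Runnable action = () -> bufferedRecords.add(record);
        return mailboxThread.submit(action);
    }

    /** Models snapshotState: copies the state on the same thread, between two records. */
    Future<List<Integer>> snapshotState() {
        Callable<List<Integer>> action = () -> new ArrayList<>(bufferedRecords);
        return mailboxThread.submit(action);
    }

    @Override
    public void close() {
        mailboxThread.shutdown();
    }
}
```

Because the snapshot action is queued behind the record actions, it always observes a consistent list without any synchronization.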





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r884537389


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */
 public class Segment implements Serializable {
 
-    private final Path path;
+    /** The pre-allocated path on disk to persist the records. */
+    Path path;
 
-    /** The count of the records in the file. */
-    private final int count;
+    /** The number of records in the file. */
+    int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The size of the records in file. */
+    long fsSize;
 
-    public Segment(Path path, int count, long size) {
+    /** The size of the records in memory. */
+    transient long inMemorySize;
+
+    /** The cached records in memory. */
+    transient List<Object> cache;
+
+    /** The serializer for the records. */
+    transient TypeSerializer<Object> serializer;
+
+    Segment() {}
+
+    Segment(Path path, int count, long fsSize) {
         this.path = path;
         this.count = count;
-        this.size = size;
-    }
-
-    public Path getPath() {
-        return path;
+        this.fsSize = fsSize;
     }
 
-    public int getCount() {
-        return count;
+    boolean isOnDisk() throws IOException {

Review Comment:
   We need to represent a segment that is both cached in memory and persisted on disk. If we split the class into MemorySegment and FsSegment, I'm not sure how to handle that case.
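A minimal sketch of a single class that covers all three states (disk only, memory only, both), assuming a plain `String` path in place of Flink's `Path` and a `List` in place of the actual in-memory representation. `CacheSegmentSketch` and `evictFromMemory()` are hypothetical; the point is that two optional locations, rather than two subclasses, make the "both" case natural.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/** Hypothetical cache unit that can be persisted on disk, cached in memory, or both at once. */
final class CacheSegmentSketch<T> {
    private final String diskPath; // null if the records were never persisted
    private final List<T> memoryCache; // null if the records are not held in memory
    private final int count;

    CacheSegmentSketch(String diskPath, List<T> memoryCache, int count) {
        if (diskPath == null && memoryCache == null) {
            throw new IllegalArgumentException("A segment must be on disk, in memory, or both.");
        }
        this.diskPath = diskPath;
        this.memoryCache =
                memoryCache == null ? null : Collections.unmodifiableList(new ArrayList<>(memoryCache));
        this.count = count;
    }

    boolean isOnDisk() {
        return diskPath != null;
    }

    boolean isInMemory() {
        return memoryCache != null;
    }

    int getCount() {
        return count;
    }

    Optional<String> getDiskPath() {
        return Optional.ofNullable(diskPath);
    }

    Optional<List<T>> getMemoryCache() {
        return Optional.ofNullable(memoryCache);
    }

    /** Drops the in-memory copy, e.g. under memory pressure; only safe if a disk copy exists. */
    CacheSegmentSketch<T> evictFromMemory() {
        if (!isOnDisk()) {
            throw new IllegalStateException("Evicting would lose the only copy of the records.");
        }
        return new CacheSegmentSketch<>(diskPath, null, count);
    }
}
```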





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r884522110


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final MemoryManager memoryManager;
+
+    private final Path path;
+
+    private final List<T> cache;
+
+    private final TypeSerializer<T> serializer;
+
+    private long inMemorySize;
+
+    private int count;
+
+    private long reservedMemorySize;
+
+    public MemorySegmentWriter(
+            Path path, MemoryManager memoryManager, TypeSerializer<T> serializer, long expectedSize)
+            throws MemoryReservationException {
+        this.serializer = serializer;
+        Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.cache = new ArrayList<>();
+        this.inMemorySize = 0L;
+        this.count = 0;
+        this.memoryManager = memoryManager;
+
+        if (expectedSize > 0) {
+            memoryManager.reserveMemory(this.path, expectedSize);
+        }
+        this.reservedMemorySize = expectedSize;
+    }
+
+    @Override
+    public boolean addRecord(T record) {

Review Comment:
   Following our offline discussion, I'll adopt the second option: storing records in Flink's managed memory in serialized form.
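A rough sketch of what keeping records in serialized form means, assuming `DataOutputStream#writeUTF` as a stand-in for `TypeSerializer#serialize` and a growable byte array in place of managed memory pages. `SerializedRecordBufferSketch` is a hypothetical name. One practical benefit is that memory usage becomes an exact byte count, instead of an estimate of the object graph (which the `org.openjdk.jol.info.GraphLayout` import in the earlier revision appears to be used for).

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical buffer that keeps records as serialized bytes rather than as Java objects. */
final class SerializedRecordBufferSketch {
    private final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    private final DataOutputStream out = new DataOutputStream(bytes);
    private int count;

    void add(String record) {
        try {
            out.writeUTF(record); // stand-in for TypeSerializer#serialize(record, outputView)
            count++;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    int getCount() {
        return count;
    }

    /** Exact number of bytes held, with no object-graph estimation needed. */
    int getSerializedSize() {
        return bytes.size();
    }

    /** Deserializes every record, in insertion order. */
    List<String> readAll() {
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        List<String> records = new ArrayList<>(count);
        try {
            for (int i = 0; i < count; i++) {
                records.add(in.readUTF());
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return records;
    }
}
```

For example, adding `"a"` and `"bc"` occupies 7 bytes: each `writeUTF` call writes a 2-byte length prefix followed by the characters.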





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r888738791


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {
+                int numPages = (int) ((expectedSize + pageSize - 1) / pageSize);
+                segments.addAll(memoryManager.allocatePages(getKey(), numPages));
+            }
+        }
+
+        public Object getKey() {
+            return key;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(int capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int requiredSegmentNum = (capacity - 1) / pageSize + 2 - segments.size();
+            if (requiredSegmentNum > 0) {
+                segments.addAll(memoryManager.allocatePages(getKey(), requiredSegmentNum));
+            }
+        }
+
+        private void writeRecursive(byte[] b, int off, int len) {

Review Comment:
   Do you mean we should allow `len` to be zero, or prohibit it? I think we should allow it, to stay consistent with the lower-level implementation, and the current code already supports it.
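For reference, `java.io.OutputStream#write(byte[], int, int)` accepts `len == 0`; only negative or out-of-range arguments are rejected. Below is a minimal paged stream, assuming plain `byte[]` pages instead of Flink `MemorySegment`s; `PagedOutputStreamSketch` is hypothetical. It returns early on a zero-length write before any capacity check, so such a write at offset 0 cannot trip a `capacity > 0` precondition like the one in the quoted `ensureCapacity`. It also allocates exactly ceil(capacity / pageSize) pages, whereas the quoted `(capacity - 1) / pageSize + 2` appears to reserve one page more than that.

```java
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical output stream that writes into fixed-size pages allocated on demand. */
final class PagedOutputStreamSketch extends OutputStream {
    private final int pageSize;
    private final List<byte[]> pages = new ArrayList<>();
    private long position; // total number of bytes written so far

    PagedOutputStreamSketch(int pageSize) {
        if (pageSize <= 0) {
            throw new IllegalArgumentException("pageSize must be positive");
        }
        this.pageSize = pageSize;
    }

    @Override
    public void write(int b) {
        write(new byte[] {(byte) b}, 0, 1);
    }

    @Override
    public void write(byte[] b, int off, int len) {
        if (off < 0 || len < 0 || off > b.length - len) {
            throw new IndexOutOfBoundsException("off=" + off + ", len=" + len);
        }
        if (len == 0) {
            return; // allowed by the OutputStream contract: write and allocate nothing
        }
        ensureCapacity(position + len);
        while (len > 0) {
            // Copy as much as fits into the page that contains `position`.
            byte[] page = pages.get((int) (position / pageSize));
            int pageOffset = (int) (position % pageSize);
            int n = Math.min(len, pageSize - pageOffset);
            System.arraycopy(b, off, page, pageOffset, n);
            off += n;
            len -= n;
            position += n;
        }
    }

    /** Allocates pages until ceil(capacity / pageSize) pages exist. */
    private void ensureCapacity(long capacity) {
        long requiredPages = (capacity + pageSize - 1) / pageSize;
        while (pages.size() < requiredPages) {
            pages.add(new byte[pageSize]);
        }
    }

    int getPageCount() {
        return pages.size();
    }

    long getPosition() {
        return position;
    }

    /** Concatenates the written bytes of all pages. */
    byte[] toByteArray() {
        byte[] result = new byte[(int) position];
        int copied = 0;
        for (byte[] page : pages) {
            int n = (int) Math.min(pageSize, position - copied);
            if (n <= 0) {
                break;
            }
            System.arraycopy(page, 0, result, copied, n);
            copied += n;
        }
        return result;
    }
}
```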





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890701650


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FileSegmentWriter.java:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/** A class that writes cache data to a target file in given file system. */
+@Internal
+class FileSegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The path to the target file. */
+    private final Path path;
+
+    /** The output stream that writes to the target file. */
+    private final FSDataOutputStream outputStream;
+
+    /** A buffer that wraps the output stream to optimize performance. */
+    private final BufferedOutputStream bufferedOutputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    FileSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.outputStream = path.getFileSystem().create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.bufferedOutputStream = new BufferedOutputStream(outputStream);
+        this.outputView = new DataOutputViewStreamWrapper(bufferedOutputStream);
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        serializer.serialize(record, outputView);
+        count++;
+        return true;
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        bufferedOutputStream.flush();

Review Comment:
   `bufferedOutputStream` wraps `outputStream`, and `outputStream` is unaware of it: calling `outputStream.flush()` does not drain the bytes still held in `bufferedOutputStream`'s buffer. When I tried removing this line, most data cache tests failed, so it has to stay.
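The behavior described above can be reproduced with plain JDK classes, using a `ByteArrayOutputStream` as a stand-in for the `FSDataOutputStream`. `BufferedFlushDemo` is a hypothetical name.

```java
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

/** Shows that flushing the wrapped stream does not drain a BufferedOutputStream around it. */
final class BufferedFlushDemo {
    /** Returns {bytes visible after the inner flush, bytes visible after the outer flush}. */
    static int[] run() throws IOException {
        ByteArrayOutputStream target = new ByteArrayOutputStream(); // stand-in for FSDataOutputStream
        BufferedOutputStream buffered = new BufferedOutputStream(target);
        buffered.write(new byte[] {1, 2, 3, 4, 5}); // small write: stays in the wrapper's buffer

        target.flush(); // the inner stream knows nothing about the wrapper's buffer
        int afterInnerFlush = target.size();

        buffered.flush(); // drains the wrapper's buffer into the target
        int afterOuterFlush = target.size();
        return new int[] {afterInnerFlush, afterOuterFlush};
    }
}
```

`run()` returns `{0, 5}`: only the outer flush makes the bytes reach the target.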





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890992134


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -167,26 +183,69 @@ public static DataCacheSnapshot recover(
             if (isDistributedFS) {
                 segments = deserializeSegments(dis);
             } else {
-                int totalRecords = dis.readInt();
-                long totalSize = dis.readLong();
-
-                Path path = pathGenerator.get();
-                try (FSDataOutputStream outputStream =
-                        fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
-
-                    BoundedInputStream inputStream =
-                            new BoundedInputStream(checkpointInputStream, totalSize);
-                    inputStream.setPropagateClose(false);
-                    IOUtils.copyBytes(inputStream, outputStream, false);
-                    inputStream.close();
+                int segmentNum = dis.readInt();
+                segments = new ArrayList<>(segmentNum);
+                for (int i = 0; i < segmentNum; i++) {
+                    int count = dis.readInt();
+                    long fsSize = dis.readLong();
+                    Path path = pathGenerator.get();
+                    try (FSDataOutputStream outputStream =
+                            fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
+
+                        BoundedInputStream boundedInputStream =
+                                new BoundedInputStream(checkpointInputStream, fsSize);
+                        boundedInputStream.setPropagateClose(false);
+                        IOUtils.copyBytes(boundedInputStream, outputStream, false);
+                        boundedInputStream.close();
+                    }
+                    segments.add(new Segment(path, count, fsSize));
                 }
-                segments = Collections.singletonList(new Segment(path, totalRecords, totalSize));
             }
 
             return new DataCacheSnapshot(fileSystem, readerPosition, segments);
         }
     }
 
+    /**
+     * Makes an attempt to cache the segments in memory.
+     *
+     * <p>The attempt is made at segment granularity, which means only some of the segments might
+     * be cached.
+     *
+     * <p>This method does not throw an exception if there is not enough memory to cache a
+     * segment.
+     */
+    public <T> void tryReadSegmentsToMemory(
+            TypeSerializer<T> serializer, MemorySegmentPool segmentPool) throws IOException {
+        boolean cacheSuccess;
+        for (Segment segment : segments) {
+            if (!segment.getCache().isEmpty()) {
+                continue;
+            }
+
+            SegmentReader<T> reader = new FileSegmentReader<>(serializer, segment, 0);
+            SegmentWriter<T> writer;
+            try {
+                writer =
+                        new MemorySegmentWriter<>(
+                                serializer, segment.getPath(), segmentPool, segment.getFsSize());
+            } catch (MemoryAllocationException e) {
+                continue;

Review Comment:
   To reduce the chance that we repeatedly read part of a segment from disk into memory only to fail on the memory limit, how about breaking out of the loop on the first failure?
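A minimal sketch of the two loop shapes, assuming each segment's size is known up front (as `getFsSize()` is in the diff) and a simple byte budget in place of the `MemorySegmentPool`; `SegmentCachingSketch` and both methods are hypothetical. The `continue` variant can still cache a later, smaller segment, but in the real code each failed attempt may already have read part of a segment from disk.

```java
import java.util.List;

/** Hypothetical sketch: caching segments into a fixed byte budget. */
final class SegmentCachingSketch {
    /** Suggested shape: stops at the first segment that does not fit; returns how many were cached. */
    static int cacheUntilFirstFailure(List<Long> segmentSizes, long memoryBudget) {
        long used = 0;
        int cached = 0;
        for (long size : segmentSizes) {
            if (used + size > memoryBudget) {
                // Later attempts are likely to fail as well, so stop trying.
                break;
            }
            used += size;
            cached++;
        }
        return cached;
    }

    /** Current shape: skips a segment that does not fit and keeps trying the rest. */
    static int cacheSkippingFailures(List<Long> segmentSizes, long memoryBudget) {
        long used = 0;
        int cached = 0;
        for (long size : segmentSizes) {
            if (used + size > memoryBudget) {
                continue;
            }
            used += size;
            cached++;
        }
        return cached;
    }
}
```

With sizes `[30, 30, 50, 10]` and a budget of 100, the first method caches 2 segments and the second caches 3.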





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889669585


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheIterator.java:
##########
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+
+/** Reads the cached data from a list of segments. */
+@Internal
+public class DataCacheIterator<T> implements Iterator<T> {

Review Comment:
   The classes of the original API are tightly coupled: whenever users use `DataCacheWriter`, they also use `DataCacheSnapshot` and `DataCacheReader`. So I think it is better to merge them into one class.
   
   Another advantage of the current API is that it better hides implementation details from developers. If we kept the original API, we would have to preserve the `Segment` class to pass information between `DataCacheWriter`, `DataCacheSnapshot` and `DataCacheReader`, and developers would need to learn the notion of `Segment` in order to use the data cache mechanism. However, in our previous offline discussion we agreed that there might not be a unified `Segment` concept, and that we might keep `FileSegment` and `MemorySegment` separate in the internal implementation. We might also find a better internal representation of the cache unit than `xxxSegment` in the future, but such an improvement would be hard to make once `Segment` has become a public API.
   
   Since `DataCache` now holds most of the functionality of the previous `DataCacheWriter` and `DataCacheSnapshot`, naming it `DataCacheWriter` may no longer be suitable.
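A minimal sketch of the merged API shape, assuming in-memory lists in place of files and memory segments; `DataCacheFacadeSketch` is hypothetical and not the PR's `DataCache`. Callers only add records and iterate over them, while the cache unit stays a private nested class that can later be replaced without changing the public surface.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/** Hypothetical single-class cache: callers add and iterate, and never see the cache unit. */
final class DataCacheFacadeSketch<T> implements Iterable<T> {
    /** Internal cache unit; being private, it can be reshaped later without an API change. */
    private static final class Unit<R> {
        final List<R> records = new ArrayList<>();
    }

    private final int maxRecordsPerUnit;
    private final List<Unit<T>> finishedUnits = new ArrayList<>();
    private Unit<T> currentUnit = new Unit<>();

    DataCacheFacadeSketch(int maxRecordsPerUnit) {
        if (maxRecordsPerUnit <= 0) {
            throw new IllegalArgumentException("maxRecordsPerUnit must be positive");
        }
        this.maxRecordsPerUnit = maxRecordsPerUnit;
    }

    void add(T record) {
        if (currentUnit.records.size() >= maxRecordsPerUnit) {
            // Seal the full unit and start a new one.
            finishedUnits.add(currentUnit);
            currentUnit = new Unit<>();
        }
        currentUnit.records.add(record);
    }

    int getUnitCount() {
        return finishedUnits.size() + (currentUnit.records.isEmpty() ? 0 : 1);
    }

    @Override
    public Iterator<T> iterator() {
        List<Unit<T>> units = new ArrayList<>(finishedUnits);
        units.add(currentUnit);
        return units.stream().flatMap(unit -> unit.records.stream()).iterator();
    }
}
```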





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891145317


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -160,6 +162,9 @@ public IterationBodyResult process(
                                             BasicTypeInfo.INT_TYPE_INFO,
                                             DenseVectorTypeInfo.INSTANCE),
                                     new SelectNearestCentroidOperator(distanceMeasure));
+            centroidIdAndPoints
+                    .getTransformation()
+                    .declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase.OPERATOR, 64);

Review Comment:
   The options in the referenced link specify the proportion of managed memory shared among all usages of `ManagedMemoryUseCase.OPERATOR`, while the `declareManagedMemoryUseCaseAtOperatorScope` call here further specifies how much of that memory this operator gets relative to all other operators that have declared, through this method, that they need managed memory. The two are configurations at different scopes.
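As a rough worked example of the two scopes, here is a simplified model, not Flink's exact computation: a cluster-level setting decides what fraction of a slot's managed memory goes to operator use cases overall (`useCaseFraction`), and the weight passed to `declareManagedMemoryUseCaseAtOperatorScope` (64 in the diff) decides this operator's share of that pool relative to the weights declared by the other operators in the slot. `ManagedMemoryShareSketch` and its parameters are hypothetical.

```java
/** Hypothetical, simplified model of two-level managed memory sharing (not Flink's exact formula). */
final class ManagedMemoryShareSketch {
    /**
     * @param slotManagedBytes managed memory available to the slot
     * @param useCaseFraction fraction given to operator use cases overall (cluster-level scope)
     * @param operatorWeight weight declared by this operator (operator scope)
     * @param totalOperatorWeight sum of the weights declared by all operators in the slot
     */
    static long operatorBytes(
            long slotManagedBytes, double useCaseFraction, int operatorWeight, int totalOperatorWeight) {
        if (useCaseFraction < 0 || useCaseFraction > 1) {
            throw new IllegalArgumentException("useCaseFraction must be in [0, 1]");
        }
        if (operatorWeight <= 0 || totalOperatorWeight < operatorWeight) {
            throw new IllegalArgumentException("weights must satisfy 0 < operatorWeight <= total");
        }
        // Operator scope further divides the pool that the cluster-level scope set aside.
        double operatorFraction = useCaseFraction * operatorWeight / totalOperatorWeight;
        return (long) (slotManagedBytes * operatorFraction);
    }
}
```

Under this model, with 1000 bytes per slot, half of it reserved for operators, and two operators each declaring weight 64, each operator would get 250 bytes; a single declaring operator would get 500.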





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r874553395


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   My experiments show that if we do `outputView = new DataOutputViewStreamWrapper(outputStream)`, disk I/O performance degrades noticeably. One possible cause is that `DataOutputViewStreamWrapper`, or the concrete subclass of `FSDataOutputStream`, flushes bytes out every time `write()` is invoked, which is what the update frequency of `outputStream.getPos()` suggests, but I have not looked further into this issue. Shall we leave this implementation as it is for now and see whether @zhipeng93 can offer a simpler implementation with better performance? I'll leave a TODO here for now.
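The flushing hypothesis can be checked without Flink. The plain-Java experiment below uses a hypothetical `CountingOutputStream` as a stand-in for the file stream and counts how many write calls reach it when a `DataOutputStream` (the class `DataOutputViewStreamWrapper` extends) writes to it directly versus through a `BufferedOutputStream`. It does not measure Flink's actual `FSDataOutputStream` subclasses; it only shows that, without a buffer, every serialized field becomes at least one call on the underlying stream.

```java
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

class WriteCallDemo {

    /** Hypothetical stand-in for a file stream: counts the write calls that reach it. */
    static final class CountingOutputStream extends OutputStream {
        int writeCalls = 0;

        @Override
        public void write(int b) {
            writeCalls++;
        }

        @Override
        public void write(byte[] b, int off, int len) {
            // A bulk write counts as one call, as one write request would on a real file stream.
            writeCalls++;
        }
    }

    /** Serializes numRecords ints and returns how many write calls reached the "file". */
    static int countUnderlyingWrites(boolean buffered, int numRecords) throws IOException {
        CountingOutputStream file = new CountingOutputStream();
        // 64 KB buffer: large enough to hold all 4 * numRecords bytes for the sizes used here.
        OutputStream target = buffered ? new BufferedOutputStream(file, 64 * 1024) : file;
        DataOutputStream view = new DataOutputStream(target);
        for (int i = 0; i < numRecords; i++) {
            view.writeInt(i);
        }
        view.flush(); // pushes any buffered bytes down in one bulk write
        return file.writeCalls;
    }

    public static void main(String[] args) throws IOException {
        System.out.println("unbuffered write calls: " + countUnderlyingWrites(false, 1000));
        System.out.println("buffered write calls:   " + countUnderlyingWrites(true, 1000));
    }
}
```

With 1000 records, the unbuffered variant sees at least one call per `writeInt` (the exact number depends on the JDK's `DataOutputStream` implementation), while the buffered variant sees a single bulk write at `flush()`. Buffering is also the approach `FileSegmentWriter` takes later in this PR, wrapping the `FSDataOutputStream` in a `BufferedOutputStream` before creating the `DataOutputViewStreamWrapper`.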





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r888745233


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {

Review Comment:
   Do you mean that we should allocate all segments up front, or that we should not do so but the current implementation forces it?
   
   In my opinion, we should not do so, in order to improve memory utilization. The current implementation already achieves this: it requests memory as more records are added instead of occupying all of the memory space during initialization. `expectedSize` is only used when recreating the cache during the reading process, in which case the size to be occupied is already known from the start.
   
   According to other offline discussions, I'll modify the code to use Flink's `LazyMemorySegmentPool` instead of `LimitedSizeMemoryManager`, and do the re-caching when recovering from a snapshot instead of during reading. The statements above remain valid, though.
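A minimal sketch of the allocation strategy described above, in plain Java: pages are requested one at a time as bytes arrive, and only a known `expectedSize` triggers up-front allocation. `LazyPagedBuffer`, `requestPage()` and `pagesRequested` are hypothetical names; `byte[]` pages stand in for Flink `MemorySegment`s, and this is not the API of `LimitedSizeMemoryManager` or `LazyMemorySegmentPool`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: a paged buffer that requests fixed-size pages only when needed. */
class LazyPagedBuffer {
    private final int pageSize;
    private final List<byte[]> pages = new ArrayList<>();
    private int pageIndex = 0;
    private int offsetInPage = 0;

    /** Number of pages handed out so far (stands in for a memory pool's accounting). */
    int pagesRequested = 0;

    LazyPagedBuffer(int pageSize, long expectedSize) {
        if (pageSize <= 0 || expectedSize < 0) {
            throw new IllegalArgumentException("pageSize must be > 0 and expectedSize >= 0");
        }
        this.pageSize = pageSize;
        // Pre-allocate only when the final size is already known, e.g. when re-caching a
        // segment whose size was recorded earlier. expectedSize == 0 means "unknown".
        long upfrontPages = (expectedSize + pageSize - 1) / pageSize;
        for (long i = 0; i < upfrontPages; i++) {
            pages.add(requestPage());
        }
    }

    private byte[] requestPage() {
        pagesRequested++;
        return new byte[pageSize];
    }

    void write(int b) {
        if (pageIndex == pages.size()) {
            // All pages obtained so far are full: request exactly one more, now.
            pages.add(requestPage());
        }
        pages.get(pageIndex)[offsetInPage++] = (byte) b;
        if (offsetInPage == pageSize) {
            pageIndex++;
            offsetInPage = 0;
        }
    }
}
```

With a page size of 4, writing 10 bytes lazily requests 3 pages, one at a time, and an idle buffer holds none; passing `expectedSize = 9` grabs the same 3 pages immediately and never asks for more.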





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889803526


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final List<Segment> finishedSegments;
+
+    private SegmentWriter<T> currentWriter;
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        for (Segment segment : finishedSegments) {
+            tryCacheSegmentToMemory(segment);
+        }
+        this.currentWriter = createSegmentWriter();
+    }
+
+    public void addRecord(T record) throws IOException {
+        try {
+            currentWriter.addRecord(record);
+        } catch (SegmentNoVacancyException e) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            currentWriter.addRecord(record);
+        }
+    }
+
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public void finish() throws IOException {
+        if (currentWriter == null) {
+            return;
+        }
+
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+    }
+
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFileSegment() != null) {
+                fileSystem.delete(segment.getFileSegment().getPath(), false);
+            }
+            if (segment.getMemorySegment() != null) {
+                segmentPool.returnAll(segment.getMemorySegment().getCache());
+            }
+        }
+        finishedSegments.clear();
+    }
+
+    private void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {

Review Comment:
   It is possible. For example, in `ReplayOperator`, `finish()` is invoked in `onEpochWatermarkIncrement` while `finishCurrentSegmentIfAny()` is invoked in `snapshotState`. Since `snapshotState` can happen before or after `onEpochWatermarkIncrement`, `finish()` and `finishCurrentSegmentIfAny()` might be invoked repeatedly and in any order, so the `if` condition is unavoidable. If we removed the `if` from `DataCacheWriter`, users would still need to add it in their own code.
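To illustrate why the guard matters, here is a simplified, hypothetical model of the two methods (`GuardedCacheWriter`, with `List<Integer>` segments standing in for real segment writers). Because each method returns early when there is nothing to do, the snapshot path and the epoch-watermark path can call them in either order, any number of times, without producing empty segments or dereferencing a writer that has already been released.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, simplified model of DataCacheWriter's segment bookkeeping. */
class GuardedCacheWriter {
    final List<List<Integer>> finishedSegments = new ArrayList<>();

    /** The open segment; null once finish() has been called. */
    private List<Integer> current = new ArrayList<>();

    void addRecord(int record) {
        if (current == null) {
            throw new IllegalStateException("writer already finished");
        }
        current.add(record);
    }

    /** Called e.g. from snapshotState(): seals the open segment only if it holds records. */
    void finishCurrentSegmentIfAny() {
        if (current == null || current.isEmpty()) {
            return;
        }
        finishedSegments.add(current);
        current = new ArrayList<>();
    }

    /** Called e.g. from onEpochWatermarkIncrement(): seals the segment and stops accepting records. */
    void finish() {
        if (current == null) {
            return; // already finished: calling again is a no-op
        }
        if (!current.isEmpty()) {
            finishedSegments.add(current);
        }
        current = null;
    }
}
```

In both orders (snapshot, then epoch increment, then another snapshot; or the reverse), two added records end up in exactly one finished segment.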





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889837092


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FileSegmentWriter.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+class FileSegmentWriter<T> implements SegmentWriter<T> {
+
+    private static final long FILE_SIZE_LIMIT = 1L << 30; // 1GB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FileSystem fileSystem;
+
+    private final FSDataOutputStream outputStream;
+
+    private final BufferedOutputStream bufferedOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    FileSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.bufferedOutputStream = new BufferedOutputStream(outputStream);
+        this.outputView = new DataOutputViewStreamWrapper(bufferedOutputStream);
+    }
+
+    @Override
+    public void addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= FILE_SIZE_LIMIT) {
+            throw new SegmentNoVacancyException();

Review Comment:
   According to the offline discussion, we still need to use an exception to pass the relevant information out of the output stream, but the wrapper method should catch that exception and convert size-limit exceptions into a proper result, such as a boolean denoting success or failure. I'll make the change this way.
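A minimal sketch of the agreed pattern, with hypothetical names (`SegmentFullException`, `LimitedOutputStream`, `BoundedSegmentWriter`): the size-limited stream signals "full" through a dedicated exception, and the wrapper catches only that exception and returns `false`, letting other I/O errors propagate. This is narrower than `MemorySegmentWriter.addRecord` earlier in this thread, which maps every `IOException` to `false`. For simplicity, each record is written in a single call and rejected before any byte is written; a real `TypeSerializer` issues many small writes, so a production version would also have to deal with a partially written record.

```java
import java.io.IOException;
import java.io.OutputStream;

/** Hypothetical exception signalling that a size-limited stream is full. */
class SegmentFullException extends IOException {
    SegmentFullException() {
        super("segment size limit reached");
    }
}

/** Size-limited stream: rejects a write that would exceed the limit before writing anything. */
class LimitedOutputStream extends OutputStream {
    private final OutputStream out;
    private final long limit;
    private long pos = 0;

    LimitedOutputStream(OutputStream out, long limit) {
        this.out = out;
        this.limit = limit;
    }

    @Override
    public void write(int b) throws IOException {
        write(new byte[] {(byte) b}, 0, 1);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        if (pos + len > limit) {
            throw new SegmentFullException();
        }
        out.write(b, off, len);
        pos += len;
    }
}

/** Wrapper that converts the size-limit exception into a boolean result. */
class BoundedSegmentWriter {
    private final LimitedOutputStream stream;
    int count = 0;

    BoundedSegmentWriter(OutputStream sink, long limit) {
        this.stream = new LimitedOutputStream(sink, limit);
    }

    /** Returns true if the record was written, false if the segment is full. */
    boolean addRecord(byte[] record) throws IOException {
        try {
            // One call per record, so a rejected record leaves no partial bytes behind.
            stream.write(record, 0, record.length);
            count++;
            return true;
        } catch (SegmentFullException e) {
            return false; // only "full" becomes false; other IOExceptions still propagate
        }
    }
}
```

With a 10-byte limit, two 4-byte records fit and the third returns `false`, leaving 8 bytes in the sink.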





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889905269


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,158 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
+    public void addRecord(T record) throws IOException {
+        assert currentWriter != null;
+        if (!currentWriter.addRecord(record)) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentWriter.addRecord(record));
+        }
     }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
-    }
+    /** Finishes current segment if records has ever been added to this segment. */
+    public void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {
+            return;
+        }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = createSegmentWriter();
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public List<Segment> finish() throws IOException {
+        if (currentWriter == null) {
+            return finishedSegments;
         }
 
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+        return finishedSegments;
     }
 
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
+    public List<Segment> getFinishedSegments() {
+        return finishedSegments;
+    }
 
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.isOnDisk()) {
+                fileSystem.delete(segment.getPath(), false);
+            }
+            if (segment.isCached()) {
+                assert segmentPool != null;
+                segmentPool.returnAll(segment.getCache());
+            }
         }
+        finishedSegments.clear();
+    }
 
-        public void addRecord(T record) throws IOException {
-            serializer.serialize(record, outputView);
-            currentSegmentCount += 1;
-        }
+    public void persistSegmentsToDisk() throws IOException {

Review Comment:
   It appears that we almost always call `finishCurrentSegmentIfAny()` before calling `persistSegmentsToDisk()`. Would it be simpler to just invoke `finishCurrentSegmentIfAny()` inside this method so that algorithm developers don't need to explicitly call `finishCurrentSegmentIfAny()`?
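A hypothetical sketch of this suggestion, assuming segments can be held as in-memory bytes and persisted to local temp files (`PersistingCacheWriter`, `Seg`, and the temp-file naming are illustrative and are not the PR's `Segment` class or its `FileSystem` usage). Here `persistSegmentsToDisk()` seals the open segment itself before writing, so callers cannot forget the `finishCurrentSegmentIfAny()` step.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, simplified writer: segments start in memory and can be persisted to files. */
class PersistingCacheWriter {
    /** A finished segment: cached bytes in memory, plus a file path once persisted. */
    static final class Seg {
        final byte[] cache;
        Path path; // null until persisted

        Seg(byte[] cache) {
            this.cache = cache;
        }
    }

    final List<Seg> finishedSegments = new ArrayList<>();
    private final ByteArrayOutputStream current = new ByteArrayOutputStream();

    void addRecord(byte[] record) {
        current.write(record, 0, record.length);
    }

    void finishCurrentSegmentIfAny() {
        if (current.size() == 0) {
            return;
        }
        finishedSegments.add(new Seg(current.toByteArray()));
        current.reset();
    }

    /** Seals the open segment first, then writes every not-yet-persisted segment to a file. */
    void persistSegmentsToDisk() throws IOException {
        finishCurrentSegmentIfAny();
        for (Seg seg : finishedSegments) {
            if (seg.path == null) {
                seg.path = Files.createTempFile("cache-segment", ".bin");
                Files.write(seg.path, seg.cache);
            }
        }
    }
}
```

The sketch keeps the in-memory copy after persisting, so a segment can be both on disk and cached, which the `cleanup()` method in this diff (checking both `isOnDisk()` and `isCached()`) appears to allow. The same idea would apply to `getFinishedSegments()`, at the cost of giving a getter a side effect.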



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,158 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
+    public void addRecord(T record) throws IOException {
+        assert currentWriter != null;
+        if (!currentWriter.addRecord(record)) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentWriter.addRecord(record));
+        }
     }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
-    }
+    /** Finishes current segment if records has ever been added to this segment. */
+    public void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {
+            return;
+        }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = createSegmentWriter();
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public List<Segment> finish() throws IOException {
+        if (currentWriter == null) {
+            return finishedSegments;
         }
 
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+        return finishedSegments;
     }
 
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
+    public List<Segment> getFinishedSegments() {
+        return finishedSegments;

Review Comment:
   It seems that we almost always call `finishCurrentSegmentIfAny()` before calling `getFinishedSegments()`. Would it be simpler to just invoke `finishCurrentSegmentIfAny()` inside this method so that algorithm developers don't need to explicitly call `finishCurrentSegmentIfAny()`?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/SegmentWriter.java:
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+
+import java.io.IOException;
+import java.util.Optional;
+
+/** Writer for the data to be cached to a segment. */
+@Internal
+interface SegmentWriter<T> {
+    /** Adds a record to the writer. */

Review Comment:
   Can you add Javadoc explaining the return value?
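One possible shape for that Javadoc. It assumes that a `false` return means the record was not written. This is inferred from `DataCacheWriter.addRecord()` in this diff, which finishes the writer and retries the same record on a new one when `addRecord()` returns `false`. Segments are modeled as `List<T>` rather than the PR's `Segment` class, and `BoundedMemoryWriter` is a hypothetical implementation used only to exercise the contract.

```java
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Writer for the data to be cached to a segment (sketch; segments modeled as lists). */
interface SegmentWriter<T> {
    /**
     * Adds a record to the writer.
     *
     * @param record the record to cache.
     * @return {@code true} if the record was accepted; {@code false} if this writer has no room
     *     left for it. In the latter case the record was NOT written, and the caller should
     *     finish this writer and add the record to a new one.
     */
    boolean addRecord(T record) throws IOException;

    /** Returns the number of records added so far. */
    int getCount();

    /**
     * Finishes adding records and releases resources held for writing.
     *
     * @return the finished segment, or {@link Optional#empty()} if no record was ever added.
     */
    Optional<List<T>> finish() throws IOException;
}

/** Hypothetical bounded in-memory writer illustrating the contract above. */
class BoundedMemoryWriter<T> implements SegmentWriter<T> {
    private final int capacity;
    private final List<T> records = new ArrayList<>();

    BoundedMemoryWriter(int capacity) {
        this.capacity = capacity;
    }

    @Override
    public boolean addRecord(T record) {
        if (records.size() >= capacity) {
            return false; // no vacancy: the record is left to the caller
        }
        records.add(record);
        return true;
    }

    @Override
    public int getCount() {
        return records.size();
    }

    @Override
    public Optional<List<T>> finish() {
        if (records.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(records);
    }
}
```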



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,158 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
+    public void addRecord(T record) throws IOException {
+        assert currentWriter != null;

Review Comment:
   After thinking about this more, I find it more consistent with other code to just remove the assert here. Note that the code will throw a NullPointerException if `currentWriter == null`, which is good enough for us to investigate such a bug.
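A sketch of `addRecord()` with the assert dropped, under simplifying assumptions. `CappedWriter` is a hypothetical writer that rejects records beyond a fixed capacity and stands in for the memory-backed writer. An unbounded `CappedWriter` stands in for `FileSegmentWriter`. Once `finish()` nulls out `currentWriter`, a later `addRecord()` fails with a `NullPointerException`, which is the diagnostic the comment relies on.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Hypothetical writer accepting at most {@code capacity} records (negative = unbounded). */
class CappedWriter<T> {
    private final int capacity;
    private final List<T> records = new ArrayList<>();

    CappedWriter(int capacity) {
        this.capacity = capacity;
    }

    boolean addRecord(T record) {
        if (capacity >= 0 && records.size() >= capacity) {
            return false;
        }
        records.add(record);
        return true;
    }

    Optional<List<T>> finish() {
        if (records.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(records);
    }
}

/** Hypothetical cache writer showing addRecord() without the assert. */
class RollingCacheWriter<T> {
    final List<List<T>> finishedSegments = new ArrayList<>();
    private CappedWriter<T> currentWriter;

    RollingCacheWriter(int memoryCapacity) {
        this.currentWriter = new CappedWriter<>(memoryCapacity);
    }

    void addRecord(T record) {
        // No assert: if currentWriter were null, the call below would throw a
        // NullPointerException, which already points at the bug.
        if (!currentWriter.addRecord(record)) {
            currentWriter.finish().ifPresent(finishedSegments::add);
            // Stand-in for switching to a file-backed writer: an unbounded writer.
            currentWriter = new CappedWriter<>(-1);
            if (!currentWriter.addRecord(record)) {
                throw new IllegalStateException("Fallback writer rejected a record");
            }
        }
    }

    void finish() {
        currentWriter.finish().ifPresent(finishedSegments::add);
        currentWriter = null;
    }
}
```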



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,158 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
+    public void addRecord(T record) throws IOException {
+        assert currentWriter != null;
+        if (!currentWriter.addRecord(record)) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentWriter.addRecord(record));
+        }
     }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
-    }
+    /** Finishes current segment if records has ever been added to this segment. */
+    public void finishCurrentSegmentIfAny() throws IOException {

Review Comment:
   It appears that this method is called when we increment the epoch watermark, end input, snapshot state, or inside `processPendingElementsAndWatermarks()`.

   The first three cases happen at most periodically, so the segment will most likely contain data. In the rare cases when it does not, the performance impact of re-creating an empty segment should be negligible given how infrequent these cases are.

   Due to the use of `hasPendingElements` inside `AbstractBroadcastWrapperOperator`, it appears that `processPendingElementsAndWatermarks()` is called only once per `AbstractBroadcastWrapperOperator` instance, which suggests that the overhead of re-creating an empty segment from this call site will also be small.

   How about we remove the `getCount()`-related optimization to simplify the code, and only add it back if we have a concrete reason to believe it will improve performance?
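To illustrate why dropping the `getCount()` check looks safe, here is a sketch that assumes finishing an empty writer yields `Optional.empty()`, as the `ifPresent(finishedSegments::add)` calls in this diff suggest. The classes are hypothetical, and `writersCreated` only counts allocations so the extra cost of an empty finish is visible. In this model that cost is one object allocation; in the real writer it may also include creating a cache file, which is the overhead the comment argues is negligible.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Hypothetical minimal segment writer. */
class SimpleWriter<T> {
    private final List<T> records = new ArrayList<>();

    void addRecord(T record) {
        records.add(record);
    }

    Optional<List<T>> finish() {
        if (records.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(records);
    }
}

/** Hypothetical cache that finishes the current segment without a getCount() check. */
class UnconditionalFinishCache<T> {
    final List<List<T>> finishedSegments = new ArrayList<>();
    int writersCreated = 0; // counts writer allocations to expose the overhead
    private SimpleWriter<T> currentWriter = newWriter();

    private SimpleWriter<T> newWriter() {
        writersCreated++;
        return new SimpleWriter<>();
    }

    void addRecord(T record) {
        currentWriter.addRecord(record);
    }

    /** Always finishes and replaces the writer; empty writers add no segment. */
    void finishCurrentSegment() {
        currentWriter.finish().ifPresent(finishedSegments::add);
        currentWriter = newWriter();
    }
}
```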



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,158 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
+    public void addRecord(T record) throws IOException {
+        assert currentWriter != null;
+        if (!currentWriter.addRecord(record)) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentWriter.addRecord(record));
+        }
     }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
-    }
+    /** Finishes current segment if records has ever been added to this segment. */
+    public void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {

Review Comment:
   It looks like we can make this method private and let `persistSegmentsToDisk()` and `getFinishedSegments()` call it.
   
   We can also rename `getFinishedSegments()` to `getSegments()`.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889666374


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final List<Segment> finishedSegments;
+
+    private SegmentWriter<T> currentWriter;
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        for (Segment segment : finishedSegments) {
+            tryCacheSegmentToMemory(segment);
+        }
+        this.currentWriter = createSegmentWriter();
+    }
+
+    public void addRecord(T record) throws IOException {
+        try {
+            currentWriter.addRecord(record);
+        } catch (SegmentNoVacancyException e) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            currentWriter.addRecord(record);
+        }
+    }
+
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public void finish() throws IOException {
+        if (currentWriter == null) {

Review Comment:
   As described in the comment below, we might need to call `finish()` before reading from the data cache, which happens every time we create an iterator. Thus `finish()` might be invoked multiple times. If we remove this statement, we would need to do the following:
   ```java
       @Override
       public DataCacheIterator<T> iterator() {
           if (!finished) {
               finish();
               finished = true;
           }
           return new DataCacheIterator<>(serializer, finishedSegments);
       }
   ```
   Thus we would still need the `if` condition, and it might be better to place it directly inside `finish()`.
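A self-contained sketch of the pattern argued for here, where the null check lives inside `finish()` so that every `iterator()` call can invoke it unconditionally. `ReplayableCache` is hypothetical: it keeps segments as in-memory lists, and it wraps the checked `IOException` in `UncheckedIOException` because `Iterable.iterator()` cannot throw checked exceptions.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/** Hypothetical cache: finish() is internal, idempotent, and runs before every read. */
class ReplayableCache<T> implements Iterable<T> {
    private final List<List<T>> finishedSegments = new ArrayList<>();
    private List<T> currentSegment = new ArrayList<>(); // null once finished

    void addRecord(T record) {
        currentSegment.add(record);
    }

    private void finish() throws IOException {
        if (currentSegment == null) {
            return; // already finished: repeated calls are no-ops
        }
        if (!currentSegment.isEmpty()) {
            finishedSegments.add(currentSegment);
        }
        currentSegment = null;
    }

    @Override
    public Iterator<T> iterator() {
        try {
            finish();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        List<T> all = new ArrayList<>();
        for (List<T> segment : finishedSegments) {
            all.addAll(segment);
        }
        return all.iterator();
    }
}
```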





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889667003


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCache.java:
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Records the data received and replays them on required. */
+@Internal
+public class DataCache<T> implements Iterable<T> {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final TypeSerializer<T> serializer;
+
+    private final FileSystem fileSystem;
+
+    private final SupplierWithException<Path, IOException> pathGenerator;
+
+    private final MemorySegmentPool segmentPool;
+
+    private final List<Segment> finishedSegments;
+
+    private SegmentWriter<T> currentWriter;
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
+
+    public DataCache(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        for (Segment segment : finishedSegments) {
+            tryCacheSegmentToMemory(segment);
+        }
+        this.currentWriter = createSegmentWriter();
+    }
+
+    public void addRecord(T record) throws IOException {
+        try {
+            currentWriter.addRecord(record);
+        } catch (SegmentNoVacancyException e) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            currentWriter.addRecord(record);
+        }
+    }
+
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public void finish() throws IOException {
+        if (currentWriter == null) {
+            return;
+        }
+
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+    }
+
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFileSegment() != null) {
+                fileSystem.delete(segment.getFileSegment().getPath(), false);
+            }
+            if (segment.getMemorySegment() != null) {
+                segmentPool.returnAll(segment.getMemorySegment().getCache());
+            }
+        }
+        finishedSegments.clear();
+    }
+
+    private void finishCurrentSegmentIfAny() throws IOException {

Review Comment:
   There is no such usage for now. I agree that we could make `finish()` an internal method and call it before reading.





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890757304


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -254,58 +272,150 @@ public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value
                             DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
                     IterationListener<Tuple2<Integer, DenseVector>> {
         private final DistanceMeasure distanceMeasure;
-        private ListState<DenseVector> points;
-        private ListState<DenseVector[]> centroids;
+        private ListState<DenseVector[]> centroidsState;
+        private DenseVector[] centroids;

Review Comment:
   Does using the `centroids` field improve performance? If not, it seems simpler not to add this variable.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890917094


##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java:
##########
@@ -20,29 +20,36 @@
 package org.apache.flink.ml.linalg.typeinfo;
 
 import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
-import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.util.Bits;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Objects;
 
 /** Specialized serializer for {@link DenseVector}. */
-public final class DenseVectorSerializer extends TypeSerializerSingleton<DenseVector> {
+public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
 
     private static final long serialVersionUID = 1L;
 
     private static final double[] EMPTY = new double[0];
 
-    public static final DenseVectorSerializer INSTANCE = new DenseVectorSerializer();
+    private final byte[] buf = new byte[1024];

Review Comment:
   As discussed offline, `BufferedOutputStream` cannot avoid the many small byte-array reads and writes, so it would still have performance issues. We'll keep using the internal byte buffer.
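A sketch of the internal-buffer approach. This is not the PR's `DenseVectorSerializer`. Doubles are encoded big-endian into a reusable 1024-byte array and passed to the output in bulk `write(buf, 0, len)` calls, instead of one small write per value. The sketch uses plain `java.io.DataOutput`/`DataInput`, which Flink's `DataOutputView`/`DataInputView` extend, rather than the `Bits` helper imported in the diff. Because the buffer is mutable per-instance state, such a serializer is no longer stateless. That is plausibly why the diff moves from `TypeSerializerSingleton` to `TypeSerializer`, since a stateful serializer must return a fresh instance from `duplicate()`.

```java
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

/** Hypothetical double[] codec that batches I/O through a reusable byte buffer. */
class BufferedDoubleArrayCodec {
    // A multiple of 8, so a double never straddles two flushes.
    private final byte[] buf = new byte[1024];

    void serialize(double[] values, DataOutput out) throws IOException {
        out.writeInt(values.length);
        int pos = 0;
        for (double v : values) {
            long bits = Double.doubleToLongBits(v);
            for (int shift = 56; shift >= 0; shift -= 8) {
                buf[pos++] = (byte) (bits >>> shift); // big-endian, like DataOutput.writeLong
            }
            if (pos == buf.length) {
                out.write(buf, 0, pos); // one bulk write per 128 doubles
                pos = 0;
            }
        }
        if (pos > 0) {
            out.write(buf, 0, pos);
        }
    }

    double[] deserialize(DataInput in) throws IOException {
        double[] values = new double[in.readInt()];
        int i = 0;
        while (i < values.length) {
            int n = Math.min(values.length - i, buf.length / 8);
            in.readFully(buf, 0, n * 8); // one bulk read per chunk
            for (int k = 0; k < n; k++) {
                long bits = 0;
                for (int b = 0; b < 8; b++) {
                    bits = (bits << 8) | (buf[k * 8 + b] & 0xFFL);
                }
                values[i++] = Double.longBitsToDouble(bits);
            }
        }
        return values;
    }
}
```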



[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r874490030


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB

Review Comment:
   In Spark, the size limit of each block (which corresponds to the concept of a "segment" in this PR) seems to be 2GB[1]. Spark does not deliberately limit the size of each block's file; the 2GB limit seems to originate from `Integer.MAX_VALUE`.
   
   Flink does not seem to have this issue, so any remaining limits would likely come from the file systems themselves (e.g., the 4GB file size limit of FAT32). I'll remove `CACHE_FILE_SIZE_LIMIT` and treat IOExceptions thrown by the file system as a sign that the file system's limit has been reached.
   
   [1] https://issues.apache.org/jira/browse/SPARK-6235
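Here is a minimal sketch of the "treat the file system's IOException as the segment limit" idea. It assumes the writer keeps a count of fully written records, so a reader can ignore any partially written trailing bytes. `LimitTolerantSegmentWriter` and `CappedOutputStream` are hypothetical names. `CappedOutputStream` only simulates a file system that rejects writes beyond a fixed size.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

/** Hypothetical stand-in for a file system that rejects writes beyond a fixed size. */
class CappedOutputStream extends OutputStream {
    private final ByteArrayOutputStream delegate = new ByteArrayOutputStream();
    private final int capacity;

    CappedOutputStream(int capacity) {
        this.capacity = capacity;
    }

    @Override
    public void write(int b) throws IOException {
        write(new byte[] {(byte) b}, 0, 1);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        if (delegate.size() + len > capacity) {
            throw new IOException("simulated file system size limit reached");
        }
        delegate.write(b, off, len);
    }

    int size() {
        return delegate.size();
    }
}

/** Hypothetical writer with no hard-coded size cap: a failed write ends the segment. */
class LimitTolerantSegmentWriter {
    private final DataOutputStream out;
    private int count; // counts only records that were written completely

    LimitTolerantSegmentWriter(OutputStream target) {
        this.out = new DataOutputStream(target);
    }

    /** Returns false when the underlying storage refuses more bytes. */
    boolean addRecord(long record) {
        try {
            out.writeLong(record);
            count++;
            return true;
        } catch (IOException e) {
            return false; // interpreted as "file system limit reached"
        }
    }

    int getCount() {
        return count;
    }
}
```

One trade-off of this choice is that a genuine I/O failure (for example, a broken disk) would also be reported as "segment full". The caller may therefore still want to surface repeated failures.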



[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r874553395


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   My experiments show that if we do `outputView = new DataOutputViewStreamWrapper(outputStream)`, disk IO performance drops noticeably. One possible cause is that `DataOutputViewStreamWrapper` or the `FSDataOutputStream` subclass flushes bytes every time `write()` is invoked, as I observed from how frequently `outputStream.getPos()` updates, but I have not investigated this further. Shall we leave this implementation as it is for now, and see whether @zhipeng93 can offer a simpler implementation with better performance? I'll leave a TODO here for now.
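This sketch illustrates the hypothesis about per-`write()` flushing. It counts how many times the underlying stream is invoked when ints are written through a `DataOutputStream`, with and without a `BufferedOutputStream` in between. Wrapping in a `BufferedOutputStream` is the approach the later `FileSegmentWriter` revision takes. `CallCountingOutputStream` and `BufferingDemo` are hypothetical; the counting stream merely stands in for an `FSDataOutputStream`. Only the number of calls that reach the sink is measured, not real disk cost.

```java
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

/** Hypothetical sink that counts invocations, standing in for an unbuffered file stream. */
class CallCountingOutputStream extends OutputStream {
    int calls;
    long bytes;

    @Override
    public void write(int b) {
        calls++;
        bytes++;
    }

    @Override
    public void write(byte[] b, int off, int len) {
        calls++;
        bytes += len;
    }
}

class BufferingDemo {
    /** Writes n ints, optionally through a 64KB BufferedOutputStream, and returns the sink. */
    static CallCountingOutputStream writeInts(int n, boolean buffered) throws IOException {
        CallCountingOutputStream sink = new CallCountingOutputStream();
        OutputStream target = buffered ? new BufferedOutputStream(sink, 64 * 1024) : sink;
        DataOutputStream out = new DataOutputStream(target);
        for (int i = 0; i < n; i++) {
            out.writeInt(i);
        }
        out.flush(); // drains the buffer (if any) into the sink
        return sink;
    }
}
```

With 1,000 ints, the unbuffered sink sees at least one call per int. The buffered sink sees a single call, made when `flush()` drains the buffer.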



[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889952819


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing cached records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The count of the records in the segment. */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing cached records. */
+    private long fsSize = -1L;
+
+    /** The memory segments containing cached records. */
+    private List<MemorySegment> cache;
+
+    Segment(Path path, int count, long fsSize) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
+        this.count = count;
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
 
-    public Segment(Path path, int count, long size) {
-        this.path = path;
+    Segment(Path path, int count, List<MemorySegment> cache) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
         this.count = count;
-        this.size = size;
+        this.cache = checkNotNull(cache);
+    }
+
+    void setCache(List<MemorySegment> cache) {
+        this.cache = checkNotNull(cache);
     }
 
-    public Path getPath() {
+    void setDiskInfo(long fsSize) {

Review Comment:
   Would it be more intuitive to rename this method as `setFsSize()`?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+class MemorySegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The pre-allocated path to hold cached records into file system. */
+    private final Path path;
+
+    /** The pool to allocate memory segments from. */
+    private final MemorySegmentPool segmentPool;
+
+    /** The output stream to write serialized content to memory segments. */
+    private final ManagedMemoryOutputStream outputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    MemorySegmentWriter(
+            TypeSerializer<T> serializer,
+            Path path,
+            MemorySegmentPool segmentPool,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.path = path;
+        this.segmentPool = segmentPool;
+        this.outputStream = new ManagedMemoryOutputStream(segmentPool, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            if (e.getCause() instanceof MemoryAllocationException) {
+                return false;
+            }
+            throw e;
+        }
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(new Segment(path, count, outputStream.getSegments()));
+        } else {
+            segmentPool.returnAll(outputStream.getSegments());
+            return Optional.empty();
+        }
+    }
+
+    /** An output stream subclass that accepts bytes and writes them to memory segments. */
+    private static class ManagedMemoryOutputStream extends OutputStream {
+
+        /** The pool to allocate memory segments from. */
+        private final MemorySegmentPool segmentPool;
+
+        /** The number of bytes in a memory segment. */
+        private final int pageSize;
+
+        /** The memory segments containing written bytes. */
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        /** The index of the segment that currently accepts written bytes. */
+        private int segmentIndex;
+
+        /** The number of bytes in the current segment that have been written. */
+        private int segmentOffset;
+
+        /** The number of bytes that have been written so far. */
+        private long globalOffset;
+
+        public ManagedMemoryOutputStream(MemorySegmentPool segmentPool, long expectedSize)
+                throws MemoryAllocationException {
+            this.segmentPool = segmentPool;
+            this.pageSize = segmentPool.pageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            ensureCapacity(Math.max(expectedSize, 1L));
+        }
+
+        public long getPos() {
+            return globalOffset;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(@Nullable byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(long capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int required =
+                    (int) (capacity % pageSize == 0 ? capacity / pageSize : capacity / pageSize + 1)
+                            - segments.size();
+
+            List<MemorySegment> allocatedSegments = new ArrayList<>();
+            for (int i = 0; i < required; i++) {
+                MemorySegment memorySegment = segmentPool.nextSegment();
+                if (memorySegment == null) {
+                    segmentPool.returnAll(allocatedSegments);
+                    throw new MemoryAllocationException();
+                }
+                allocatedSegments.add(memorySegment);
+            }
+
+            segments.addAll(allocatedSegments);
+        }
+
+        private void writeRecursive(byte[] b, int off, int len) {

Review Comment:
   Could we change this function to be iterative to improve performance?
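An iterative version could work as follows. Copy as much as fits into the current page, advance the page index and offset, and loop until the input is exhausted. The sketch below makes three simplifying assumptions:

- Plain `byte[]` pages stand in for Flink's `MemorySegment`.
- Pages are allocated directly instead of through a `MemorySegmentPool`.
- `PagedBuffer` is a hypothetical name.

With a `MemorySegment`, the copy step would presumably become `segment.put(segmentOffset, b, off, toCopy)`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of an iterative page-spanning write; byte[] pages stand in for MemorySegment. */
class PagedBuffer {
    private final int pageSize;
    private final List<byte[]> pages = new ArrayList<>();
    private int segmentIndex;  // page currently accepting bytes
    private int segmentOffset; // bytes already written into that page
    private long globalOffset; // bytes written in total

    PagedBuffer(int pageSize) {
        this.pageSize = pageSize;
    }

    void write(byte[] b, int off, int len) {
        // Simplified capacity check: grow the page list until len more bytes fit.
        while ((long) pages.size() * pageSize < globalOffset + len) {
            pages.add(new byte[pageSize]);
        }
        // Iterative copy: fill the current page, then move on to the next one.
        while (len > 0) {
            int toCopy = Math.min(len, pageSize - segmentOffset);
            System.arraycopy(b, off, pages.get(segmentIndex), segmentOffset, toCopy);
            off += toCopy;
            len -= toCopy;
            segmentOffset += toCopy;
            globalOffset += toCopy;
            if (segmentOffset == pageSize) {
                segmentIndex++;
                segmentOffset = 0;
            }
        }
    }

    int pageCount() {
        return pages.size();
    }

    /** Copies the written bytes out, page by page, for inspection. */
    byte[] toByteArray() {
        byte[] result = new byte[(int) globalOffset];
        int copied = 0;
        for (int i = 0; copied < result.length; i++) {
            int n = Math.min(pageSize, result.length - copied);
            System.arraycopy(pages.get(i), 0, result, copied, n);
            copied += n;
        }
        return result;
    }
}
```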



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+class MemorySegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The pre-allocated path to hold cached records into file system. */
+    private final Path path;
+
+    /** The pool to allocate memory segments from. */
+    private final MemorySegmentPool segmentPool;
+
+    /** The output stream to write serialized content to memory segments. */
+    private final ManagedMemoryOutputStream outputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    MemorySegmentWriter(
+            TypeSerializer<T> serializer,
+            Path path,
+            MemorySegmentPool segmentPool,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.path = path;
+        this.segmentPool = segmentPool;
+        this.outputStream = new ManagedMemoryOutputStream(segmentPool, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            if (e.getCause() instanceof MemoryAllocationException) {
+                return false;
+            }
+            throw e;
+        }
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(new Segment(path, count, outputStream.getSegments()));
+        } else {
+            segmentPool.returnAll(outputStream.getSegments());
+            return Optional.empty();
+        }
+    }
+
+    /** An output stream subclass that accepts bytes and writes them to memory segments. */
+    private static class ManagedMemoryOutputStream extends OutputStream {
+
+        /** The pool to allocate memory segments from. */
+        private final MemorySegmentPool segmentPool;
+
+        /** The number of bytes in a memory segment. */
+        private final int pageSize;
+
+        /** The memory segments containing written bytes. */
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        /** The index of the segment that currently accepts written bytes. */
+        private int segmentIndex;
+
+        /** The number of bytes in the current segment that have been written. */
+        private int segmentOffset;
+
+        /** The number of bytes that have been written so far. */
+        private long globalOffset;
+
+        public ManagedMemoryOutputStream(MemorySegmentPool segmentPool, long expectedSize)
+                throws MemoryAllocationException {
+            this.segmentPool = segmentPool;
+            this.pageSize = segmentPool.pageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            ensureCapacity(Math.max(expectedSize, 1L));
+        }
+
+        public long getPos() {
+            return globalOffset;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(@Nullable byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(long capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int required =
+                    (int) (capacity % pageSize == 0 ? capacity / pageSize : capacity / pageSize + 1)
+                            - segments.size();
+
+            List<MemorySegment> allocatedSegments = new ArrayList<>();
+            for (int i = 0; i < required; i++) {
+                MemorySegment memorySegment = segmentPool.nextSegment();
+                if (memorySegment == null) {
+                    segmentPool.returnAll(allocatedSegments);

Review Comment:
   Should we also return `segments` to the segmentPool? This could be useful when `write(...)` hits the limit.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+class MemorySegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The pre-allocated path to hold cached records into file system. */
+    private final Path path;
+
+    /** The pool to allocate memory segments from. */
+    private final MemorySegmentPool segmentPool;
+
+    /** The output stream to write serialized content to memory segments. */
+    private final ManagedMemoryOutputStream outputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    MemorySegmentWriter(
+            TypeSerializer<T> serializer,
+            Path path,
+            MemorySegmentPool segmentPool,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.path = path;
+        this.segmentPool = segmentPool;
+        this.outputStream = new ManagedMemoryOutputStream(segmentPool, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            if (e.getCause() instanceof MemoryAllocationException) {
+                return false;
+            }
+            throw e;
+        }
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(new Segment(path, count, outputStream.getSegments()));
+        } else {
+            segmentPool.returnAll(outputStream.getSegments());
+            return Optional.empty();
+        }
+    }
+
+    /** An output stream subclass that accepts bytes and writes them to memory segments. */
+    private static class ManagedMemoryOutputStream extends OutputStream {
+
+        /** The pool to allocate memory segments from. */
+        private final MemorySegmentPool segmentPool;
+
+        /** The number of bytes in a memory segment. */
+        private final int pageSize;
+
+        /** The memory segments containing written bytes. */
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        /** The index of the segment that currently accepts written bytes. */
+        private int segmentIndex;
+
+        /** The number of bytes in the current segment that have been written. */
+        private int segmentOffset;
+
+        /** The number of bytes that have been written so far. */
+        private long globalOffset;
+
+        public ManagedMemoryOutputStream(MemorySegmentPool segmentPool, long expectedSize)
+                throws MemoryAllocationException {
+            this.segmentPool = segmentPool;
+            this.pageSize = segmentPool.pageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            ensureCapacity(Math.max(expectedSize, 1L));
+        }
+
+        public long getPos() {
+            return globalOffset;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(@Nullable byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);

Review Comment:
   Instead of repeatedly allocating a `List<MemorySegment>` instance and performing a division for each write(...), would it be simpler to maintain `allocatedBytes` and just check `globalOffset + len <= allocatedBytes` for most write() operations?
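This is a minimal sketch of the `allocatedBytes` suggestion. It assumes `byte[]` pages in place of `MemorySegment`, and direct allocation in place of `MemorySegmentPool.nextSegment()`. `CapacityTrackingStream` and its methods are hypothetical. The point is that the common case costs one comparison, and the page list is only touched when the current pages are full.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: capacity tracking with an allocatedBytes counter instead of per-write page math. */
class CapacityTrackingStream {
    private final int pageSize;
    private final List<byte[]> pages = new ArrayList<>();
    private long allocatedBytes; // equals pages.size() * pageSize, kept as a running total
    private long globalOffset;
    private int slowPathCalls;

    CapacityTrackingStream(int pageSize) {
        this.pageSize = pageSize;
    }

    private void ensureCapacityFor(int len) {
        long required = globalOffset + len;
        if (required <= allocatedBytes) {
            return; // fast path: a single comparison, no list allocation or division
        }
        slowPathCalls++;
        while (allocatedBytes < required) {
            pages.add(new byte[pageSize]); // a real implementation would take pages from a pool
            allocatedBytes += pageSize;
        }
    }

    /** Reserves room for len more bytes; the actual byte copy is left out of this sketch. */
    void reserve(int len) {
        ensureCapacityFor(len);
        globalOffset += len;
    }

    int slowPathCalls() {
        return slowPathCalls;
    }

    int pageCount() {
        return pages.size();
    }

    long allocatedBytes() {
        return allocatedBytes;
    }
}
```

Take 100 one-byte writes with 32-byte pages. Only 4 of them take the slow path, one per newly needed page.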



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FileSegmentWriter.java:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/** A class that writes cache data to a target file in given file system. */
+@Internal
+class FileSegmentWriter<T> implements SegmentWriter<T> {
+
+    /** The tool to serialize received records into bytes. */
+    private final TypeSerializer<T> serializer;
+
+    /** The path to the target file. */
+    private final Path path;
+
+    /** The output stream that writes to the target file. */
+    private final FSDataOutputStream outputStream;
+
+    /** A buffer that wraps the output stream to optimize performance. */
+    private final BufferedOutputStream bufferedOutputStream;
+
+    /** The wrapper view of output stream to be used with TypeSerializer API. */
+    private final DataOutputView outputView;
+
+    /** The number of records added so far. */
+    private int count;
+
+    FileSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.outputStream = path.getFileSystem().create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.bufferedOutputStream = new BufferedOutputStream(outputStream);
+        this.outputView = new DataOutputViewStreamWrapper(bufferedOutputStream);
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (outputStream.getPos() >= DataCacheWriter.MAX_SEGMENT_SIZE) {
+            return false;
+        }
+        serializer.serialize(record, outputView);
+        count++;
+        return true;
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        this.bufferedOutputStream.flush();

Review Comment:
   Nit: we use `this` when a member variable name collides with an input parameter name, which typically happens in the constructor. It looks like we don't need `this` here.
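   A minimal illustration of the convention this nit describes. `RecordCounter` is a hypothetical class, not part of the PR: `this.` appears only where a constructor parameter shadows a field.

```java
// Hypothetical class illustrating when `this.` is needed.
class RecordCounter {
    private final String name;
    private int count;

    RecordCounter(String name) {
        // The parameter `name` shadows the field `name`, so `this.` is required here.
        this.name = name;
    }

    void add() {
        count++; // no shadowing: plain field access is enough
    }

    String finish() {
        // `this.name` would also compile, but `this.` adds nothing when nothing is shadowed.
        return name + ":" + count;
    }
}
```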



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,91 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 
 import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
+    /** The tool to deserialize bytes into records. */
     private final TypeSerializer<T> serializer;
 
-    private final FileSystem fileSystem;
-
+    /** The segments where to read the records from. */
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    /** The current reader for next records. */
+    @Nullable private SegmentReader<T> currentReader;
 
-    public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+    /** The index of the segment that current reader reads from. */
+    private int segmentIndex;
+
+    /** The number of records that have been read through current reader so far. */
+    private int segmentCount;
+
+    public DataCacheReader(TypeSerializer<T> serializer, List<Segment> segments) {
+        this(serializer, segments, Tuple2.of(0, 0));
     }
 
     public DataCacheReader(
             TypeSerializer<T> serializer,
-            FileSystem fileSystem,
             List<Segment> segments,
-            Tuple2<Integer, Integer> readerPosition)
-            throws IOException {
-
+            Tuple2<Integer, Integer> readerPosition) {
         this.serializer = serializer;
-        this.fileSystem = fileSystem;
         this.segments = segments;
+        this.segmentIndex = readerPosition.f0;
+        this.segmentCount = readerPosition.f1;
 
-        if (readerPosition.f0 < segments.size()) {
-            this.currentSegmentReader = new SegmentReader(readerPosition.f0, readerPosition.f1);
-        }
+        createSegmentReader(readerPosition.f0, readerPosition.f1);

Review Comment:
   The `readerPosition.f1` refers to the number of records that have been read from the `DataCache`, right?
   
   But `createSegmentReader(...)` passes this position directly to the `MemorySegmentReader` constructor, which interprets this value as the number of records that should be skipped `within this segment`.
   
   Should this inconsistency be fixed?
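   If `readerPosition.f1` is meant as a global count of records already read from the cache, one way to reconcile it with a per-segment skip count is to walk the per-segment record counts. The sketch below assumes that interpretation; `ReaderPositions.locate` and the `List<Integer>` of counts are hypothetical stand-ins for `List<Segment>` and `Segment#getCount()`.

```java
import java.util.List;

// Hypothetical helper: maps a global record offset to (segment index, offset within segment).
final class ReaderPositions {
    private ReaderPositions() {}

    /**
     * @param segmentCounts number of records in each segment, in read order
     * @param globalOffset number of records already read across all segments
     * @return {segmentIndex, offsetWithinSegment}; a segmentIndex equal to
     *     segmentCounts.size() means every record has been consumed
     */
    static int[] locate(List<Integer> segmentCounts, int globalOffset) {
        if (globalOffset < 0) {
            throw new IllegalArgumentException("negative offset: " + globalOffset);
        }
        int remaining = globalOffset;
        for (int i = 0; i < segmentCounts.size(); i++) {
            int count = segmentCounts.get(i);
            if (remaining < count) {
                return new int[] {i, remaining};
            }
            remaining -= count;
        }
        if (remaining != 0) {
            throw new IllegalArgumentException("offset beyond end of cache: " + globalOffset);
        }
        return new int[] {segmentCounts.size(), 0};
    }
}
```

   With this translation done up front, the reader could pass only the in-segment offset to the segment reader's constructor. Alternatively, the Javadoc could state that `f1` is already a per-segment offset; either way, the two places should agree.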



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890693618


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,73 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing persisted records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /**
+     * The count of records in the file at the path if the file size is not zero, otherwise the
+     * count of records in the cache.
+     */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing persisted records. */
+    private long fsSize = -1L;
+
+    /** The memory segments containing cached records. */

Review Comment:
   Could you update the comment to explain that `the cache list is empty iff the segment has not been cached in memory`?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,73 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing persisted records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /**
+     * The count of records in the file at the path if the file size is not zero, otherwise the
+     * count of records in the cache.
+     */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing persisted records. */
+    private long fsSize = -1L;
+
+    /** The memory segments containing cached records. */
+    private List<MemorySegment> cache;
+
+    Segment(Path path, int count, long fsSize) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
+        this.count = count;
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
 
-    public Segment(Path path, int count, long size) {
-        this.path = path;
+    Segment(Path path, int count, List<MemorySegment> cache) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
         this.count = count;
-        this.size = size;
+        this.cache = checkNotNull(cache);
     }
 
-    public Path getPath() {
+    void setCache(List<MemorySegment> cache) {
+        this.cache = checkNotNull(cache);
+    }
+
+    void setFsSize(long fsSize) {
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
+
+    Path getPath() {
         return path;
     }
 
-    public int getCount() {
+    int getCount() {
         return count;
     }
 
-    public long getSize() {
-        return size;
+    long getFsSize() {
+        return fsSize;
+    }
+
+    List<MemorySegment> getCache() {
+        return cache;

Review Comment:
   Instead of using `cache == null` to indicate that the segment has not been read into memory, would it be simpler to use `cache.isEmpty()`?
   
   This would let us handle fewer special cases (e.g. null). After all, if the segment has been read into memory, it must have a non-empty memory segment list.
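   A minimal sketch of the suggestion, assuming the field starts as an empty list and is never null. `CachedSegment` is hypothetical, and `byte[]` stands in for Flink's `MemorySegment` so the example runs without Flink on the classpath.

```java
import java.util.Collections;
import java.util.List;
import java.util.Objects;

// Hypothetical, simplified Segment: `byte[]` stands in for Flink's MemorySegment.
class CachedSegment {
    /** The in-memory pages; empty iff the segment has not been cached in memory. Never null. */
    private List<byte[]> cache = Collections.emptyList();

    void setCache(List<byte[]> cache) {
        this.cache = Objects.requireNonNull(cache);
    }

    boolean isCached() {
        // No null check needed: a cached segment always holds at least one page.
        return !cache.isEmpty();
    }

    List<byte[]> getCache() {
        return cache;
    }
}
```

   Callers can then iterate `getCache()` unconditionally; an uncached segment simply yields no pages.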



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,73 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing persisted records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /**
+     * The count of records in the file at the path if the file size is not zero, otherwise the
+     * count of records in the cache.
+     */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing persisted records. */
+    private long fsSize = -1L;

Review Comment:
   Would it be more intuitive to initialize the fsSize to 0?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,73 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing persisted records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /**
+     * The count of records in the file at the path if the file size is not zero, otherwise the
+     * count of records in the cache.
+     */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing persisted records. */

Review Comment:
   Could you update the comment to explain that `fsSize is 0 iff the segment has not been written to the given path`?
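   A sketch of the `fsSize == 0` convention suggested in this and the previous comment. `PersistedSegment` is hypothetical, and `String` stands in for `org.apache.flink.core.fs.Path`. With 0 as the "not yet written" marker, `getFsSize()` can be called at any time and callers test `getFsSize() > 0`.

```java
// Hypothetical, simplified Segment showing the `fsSize == 0` convention.
class PersistedSegment {
    private final String path; // stand-in for org.apache.flink.core.fs.Path
    private final int count;

    /** The length of the file at {@code path}; 0 iff the segment has not been written there. */
    private long fsSize = 0L; // explicit for readability; 0 is also Java's default

    PersistedSegment(String path, int count) {
        if (count <= 0) {
            throw new IllegalArgumentException("count must be positive: " + count);
        }
        this.path = path;
        this.count = count;
    }

    void setFsSize(long fsSize) {
        if (fsSize <= 0) {
            throw new IllegalArgumentException("fsSize must be positive: " + fsSize);
        }
        this.fsSize = fsSize;
    }

    /** Returns 0 if the segment has not been persisted yet. */
    long getFsSize() {
        return fsSize;
    }

    String getPath() {
        return path;
    }

    int getCount() {
        return count;
    }
}
```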



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -63,16 +98,18 @@ public boolean equals(Object o) {
         }
 
         Segment segment = (Segment) o;
-        return count == segment.count && size == segment.size && Objects.equals(path, segment.path);
+        return count == segment.count
+                && fsSize == segment.fsSize

Review Comment:
   Should a segment be uniquely identified by its count and path only?
   
   I suppose if we write segment data from memory to file, it should still represent the same segment, right?
   
   Same for `hashCode()`.
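   A sketch of identity based on `(path, count)` only, as the comment proposes. `SegmentId` is hypothetical and `String` stands in for Flink's `Path`. Leaving the mutable `fsSize` out of `equals()`/`hashCode()` also keeps the hash stable if a segment already stored in a `HashSet` or `HashMap` is later spilled to disk.

```java
import java.util.Objects;

// Hypothetical, simplified Segment: identity is (path, count); where the data currently
// lives (memory, file, or both) does not affect equality.
final class SegmentId {
    private final String path; // stand-in for org.apache.flink.core.fs.Path
    private final int count;
    private long fsSize; // mutable storage detail, deliberately excluded from equality

    SegmentId(String path, int count) {
        this.path = Objects.requireNonNull(path);
        this.count = count;
    }

    void setFsSize(long fsSize) {
        this.fsSize = fsSize;
    }

    long getFsSize() {
        return fsSize;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        SegmentId that = (SegmentId) o;
        return count == that.count && path.equals(that.path);
    }

    @Override
    public int hashCode() {
        return Objects.hash(path, count);
    }
}
```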



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] lindong28 merged pull request #97: [FLINK-27096] Improve DataCache and KMeans performance

Posted by GitBox <gi...@apache.org>.
lindong28 merged PR #97:
URL: https://github.com/apache/flink-ml/pull/97


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889669964


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,61 +18,37 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
-import org.apache.flink.core.fs.Path;
+import org.apache.flink.annotation.Internal;
 
-import java.io.Serializable;
-import java.util.Objects;
+/** A segment contains the information about a cache unit. */
+@Internal
+class Segment {
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+    private FileSegment fileSegment;

Review Comment:
   We need to create the `Segment` instance before we persist the segment to disk and acquire enough information to create a `FileSegment`. This means that the `fileSegment` field needs to be set after the `Segment` is created, so it cannot be final.
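   A sketch of the two-phase lifecycle described above, assuming the file information is assigned at most once. `LazySegment` and its nested `FileInfo` are hypothetical stand-ins for `Segment` and the PR's `FileSegment`, not the actual flink-ml code. The set-once guard recovers some of the safety that `final` would otherwise provide.

```java
import java.util.Objects;
import java.util.Optional;

// Hypothetical sketch: a segment created in memory first, persisted to a file later.
class LazySegment {
    /** Immutable facts about a persisted file (stand-in for FileSegment). */
    static final class FileInfo {
        final String path;
        final long size;

        FileInfo(String path, long size) {
            this.path = Objects.requireNonNull(path);
            this.size = size;
        }
    }

    private final int count;

    // Not final: only known after the in-memory data has been spilled to disk.
    private FileInfo fileInfo;

    LazySegment(int count) {
        this.count = count;
    }

    /** Records where the segment was persisted; may be called at most once. */
    void persisted(FileInfo fileInfo) {
        if (this.fileInfo != null) {
            throw new IllegalStateException("segment already persisted");
        }
        this.fileInfo = Objects.requireNonNull(fileInfo);
    }

    Optional<FileInfo> getFileInfo() {
        return Optional.ofNullable(fileInfo);
    }

    int getCount() {
        return count;
    }
}
```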



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889930583


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,158 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
+    public void addRecord(T record) throws IOException {
+        assert currentWriter != null;
+        if (!currentWriter.addRecord(record)) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentWriter.addRecord(record));
+        }
     }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
-    }
+    /** Finishes current segment if records has ever been added to this segment. */
+    public void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {
+            return;
+        }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = createSegmentWriter();
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public List<Segment> finish() throws IOException {
+        if (currentWriter == null) {
+            return finishedSegments;
         }
 
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+        return finishedSegments;
     }
 
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
+    public List<Segment> getFinishedSegments() {
+        return finishedSegments;
+    }
 
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.isOnDisk()) {

Review Comment:
   Suppose `fsSize == 0`, should we still invoke `fileSystem.delete(...)` to delete the empty file?
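   One way to avoid leaking empty files is to key cleanup on whether a file exists at the segment's path rather than on `fsSize > 0`. The sketch below uses `java.nio.file.Files.deleteIfExists` as a stand-in for Flink's `FileSystem` API, and `CacheCleanup.deleteAll` is a hypothetical helper.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

// Hypothetical cleanup sketch using java.nio.file in place of Flink's FileSystem API.
final class CacheCleanup {
    private CacheCleanup() {}

    /** Deletes every cache file that exists, including zero-length ones; returns how many were removed. */
    static int deleteAll(List<Path> paths) {
        int deleted = 0;
        for (Path path : paths) {
            try {
                // Keyed on existence, not size: an empty file left behind by a writer
                // that received no records is removed as well.
                if (Files.deleteIfExists(path)) {
                    deleted++;
                }
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
        return deleted;
    }
}
```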



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing cached records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The count of the records in the segment. */

Review Comment:
   By `in the segment`, do you mean `in the file` or `in the memory`? Could you explain this in the Java doc?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing cached records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The count of the records in the segment. */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing cached records. */
+    private long fsSize = -1L;
+
+    /** The memory segments containing cached records. */
+    private List<MemorySegment> cache;
+
+    Segment(Path path, int count, long fsSize) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
+        this.count = count;
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
 
-    public Segment(Path path, int count, long size) {
-        this.path = path;
+    Segment(Path path, int count, List<MemorySegment> cache) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
         this.count = count;
-        this.size = size;
+        this.cache = checkNotNull(cache);
+    }
+
+    void setCache(List<MemorySegment> cache) {
+        this.cache = checkNotNull(cache);
     }
 
-    public Path getPath() {
+    void setDiskInfo(long fsSize) {
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
+
+    boolean isOnDisk() {

Review Comment:
   If `getFsSize()` can also be called when `fsSize == 0`, would it be simpler to remove this method and let callers use `getFsSize() > 0`?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,158 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
+    public void addRecord(T record) throws IOException {
+        assert currentWriter != null;
+        if (!currentWriter.addRecord(record)) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentWriter.addRecord(record));
+        }
     }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
-    }
+    /** Finishes current segment if records has ever been added to this segment. */
+    public void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null || currentWriter.getCount() == 0) {
+            return;
+        }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = createSegmentWriter();
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    /** Finishes adding records and closes resources occupied for adding records. */
+    public List<Segment> finish() throws IOException {
+        if (currentWriter == null) {
+            return finishedSegments;
         }
 
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+        return finishedSegments;
     }
 
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
+    public List<Segment> getFinishedSegments() {
+        return finishedSegments;
+    }
 
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.isOnDisk()) {
+                fileSystem.delete(segment.getPath(), false);
+            }
+            if (segment.isCached()) {
+                assert segmentPool != null;

Review Comment:
   After thinking about this more, I find it more consistent with other code to just remove the assert here. Note that the code will throw NullPointerException if segmentPool == null, which is good enough for us to investigate this bug.
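   A small sketch of the reasoning. `PagePool` is a hypothetical stand-in for Flink's `MemorySegmentPool`, and `releaseCache` is a hypothetical helper. A Java `assert` only fires when the JVM runs with `-ea`, whereas dereferencing a null pool throws `NullPointerException` regardless, so the assert adds little.

```java
import java.util.List;

// Hypothetical sketch: `PagePool` stands in for Flink's MemorySegmentPool.
final class AssertVsNpe {
    private AssertVsNpe() {}

    interface PagePool {
        void returnAll(List<byte[]> pages);
    }

    static void releaseCache(PagePool pool, List<byte[]> cache) {
        // No `assert pool != null;`: assertions are disabled unless the JVM runs with -ea,
        // and a null pool already fails fast with a NullPointerException on the next line.
        pool.returnAll(cache);
    }
}
```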



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing cached records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The count of the records in the segment. */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing cached records. */
+    private long fsSize = -1L;
+
+    /** The memory segments containing cached records. */
+    private List<MemorySegment> cache;
+
+    Segment(Path path, int count, long fsSize) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
+        this.count = count;
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
 
-    public Segment(Path path, int count, long size) {
-        this.path = path;
+    Segment(Path path, int count, List<MemorySegment> cache) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
         this.count = count;
-        this.size = size;
+        this.cache = checkNotNull(cache);
+    }
+
+    void setCache(List<MemorySegment> cache) {
+        this.cache = checkNotNull(cache);
     }
 
-    public Path getPath() {
+    void setDiskInfo(long fsSize) {
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
+
+    boolean isOnDisk() {
+        return fsSize > 0;
+    }
+
+    boolean isCached() {
+        return cache != null;
+    }
+
+    Path getPath() {
         return path;
     }
 
-    public int getCount() {
+    int getCount() {
         return count;
     }
 
-    public long getSize() {
-        return size;
+    long getFsSize() {
+        checkState(fsSize > 0);
+        return fsSize;
+    }
+
+    List<MemorySegment> getCache() {
+        return checkNotNull(cache);

Review Comment:
   For `getXXX()` methods, it is generally simpler and more intuitive to just return the actual state (e.g. null or an empty list) instead of throwing an exception.
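   A minimal sketch of this suggestion, using a hypothetical, simplified `SegmentInfo` class in place of `Segment` (a `String` path and `byte[]` buffers stand in for Flink's `Path` and `MemorySegment` so the example compiles on its own). The getters report the current state, `null` when nothing is cached and `-1` when there is no file, and callers branch on that state instead of catching exceptions:

   ```java
   import java.util.List;

   /** Hypothetical stand-in for Segment; String and byte[] replace Flink's Path and MemorySegment. */
   class SegmentInfo {
       private final String path;
       private final int count;

       /** Size of the backing file in bytes, or -1 if the records were never written to a file. */
       private long fsSize = -1L;

       /** Buffers holding the records in memory, or null if the records are not held in memory. */
       private List<byte[]> cache;

       SegmentInfo(String path, int count) {
           this.path = path;
           this.count = count;
       }

       void setCache(List<byte[]> cache) {
           this.cache = cache;
       }

       void setFsSize(long fsSize) {
           this.fsSize = fsSize;
       }

       String getPath() {
           return path;
       }

       int getCount() {
           return count;
       }

       /** Returns the file size, or -1 if there is no file; never throws. */
       long getFsSize() {
           return fsSize;
       }

       /** Returns the cached buffers, or null if nothing is cached; never throws. */
       List<byte[]> getCache() {
           return cache;
       }
   }
   ```

   Callers can then write `if (segment.getCache() != null) { ... }`, which keeps the `getXXX()` methods free of precondition checks.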



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing cached records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The count of the records in the segment. */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing cached records. */
+    private long fsSize = -1L;
+
+    /** The memory segments containing cached records. */
+    private List<MemorySegment> cache;
+
+    Segment(Path path, int count, long fsSize) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
+        this.count = count;
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
 
-    public Segment(Path path, int count, long size) {
-        this.path = path;
+    Segment(Path path, int count, List<MemorySegment> cache) {
+        this.path = checkNotNull(path);
+        checkArgument(count > 0);
         this.count = count;
-        this.size = size;
+        this.cache = checkNotNull(cache);
+    }
+
+    void setCache(List<MemorySegment> cache) {
+        this.cache = checkNotNull(cache);
     }
 
-    public Path getPath() {
+    void setDiskInfo(long fsSize) {
+        checkArgument(fsSize > 0);
+        this.fsSize = fsSize;
+    }
+
+    boolean isOnDisk() {
+        return fsSize > 0;
+    }
+
+    boolean isCached() {
+        return cache != null;
+    }
+
+    Path getPath() {
         return path;
     }
 
-    public int getCount() {
+    int getCount() {
         return count;
     }
 
-    public long getSize() {
-        return size;
+    long getFsSize() {
+        checkState(fsSize > 0);

Review Comment:
   Suppose the input data stream does not have any data. We should still be able to snapshot/persist this data stream to disk and reload it after restarting the Flink job, right?
   
   Then we will need to support a segment with `fsSize = 0`.
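   One way to allow this, sketched with a hypothetical, simplified `EmptyAwareSegment` class rather than the PR's `Segment`: reserve `-1` as the "no file" sentinel so that `0` becomes a legal size for the file of an empty segment, and accept a record count of `0`.

   ```java
   /** Hypothetical sketch: -1 means "not written to a file", so 0 is a valid size for an empty file. */
   class EmptyAwareSegment {
       private static final long NOT_ON_DISK = -1L;

       private final String path;
       private final int count;
       private long fsSize = NOT_ON_DISK;

       EmptyAwareSegment(String path, int count) {
           if (path == null) {
               throw new NullPointerException("path");
           }
           // An empty input stream yields a segment with zero records.
           if (count < 0) {
               throw new IllegalArgumentException("count must be non-negative: " + count);
           }
           this.path = path;
           this.count = count;
       }

       /** Records that the segment was written to a file of the given size (possibly 0 bytes). */
       void setDiskInfo(long fsSize) {
           if (fsSize < 0) {
               throw new IllegalArgumentException("fsSize must be non-negative: " + fsSize);
           }
           this.fsSize = fsSize;
       }

       /** True once a file exists for this segment, even if that file is empty. */
       boolean isOnDisk() {
           return fsSize != NOT_ON_DISK;
       }

       long getFsSize() {
           return fsSize;
       }

       int getCount() {
           return count;
       }

       String getPath() {
           return path;
       }
   }
   ```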



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
 
-import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+    /** The path to the file containing cached records. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The count of the records in the segment. */
     private final int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The total length of file containing cached records. */

Review Comment:
   nits: the records in the file are not `cached`.



[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r889994816


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,91 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 
 import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
+    /** The tool to deserialize bytes into records. */
     private final TypeSerializer<T> serializer;
 
-    private final FileSystem fileSystem;
-
+    /** The segments where to read the records from. */
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    /** The current reader for next records. */
+    @Nullable private SegmentReader<T> currentReader;

Review Comment:
   nits: since the class name is `DataCacheReader`, could we rename this variable to `currentSegmentReader` to minimize confusion?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,162 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;

Review Comment:
   nits: could we rename this variable to `currentSegmentWriter`?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,91 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 
 import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
+    /** The tool to deserialize bytes into records. */
     private final TypeSerializer<T> serializer;
 
-    private final FileSystem fileSystem;
-
+    /** The segments where to read the records from. */
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    /** The current reader for next records. */
+    @Nullable private SegmentReader<T> currentReader;
 
-    public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+    /** The index of the segment that current reader reads from. */
+    private int segmentIndex;

Review Comment:
   nits: how about renaming it to `currentSegmentIndex` to make the name more consistent with other variables?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,91 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 
 import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
+    /** The tool to deserialize bytes into records. */
     private final TypeSerializer<T> serializer;
 
-    private final FileSystem fileSystem;
-
+    /** The segments where to read the records from. */
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    /** The current reader for next records. */
+    @Nullable private SegmentReader<T> currentReader;
 
-    public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+    /** The index of the segment that current reader reads from. */
+    private int segmentIndex;
+
+    /** The number of records that have been read through current reader so far. */
+    private int segmentCount;

Review Comment:
   nits: to avoid confusing this count with the total number of records read by this data cache, how about renaming it to `currentSegmentCount`?
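   To illustrate the distinction this rename is meant to make, here is a minimal sketch of an iterator over several segments whose position is the pair (index of the current segment, records read within that segment) rather than a total count. `SegmentedReader` is hypothetical: in-memory `List`s stand in for segments, and an `int[]` stands in for the `Tuple2` start position that the old constructor passes as `new Tuple2<>(0, 0)`.

   ```java
   import java.util.Iterator;
   import java.util.List;
   import java.util.NoSuchElementException;

   /** Hypothetical reader over in-memory segments; position = (segment index, count in segment). */
   class SegmentedReader<T> implements Iterator<T> {
       private final List<List<T>> segments;

       /** Index of the segment currently being read. */
       private int currentSegmentIndex;

       /** Records read from the current segment only, not from the whole cache. */
       private int currentSegmentCount;

       SegmentedReader(List<List<T>> segments, int startSegmentIndex, int startSegmentCount) {
           this.segments = segments;
           this.currentSegmentIndex = startSegmentIndex;
           this.currentSegmentCount = startSegmentCount;
           skipExhaustedSegments();
       }

       /** Moves past segments that have no unread records left, resetting the per-segment count. */
       private void skipExhaustedSegments() {
           while (currentSegmentIndex < segments.size()
                   && currentSegmentCount >= segments.get(currentSegmentIndex).size()) {
               currentSegmentIndex++;
               currentSegmentCount = 0;
           }
       }

       @Override
       public boolean hasNext() {
           return currentSegmentIndex < segments.size();
       }

       @Override
       public T next() {
           if (!hasNext()) {
               throw new NoSuchElementException();
           }
           T record = segments.get(currentSegmentIndex).get(currentSegmentCount++);
           skipExhaustedSegments();
           return record;
       }

       /** Returns {segment index, records read in that segment}, from which reading can resume. */
       int[] getPosition() {
           return new int[] {currentSegmentIndex, currentSegmentCount};
       }
   }
   ```

   After reading three records from segments of sizes 2, 0 and 3, the position is `{2, 1}`: the per-segment count is 1 even though 3 records have been read in total.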



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,162 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
+        this(serializer, fileSystem, pathGenerator, segmentPool, Collections.emptyList());
+    }
 
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this.fileSystem = fileSystem;
+        this.pathGenerator = pathGenerator;
+        this.segmentPool = segmentPool;
+        this.serializer = serializer;
+        this.finishedSegments = new ArrayList<>();
+        this.finishedSegments.addAll(finishedSegments);
+        this.currentWriter = createSegmentWriter();
     }
 
-    public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+    public void addRecord(T record) throws IOException {
+        if (!currentWriter.addRecord(record)) {
+            currentWriter.finish().ifPresent(finishedSegments::add);
+            currentWriter = new FileSegmentWriter<>(serializer, pathGenerator.get());
+            Preconditions.checkState(currentWriter.addRecord(record));
+        }
     }
 
+    /** Finishes adding records and closes resources occupied for adding records. */
     public List<Segment> finish() throws IOException {
-        finishCurrentSegment(false);
-        return finishSegments;
-    }
+        if (currentWriter == null) {
+            return finishedSegments;
+        }
 
-    public FileSystem getFileSystem() {
-        return fileSystem;
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = null;
+        return finishedSegments;
     }
 
-    public List<Segment> getFinishSegments() {
-        return finishSegments;
+    /**
+     * Flushes all added records to segments and returns a list of segments containing all cached
+     * records.
+     */
+    public List<Segment> getSegments() throws IOException {
+        finishCurrentSegmentIfAny();
+        return finishedSegments;
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
+    private void finishCurrentSegmentIfAny() throws IOException {
+        if (currentWriter == null) {
+            return;
         }
 
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
+        currentWriter.finish().ifPresent(finishedSegments::add);
+        currentWriter = createSegmentWriter();
     }
 
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
-
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
+    /** Cleans up all previously added records. */
+    public void cleanup() throws IOException {
+        finishCurrentSegmentIfAny();
+        for (Segment segment : finishedSegments) {
+            if (segment.getFsSize() > 0) {
+                fileSystem.delete(segment.getPath(), false);
+            }
+            if (segment.getCache() != null) {
+                segmentPool.returnAll(segment.getCache());
+            }
         }
+        finishedSegments.clear();
+    }
 
-        public void addRecord(T record) throws IOException {
-            serializer.serialize(record, outputView);
-            currentSegmentCount += 1;
-        }
+    /** Persists the segments in this writer to disk. */
+    public void persistSegmentsToDisk() throws IOException {

Review Comment:
   nits: given that we use `FileSegmentWriter`, would it be more consistent to rename this method to `writeSegmentsToFiles()`?
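   For context, the `addRecord()` change in this diff writes to the current segment writer first and, once that writer rejects a record (e.g. because there is not enough memory), finishes it and continues with a `FileSegmentWriter`. The sketch below models only that control flow with hypothetical, simplified types: a bounded in-memory writer and an unbounded list that stands in for a file. It performs no real I/O or serialization.

   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.Optional;

   /** Hypothetical segment writer: addRecord returns false when it cannot take the record. */
   interface SimpleSegmentWriter<T> {
       boolean addRecord(T record);

       /** Returns the finished segment, or empty if no record was written. */
       Optional<List<T>> finish();
   }

   /** Stands in for a memory-backed writer whose pool holds at most `capacity` records. */
   class BoundedMemoryWriter<T> implements SimpleSegmentWriter<T> {
       private final int capacity;
       private final List<T> records = new ArrayList<>();

       BoundedMemoryWriter(int capacity) {
           this.capacity = capacity;
       }

       @Override
       public boolean addRecord(T record) {
           if (records.size() >= capacity) {
               return false; // not enough "memory" left
           }
           records.add(record);
           return true;
       }

       @Override
       public Optional<List<T>> finish() {
           return records.isEmpty() ? Optional.empty() : Optional.of(records);
       }
   }

   /** Stands in for a file-backed writer, which always accepts records. */
   class UnboundedFileWriter<T> implements SimpleSegmentWriter<T> {
       private final List<T> records = new ArrayList<>();

       @Override
       public boolean addRecord(T record) {
           records.add(record);
           return true;
       }

       @Override
       public Optional<List<T>> finish() {
           return records.isEmpty() ? Optional.empty() : Optional.of(records);
       }
   }

   /** Mirrors the memory-first, file-fallback flow of DataCacheWriter#addRecord in the diff. */
   class FallbackCacheWriter<T> {
       private final List<List<T>> finishedSegments = new ArrayList<>();
       private SimpleSegmentWriter<T> currentSegmentWriter;

       FallbackCacheWriter(int memoryCapacity) {
           this.currentSegmentWriter = new BoundedMemoryWriter<>(memoryCapacity);
       }

       void addRecord(T record) {
           if (!currentSegmentWriter.addRecord(record)) {
               // Seal the in-memory segment and continue with a file-backed writer.
               currentSegmentWriter.finish().ifPresent(finishedSegments::add);
               currentSegmentWriter = new UnboundedFileWriter<>();
               if (!currentSegmentWriter.addRecord(record)) {
                   throw new IllegalStateException("file writer rejected a record");
               }
           }
       }

       List<List<T>> finish() {
           if (currentSegmentWriter != null) {
               currentSegmentWriter.finish().ifPresent(finishedSegments::add);
               currentSegmentWriter = null;
           }
           return finishedSegments;
       }
   }
   ```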



[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans performance

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891853412


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -256,4 +325,79 @@ public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) {
             }
         }
     }
+
+    /*
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<T>
+            implements OneInputStreamOperator<T, T>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>(numSamples);
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T value = streamRecord.getValue();
+            count++;
+
+            if (samples.size() < numSamples) {
+                samples.add(value);
+            } else {
+                if (random.nextInt(count) < numSamples) {

Review Comment:
   nits: could we re-use `random.nextInt(count)` for simplicity?
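   A minimal sketch of this simplification, assuming the replacement slot is otherwise drawn with a second call to `random`: a single `random.nextInt(count)` both decides whether to keep the new element (probability `numSamples / count`) and, when it does, picks the slot to overwrite. This is the standard reservoir sampling scheme (Algorithm R); `ReservoirSampler` is a hypothetical standalone class without the operator's state handling.

   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.Random;

   /** Hypothetical standalone reservoir sampler (Algorithm R) using one random draw per element. */
   class ReservoirSampler<T> {
       private final int numSamples;
       private final Random random;
       private final List<T> samples;

       /** Number of elements seen so far. */
       private int count;

       ReservoirSampler(int numSamples, long randomSeed) {
           this.numSamples = numSamples;
           this.random = new Random(randomSeed);
           this.samples = new ArrayList<>(numSamples);
       }

       void add(T value) {
           count++;
           if (samples.size() < numSamples) {
               samples.add(value);
               return;
           }
           // One draw decides both whether to keep the value (probability numSamples / count)
           // and, if so, which existing sample it replaces.
           int index = random.nextInt(count);
           if (index < numSamples) {
               samples.set(index, value);
           }
       }

       List<T> getSamples() {
           return samples;
       }
   }
   ```

   Reusing the draw is correct because, conditioned on `index < numSamples`, `index` is uniformly distributed over `[0, numSamples)`.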


