You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/05/11 14:39:42 UTC

[GitHub] [flink-ml] lindong28 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r869946600


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,65 +19,80 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
-/** Records the data received and replayed them on required. */
+/** Records the data received and replays them on required. */
 public class DataCacheWriter<T> {
 
-    private final TypeSerializer<T> serializer;
-
     private final FileSystem fileSystem;
 
     private final SupplierWithException<Path, IOException> pathGenerator;
 
+    private final MemoryManager memoryManager;
+
+    private final TypeSerializer<T> serializer;
+
     private final List<Segment> finishSegments;
 
-    private SegmentWriter currentSegment;
+    private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
-            SupplierWithException<Path, IOException> pathGenerator)
+            SupplierWithException<Path, IOException> pathGenerator,
+            MemoryManager memoryManager)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, memoryManager, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
+            MemoryManager memoryManager,
             List<Segment> priorFinishedSegments)
             throws IOException {
         this.serializer = serializer;
         this.fileSystem = fileSystem;
         this.pathGenerator = pathGenerator;
-
+        this.memoryManager = memoryManager;
         this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this.currentWriter = createSegmentWriter(pathGenerator, this.memoryManager);
     }
 
     public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+        boolean success = currentWriter.addRecord(record);
+        if (!success) {
+            finishCurrentSegment();
+            success = currentWriter.addRecord(record);
+            Preconditions.checkState(success);
+        }
     }
 
     public void finishCurrentSegment() throws IOException {
-        finishCurrentSegment(true);
+        if (currentWriter != null) {
+            currentWriter.finish().ifPresent(finishSegments::add);
+            currentWriter = null;

Review Comment:
   nit: this line (`currentWriter = null;`) is redundant.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   Is it possible to remove `objectOutputStream` and `byteArrayOutputStream`, and just use `outputView = new DataOutputViewStreamWrapper(outputStream)`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -253,58 +261,125 @@ public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value
                             DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
                     IterationListener<Tuple2<Integer, DenseVector>> {
         private final DistanceMeasure distanceMeasure;
-        private ListState<DenseVector> points;
-        private ListState<DenseVector[]> centroids;
+        private ListState<DenseVector[]> centroidsState;
+        private DenseVector[] centroids;
+
+        private Path basePath;
+        private StreamConfig config;
+        private DataCacheWriter<DenseVector> dataCacheWriter;
 
         public SelectNearestCentroidOperator(DistanceMeasure distanceMeasure) {
+            super();
             this.distanceMeasure = distanceMeasure;
         }
 
+        @Override
+        public void setup(
+                StreamTask<?, ?> containingTask,
+                StreamConfig config,
+                Output<StreamRecord<Tuple2<Integer, DenseVector>>> output) {
+            super.setup(containingTask, config, output);
+
+            basePath =
+                    OperatorUtils.getDataCachePath(
+                            containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                            containingTask
+                                    .getEnvironment()
+                                    .getIOManager()
+                                    .getSpillingDirectoriesPaths());
+
+            this.config = config;
+        }
+
         @Override
         public void initializeState(StateInitializationContext context) throws Exception {
             super.initializeState(context);
-            points =
-                    context.getOperatorStateStore()
-                            .getListState(new ListStateDescriptor<>("points", DenseVector.class));
 
             TypeInformation<DenseVector[]> type =
                     ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
-            centroids =
+            centroidsState =
                     context.getOperatorStateStore()
                             .getListState(new ListStateDescriptor<>("centroids", type));
+            centroids =
+                    OperatorStateUtils.getUniqueElement(centroidsState, "centroids").orElse(null);
+
+            List<StatePartitionStreamProvider> inputs =
+                    IteratorUtils.toList(context.getRawOperatorStateInputs().iterator());
+            Preconditions.checkState(
+                    inputs.size() < 2, "The input from raw operator state should be one or zero.");
+
+            List<Segment> priorFinishedSegments = new ArrayList<>();
+            if (inputs.size() > 0) {
+                InputStream inputStream = inputs.get(0).getStream();
+
+                DataCacheSnapshot dataCacheSnapshot =
+                        DataCacheSnapshot.recover(
+                                inputStream,
+                                basePath.getFileSystem(),
+                                OperatorUtils.createDataCacheFileGenerator(
+                                        basePath, "cache", config.getOperatorID()));
+
+                priorFinishedSegments = dataCacheSnapshot.getSegments();
+            }
+
+            dataCacheWriter =
+                    new DataCacheWriter<>(
+                            DenseVectorSerializer.INSTANCE,
+                            basePath.getFileSystem(),
+                            OperatorUtils.createDataCacheFileGenerator(
+                                    basePath, "cache", config.getOperatorID()),
+                            getContainingTask().getEnvironment().getMemoryManager(),
+                            priorFinishedSegments);
         }
 
         @Override
         public void processElement1(StreamRecord<DenseVector> streamRecord) throws Exception {
-            points.add(streamRecord.getValue());
+            dataCacheWriter.addRecord(streamRecord.getValue());
+            if (centroids != null) {
+                DenseVector point = streamRecord.getValue();
+                int closestCentroidId = findClosestCentroidId(centroids, point, distanceMeasure);
+                output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
+            }
         }
 
         @Override
         public void processElement2(StreamRecord<DenseVector[]> streamRecord) throws Exception {
-            centroids.add(streamRecord.getValue());
+            Preconditions.checkState(centroids == null);
+            centroidsState.add(streamRecord.getValue());
+            centroids = streamRecord.getValue();
+
+            dataCacheWriter.finishCurrentSegment();

Review Comment:
   It seems inefficient to create and delete an empty file in each round. How about we optimize `finishCurrentSegment()` (and maybe rename it) so that it does not create an empty file when `count == 0`?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB

Review Comment:
   Why do we choose 100MB as the segment size? How does it compare with Spark ML's approach?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)

Review Comment:
   Do we need the `expectedSize` and `MemoryAllocationException` here?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,121 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-
-import javax.annotation.Nullable;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
-    private final TypeSerializer<T> serializer;
+    private final MemoryManager memoryManager;
 
-    private final FileSystem fileSystem;
+    private final TypeSerializer<T> serializer;
 
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    private SegmentReader<T> currentReader;
+
+    private MemorySegmentWriter<T> cacheWriter;
+
+    private int segmentIndex;
 
     public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+            TypeSerializer<T> serializer, MemoryManager memoryManager, List<Segment> segments) {
+        this(serializer, memoryManager, segments, new Tuple2<>(0, 0));
     }
 
     public DataCacheReader(
             TypeSerializer<T> serializer,
-            FileSystem fileSystem,
+            MemoryManager memoryManager,
             List<Segment> segments,
-            Tuple2<Integer, Integer> readerPosition)
-            throws IOException {
-
+            Tuple2<Integer, Integer> readerPosition) {
+        this.memoryManager = memoryManager;
         this.serializer = serializer;
-        this.fileSystem = fileSystem;
         this.segments = segments;
+        this.segmentIndex = readerPosition.f0;
+
+        createSegmentReaderAndCache(readerPosition.f0, readerPosition.f1);
+    }
+
+    private void createSegmentReaderAndCache(int index, int startOffset) {
+        try {
+            cacheWriter = null;
 
-        if (readerPosition.f0 < segments.size()) {
-            this.currentSegmentReader = new SegmentReader(readerPosition.f0, readerPosition.f1);
+            if (index >= segments.size()) {
+                currentReader = null;
+                return;
+            }
+
+            currentReader = SegmentReader.create(serializer, segments.get(index), startOffset);
+
+            boolean shouldCacheInMemory =
+                    startOffset == 0
+                            && currentReader instanceof FsSegmentReader
+                            && MemoryUtils.isMemoryEnoughForCache(memoryManager);
+
+            if (shouldCacheInMemory) {
+                cacheWriter =
+                        new MemorySegmentWriter<>(
+                                segments.get(index).path,
+                                memoryManager,
+                                segments.get(index).inMemorySize);
+            }
+        } catch (MemoryAllocationException e) {
+            cacheWriter = null;
+        } catch (IOException e) {
+            throw new RuntimeException(e);
         }
     }
 
     @Override
     public boolean hasNext() {
-        return currentSegmentReader != null && currentSegmentReader.hasNext();
+        return currentReader != null && currentReader.hasNext();
     }
 
     @Override
     public T next() {
         try {
-            T next = currentSegmentReader.next();
-
-            if (!currentSegmentReader.hasNext()) {
-                currentSegmentReader.close();
-                if (currentSegmentReader.index < segments.size() - 1) {
-                    currentSegmentReader = new SegmentReader(currentSegmentReader.index + 1, 0);
-                } else {
-                    currentSegmentReader = null;
+            T record = currentReader.next();
+
+            if (cacheWriter != null) {
+                if (!cacheWriter.addRecord(record)) {
+                    cacheWriter
+                            .finish()
+                            .ifPresent(x -> memoryManager.releaseMemory(x.path, x.inMemorySize));
+                    cacheWriter = null;
+                }
+            }
+
+            if (!currentReader.hasNext()) {
+                currentReader.close();
+                if (cacheWriter != null) {
+                    cacheWriter
+                            .finish()
+                            .ifPresent(
+                                    x -> {
+                                        x.fsSize = segments.get(segmentIndex).fsSize;
+                                        segments.set(segmentIndex, x);

Review Comment:
   This method replaces a file-only segment with an in-memory segment. It is a bit counter-intuitive to make such changes in an XXXReader class.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemoryUtils.java:
##########
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.memory.MemoryManager;
+
+/** Utility variables and methods for memory operation. */
+@Internal
+class MemoryUtils {
+    // Cache is not suggested if over 80% of memory has been occupied.
+    private static final double CACHE_MEMORY_THRESHOLD = 0.2;

Review Comment:
   How does Spark ML decide when to switch from memory to disk? Does Spark also use such a magic number?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)
+            throws MemoryAllocationException {
+        Preconditions.checkNotNull(memoryManager);
+        this.segment = new Segment();
+        this.segment.path = path;
+        this.segment.cache = new ArrayList<>();
+        this.segment.inMemorySize = 0L;
+        this.memoryManager = memoryManager;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (!MemoryUtils.isMemoryEnoughForCache(memoryManager)) {
+            return false;
+        }
+
+        long recordSize = GraphLayout.parseInstance(record).totalSize();

Review Comment:
   Does Spark ML also estimate the record size in this way? Given that we invoke this method for every record, we probably need to make sure its overhead is low enough even for small records. 



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */
 public class Segment implements Serializable {
 
-    private final Path path;
+    /** The pre-allocated path on disk to persist the records. */
+    Path path;

Review Comment:
   It is not easy to understand how this class is meant to be used now that its member variables have been changed to be non-final and non-private. The code looks a bit hacky.
   
   Is there any way to improve it, e.g. by still using private final member variables?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -89,57 +104,31 @@ public List<Segment> getFinishSegments() {
         return finishSegments;
     }
 
-    private void finishCurrentSegment(boolean newSegment) throws IOException {
-        if (currentSegment != null) {
-            currentSegment.finish().ifPresent(finishSegments::add);
-            currentSegment = null;
-        }
-
-        if (newSegment) {
-            currentSegment = new SegmentWriter(pathGenerator.get());
-        }
-    }
-
-    private class SegmentWriter {
-
-        private final Path path;
-
-        private final FSDataOutputStream outputStream;
-
-        private final DataOutputView outputView;
-
-        private int currentSegmentCount;
-
-        public SegmentWriter(Path path) throws IOException {
-            this.path = path;
-            this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
-            this.outputView = new DataOutputViewStreamWrapper(outputStream);
-        }
-
-        public void addRecord(T record) throws IOException {
-            serializer.serialize(record, outputView);
-            currentSegmentCount += 1;
-        }
+    private SegmentWriter<T> createSegmentWriter(
+            SupplierWithException<Path, IOException> pathGenerator, MemoryManager memoryManager)
+            throws IOException {
+        boolean shouldCacheInMemory = MemoryUtils.isMemoryEnoughForCache(memoryManager);
 
-        public Optional<Segment> finish() throws IOException {
-            this.outputStream.flush();
-            long size = outputStream.getPos();
-            this.outputStream.close();
-
-            if (currentSegmentCount > 0) {
-                return Optional.of(new Segment(path, currentSegmentCount, size));
-            } else {
-                // If there are no records, we tend to directly delete this file
-                fileSystem.delete(path, false);
-                return Optional.empty();
+        if (shouldCacheInMemory) {
+            try {
+                return new MemorySegmentWriter<>(pathGenerator.get(), memoryManager);
+            } catch (MemoryAllocationException e) {
+                return new FsSegmentWriter<>(serializer, pathGenerator.get());
             }
         }
+        return new FsSegmentWriter<>(serializer, pathGenerator.get());
     }
 
     public void cleanup() throws IOException {
-        finishCurrentSegment();
+        finish();
         for (Segment segment : finishSegments) {
-            fileSystem.delete(segment.getPath(), false);
+            if (segment.isOnDisk()) {
+                fileSystem.delete(segment.path, false);
+            }
+            if (segment.isInMemory()) {
+                memoryManager.releaseMemory(segment.path, segment.inMemorySize);

Review Comment:
   This call basically subtracts `segment.inMemorySize` owned by `segment.path` in memoryManager's bookkeeping.
   
   I suppose there should be a corresponding `reserveMemory(...)` with owner = `segment.path`. Where is that call?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/ForwardInputsOfLastRound.java:
##########
@@ -42,19 +41,13 @@ public void flatMap(T value, Collector<T> out) {
 
     @Override
     public void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> out) {
-        valuesInLastEpoch = valuesInCurrentEpoch;
-        valuesInCurrentEpoch = new ArrayList<>();
+        valuesInCurrentEpoch.clear();
     }
 
     @Override
     public void onIterationTerminated(Context context, Collector<T> out) {
-        for (T value : valuesInLastEpoch) {

Review Comment:
   Prior to this PR, when the following sequence of statements is called, two values would be collected. After this PR, no values would be collected. In other words, this changes the behavior of public APIs introduced in FLIP-176. Is this expected?
   
   ```
   flatMap(x1, ...)
   
   flatMap(x2, ...)
   
   onEpochWatermarkIncremented(...)
   
   onIterationTerminated(...)
   ```



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */

Review Comment:
   After this PR, the Javadoc says this class represents a cache unit, but the implementation shows that this class could represent a file, which seems confusing. Could you clarify it?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org