Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/06/01 09:02:54 UTC

[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #97: [FLINK-27096] Improve DataCache and KMeans Performance

zhipeng93 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r885138130


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;

Review Comment:
   `countState` seems unnecessary, since we can get the count from `samples.size()`.



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T sample = streamRecord.getValue();
+            count++;
+
+            // Code below is inspired by the Reservoir Sampling algorithm.
+            if (samples.size() < numSamples) {
+                samples.add(sample);
+            } else {
+                if (random.nextInt(count) < numSamples) {
+                    samples.set(random.nextInt(numSamples), sample);
+                }
+            }
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            Collections.shuffle(samples, random);

Review Comment:
   Is shuffle necessary here?



##########
flink-ml-dist/pom.xml:
##########
@@ -55,6 +55,15 @@ under the License.
             <scope>compile</scope>
         </dependency>
 
+        <!-- Java Object Layout Dependencies -->
+
+        <dependency>

Review Comment:
   Is this still necessary?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();

Review Comment:
   nit: `new ArrayList<>(numSamples)`.



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.

Review Comment:
   There are many different sampling methods, and this class implements only one of them. Could you update the Javadoc to explain that we are doing uniform sampling without replacement?
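   
   For instance, the Javadoc could read roughly as follows (a suggested wording, not taken from the PR):
   
   ```
   /**
    * A stream operator that performs uniform sampling without replacement over a
    * bounded input stream via reservoir sampling: it retains at most {@code
    * numSamples} elements, each input element being kept with equal probability,
    * and emits the sampled subset as a single list when the input ends.
    */
   ```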



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java:
##########
@@ -75,9 +76,11 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept
 
         final int len = vector.values.length;
         target.writeInt(len);
+        ByteBuffer buffer = ByteBuffer.allocate(len << 3);

Review Comment:
   We should probably follow the implementation in `ObjectOutputStream#writeInts`. The current implementation introduces some extra issues:
   - Allocating one ByteBuffer per serialize call may cause more garbage collection.
   - It may lead to OOM when the DenseVector is large.
   - Suppose we are writing a DenseVector using a memory segment: after `len` is written, the MemorySegment runs out of memory and the remaining content is written to an FsSegment. Does this still work correctly?
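   
   For reference, a minimal sketch of the chunked approach, in the spirit of `ObjectOutputStream`'s internal block buffer. The chunk size and the helper name are illustrative assumptions, not part of the PR:
   
   ```
   // Reuse one small buffer across serialize() calls instead of allocating a
   // ByteBuffer per vector; allocation stays bounded however large the vector is.
   private static final int CHUNK_DOUBLES = 256; // assumed chunk size
   private final byte[] buf = new byte[CHUNK_DOUBLES << 3];
   
   private void writeDoubles(double[] values, DataOutputView target) throws IOException {
       int pos = 0;
       while (pos < values.length) {
           int n = Math.min(CHUNK_DOUBLES, values.length - pos);
           for (int i = 0; i < n; i++) {
               long bits = Double.doubleToLongBits(values[pos + i]);
               int off = i << 3;
               buf[off] = (byte) (bits >>> 56);
               buf[off + 1] = (byte) (bits >>> 48);
               buf[off + 2] = (byte) (bits >>> 40);
               buf[off + 3] = (byte) (bits >>> 32);
               buf[off + 4] = (byte) (bits >>> 24);
               buf[off + 5] = (byte) (bits >>> 16);
               buf[off + 6] = (byte) (bits >>> 8);
               buf[off + 7] = (byte) bits;
           }
           // Each chunk goes through the view as an ordinary streaming write, so
           // spilling from one segment to the next is handled by the view itself.
           target.write(buf, 0, n << 3);
           pos += n;
       }
   }
   ```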



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -90,18 +90,18 @@ public void writeTo(OutputStream checkpointOutputStream) throws IOException {
             }
 
             dos.writeBoolean(fileSystem.isDistributedFS());
+            for (Segment segment : segments) {
+                persistSegmentToDisk(segment);
+            }
             if (fileSystem.isDistributedFS()) {
                 // We only need to record the segments itself
                 serializeSegments(segments, dos);
             } else {
                 // We have to copy the whole streams.
-                int totalRecords = segments.stream().mapToInt(Segment::getCount).sum();
-                long totalSize = segments.stream().mapToLong(Segment::getSize).sum();
-                checkState(totalRecords >= 0, "overflowed: " + totalRecords);
-                dos.writeInt(totalRecords);
-                dos.writeLong(totalSize);
-
+                dos.writeInt(segments.size());
                 for (Segment segment : segments) {
+                    dos.writeInt(segment.getCount());

Review Comment:
   Why do you write the size/count of each segment rather than treating them as a whole, as in the original version?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {

Review Comment:
   Can we allocate the segments eagerly? Even if the expected size is zero, we could still allocate all the segments here, since the memory is reserved statically anyway.
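   
   One possible shape of the eager variant, as a sketch only — `getTotalNumPages()` is a hypothetical accessor on the manager, not an existing method:
   
   ```
   public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager)
           throws MemoryAllocationException {
       this.memoryManager = memoryManager;
       this.pageSize = memoryManager.getPageSize();
       // Claim the whole statically reserved budget up front, so later writes
       // never allocate and expectedSize no longer matters.
       segments.addAll(
               memoryManager.allocatePages(getKey(), memoryManager.getTotalNumPages()));
   }
   ```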



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {
+                int numPages = (int) ((expectedSize + pageSize - 1) / pageSize);
+                segments.addAll(memoryManager.allocatePages(getKey(), numPages));
+            }
+        }
+
+        public Object getKey() {
+            return key;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(int capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int requiredSegmentNum = (capacity - 1) / pageSize + 2 - segments.size();
+            if (requiredSegmentNum > 0) {
+                segments.addAll(memoryManager.allocatePages(getKey(), requiredSegmentNum));
+            }
+        }
+
+        private void writeRecursive(byte[] b, int off, int len) {

Review Comment:
   If `len` is zero, it seems an exception would be thrown here; at minimum, `ensureCapacity(globalOffset + len)` fails its `checkArgument(capacity > 0)` while `globalOffset` is still zero.
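   
   A possible guard, assuming empty writes should be a no-op (as `java.io.OutputStream` permits):
   
   ```
   @Override
   public void write(byte[] b, int off, int len) throws IOException {
       if (len == 0) {
           // Nothing to write; also avoids calling ensureCapacity(0), which
           // would fail the checkArgument(capacity > 0) precondition.
           return;
       }
       try {
           ensureCapacity(globalOffset + len);
       } catch (MemoryAllocationException e) {
           throw new IOException(e);
       }
       writeRecursive(b, off, len);
   }
   ```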



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java:
##########
@@ -75,9 +76,11 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept
 
         final int len = vector.values.length;
         target.writeInt(len);
+        ByteBuffer buffer = ByteBuffer.allocate(len << 3);

Review Comment:
   Moreover, should we also update `SparseVectorSerializer`? I am also OK with leaving a TODO there.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,104 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
-/** Records the data received and replayed them on required. */
+/** Records the data received and replays them on required. */
 public class DataCacheWriter<T> {
 
-    private final TypeSerializer<T> serializer;
-
     private final FileSystem fileSystem;
 
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final TypeSerializer<T> serializer;
+
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    private SegmentWriter<T> currentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
-            SupplierWithException<Path, IOException> pathGenerator)
+            SupplierWithException<Path, IOException> pathGenerator,
+            LimitedSizeMemoryManager memoryManager)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, memoryManager, Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
+            LimitedSizeMemoryManager memoryManager,
             List<Segment> priorFinishedSegments)
             throws IOException {
         this.serializer = serializer;
         this.fileSystem = fileSystem;
         this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
-
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+        this.memoryManager = memoryManager;
+        this.finishedSegments = new ArrayList<>(priorFinishedSegments);
+        this.currentWriter =
+                SegmentWriter.create(
+                        pathGenerator.get(), this.memoryManager, serializer, 0L, true, true);
     }
 
     public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+        boolean success = currentWriter.addRecord(record);

Review Comment:
   Since the writer may not be atomic, is there a case where a single vector ends up partially written to memory and partially to disk?
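   
   If that can happen, one option is to roll the memory stream back on failure before falling over to the file writer — a sketch only, where `getPosition`/`setPosition` are hypothetical helpers on `ManagedMemoryOutputStream`:
   
   ```
   @Override
   public boolean addRecord(T record) {
       long before = outputStream.getPosition(); // hypothetical: current global offset
       try {
           serializer.serialize(record, outputView);
           count++;
           return true;
       } catch (IOException e) {
           // Hypothetical rollback: discard the partially written record so the
           // caller can re-serialize it cleanly into the fallback file segment.
           outputStream.setPosition(before);
           return false;
       }
   }
   ```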



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final LimitedSizeMemoryManager memoryManager;
+
+    private final Path path;
+
+    private final TypeSerializer<T> serializer;
+
+    private final ManagedMemoryOutputStream outputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public MemorySegmentWriter(
+            Path path,
+            LimitedSizeMemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long expectedSize)
+            throws MemoryAllocationException {
+        this.serializer = serializer;
+        this.memoryManager = Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.outputStream = new ManagedMemoryOutputStream(memoryManager, expectedSize);
+        this.outputView = new DataOutputViewStreamWrapper(outputStream);
+        this.count = 0;
+    }
+
+    @Override
+    public boolean addRecord(T record) {
+        try {
+            serializer.serialize(record, outputView);
+            count++;
+            return true;
+        } catch (IOException e) {
+            return false;
+        }
+    }
+
+    @Override
+    public int getCount() {
+        return this.count;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Optional<Segment> finish() throws IOException {
+        if (count > 0) {
+            return Optional.of(
+                    new Segment(
+                            path,
+                            count,
+                            outputStream.getKey(),
+                            outputStream.getSegments(),
+                            (TypeSerializer<Object>) serializer));
+        } else {
+            memoryManager.releaseAll(outputStream.getKey());
+            return Optional.empty();
+        }
+    }
+
+    private static class ManagedMemoryOutputStream extends OutputStream {
+        private final LimitedSizeMemoryManager memoryManager;
+
+        private final int pageSize;
+
+        private final Object key = new Object();
+
+        private final List<MemorySegment> segments = new ArrayList<>();
+
+        private int segmentIndex;
+
+        private int segmentOffset;
+
+        private int globalOffset;
+
+        public ManagedMemoryOutputStream(LimitedSizeMemoryManager memoryManager, long expectedSize)
+                throws MemoryAllocationException {
+            this.memoryManager = memoryManager;
+            this.pageSize = memoryManager.getPageSize();
+            this.segmentIndex = 0;
+            this.segmentOffset = 0;
+
+            Preconditions.checkArgument(expectedSize >= 0);
+            if (expectedSize > 0) {
+                int numPages = (int) ((expectedSize + pageSize - 1) / pageSize);
+                segments.addAll(memoryManager.allocatePages(getKey(), numPages));
+            }
+        }
+
+        public Object getKey() {
+            return key;
+        }
+
+        public List<MemorySegment> getSegments() {
+            return segments;
+        }
+
+        @Override
+        public void write(int b) throws IOException {
+            write(new byte[] {(byte) b}, 0, 1);
+        }
+
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            try {
+                ensureCapacity(globalOffset + len);
+            } catch (MemoryAllocationException e) {
+                throw new IOException(e);
+            }
+            writeRecursive(b, off, len);
+        }
+
+        private void ensureCapacity(int capacity) throws MemoryAllocationException {
+            Preconditions.checkArgument(capacity > 0);
+            int requiredSegmentNum = (capacity - 1) / pageSize + 2 - segments.size();

Review Comment:
   Should it be `(capacity - 1) / pageSize + 1 - segments.size()`? Moreover, do you think the following code is easier to understand?
   
   ```
   private void ensureCapacity(int capacity) throws MemoryAllocationException {
       if (capacity > pageSize * segments.size()) {
           int requiredSegmentNum =
                   capacity % pageSize == 0 ? capacity / pageSize : capacity / pageSize + 1;
           if (requiredSegmentNum > segments.size()) {
               segments.addAll(
                       memoryManager.allocatePages(getKey(), requiredSegmentNum - segments.size()));
           }
       }
   }
   ```
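   
   As a quick sanity check: with `pageSize = 4096` and `capacity = 4096`, `(capacity - 1) / pageSize + 1` gives the expected 1 page, while the current `+ 2` requests 2. If the extra spare page is intentional (e.g. so the next write always has a segment ready), a comment explaining that would help.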
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org