You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mina.apache.org by gn...@apache.org on 2021/06/14 07:02:52 UTC

[mina-sshd] branch master updated: [SSHD-1181] Fix sftp file downloads when the server uses the EOF indicator

This is an automated email from the ASF dual-hosted git repository.

gnodet pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git


The following commit(s) were added to refs/heads/master by this push:
     new 3db721d  [SSHD-1181] Fix sftp file downloads when the server uses the EOF indicator
3db721d is described below

commit 3db721d9a109d9aa80b3a662c43b57458acd99e4
Author: Guillaume Nodet <gn...@gmail.com>
AuthorDate: Mon Jun 14 09:02:45 2021 +0200

    [SSHD-1181] Fix sftp file downloads when the server uses the EOF indicator
    
    * Use a single method for the read logic in SftpInputStreamAsync
    * Add a small file test
    * Add infrastructure to send the EOF indicator
    * Correctly support the EOF indicator in all cases; fixes SSHD-1181
---
 .../sshd/sftp/client/impl/AbstractSftpClient.java  |  10 --
 .../sftp/client/impl/SftpInputStreamAsync.java     | 127 +++++++++------------
 .../sftp/server/AbstractSftpSubsystemHelper.java   |  15 ++-
 .../org/apache/sshd/sftp/server/FileHandle.java    |  15 ++-
 .../org/apache/sshd/sftp/server/SftpSubsystem.java |   5 +-
 .../java/org/apache/sshd/sftp/client/SftpTest.java |  13 +++
 6 files changed, 96 insertions(+), 89 deletions(-)

diff --git a/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/AbstractSftpClient.java b/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/AbstractSftpClient.java
index 89df6f5..0d29b51 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/AbstractSftpClient.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/AbstractSftpClient.java
@@ -772,9 +772,6 @@ public abstract class AbstractSftpClient extends AbstractSubsystemClient impleme
     @Override
     public int read(Handle handle, long fileOffset, byte[] dst, int dstOffset, int len, AtomicReference<Boolean> eofSignalled)
             throws IOException {
-        if (eofSignalled != null) {
-            eofSignalled.set(null);
-        }
         if (!isOpen()) {
             throw new IOException("read(" + handle + "/" + fileOffset + ")[" + dstOffset + "/" + len + "] client is closed");
         }
@@ -790,9 +787,6 @@ public abstract class AbstractSftpClient extends AbstractSubsystemClient impleme
     protected int checkData(
             int cmd, Buffer request, int dstOffset, byte[] dst, AtomicReference<Boolean> eofSignalled)
             throws IOException {
-        if (eofSignalled != null) {
-            eofSignalled.set(null);
-        }
         int reqId = send(cmd, request);
         Buffer response = receive(reqId);
         return checkDataResponse(cmd, response, dstOffset, dst, eofSignalled);
@@ -801,10 +795,6 @@ public abstract class AbstractSftpClient extends AbstractSubsystemClient impleme
     protected int checkDataResponse(
             int cmd, Buffer buffer, int dstoff, byte[] dst, AtomicReference<Boolean> eofSignalled)
             throws IOException {
-        if (eofSignalled != null) {
-            eofSignalled.set(null);
-        }
-
         int length = buffer.getInt();
         int type = buffer.getUByte();
         int id = buffer.getInt();
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/SftpInputStreamAsync.java b/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/SftpInputStreamAsync.java
index eeddb4a..6721c74 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/SftpInputStreamAsync.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/sftp/client/impl/SftpInputStreamAsync.java
@@ -26,6 +26,7 @@ import java.util.Collection;
 import java.util.Deque;
 import java.util.LinkedList;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.sshd.common.SshConstants;
@@ -125,60 +126,30 @@ public class SftpInputStreamAsync extends InputStreamWithChannel implements Sftp
             throw new IOException("read(" + getPath() + ") stream closed");
         }
 
-        int idx = off;
-        while ((len > 0) && (!eofIndicator)) {
-            if (hasNoData()) {
-                fillData();
-                if (eofIndicator && (hasNoData())) {
-                    break;
-                }
-                sendRequests();
-            } else {
-                int nb = Math.min(buffer.available(), len);
-                buffer.getRawBytes(b, off, nb);
-                off += nb;
-                len -= nb;
-                clientOffset += nb;
-            }
-        }
-
-        int res = off - idx;
-        if ((res == 0) && eofIndicator) {
+        AtomicInteger offset = new AtomicInteger(off);
+        int res = (int) doRead(len, buf -> {
+            int l = buf.available();
+            buf.getRawBytes(b, offset.getAndAdd(l), l);
+        });
+        if (res == 0 && eofIndicator) {
             res = -1;
         }
         return res;
     }
 
-    public long transferTo(long max, WritableByteChannel out) throws IOException {
+    public long transferTo(long len, WritableByteChannel out) throws IOException {
         if (!isOpen()) {
             throw new IOException("transferTo(" + getPath() + ") stream closed");
         }
 
-        long orgOffset = clientOffset;
-        long totalRequested = max;
-        while ((!eofIndicator) && (max > 0L)) {
-            if (hasNoData()) {
-                fillData();
-                if (eofIndicator && hasNoData()) {
-                    break;
-                }
-                sendRequests();
-            } else {
-                int nb = buffer.available();
-                int toRead = (int) Math.min(nb, max);
-                ByteBuffer bb = ByteBuffer.wrap(buffer.array(), buffer.rpos(), toRead);
-                while (bb.hasRemaining()) {
-                    out.write(bb);
-                }
-                buffer.rpos(buffer.rpos() + toRead);
-                clientOffset += toRead;
-                max -= toRead;
+        long numXfered = doRead(len, buf -> {
+            ByteBuffer bb = ByteBuffer.wrap(buf.array(), buf.rpos(), buf.available());
+            while (bb.hasRemaining()) {
+                out.write(bb);
             }
-        }
-
-        long numXfered = clientOffset - orgOffset;
+        });
         if (log.isDebugEnabled()) {
-            log.debug("transferTo({}) transferred {}/{} bytes", numXfered, totalRequested);
+            log.debug("transferTo({}) transferred {}/{} bytes", this, numXfered, len);
         }
         return numXfered;
     }
@@ -189,27 +160,41 @@ public class SftpInputStreamAsync extends InputStreamWithChannel implements Sftp
             throw new IOException("transferTo(" + getPath() + ") stream closed");
         }
 
+        long numXfered = doRead(Long.MAX_VALUE, buf -> {
+            out.write(buf.array(), buf.rpos(), buf.available());
+        });
+        if (log.isDebugEnabled()) {
+            log.debug("transferTo({}) transferred {} bytes", this, numXfered);
+        }
+        return numXfered;
+    }
+
+    interface BufferConsumer {
+        void consume(Buffer buffer) throws IOException;
+    }
+
+    private long doRead(long max, BufferConsumer consumer) throws IOException {
         long orgOffset = clientOffset;
-        while (!eofIndicator) {
+        while (max > 0) {
             if (hasNoData()) {
-                fillData();
-                if (eofIndicator && hasNoData()) {
+                if (eofIndicator) {
                     break;
                 }
-                sendRequests();
+                if (!pendingReads.isEmpty()) {
+                    fillData();
+                }
+                if (!eofIndicator) {
+                    sendRequests();
+                }
             } else {
-                int nb = buffer.available();
-                out.write(buffer.array(), buffer.rpos(), nb);
+                int nb = (int) Math.min(max, buffer.available());
+                consumer.consume(new ByteArrayBuffer(buffer.array(), buffer.rpos(), nb));
                 buffer.rpos(buffer.rpos() + nb);
                 clientOffset += nb;
+                max -= nb;
             }
         }
-
-        long numXfered = clientOffset - orgOffset;
-        if (log.isDebugEnabled()) {
-            log.debug("transferTo({}) transferred {} bytes", this, numXfered);
-        }
-        return numXfered;
+        return clientOffset - orgOffset;
     }
 
     @Override
@@ -234,13 +219,6 @@ public class SftpInputStreamAsync extends InputStreamWithChannel implements Sftp
     }
 
     protected void sendRequests() throws IOException {
-        if (eofIndicator) {
-            if (log.isDebugEnabled()) {
-                log.debug("sendRequests({}) EOF indicator ON", this);
-            }
-            return;
-        }
-
         AbstractSftpClient client = getClient();
         Channel channel = client.getChannel();
         Window localWindow = channel.getLocalWindow();
@@ -248,10 +226,8 @@ public class SftpInputStreamAsync extends InputStreamWithChannel implements Sftp
         Session session = client.getSession();
         byte[] id = handle.getIdentifier();
         boolean traceEnabled = log.isTraceEnabled();
-        for (int ackIndex = 1;
-             (pendingReads.size() < (int) (windowSize / bufferSize)) && (requestOffset < (fileSize + bufferSize))
-                     || pendingReads.isEmpty();
-             ackIndex++) {
+        while (pendingReads.size() < Math.max(1, windowSize / bufferSize)
+                && (fileSize <= 0 || requestOffset < fileSize + bufferSize)) {
             Buffer buf = session.createBuffer(SshConstants.SSH_MSG_CHANNEL_DATA,
                     23 /* sftp packet */ + 16 + id.length);
             buf.rpos(23);
@@ -262,7 +238,7 @@ public class SftpInputStreamAsync extends InputStreamWithChannel implements Sftp
             int reqId = client.send(SftpConstants.SSH_FXP_READ, buf);
             SftpAckData ack = new SftpAckData(reqId, requestOffset, bufferSize);
             if (traceEnabled) {
-                log.trace("sendRequests({}) enqueue pending ack #{}: {}", this, ackIndex, ack);
+                log.trace("sendRequests({}) enqueue pending ack: {}", this, ack);
             }
             pendingReads.add(ack);
             requestOffset += bufferSize;
@@ -282,9 +258,10 @@ public class SftpInputStreamAsync extends InputStreamWithChannel implements Sftp
         if (traceEnabled) {
             log.trace("fillData({}) process ack={}", this, ack);
         }
+        boolean alreadyEof = eofIndicator;
         pollBuffer(ack);
 
-        if ((!eofIndicator) && (clientOffset < ack.offset)) {
+        if (!alreadyEof && clientOffset < ack.offset) {
             // we are actually missing some data
             // so request is synchronously
             byte[] data = new byte[(int) (ack.offset - clientOffset + buffer.available())];
@@ -295,21 +272,27 @@ public class SftpInputStreamAsync extends InputStreamWithChannel implements Sftp
 
             AtomicReference<Boolean> eof = new AtomicReference<>();
             SftpClient client = getClient();
-            for (int cur = 0; cur < nb;) {
+            int cur = 0;
+            while (cur < nb) {
                 int dlen = client.read(handle, clientOffset, data, cur, nb - cur, eof);
                 Boolean eofSignal = eof.getAndSet(null);
                 if ((dlen < 0) || ((eofSignal != null) && eofSignal.booleanValue())) {
                     eofIndicator = true;
+                    break;
                 }
                 cur += dlen;
             }
 
             if (traceEnabled) {
-                log.trace("fillData({}) read {} bytes - EOF={}", this, nb, eofIndicator);
+                log.trace("fillData({}) read {} bytes - EOF={}", this, cur, eofIndicator);
             }
 
-            buffer.getRawBytes(data, nb, buffer.available());
-            buffer = new ByteArrayBuffer(data);
+            if (cur > 0) {
+                buffer.getRawBytes(data, cur, buffer.available());
+                buffer = new ByteArrayBuffer(data);
+            } else {
+                buffer.rpos(buffer.wpos());
+            }
         }
     }
 
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/AbstractSftpSubsystemHelper.java b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/AbstractSftpSubsystemHelper.java
index d2d2c5a..dbac259 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/AbstractSftpSubsystemHelper.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/AbstractSftpSubsystemHelper.java
@@ -63,6 +63,7 @@ import java.util.Set;
 import java.util.TreeMap;
 import java.util.TreeSet;
 import java.util.concurrent.CopyOnWriteArraySet;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.IntUnaryOperator;
 
 import org.apache.sshd.common.FactoryManager;
@@ -561,13 +562,23 @@ public abstract class AbstractSftpSubsystemHelper
             int lenPos = buffer.wpos();
             buffer.putInt(0);
 
+            AtomicReference<Boolean> eofRef = new AtomicReference<>();
             int startPos = buffer.wpos();
-            int len = doRead(id, handle, offset, readLen, buffer.array(), startPos);
+            int len = doRead(id, handle, offset, readLen, buffer.array(), startPos, eofRef);
             if (len < 0) {
                 throw new EOFException("Unable to read " + readLen + " bytes from offset=" + offset + " of " + handle);
             }
             buffer.wpos(startPos + len);
             BufferUtils.updateLengthPlaceholder(buffer, lenPos, len);
+            if (len < readLen) {
+                int version = getVersion();
+                if (version >= SftpConstants.SFTP_V6) {
+                    Boolean eof = eofRef.get();
+                    if (eof != null) {
+                        buffer.putBoolean(eof);
+                    }
+                }
+            }
         } catch (IOException | RuntimeException e) {
             sendStatus(prepareReply(buffer), id, e, SftpConstants.SSH_FXP_READ, handle, offset, requestedLength);
             return;
@@ -577,7 +588,7 @@ public abstract class AbstractSftpSubsystemHelper
     }
 
     protected abstract int doRead(
-            int id, String handle, long offset, int length, byte[] data, int doff)
+            int id, String handle, long offset, int length, byte[] data, int doff, AtomicReference<Boolean> eof)
             throws IOException;
 
     protected void doWrite(Buffer buffer, int id) throws IOException {
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/FileHandle.java b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/FileHandle.java
index d86ded4..a3dff3a 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/FileHandle.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/FileHandle.java
@@ -34,6 +34,7 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.sshd.common.util.GenericUtils;
 import org.apache.sshd.common.util.MapEntryUtils;
@@ -108,14 +109,22 @@ public class FileHandle extends Handle {
     }
 
     public int read(byte[] data, long offset) throws IOException {
-        return read(data, 0, data.length, offset);
+        return read(data, 0, data.length, offset, null);
     }
 
-    @SuppressWarnings("resource")
     public int read(byte[] data, int doff, int length, long offset) throws IOException {
+        return read(data, doff, length, offset, null);
+    }
+
+    @SuppressWarnings("resource")
+    public int read(byte[] data, int doff, int length, long offset, AtomicReference<Boolean> eof) throws IOException {
         SeekableByteChannel channel = getFileChannel();
         channel = channel.position(offset);
-        return channel.read(ByteBuffer.wrap(data, doff, length));
+        int l = channel.read(ByteBuffer.wrap(data, doff, length));
+        if (l > 0 && eof != null && l < length) {
+            eof.set(channel.position() >= channel.size());
+        }
+        return l;
     }
 
     public void append(byte[] data) throws IOException {
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
index 85c3fab..40964ba 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
@@ -44,6 +44,7 @@ import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.sshd.common.Factory;
 import org.apache.sshd.common.FactoryManager;
@@ -835,7 +836,7 @@ public class SftpSubsystem
 
     @Override
     protected int doRead(
-            int id, String handle, long offset, int length, byte[] data, int doff)
+            int id, String handle, long offset, int length, byte[] data, int doff, AtomicReference<Boolean> eof)
             throws IOException {
         Handle h = handles.get(handle);
         ServerSession session = getServerSession();
@@ -850,7 +851,7 @@ public class SftpSubsystem
         int readLen;
         listener.reading(session, handle, fh, offset, data, doff, length);
         try {
-            readLen = fh.read(data, doff, length, offset);
+            readLen = fh.read(data, doff, length, offset, eof);
         } catch (IOException | RuntimeException | Error e) {
             listener.read(session, handle, fh, offset, data, doff, length, -1, e);
             throw e;
diff --git a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/SftpTest.java b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/SftpTest.java
index c67497e..ed67b0c 100644
--- a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/SftpTest.java
+++ b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/SftpTest.java
@@ -1335,6 +1335,19 @@ public class SftpTest extends AbstractSftpClientTestSupport {
 
         sftp.remove(file);
 
+        byte[] smallBuf = "Hello world".getBytes(StandardCharsets.UTF_8);
+        try (OutputStream os = sftp.write(file)) {
+            os.write(smallBuf);
+        }
+        try (InputStream is = sftp.read(file)) {
+            int readLen = is.read(smallBuf);
+            assertEquals("Mismatched read data length", smallBuf.length, readLen);
+            assertEquals("Hello world", new String(smallBuf, StandardCharsets.UTF_8));
+
+            int i = is.read();
+            assertEquals("Unexpected read past EOF", -1, i);
+        }
+
         final int sizeFactor = Short.SIZE;
         byte[] workBuf = new byte[IoUtils.DEFAULT_COPY_SIZE * Short.SIZE];
         Factory<? extends Random> factory = manager.getRandomFactory();