You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@tomcat.apache.org by ma...@apache.org on 2021/06/04 17:00:23 UTC

[tomcat] branch 8.5.x updated: Fix HEAD for the non-blocking case.

This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch 8.5.x
in repository https://gitbox.apache.org/repos/asf/tomcat.git


The following commit(s) were added to refs/heads/8.5.x by this push:
     new b9d57c0  Fix HEAD for the non-blocking case.
b9d57c0 is described below

commit b9d57c0c413f86e72f908fa47dfe6fab8dea4393
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Fri Jun 4 17:26:37 2021 +0100

    Fix HEAD for the non-blocking case.
---
 java/javax/servlet/http/HttpServlet.java     | 96 +++++++++++++++++++++-------
 test/javax/servlet/http/TestHttpServlet.java | 67 ++++++++++++++++++-
 webapps/docs/changelog.xml                   |  5 ++
 3 files changed, 144 insertions(+), 24 deletions(-)

diff --git a/java/javax/servlet/http/HttpServlet.java b/java/javax/servlet/http/HttpServlet.java
index bcede64..2289b70 100644
--- a/java/javax/servlet/http/HttpServlet.java
+++ b/java/javax/servlet/http/HttpServlet.java
@@ -27,13 +27,15 @@ import java.text.MessageFormat;
 import java.util.Enumeration;
 import java.util.ResourceBundle;
 
+import javax.servlet.AsyncEvent;
+import javax.servlet.AsyncListener;
 import javax.servlet.DispatcherType;
 import javax.servlet.GenericServlet;
 import javax.servlet.ServletException;
 import javax.servlet.ServletOutputStream;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
-
+import javax.servlet.WriteListener;
 
 /**
  * Provides an abstract class to be subclassed to create
@@ -110,7 +112,7 @@ public abstract class HttpServlet extends GenericServlet {
      * response, only the request header fields.
      *
      * <p>When overriding this method, read the request data,
-     * write the response headers, get the response's writer or
+     * write the response headers, get the response's writer or
      * output stream object, and finally, write the response data.
      * It's best to include content type and encoding. When using
      * a <code>PrintWriter</code> object to return the response,
@@ -237,7 +239,11 @@ public abstract class HttpServlet extends GenericServlet {
         } else {
             NoBodyResponse response = new NoBodyResponse(resp);
             doGet(req, response);
-            response.setContentLength();
+            if (req.isAsyncStarted()) {
+                req.getAsyncContext().addListener(new NoBodyAsyncContextListener(response));
+            } else {
+                response.setContentLength();
+            }
         }
     }
 
@@ -252,7 +258,7 @@ public abstract class HttpServlet extends GenericServlet {
      * credit card numbers.
      *
      * <p>When overriding this method, read the request data,
-     * write the response headers, get the response's writer or output
+     * write the response headers, get the response's writer or output
      * stream object, and finally, write the response data. It's best
      * to include content type and encoding. When using a
      * <code>PrintWriter</code> object to return the response, set the
@@ -766,21 +772,22 @@ public abstract class HttpServlet extends GenericServlet {
      * wrapped HTTP Servlet Response object.
      */
     private static class NoBodyResponse extends HttpServletResponseWrapper {
-        private final NoBodyOutputStream noBody;
-        private NoBodyPrintWriter writer;
+        private final NoBodyOutputStream noBodyOutputStream;
+        private ServletOutputStream originalOutputStream;
+        private NoBodyPrintWriter noBodyWriter;
         private boolean didSetContentLength;
 
         private NoBodyResponse(HttpServletResponse r) {
             super(r);
-            noBody = new NoBodyOutputStream(this);
+            noBodyOutputStream = new NoBodyOutputStream(this);
         }
 
         private void setContentLength() {
             if (!didSetContentLength) {
-                if (writer != null) {
-                    writer.flush();
+                if (noBodyWriter != null) {
+                    noBodyWriter.flush();
                 }
-                super.setContentLengthLong(noBody.getWrittenByteCount());
+                super.setContentLengthLong(noBodyOutputStream.getWrittenByteCount());
             }
         }
 
@@ -829,29 +836,31 @@ public abstract class HttpServlet extends GenericServlet {
 
         @Override
         public ServletOutputStream getOutputStream() throws IOException {
-            return noBody;
+            originalOutputStream = getResponse().getOutputStream();
+            return noBodyOutputStream;
         }
 
         @Override
         public PrintWriter getWriter() throws UnsupportedEncodingException {
 
-            if (writer == null) {
-                writer = new NoBodyPrintWriter(noBody, getCharacterEncoding());
+            if (noBodyWriter == null) {
+                noBodyWriter = new NoBodyPrintWriter(noBodyOutputStream, getCharacterEncoding());
             }
-            return writer;
+            return noBodyWriter;
         }
 
         @Override
         public void reset() {
             super.reset();
             resetBuffer();
+            originalOutputStream = null;
         }
 
         @Override
         public void resetBuffer() {
-            noBody.resetBuffer();
-            if (writer != null) {
-                writer.resetBuffer();
+            noBodyOutputStream.resetBuffer();
+            if (noBodyWriter != null) {
+                noBodyWriter.resetBuffer();
             }
         }
     }
@@ -865,11 +874,11 @@ public abstract class HttpServlet extends GenericServlet {
         private static final String LSTRING_FILE = "javax.servlet.http.LocalStrings";
         private static final ResourceBundle lStrings = ResourceBundle.getBundle(LSTRING_FILE);
 
-        private final HttpServletResponse response;
+        private final NoBodyResponse response;
         private boolean flushed = false;
         private long writtenByteCount = 0;
 
-        private NoBodyOutputStream(HttpServletResponse response) {
+        private NoBodyOutputStream(NoBodyResponse response) {
             this.response = response;
         }
 
@@ -906,13 +915,13 @@ public abstract class HttpServlet extends GenericServlet {
 
         @Override
         public boolean isReady() {
-            // TODO SERVLET 3.1
-            return false;
+            // Will always be ready as data is swallowed.
+            return true;
         }
 
         @Override
-        public void setWriteListener(javax.servlet.WriteListener listener) {
-            // TODO SERVLET 3.1
+        public void setWriteListener(WriteListener listener) {
+            response.originalOutputStream.setWriteListener(listener);
         }
 
         private void checkCommit() throws IOException {
@@ -931,6 +940,13 @@ public abstract class HttpServlet extends GenericServlet {
     }
 
 
+    /*
+     * On reset() and resetBuffer() need to clear the data buffered in the
+     * OutputStreamWriter. No easy way to do that so NoBodyPrintWriter wraps a
+     * PrintWriter that can be thrown away on reset()/resetBuffer() and a new
+     * one constructed while the application retains a reference to the
+     * NoBodyPrintWriter instance.
+     */
     private static class NoBodyPrintWriter extends PrintWriter {
 
         private final NoBodyOutputStream out;
@@ -1096,4 +1112,38 @@ public abstract class HttpServlet extends GenericServlet {
             pw.println(x);
         }
     }
+
+
+    /*
+     * Defers NoBodyResponse.setContentLength() until the async request is
+     * complete, as the swallowed body may be written after doGet() returns.
+     */
+    private static class NoBodyAsyncContextListener implements AsyncListener {
+
+        private final NoBodyResponse noBodyResponse;
+
+        public NoBodyAsyncContextListener(NoBodyResponse noBodyResponse) {
+            this.noBodyResponse = noBodyResponse;
+        }
+
+        @Override
+        public void onComplete(AsyncEvent event) throws IOException {
+            noBodyResponse.setContentLength();
+        }
+
+        @Override
+        public void onTimeout(AsyncEvent event) throws IOException {
+            // NO-OP: Content-Length is only set from onComplete()
+        }
+
+        @Override
+        public void onError(AsyncEvent event) throws IOException {
+            // NO-OP: Content-Length is only set from onComplete()
+        }
+
+        @Override
+        public void onStartAsync(AsyncEvent event) throws IOException {
+            // NO-OP - NOTE(review): not re-added for a new async cycle; confirm
+        }
+    }
 }
diff --git a/test/javax/servlet/http/TestHttpServlet.java b/test/javax/servlet/http/TestHttpServlet.java
index bc31e88..fbea9e6 100644
--- a/test/javax/servlet/http/TestHttpServlet.java
+++ b/test/javax/servlet/http/TestHttpServlet.java
@@ -22,14 +22,17 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import javax.servlet.AsyncContext;
 import javax.servlet.Servlet;
 import javax.servlet.ServletException;
 import javax.servlet.ServletOutputStream;
+import javax.servlet.WriteListener;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.catalina.Context;
+import org.apache.catalina.Wrapper;
 import org.apache.catalina.core.StandardContext;
 import org.apache.catalina.startup.SimpleHttpClient;
 import org.apache.catalina.startup.TesterServlet;
@@ -147,13 +150,22 @@ public class TestHttpServlet extends TomcatBaseTest {
     }
 
 
+    @Test
+    public void testHeadWithNonBlocking() throws Exception {
+        // 4k body written 1k at a time - less than the buffer size
+        doTestHead(new NonBlockingWriteServlet(4 * 1024));
+    }
+
+
     private void doTestHead(Servlet servlet) throws Exception {
         Tomcat tomcat = getTomcatInstance();
 
         // No file system docBase required
         StandardContext ctx = (StandardContext) tomcat.addContext("", null);
 
-        Tomcat.addServlet(ctx, "TestServlet", servlet);
+        Wrapper w = Tomcat.addServlet(ctx, "TestServlet", servlet);
+        // Not all need/use this but it is simpler to set it for all
+        w.setAsyncSupported(true);
         ctx.addServletMappingDecoded("/test", "TestServlet");
 
         tomcat.start();
@@ -437,6 +449,59 @@ public class TestHttpServlet extends TomcatBaseTest {
     }
 
 
+    private static class NonBlockingWriteServlet extends HttpServlet {
+
+        private static final long serialVersionUID = 1L;
+
+        private final int bytesToWrite;
+
+        public NonBlockingWriteServlet(int bytesToWrite) {
+            this.bytesToWrite = bytesToWrite;
+        }
+
+        @Override
+        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
+                throws ServletException, IOException {
+            AsyncContext ac = req.startAsync(req, resp); // ac.getResponse() returns resp
+            ac.setTimeout(3000); // milliseconds
+            WriteListener wListener = new NonBlockingWriteListener(ac, bytesToWrite);
+            resp.getOutputStream().setWriteListener(wListener);
+        }
+        // Writes bytesToWrite zero-valued bytes to sos, then calls ac.complete()
+        private static class NonBlockingWriteListener implements WriteListener {
+
+            private final AsyncContext ac;
+            private final ServletOutputStream sos;
+            private int bytesToWrite;
+
+            public NonBlockingWriteListener(AsyncContext ac, int bytesToWrite) throws IOException {
+                this.ac = ac;
+                this.sos = ac.getResponse().getOutputStream();
+                this.bytesToWrite = bytesToWrite;
+            }
+
+            @Override
+            public void onWritePossible() throws IOException {
+                do {
+                    // Write up to 1k at a time
+                    int bytesThisTime = Math.min(bytesToWrite, 1024);
+                    sos.write(new byte[bytesThisTime]);
+                    bytesToWrite -= bytesThisTime;
+                } while (sos.isReady() && bytesToWrite > 0);
+                // If sos is not ready, wait for the next onWritePossible() call
+                if (sos.isReady() && bytesToWrite == 0) {
+                    ac.complete();
+                }
+            }
+
+            @Override
+            public void onError(Throwable throwable) {
+                throwable.printStackTrace();
+            }
+        }
+    }
+
+
     private static class OptionsServlet extends HttpServlet {
 
         private static final long serialVersionUID = 1L;
diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml
index 82a0fc3..53be318 100644
--- a/webapps/docs/changelog.xml
+++ b/webapps/docs/changelog.xml
@@ -168,6 +168,11 @@
         calls <code>ServletResponse.reset()</code> and/or
         <code>ServletResponse.resetBuffer()</code>. (markt)
       </fix>
+      <fix>
+        Fix the default <code>doHead()</code> implementation in
+        <code>HttpServlet</code> to correctly handle responses generated using
+        the Servlet non-blocking API. (markt)
+      </fix>
     </changelog>
   </subsection>
   <subsection name="Coyote">

---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@tomcat.apache.org
For additional commands, e-mail: dev-help@tomcat.apache.org