You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@tomee.apache.org by rz...@apache.org on 2024/03/25 10:22:04 UTC

(tomee) 02/04: TOMEE-4305 - Port changes from https://github.com/apache/tomcat/commit/0052b374684b613b0c849899b325ebe334ac6501

This is an automated email from the ASF dual-hosted git repository.

rzo1 pushed a commit to branch TOMEE-4305
in repository https://gitbox.apache.org/repos/asf/tomee.git

commit 5420188ed2263e2dec491d29d049f21155d28252
Author: Richard Zowalla <rz...@apache.org>
AuthorDate: Mon Mar 25 11:14:35 2024 +0100

    TOMEE-4305 - Port changes from https://github.com/apache/tomcat/commit/0052b374684b613b0c849899b325ebe334ac6501
---
 .../org/apache/tomcat/websocket/Constants.java     |   6 +
 .../org/apache/tomcat/websocket/WsSession.java     | 487 ++++++++++-----------
 .../tomcat/websocket/WsWebSocketContainer.java     |  10 +-
 .../tomcat/websocket/server/WsServerContainer.java |   2 +-
 4 files changed, 235 insertions(+), 270 deletions(-)

diff --git a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/Constants.java b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/Constants.java
index 8dffa30399..3bb4e00578 100644
--- a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/Constants.java
+++ b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/Constants.java
@@ -19,6 +19,7 @@ package org.apache.tomcat.websocket;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 
 import jakarta.websocket.Extension;
 
@@ -118,6 +119,11 @@ public class Constants {
     // Milliseconds so this is 20 seconds
     public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;
 
+    // Configuration for session close timeout
+    public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
+    // Default is 30 seconds - setting is in milliseconds
+    public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);
+
     // Configuration for read idle timeout on WebSocket session
     public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS";
 
diff --git a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsSession.java b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsSession.java
index 8766d46b1c..6cc5e85668 100644
--- a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsSession.java
+++ b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsSession.java
@@ -27,7 +27,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 
 import javax.naming.NamingException;
 
@@ -78,8 +80,8 @@ public class WsSession implements Session {
         // be sufficient to pass the validation tests.
         ServerEndpointConfig.Builder builder = ServerEndpointConfig.Builder.create(Object.class, "/");
         ServerEndpointConfig sec = builder.build();
-        SEC_CONFIGURATOR_USES_IMPL_DEFAULT =
-                sec.getConfigurator().getClass().equals(DefaultServerEndpointConfigurator.class);
+        SEC_CONFIGURATOR_USES_IMPL_DEFAULT = sec.getConfigurator().getClass()
+                .equals(DefaultServerEndpointConfigurator.class);
     }
 
     private final Endpoint localEndpoint;
@@ -106,8 +108,7 @@ public class WsSession implements Session {
     // Expected to handle message types of <ByteBuffer> only
     private volatile MessageHandler binaryMessageHandler = null;
     private volatile MessageHandler.Whole<PongMessage> pongMessageHandler = null;
-    private volatile State state = State.OPEN;
-    private final Object stateLock = new Object();
+    private AtomicReference<State> state = new AtomicReference<>(State.OPEN);
     private final Map<String, Object> userProperties = new ConcurrentHashMap<>();
     private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
     private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
@@ -115,35 +116,30 @@ public class WsSession implements Session {
     private volatile long lastActiveRead = System.currentTimeMillis();
     private volatile long lastActiveWrite = System.currentTimeMillis();
     private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>();
+    private volatile Long sessionCloseTimeoutExpiry;
 
 
     /**
-     * Creates a new WebSocket session for communication between the provided
-     * client and remote end points. The result of
-     * {@link Thread#getContextClassLoader()} at the time this constructor is
-     * called will be used when calling
+     * Creates a new WebSocket session for communication between the provided client and remote end points. The result
+     * of {@link Thread#getContextClassLoader()} at the time this constructor is called will be used when calling
      * {@link Endpoint#onClose(Session, CloseReason)}.
      *
      * @param clientEndpointHolder The end point managed by this code
      * @param wsRemoteEndpoint     The other / remote end point
      * @param wsWebSocketContainer The container that created this session
      * @param negotiatedExtensions The agreed extensions to use for this session
-     * @param subProtocol          The agreed sub-protocol to use for this
-     *                             session
-     * @param pathParameters       The path parameters associated with the
-     *                             request that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param secure               Was this session initiated over a secure
-     *                             connection?
-     * @param clientEndpointConfig The configuration information for the client
-     *                             end point
+     * @param subProtocol          The agreed sub-protocol to use for this session
+     * @param pathParameters       The path parameters associated with the request that initiated this session or
+     *                                 <code>null</code> if this is a client session
+     * @param secure               Was this session initiated over a secure connection?
+     * @param clientEndpointConfig The configuration information for the client end point
+     *
      * @throws DeploymentException if an invalid encode is specified
      */
-    public WsSession(ClientEndpointHolder clientEndpointHolder,
-            WsRemoteEndpointImplBase wsRemoteEndpoint,
-            WsWebSocketContainer wsWebSocketContainer,
-            List<Extension> negotiatedExtensions, String subProtocol, Map<String, String> pathParameters,
-            boolean secure, ClientEndpointConfig clientEndpointConfig) throws DeploymentException {
+    public WsSession(ClientEndpointHolder clientEndpointHolder, WsRemoteEndpointImplBase wsRemoteEndpoint,
+                     WsWebSocketContainer wsWebSocketContainer, List<Extension> negotiatedExtensions, String subProtocol,
+                     Map<String, String> pathParameters, boolean secure, ClientEndpointConfig clientEndpointConfig)
+            throws DeploymentException {
         this.wsRemoteEndpoint = wsRemoteEndpoint;
         this.wsRemoteEndpoint.setSession(this);
         this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint);
@@ -175,53 +171,43 @@ public class WsSession implements Session {
 
         this.localEndpoint = clientEndpointHolder.getInstance(getInstanceManager());
 
-        if (log.isDebugEnabled()) {
-            log.debug(sm.getString("wsSession.created", id));
+        if (log.isTraceEnabled()) {
+            log.trace(sm.getString("wsSession.created", id));
         }
     }
 
 
     /**
-     * Creates a new WebSocket session for communication between the provided
-     * server and remote end points. The result of
-     * {@link Thread#getContextClassLoader()} at the time this constructor is
-     * called will be used when calling
+     * Creates a new WebSocket session for communication between the provided server and remote end points. The result
+     * of {@link Thread#getContextClassLoader()} at the time this constructor is called will be used when calling
      * {@link Endpoint#onClose(Session, CloseReason)}.
      *
      * @param wsRemoteEndpoint     The other / remote end point
      * @param wsWebSocketContainer The container that created this session
-     * @param requestUri           The URI used to connect to this end point or
-     *                             <code>null</code> if this is a client session
-     * @param requestParameterMap  The parameters associated with the request
-     *                             that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param queryString          The query string associated with the request
-     *                             that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param userPrincipal        The principal associated with the request
-     *                             that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param httpSessionId        The HTTP session ID associated with the
-     *                             request that initiated this session or
-     *                             <code>null</code> if this is a client session
+     * @param requestUri           The URI used to connect to this end point or <code>null</code> if this is a client
+     *                                 session
+     * @param requestParameterMap  The parameters associated with the request that initiated this session or
+     *                                 <code>null</code> if this is a client session
+     * @param queryString          The query string associated with the request that initiated this session or
+     *                                 <code>null</code> if this is a client session
+     * @param userPrincipal        The principal associated with the request that initiated this session or
+     *                                 <code>null</code> if this is a client session
+     * @param httpSessionId        The HTTP session ID associated with the request that initiated this session or
+     *                                 <code>null</code> if this is a client session
      * @param negotiatedExtensions The agreed extensions to use for this session
-     * @param subProtocol          The agreed sub-protocol to use for this
-     *                             session
-     * @param pathParameters       The path parameters associated with the
-     *                             request that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param secure               Was this session initiated over a secure
-     *                             connection?
-     * @param serverEndpointConfig The configuration information for the server
-     *                             end point
+     * @param subProtocol          The agreed sub-protocol to use for this session
+     * @param pathParameters       The path parameters associated with the request that initiated this session or
+     *                                 <code>null</code> if this is a client session
+     * @param secure               Was this session initiated over a secure connection?
+     * @param serverEndpointConfig The configuration information for the server end point
+     *
      * @throws DeploymentException if an invalid encode is specified
      */
-    public WsSession(WsRemoteEndpointImplBase wsRemoteEndpoint,
-            WsWebSocketContainer wsWebSocketContainer,
-            URI requestUri, Map<String, List<String>> requestParameterMap,
-            String queryString, Principal userPrincipal, String httpSessionId,
-            List<Extension> negotiatedExtensions, String subProtocol, Map<String, String> pathParameters,
-            boolean secure, ServerEndpointConfig serverEndpointConfig) throws DeploymentException {
+    public WsSession(WsRemoteEndpointImplBase wsRemoteEndpoint, WsWebSocketContainer wsWebSocketContainer,
+                     URI requestUri, Map<String, List<String>> requestParameterMap, String queryString, Principal userPrincipal,
+                     String httpSessionId, List<Extension> negotiatedExtensions, String subProtocol,
+                     Map<String, String> pathParameters, boolean secure, ServerEndpointConfig serverEndpointConfig)
+            throws DeploymentException {
 
         this.wsRemoteEndpoint = wsRemoteEndpoint;
         this.wsRemoteEndpoint.setSession(this);
@@ -284,8 +270,8 @@ public class WsSession implements Session {
             this.localEndpoint = new PojoEndpointServer(pathParameters, endpointInstance);
         }
 
-        if (log.isDebugEnabled()) {
-            log.debug(sm.getString("wsSession.created", id));
+        if (log.isTraceEnabled()) {
+            log.trace(sm.getString("wsSession.created", id));
         }
     }
 
@@ -302,100 +288,6 @@ public class WsSession implements Session {
     }
 
 
-    /**
-     * Creates a new WebSocket session for communication between the two
-     * provided end points. The result of {@link Thread#getContextClassLoader()}
-     * at the time this constructor is called will be used when calling
-     * {@link Endpoint#onClose(Session, CloseReason)}.
-     *
-     * @param localEndpoint        The end point managed by this code
-     * @param wsRemoteEndpoint     The other / remote endpoint
-     * @param wsWebSocketContainer The container that created this session
-     * @param requestUri           The URI used to connect to this endpoint or
-     *                             <code>null</code> is this is a client session
-     * @param requestParameterMap  The parameters associated with the request
-     *                             that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param queryString          The query string associated with the request
-     *                             that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param userPrincipal        The principal associated with the request
-     *                             that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param httpSessionId        The HTTP session ID associated with the
-     *                             request that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param negotiatedExtensions The agreed extensions to use for this session
-     * @param subProtocol          The agreed subprotocol to use for this
-     *                             session
-     * @param pathParameters       The path parameters associated with the
-     *                             request that initiated this session or
-     *                             <code>null</code> if this is a client session
-     * @param secure               Was this session initiated over a secure
-     *                             connection?
-     * @param endpointConfig       The configuration information for the
-     *                             endpoint
-     * @throws DeploymentException if an invalid encode is specified
-     *
-     * @deprecated  Unused. This will be removed in Tomcat 10.1
-     */
-    @Deprecated
-    public WsSession(Endpoint localEndpoint,
-            WsRemoteEndpointImplBase wsRemoteEndpoint,
-            WsWebSocketContainer wsWebSocketContainer,
-            URI requestUri, Map<String, List<String>> requestParameterMap,
-            String queryString, Principal userPrincipal, String httpSessionId,
-            List<Extension> negotiatedExtensions, String subProtocol, Map<String, String> pathParameters,
-            boolean secure, EndpointConfig endpointConfig) throws DeploymentException {
-        this.localEndpoint = localEndpoint;
-        this.wsRemoteEndpoint = wsRemoteEndpoint;
-        this.wsRemoteEndpoint.setSession(this);
-        this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint);
-        this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint);
-        this.webSocketContainer = wsWebSocketContainer;
-        applicationClassLoader = Thread.currentThread().getContextClassLoader();
-        wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout());
-        this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize();
-        this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize();
-        this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout();
-        this.requestUri = requestUri;
-        if (requestParameterMap == null) {
-            this.requestParameterMap = Collections.emptyMap();
-        } else {
-            this.requestParameterMap = requestParameterMap;
-        }
-        this.queryString = queryString;
-        this.userPrincipal = userPrincipal;
-        this.httpSessionId = httpSessionId;
-        this.negotiatedExtensions = negotiatedExtensions;
-        if (subProtocol == null) {
-            this.subProtocol = "";
-        } else {
-            this.subProtocol = subProtocol;
-        }
-        this.pathParameters = pathParameters;
-        this.secure = secure;
-        this.wsRemoteEndpoint.setEncoders(endpointConfig);
-        this.endpointConfig = endpointConfig;
-
-        this.userProperties.putAll(endpointConfig.getUserProperties());
-        this.id = Long.toHexString(ids.getAndIncrement());
-
-        InstanceManager instanceManager = getInstanceManager();
-        if (instanceManager != null) {
-            try {
-                instanceManager.newInstance(localEndpoint);
-            } catch (Exception e) {
-                throw new DeploymentException(sm.getString("wsSession.instanceNew"), e);
-            }
-        }
-
-        if (log.isDebugEnabled()) {
-            log.debug(sm.getString("wsSession.created", id));
-        }
-    }
-
-
     public InstanceManager getInstanceManager() {
         return webSocketContainer.getInstanceManager(applicationClassLoader);
     }
@@ -416,15 +308,13 @@ public class WsSession implements Session {
 
 
     @Override
-    public <T> void addMessageHandler(Class<T> clazz, Partial<T> handler)
-            throws IllegalStateException {
+    public <T> void addMessageHandler(Class<T> clazz, Partial<T> handler) throws IllegalStateException {
         doAddMessageHandler(clazz, handler);
     }
 
 
     @Override
-    public <T> void addMessageHandler(Class<T> clazz, Whole<T> handler)
-            throws IllegalStateException {
+    public <T> void addMessageHandler(Class<T> clazz, Whole<T> handler) throws IllegalStateException {
         doAddMessageHandler(clazz, handler);
     }
 
@@ -443,44 +333,41 @@ public class WsSession implements Session {
         // arbitrary objects with MessageHandlers and can wrap MessageHandlers
         // just as easily.
 
-        Set<MessageHandlerResult> mhResults = Util.getMessageHandlers(target, listener,
-                endpointConfig, this);
+        Set<MessageHandlerResult> mhResults = Util.getMessageHandlers(target, listener, endpointConfig, this);
 
         for (MessageHandlerResult mhResult : mhResults) {
             switch (mhResult.getType()) {
-            case TEXT: {
-                if (textMessageHandler != null) {
-                    throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText"));
+                case TEXT: {
+                    if (textMessageHandler != null) {
+                        throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText"));
+                    }
+                    textMessageHandler = mhResult.getHandler();
+                    break;
                 }
-                textMessageHandler = mhResult.getHandler();
-                break;
-            }
-            case BINARY: {
-                if (binaryMessageHandler != null) {
-                    throw new IllegalStateException(
-                            sm.getString("wsSession.duplicateHandlerBinary"));
+                case BINARY: {
+                    if (binaryMessageHandler != null) {
+                        throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerBinary"));
+                    }
+                    binaryMessageHandler = mhResult.getHandler();
+                    break;
                 }
-                binaryMessageHandler = mhResult.getHandler();
-                break;
-            }
-            case PONG: {
-                if (pongMessageHandler != null) {
-                    throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong"));
+                case PONG: {
+                    if (pongMessageHandler != null) {
+                        throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong"));
+                    }
+                    MessageHandler handler = mhResult.getHandler();
+                    if (handler instanceof MessageHandler.Whole<?>) {
+                        pongMessageHandler = (MessageHandler.Whole<PongMessage>) handler;
+                    } else {
+                        throw new IllegalStateException(sm.getString("wsSession.invalidHandlerTypePong"));
+                    }
+
+                    break;
                 }
-                MessageHandler handler = mhResult.getHandler();
-                if (handler instanceof MessageHandler.Whole<?>) {
-                    pongMessageHandler = (MessageHandler.Whole<PongMessage>) handler;
-                } else {
-                    throw new IllegalStateException(
-                            sm.getString("wsSession.invalidHandlerTypePong"));
+                default: {
+                    throw new IllegalArgumentException(
+                            sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType()));
                 }
-
-                break;
-            }
-            default: {
-                throw new IllegalArgumentException(
-                        sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType()));
-            }
             }
         }
     }
@@ -539,8 +426,7 @@ public class WsSession implements Session {
         if (!removed) {
             // ISE for now. Could swallow this silently / log this if the ISE
             // becomes a problem
-            throw new IllegalStateException(
-                    sm.getString("wsSession.removeHandlerFailed", listener));
+            throw new IllegalStateException(sm.getString("wsSession.removeHandlerFailed", listener));
         }
     }
 
@@ -575,7 +461,12 @@ public class WsSession implements Session {
 
     @Override
     public boolean isOpen() {
-        return state == State.OPEN;
+        return state.get() == State.OPEN || state.get() == State.OUTPUT_CLOSING || state.get() == State.CLOSING;
+    }
+
+
+    public boolean isClosed() {
+        return state.get() == State.CLOSED;
     }
 
 
@@ -655,9 +546,8 @@ public class WsSession implements Session {
 
 
     /**
-     * WebSocket 1.0. Section 2.1.5.
-     * Need internal close method as spec requires that the local endpoint
-     * receives a 1006 on timeout.
+     * WebSocket 1.0. Section 2.1.5. Need internal close method as spec requires that the local endpoint receives a 1006
+     * on timeout.
      *
      * @param closeReasonMessage The close reason to pass to the remote endpoint
      * @param closeReasonLocal   The close reason to pass to the local endpoint
@@ -668,55 +558,54 @@ public class WsSession implements Session {
 
 
     /**
-     * WebSocket 1.0. Section 2.1.5.
-     * Need internal close method as spec requires that the local endpoint
-     * receives a 1006 on timeout.
+     * WebSocket 1.0. Section 2.1.5. Need internal close method as spec requires that the local endpoint receives a 1006
+     * on timeout.
      *
      * @param closeReasonMessage The close reason to pass to the remote endpoint
      * @param closeReasonLocal   The close reason to pass to the local endpoint
-     * @param closeSocket        Should the socket be closed immediately rather than waiting
-     *                           for the server to respond
+     * @param closeSocket        Should the socket be closed immediately rather than waiting for the server to respond
      */
-    public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal,
-            boolean closeSocket) {
-        // Double-checked locking. OK because state is volatile
-        if (state != State.OPEN) {
+    public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal, boolean closeSocket) {
+
+        if (!state.compareAndSet(State.OPEN, State.OUTPUT_CLOSING)) {
+            // Close process has already been started. Don't start it again.
             return;
         }
 
-        synchronized (stateLock) {
-            if (state != State.OPEN) {
-                return;
-            }
-
-            if (log.isDebugEnabled()) {
-                log.debug(sm.getString("wsSession.doClose", id));
-            }
+        if (log.isTraceEnabled()) {
+            log.trace(sm.getString("wsSession.doClose", id));
+        }
 
-            // This will trigger a flush of any batched messages.
-            try {
-                wsRemoteEndpoint.setBatchingAllowed(false);
-            } catch (IOException e) {
-                log.warn(sm.getString("wsSession.flushFailOnClose"), e);
-                fireEndpointOnError(e);
-            }
+        // Flush any batched messages not yet sent.
+        try {
+            wsRemoteEndpoint.setBatchingAllowed(false);
+        } catch (Throwable t) {
+            ExceptionUtils.handleThrowable(t);
+            log.warn(sm.getString("wsSession.flushFailOnClose"), t);
+            fireEndpointOnError(t);
+        }
 
+        // Send the close message to the remote endpoint.
+        sendCloseMessage(closeReasonMessage);
+        fireEndpointOnClose(closeReasonLocal);
+        if (!state.compareAndSet(State.OUTPUT_CLOSING, State.OUTPUT_CLOSED) || closeSocket) {
             /*
-             * If the flush above fails the error handling could call this
-             * method recursively. Without this check, the close message and
-             * notifications could be sent multiple times.
+             * A close message was received in another thread or this is handling an error condition. Either way, no
+             * further close message is expected to be received. Mark the session as fully closed...
              */
-            if (state != State.OUTPUT_CLOSED) {
-                state = State.OUTPUT_CLOSED;
-
-                sendCloseMessage(closeReasonMessage);
-                if (closeSocket) {
-                    wsRemoteEndpoint.close();
-                }
-                fireEndpointOnClose(closeReasonLocal);
-            }
+            state.set(State.CLOSED);
+            // ... and close the network connection.
+            closeConnection();
+        } else {
+            /*
+             * Set close timeout. If the client fails to send a close message response within the timeout, the session
+             * and the connection will be closed when the timeout expires.
+             */
+            sessionCloseTimeoutExpiry =
+                    Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout()));
         }
 
+        // Fail any uncompleted messages.
         IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
         SendResult sr = new SendResult(ioe);
         for (FutureToSendHandler f2sh : futures.keySet()) {
@@ -726,36 +615,93 @@ public class WsSession implements Session {
 
 
     /**
-     * Called when a close message is received. Should only ever happen once.
-     * Also called after a protocol error when the ProtocolHandler needs to
-     * force the closing of the connection.
+     * Called when a close message is received. Should only ever happen once. Also called after a protocol error when
+     * the ProtocolHandler needs to force the closing of the connection.
      *
-     * @param closeReason The reason contained within the received close
-     *                    message.
+     * @param closeReason The reason contained within the received close message.
      */
     public void onClose(CloseReason closeReason) {
+        if (state.compareAndSet(State.OPEN, State.CLOSING)) {
+            // Standard close.
 
-        synchronized (stateLock) {
-            if (state != State.CLOSED) {
-                try {
-                    wsRemoteEndpoint.setBatchingAllowed(false);
-                } catch (IOException e) {
-                    log.warn(sm.getString("wsSession.flushFailOnClose"), e);
-                    fireEndpointOnError(e);
-                }
-                if (state == State.OPEN) {
-                    state = State.OUTPUT_CLOSED;
-                    sendCloseMessage(closeReason);
-                    fireEndpointOnClose(closeReason);
-                }
-                state = State.CLOSED;
+            // Flush any batched messages not yet sent.
+            try {
+                wsRemoteEndpoint.setBatchingAllowed(false);
+            } catch (Throwable t) {
+                ExceptionUtils.handleThrowable(t);
+                log.warn(sm.getString("wsSession.flushFailOnClose"), t);
+                fireEndpointOnError(t);
+            }
+
+            // Send the close message response to the remote endpoint.
+            sendCloseMessage(closeReason);
+            fireEndpointOnClose(closeReason);
+
+            // Mark the session as fully closed.
+            state.set(State.CLOSED);
 
-                // Close the socket
-                wsRemoteEndpoint.close();
+            // Close the network connection.
+            closeConnection();
+        } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) {
+            /*
+             * The local endpoint sent a close message the the same time as the remote endpoint. The local close is
+             * still being processed. Update the state so the the local close process will also close the network
+             * connection once it has finished sending a close message.
+             */
+        } else if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
+            /*
+             * The local endpoint sent the first close message. The remote endpoint has now responded with its own close
+             * message so mark the session as fully closed and close the network connection.
+             */
+            closeConnection();
+        }
+        // CLOSING and CLOSED are NO-OPs
+    }
+
+
+    private void closeConnection() {
+        /*
+         * Close the network connection.
+         */
+        wsRemoteEndpoint.close();
+        /*
+         * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for
+         * tracking the session close timeout.
+         */
+        webSocketContainer.unregisterSession(getSessionMapKey(), this);
+    }
+
+
+    /*
+     * Returns the session close timeout in milliseconds
+     */
+    protected long getSessionCloseTimeout() {
+        long result = 0;
+        Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY);
+        if (obj instanceof Long) {
+            result = ((Long) obj).intValue();
+        }
+        if (result <= 0) {
+            result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT;
+        }
+        return result;
+    }
+
+
+    protected void checkCloseTimeout() {
+        // Skip the check if no session close timeout has been set.
+        if (sessionCloseTimeoutExpiry != null) {
+            // Check if the timeout has expired.
+            if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) {
+                // Check if the session has been closed in another thread while the timeout was being processed.
+                if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
+                    closeConnection();
+                }
             }
         }
     }
 
+
     private void fireEndpointOnClose(CloseReason closeReason) {
 
         // Fire the onClose event
@@ -789,7 +735,6 @@ public class WsSession implements Session {
     }
 
 
-
     private void fireEndpointOnError(Throwable throwable) {
 
         // Fire the onError event
@@ -829,7 +774,7 @@ public class WsSession implements Session {
             if (log.isDebugEnabled()) {
                 log.debug(sm.getString("wsSession.sendCloseFail", id), e);
             }
-            wsRemoteEndpoint.close();
+            closeConnection();
             // Failure to send a close message is not unexpected in the case of
             // an abnormal closure (usually triggered by a failure to read/write
             // from/to the client. In this case do not trigger the endpoint's
@@ -837,8 +782,6 @@ public class WsSession implements Session {
             if (closeCode != CloseCodes.CLOSED_ABNORMALLY) {
                 localEndpoint.onError(this, e);
             }
-        } finally {
-            webSocketContainer.unregisterSession(getSessionMapKey(), this);
         }
     }
 
@@ -855,7 +798,8 @@ public class WsSession implements Session {
 
     /**
      * Use protected so unit tests can access this method directly.
-     * @param msg The message
+     *
+     * @param msg    The message
      * @param reason The reason
      */
     protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String reason) {
@@ -885,9 +829,9 @@ public class WsSession implements Session {
 
 
     /**
-     * Make the session aware of a {@link FutureToSendHandler} that will need to
-     * be forcibly closed if the session closes before the
-     * {@link FutureToSendHandler} completes.
+     * Make the session aware of a {@link FutureToSendHandler} that will need to be forcibly closed if the session
+     * closes before the {@link FutureToSendHandler} completes.
+     *
      * @param f2sh The handler
      */
     protected void registerFuture(FutureToSendHandler f2sh) {
@@ -900,13 +844,13 @@ public class WsSession implements Session {
         // Always register the future.
         futures.put(f2sh, f2sh);
 
-        if (state == State.OPEN) {
+        if (isOpen()) {
             // The session is open. The future has been registered with the open
             // session. Normal processing continues.
             return;
         }
 
-        // The session is closed. The future may or may not have been registered
+        // The session is closing / closed. The future may or may not have been registered
         // in time for it to be processed during session closure.
 
         if (f2sh.isDone()) {
@@ -916,7 +860,7 @@ public class WsSession implements Session {
             return;
         }
 
-        // The session is closed. The Future had not completed when last checked.
+        // The session is closing / closed. The Future had not completed when last checked.
         // There is a small timing window that means the Future may have been
         // completed since the last check. There is also the possibility that
         // the Future was not registered in time to be cleaned up during session
@@ -928,7 +872,7 @@ public class WsSession implements Session {
         // complete the Future but knowing if this is the case requires the sync
         // on stateLock (see above).
         // Note: If multiple attempts are made to complete the Future, the
-        //       second and subsequent attempts are ignored.
+        // second and subsequent attempts are ignored.
 
         IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
         SendResult sr = new SendResult(ioe);
@@ -938,6 +882,7 @@ public class WsSession implements Session {
 
     /**
      * Remove a {@link FutureToSendHandler} from the set of tracked instances.
+     *
      * @param f2sh The handler
      */
     protected void unregisterFuture(FutureToSendHandler f2sh) {
@@ -969,6 +914,11 @@ public class WsSession implements Session {
     @Override
     public Principal getUserPrincipal() {
         checkState();
+        return getUserPrincipalInternal();
+    }
+
+
+    public Principal getUserPrincipalInternal() {
+        return userPrincipal; // No checkState() here (unlike getUserPrincipal()), so usable on closed sessions.
+    }
 
@@ -1075,10 +1025,10 @@ public class WsSession implements Session {
 
 
     private void checkState() {
-        if (state == State.CLOSED) {
+        if (isClosed()) {
             /*
-             * As per RFC 6455, a WebSocket connection is considered to be
-             * closed once a peer has sent and received a WebSocket close frame.
+             * As per RFC 6455, a WebSocket connection is considered to be closed once a peer has sent and received a
+             * WebSocket close frame.
              */
             throw new IllegalStateException(sm.getString("wsSession.closed", id));
         }
@@ -1086,12 +1036,15 @@ public class WsSession implements Session {
 
     private enum State {
         OPEN,
+        OUTPUT_CLOSING,
         OUTPUT_CLOSED,
+        CLOSING,
         CLOSED
     }
 
 
     private WsFrameBase wsFrame;
+
     void setWsFrame(WsFrameBase wsFrame) {
         this.wsFrame = wsFrame;
     }
@@ -1111,4 +1064,4 @@ public class WsSession implements Session {
     public void resume() {
         wsFrame.resume();
     }
-}
+}
\ No newline at end of file
diff --git a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsWebSocketContainer.java b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
index 509fabe979..04a6987e2a 100644
--- a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
+++ b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
@@ -646,7 +646,12 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce
         synchronized (endPointSessionMapLock) {
             Set<WsSession> sessions = endpointSessionMap.get(key);
             if (sessions != null) {
-                result.addAll(sessions);
+                // Some sessions may be in the process of closing
+                for (WsSession session : sessions) {
+                    if (session.isOpen()) {
+                        result.add(session);
+                    }
+                }
             }
         }
         return result;
@@ -1111,9 +1116,10 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce
         backgroundProcessCount ++;
         if (backgroundProcessCount >= processPeriod) {
             backgroundProcessCount = 0;
-
+            // Some sessions may be in the process of closing
             for (WsSession wsSession : sessions.keySet()) {
                 wsSession.checkExpiration();
+                wsSession.checkCloseTimeout();
             }
         }
 
diff --git a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/server/WsServerContainer.java b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/server/WsServerContainer.java
index 6d2171989a..c945ee12f4 100644
--- a/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/server/WsServerContainer.java
+++ b/tomee/apache-tomee/src/patch/java/org/apache/tomcat/websocket/server/WsServerContainer.java
@@ -475,7 +475,7 @@ public class WsServerContainer extends WsWebSocketContainer
      */
     @Override
     protected void unregisterSession(Object key, WsSession wsSession) {
-        if (wsSession.getUserPrincipal() != null &&
+        if (wsSession.getUserPrincipalInternal() != null &&
                 wsSession.getHttpSessionId() != null) {
             unregisterAuthenticatedSession(wsSession,
                     wsSession.getHttpSessionId());