You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text version.
Posted to commits@cxf.apache.org by re...@apache.org on 2021/06/27 00:25:50 UTC

[cxf] branch master updated: CXF-8559: SseInterceptor uses PROTOCOL_HEADERS from request (in) message, not from response (out) one (#821)

This is an automated email from the ASF dual-hosted git repository.

reta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/cxf.git


The following commit(s) were added to refs/heads/master by this push:
     new 3abb509  CXF-8559: SseInterceptor uses PROTOCOL_HEADERS from request (in) message, not from response (out) one (#821)
3abb509 is described below

commit 3abb5099003e078f3537051eaa36029ea0131af4
Author: Andriy Redko <dr...@gmail.com>
AuthorDate: Sat Jun 26 20:25:39 2021 -0400

    CXF-8559: SseInterceptor uses PROTOCOL_HEADERS from request (in) message, not from response (out) one (#821)
---
 .../cxf/jaxrs/sse/interceptor/SseInterceptor.java  | 46 +++++++++++++---------
 .../cxf/systest/jaxrs/sse/AbstractSseTest.java     | 10 +++++
 .../apache/cxf/systest/jaxrs/sse/BookStore.java    | 18 ++++++++-
 .../systest/jaxrs/sse/BookStoreResponseFilter.java | 22 ++++++++++-
 4 files changed, 75 insertions(+), 21 deletions(-)

diff --git a/rt/rs/sse/src/main/java/org/apache/cxf/jaxrs/sse/interceptor/SseInterceptor.java b/rt/rs/sse/src/main/java/org/apache/cxf/jaxrs/sse/interceptor/SseInterceptor.java
index 0599ee0..4e510e5 100644
--- a/rt/rs/sse/src/main/java/org/apache/cxf/jaxrs/sse/interceptor/SseInterceptor.java
+++ b/rt/rs/sse/src/main/java/org/apache/cxf/jaxrs/sse/interceptor/SseInterceptor.java
@@ -31,6 +31,8 @@ import javax.ws.rs.core.Response;
 import javax.ws.rs.sse.SseEventSink;
 
 import org.apache.cxf.common.logging.LogUtils;
+import org.apache.cxf.endpoint.Endpoint;
+import org.apache.cxf.helpers.CastUtils;
 import org.apache.cxf.jaxrs.impl.ResponseImpl;
 import org.apache.cxf.jaxrs.model.OperationResourceInfo;
 import org.apache.cxf.jaxrs.provider.ServerProviderFactory;
@@ -38,6 +40,7 @@ import org.apache.cxf.jaxrs.utils.HttpUtils;
 import org.apache.cxf.jaxrs.utils.JAXRSUtils;
 import org.apache.cxf.message.Exchange;
 import org.apache.cxf.message.Message;
+import org.apache.cxf.message.MessageImpl;
 import org.apache.cxf.phase.AbstractPhaseInterceptor;
 import org.apache.cxf.phase.Phase;
 import org.apache.cxf.transport.http.AbstractHTTPDestination;
@@ -72,17 +75,7 @@ public class SseInterceptor extends AbstractPhaseInterceptor<Message> {
                 if (response instanceof HttpServletResponse) {
                     servletResponse = (HttpServletResponse)response;
                     builder = Response.status(servletResponse.getStatus());
-                    
-                    @SuppressWarnings("unchecked")
-                    final Map<String, List<Object>> userHeaders = (Map<String, List<Object>>)message
-                        .get(Message.PROTOCOL_HEADERS);
 
-                    if (userHeaders != null) {
-                        for (Map.Entry<String, List<Object>> entry: userHeaders.entrySet()) {
-                            addHeader(builder, entry);
-                        }
-                    }
-                    
                     for (final String header: servletResponse.getHeaderNames()) {
                         final Collection<String> headers = servletResponse.getHeaders(header);
                         addHeader(builder, header, headers);
@@ -92,13 +85,23 @@ public class SseInterceptor extends AbstractPhaseInterceptor<Message> {
                 // Run the filters
                 try {
                     final ResponseImpl responseImpl = (ResponseImpl)builder.build();
+                    final Message outMessage = getOutMessage(message);
 
                     JAXRSUtils.runContainerResponseFilters(providerFactory, responseImpl, 
-                        message, ori, ori.getAnnotatedMethod());
+                        outMessage, ori, ori.getAnnotatedMethod());
 
                     if (servletResponse != null) {
-                        final MultivaluedMap<String, String> headers = responseImpl.getStringHeaders();
                         servletResponse.setStatus(responseImpl.getStatus());
+                        
+                        final Map<String, List<String>> userHeaders =  CastUtils.cast((Map<?, ?>)outMessage
+                            .get(Message.PROTOCOL_HEADERS));
+                        if (userHeaders != null) {
+                            for (Map.Entry<String, List<String>> entry: userHeaders.entrySet()) {
+                                setHeader(servletResponse, entry);
+                            }
+                        }
+                        
+                        final MultivaluedMap<String, String> headers = responseImpl.getStringHeaders();
                         if (headers != null) {
                             for (Map.Entry<String, List<String>> entry: headers.entrySet()) {
                                 setHeader(servletResponse, entry);
@@ -113,13 +116,20 @@ public class SseInterceptor extends AbstractPhaseInterceptor<Message> {
             }
         }
     }
-  
-    private void addHeader(Response.ResponseBuilder builder, Map.Entry<String, List<Object>> entry) {
-        if (entry.getValue() != null) {
-            for (Object value: entry.getValue()) {
-                builder.header(entry.getKey(), value);
-            }
+    
+    private Message getOutMessage(final Message message) {
+        final Exchange exchange = message.getExchange();
+        Message outMessage = message.getExchange().getOutMessage();
+        
+        if (outMessage == null) {
+            final Endpoint ep = exchange.getEndpoint();
+            outMessage = new MessageImpl();
+            outMessage.setExchange(exchange);
+            outMessage = ep.getBinding().createMessage(outMessage);
+            message.getExchange().setOutMessage(outMessage);
         }
+
+        return outMessage;
     }
     
     private void addHeader(Response.ResponseBuilder builder, final String header, final Collection<String> headers) {
diff --git a/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/AbstractSseTest.java b/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/AbstractSseTest.java
index f6ad024..8a82b21 100644
--- a/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/AbstractSseTest.java
+++ b/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/AbstractSseTest.java
@@ -325,6 +325,16 @@ public abstract class AbstractSseTest extends AbstractSseBaseTest {
         assertTrue(stats.getCompleted() == 2 || stats.getCompleted() == 1);
     }
 
+    @Test
+    public void testBooksSseContainerResponseAddedHeaders() throws InterruptedException {
+        final WebTarget target = createWebTarget("/rest/api/bookstore/headers/sse");
+        try (Response response = target.request(MediaType.SERVER_SENT_EVENTS).get()) {
+            assertThat(response.getStatus(), equalTo(202));
+            assertThat(response.getHeaderString("X-My-Header"), equalTo("headers"));
+            assertThat(response.getHeaderString("X-My-ProtocolHeader"), equalTo("protocol-headers"));
+        }
+    }
+
     /**
      * Jetty / Undertow do not propagate errors from the runnable passed to
      * AsyncContext::start() up to the AsyncEventListener::onError(). Tomcat however
diff --git a/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStore.java b/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStore.java
index de28eef..7fa54c5 100644
--- a/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStore.java
+++ b/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStore.java
@@ -183,7 +183,23 @@ public class BookStore extends BookStoreClientCloseable {
             }
         }.start();
     }
-    
+
+    @GET
+    @Path("/headers/sse")
+    @Produces(MediaType.SERVER_SENT_EVENTS)
+    public void headers(@Context SseEventSink sink) {
+        new Thread() {
+            public void run() {
+                try {
+                    Thread.sleep(200);
+                    sink.close();
+                } catch (final InterruptedException ex) {
+                    LOG.error("Communication error", ex);
+                }
+            }
+        }.start();
+    }
+
     @GET
     @Path("/filtered/stats")
     @Produces(MediaType.TEXT_PLAIN)
diff --git a/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStoreResponseFilter.java b/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStoreResponseFilter.java
index 0dba7d6..0a8270e 100644
--- a/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStoreResponseFilter.java
+++ b/systests/rs-sse/rs-sse-base/src/main/java/org/apache/cxf/systest/jaxrs/sse/BookStoreResponseFilter.java
@@ -20,6 +20,9 @@
 package org.apache.cxf.systest.jaxrs.sse;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import javax.ws.rs.container.ContainerRequestContext;
@@ -29,6 +32,10 @@ import javax.ws.rs.core.Context;
 import javax.ws.rs.core.UriInfo;
 import javax.ws.rs.ext.Provider;
 
+import org.apache.cxf.jaxrs.utils.JAXRSUtils;
+import org.apache.cxf.message.Message;
+import org.apache.cxf.transport.http.Headers;
+
 @Provider
 public class BookStoreResponseFilter implements ContainerResponseFilter {
     private static AtomicInteger counter = new AtomicInteger(0);
@@ -40,9 +47,20 @@ public class BookStoreResponseFilter implements ContainerResponseFilter {
 
     @Override
     public void filter(ContainerRequestContext reqContext, ContainerResponseContext rspContext) throws IOException {
-        if (!uriInfo.getRequestUri().getPath().endsWith("/filtered/stats")) {
+        final String path = uriInfo.getRequestUri().getPath();
+        
+        if (!path.endsWith("/filtered/stats")) {
             counter.incrementAndGet();
-        }
+            
+            if (path.endsWith("/headers/sse")) {
+                rspContext.setStatus(202);
+                rspContext.getHeaders().add("X-My-Header", "headers");
+                
+                final Message message = JAXRSUtils.getCurrentMessage().getExchange().getOutMessage();
+                final Map<String, List<String>> headers = Headers.getSetProtocolHeaders(message);
+                headers.put("X-My-ProtocolHeader", Collections.singletonList("protocol-headers"));
+            }
+        } 
     }
     
     public static int getInvocations() {