You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to notifications@shenyu.apache.org by sa...@apache.org on 2022/07/01 13:36:08 UTC

[incubator-shenyu] branch master updated: [issue #3508] support cors whitelist. (#3646)

This is an automated email from the ASF dual-hosted git repository.

sabersola pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-shenyu.git


The following commit(s) were added to refs/heads/master by this push:
     new 8eed0ccc6 [issue #3508] support cors whitelist. (#3646)
8eed0ccc6 is described below

commit 8eed0ccc6e0d839fd5cbdc311a2ef66a2071547a
Author: Qicz <qi...@gmail.com>
AuthorDate: Fri Jul 1 21:36:02 2022 +0800

    [issue #3508] support cors whitelist. (#3646)
    
    * [issue #3508] support cors whitelist.
    
    * [type: refactor] optimize and safety.
    
    * optimize cors
    
    * refactor whitelist logic
    
    * optimize prefix check
    
    * code polish
    
    * code polish
    
    * code polish
    
    * optimize cors logic
---
 .../src/main/resources/application.yml             |  9 +++-
 .../apache/shenyu/common/config/ShenyuConfig.java  | 58 ++++++++++++++++++++--
 .../shenyu/common/config/ShenyuConfigTest.java     |  2 +-
 .../org/apache/shenyu/web/filter/CrossFilter.java  | 21 +++++---
 .../apache/shenyu/web/filter/CrossFilterTest.java  | 27 ++++++++++
 5 files changed, 104 insertions(+), 13 deletions(-)

diff --git a/shenyu-bootstrap/src/main/resources/application.yml b/shenyu-bootstrap/src/main/resources/application.yml
index b5c2ea5a9..4efb30490 100644
--- a/shenyu-bootstrap/src/main/resources/application.yml
+++ b/shenyu-bootstrap/src/main/resources/application.yml
@@ -159,10 +159,15 @@ shenyu:
     enabled: true
     allowedHeaders:
     allowedMethods: "*"
-    allowedOrigin: "*"
-    allowedExpose: "*"
+    allowedOrigin:
+      domain: apache.org
+      prefixes:
+        - a # a.apache.org
+        - b # b.apache.org
+    allowedExpose: ""
     maxAge: "18000"
     allowCredentials: true
+
   switchConfig:
     local: true
   file:
diff --git a/shenyu-common/src/main/java/org/apache/shenyu/common/config/ShenyuConfig.java b/shenyu-common/src/main/java/org/apache/shenyu/common/config/ShenyuConfig.java
index 925e572be..363ed6cb4 100644
--- a/shenyu-common/src/main/java/org/apache/shenyu/common/config/ShenyuConfig.java
+++ b/shenyu-common/src/main/java/org/apache/shenyu/common/config/ShenyuConfig.java
@@ -23,6 +23,7 @@ import org.springframework.util.StringUtils;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Objects;
 import java.util.Properties;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -964,9 +965,9 @@ public class ShenyuConfig {
          */
         private String allowedMethods = "*";
 
-        private String allowedOrigin = "*";
+        private AllowedOriginConfig allowedOrigin = new AllowedOriginConfig();
 
-        private String allowedExpose = "*";
+        private String allowedExpose = "";
 
         private String maxAge = "18000";
 
@@ -1045,7 +1046,7 @@ public class ShenyuConfig {
          *
          * @return the value of allowedOrigin
          */
-        public String getAllowedOrigin() {
+        public AllowedOriginConfig getAllowedOrigin() {
             return allowedOrigin;
         }
     
@@ -1054,7 +1055,7 @@ public class ShenyuConfig {
          *
          * @param allowedOrigin allowedOrigin
          */
-        public void setAllowedOrigin(final String allowedOrigin) {
+        public void setAllowedOrigin(final AllowedOriginConfig allowedOrigin) {
             this.allowedOrigin = allowedOrigin;
         }
     
@@ -1111,6 +1112,55 @@ public class ShenyuConfig {
         public void setAllowCredentials(final boolean allowCredentials) {
             this.allowCredentials = allowCredentials;
         }
+
+        /**
+         * The CORS allowed-origin whitelist: each prefix is combined with the domain as prefix.domain (e.g. a + apache.org).
+         */
+        public static class AllowedOriginConfig {
+
+            private String domain;
+
+            private Set<String> prefixes = new HashSet<>();
+
+            /**
+             * Gets the base domain that every whitelisted prefix is appended to.
+             *
+             * @return the value of domain, or null when it has not been configured
+             */
+            public String getDomain() {
+                return domain;
+            }
+
+            /**
+             * Sets the domain.
+             *
+             * @param domain the base domain, e.g. apache.org
+             */
+            public void setDomain(final String domain) {
+                this.domain = domain;
+            }
+
+            /**
+             * Gets the prefixes, replacing a null set with a new empty set on first access.
+             *
+             * @return the prefix set, never null
+             */
+            public Set<String> getPrefixes() {
+                if (Objects.isNull(prefixes)) {
+                    prefixes = new HashSet<>();
+                }
+                return prefixes;
+            }
+
+            /**
+             * Sets the prefixes.
+             *
+             * @param prefixes the sub-domain prefixes, e.g. a for a.apache.org
+             */
+            public void setPrefixes(final Set<String> prefixes) {
+                this.prefixes = prefixes;
+            }
+        }
     }
     
     /**
diff --git a/shenyu-common/src/test/java/org/apache/shenyu/common/config/ShenyuConfigTest.java b/shenyu-common/src/test/java/org/apache/shenyu/common/config/ShenyuConfigTest.java
index 9f0677d27..ff8da4d3a 100644
--- a/shenyu-common/src/test/java/org/apache/shenyu/common/config/ShenyuConfigTest.java
+++ b/shenyu-common/src/test/java/org/apache/shenyu/common/config/ShenyuConfigTest.java
@@ -258,7 +258,7 @@ public class ShenyuConfigTest {
 
         String allowedExpose = cross.getAllowedExpose();
         String allowedHeaders = cross.getAllowedHeaders();
-        String allowedOrigin = cross.getAllowedOrigin();
+        ShenyuConfig.CrossFilterConfig.AllowedOriginConfig allowedOrigin = cross.getAllowedOrigin();
         Boolean enabled = cross.getEnabled();
         String maxAge = cross.getMaxAge();
         String allowedMethods = cross.getAllowedMethods();
diff --git a/shenyu-web/src/main/java/org/apache/shenyu/web/filter/CrossFilter.java b/shenyu-web/src/main/java/org/apache/shenyu/web/filter/CrossFilter.java
index fd0e2d645..6356b3bf0 100644
--- a/shenyu-web/src/main/java/org/apache/shenyu/web/filter/CrossFilter.java
+++ b/shenyu-web/src/main/java/org/apache/shenyu/web/filter/CrossFilter.java
@@ -33,6 +33,7 @@ import org.springframework.web.server.WebFilterChain;
 import reactor.core.publisher.Mono;
 
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -58,12 +59,20 @@ public class CrossFilter implements WebFilter {
             ServerHttpResponse response = exchange.getResponse();
             HttpHeaders headers = response.getHeaders();
             // "Access-Control-Allow-Origin"
-            // if the allowed origin is empty use the request 's origin
-            if (StringUtils.isBlank(this.filterConfig.getAllowedOrigin())) {
-                headers.set(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, request.getHeaders().getOrigin());
-            } else {
-                this.filterSameHeader(headers, HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN,
-                        this.filterConfig.getAllowedOrigin());
+            if (Objects.nonNull(this.filterConfig.getAllowedOrigin())
+                    && CollectionUtils.isNotEmpty(this.filterConfig.getAllowedOrigin().getPrefixes())) {
+                final String scheme = exchange.getRequest().getURI().getScheme();
+                Set<String> allowedOrigin = this.filterConfig.getAllowedOrigin().getPrefixes()
+                        .stream()
+                        .filter(StringUtils::isNoneBlank)
+                        // scheme://prefix.domain
+                        .map(prefix -> String.format("%s://%s.%s", scheme, prefix.trim(), this.filterConfig.getAllowedOrigin().getDomain()))
+                        .collect(Collectors.toSet());
+                String origin = request.getHeaders().getOrigin();
+                if (allowedOrigin.contains(origin)) {
+                    origin = String.join(",", allowedOrigin);
+                    headers.set(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
+                }
             }
             // "Access-Control-Allow-Methods"
             this.filterSameHeader(headers, HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS,
diff --git a/shenyu-web/src/test/java/org/apache/shenyu/web/filter/CrossFilterTest.java b/shenyu-web/src/test/java/org/apache/shenyu/web/filter/CrossFilterTest.java
index 7d834dfb0..ec942bc1c 100644
--- a/shenyu-web/src/test/java/org/apache/shenyu/web/filter/CrossFilterTest.java
+++ b/shenyu-web/src/test/java/org/apache/shenyu/web/filter/CrossFilterTest.java
@@ -26,6 +26,8 @@ import org.springframework.web.server.WebFilterChain;
 import reactor.core.publisher.Mono;
 import reactor.test.StepVerifier;
 
+import java.util.HashSet;
+
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -72,4 +74,29 @@ public final class CrossFilterTest {
                 .verifyComplete();
     }
 
+    /**
+     * test {@link CrossFilter#filter(ServerWebExchange, WebFilterChain)} with a prefix whitelist; NOTE(review): the Origin below has no scheme, so it never equals "http://a.apache.org" and only completion is verified.
+     */
+    @Test
+    public void testCorsWhitelist() {
+        ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest
+                .get("http://localhost:8080")
+                .header("Origin", "a.apache.org")
+                .build());
+        WebFilterChain chainNormal = mock(WebFilterChain.class);
+        when(chainNormal.filter(exchange)).thenReturn(Mono.empty());
+        final CrossFilterConfig filterConfig = new CrossFilterConfig();
+        CrossFilterConfig.AllowedOriginConfig allowedOriginConfig = new CrossFilterConfig.AllowedOriginConfig();
+        allowedOriginConfig.setDomain("apache.org");
+        allowedOriginConfig.setPrefixes(new HashSet<String>() {
+            {
+                add("a");
+            }
+        });
+        filterConfig.setAllowedOrigin(allowedOriginConfig);
+        CrossFilter filter = new CrossFilter(filterConfig);
+        StepVerifier.create(filter.filter(exchange, chainNormal))
+                .expectSubscription()
+                .verifyComplete();
+    }
 }