You are viewing a plain-text version of this content. The canonical link for it ("here") was not preserved in this plain-text copy.
Posted to commits@nifi.apache.org by jo...@apache.org on 2021/02/10 20:35:27 UTC
[nifi] branch main updated: NIFI-8218 This closes #4816. Use proxy headers when available when getting request values while processing SAML responses
This is an automated email from the ASF dual-hosted git repository.
joewitt pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git
The following commit(s) were added to refs/heads/main by this push:
new 1d82fb8 NIFI-8218 This closes #4816. Use proxy headers when available when getting request values while processing SAML responses
1d82fb8 is described below
commit 1d82fb8e01f3e8d3b25fcd773eaa7add03aad363
Author: Bryan Bende <bb...@apache.org>
AuthorDate: Wed Feb 10 09:39:27 2021 -0500
NIFI-8218 This closes #4816. Use proxy headers when available when getting request values while processing SAML responses
Signed-off-by: Joe Witt <jo...@apache.org>
---
.../java/org/apache/nifi/web/util/WebUtils.java | 153 +++++++++++++++++++--
...bUtilsTest.groovy => WebUtilsGroovyTest.groovy} | 4 +-
.../org/apache/nifi/web/util/WebUtilsTest.java | 139 +++++++++++++++++++
.../apache/nifi/web/api/ApplicationResource.java | 142 ++++++-------------
.../nifi/web/api/TestDataTransferResource.java | 51 +++----
.../saml/impl/NiFiSAMLContextProviderImpl.java | 59 ++------
.../http/HttpServletRequestWithParameters.java | 64 +++++++++
.../http/ProxyAwareHttpServletRequestWrapper.java | 96 +++++++++++++
.../http/TestHttpServletRequestWithParameters.java | 98 +++++++++++++
.../TestProxyAwareHttpServletRequestWrapper.java | 77 +++++++++++
10 files changed, 696 insertions(+), 187 deletions(-)
diff --git a/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java b/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java
index 578206a..5e508e3 100644
--- a/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java
+++ b/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java
@@ -16,18 +16,6 @@
*/
package org.apache.nifi.web.util;
-import java.net.URI;
-import java.util.Arrays;
-import java.util.List;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
-import java.util.stream.Stream;
-import javax.net.ssl.SSLContext;
-import javax.servlet.ServletRequest;
-import javax.servlet.http.HttpServletRequest;
-import javax.ws.rs.client.Client;
-import javax.ws.rs.client.ClientBuilder;
-import javax.ws.rs.core.UriBuilderException;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.conn.ssl.DefaultHostnameVerifier;
import org.glassfish.jersey.client.ClientConfig;
@@ -35,6 +23,19 @@ import org.glassfish.jersey.jackson.internal.jackson.jaxrs.json.JacksonJaxbJsonP
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.net.ssl.SSLContext;
+import javax.servlet.ServletRequest;
+import javax.servlet.http.HttpServletRequest;
+import javax.ws.rs.client.Client;
+import javax.ws.rs.client.ClientBuilder;
+import javax.ws.rs.core.UriBuilderException;
+import java.net.URI;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.stream.Stream;
+
/**
* Common utilities related to web development.
*/
@@ -44,9 +45,17 @@ public final class WebUtils {
final static ReadWriteLock lock = new ReentrantReadWriteLock();
- private static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
- private static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
- private static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
+ public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme";
+ public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost";
+ public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort";
+
+ public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto";
+ public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host";
+ public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port";
+
+ public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
+ public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
+ public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
private WebUtils() {
}
@@ -248,4 +257,118 @@ public final class WebUtils {
return false;
}
+ /**
+ * Returns the value for the first key discovered when inspecting the current request. Will
+ * return null if there are no keys specified or if none of the specified keys are found.
+ *
+ * @param httpServletRequest request
+ * @param keys http header keys
+ * @return the value for the first key found, or null if no matching keys found
+ */
+ public static String getFirstHeaderValue(final HttpServletRequest httpServletRequest, final String... keys) {
+ if (keys == null) {
+ return null;
+ }
+
+ for (final String key : keys) {
+ final String value = httpServletRequest.getHeader(key);
+
+ // if we found an entry for this key, return the value
+ if (value != null) {
+ return value;
+ }
+ }
+
+ // unable to find any matching keys
+ return null;
+ }
+
+ /**
+ * Determines the scheme based on considering proxy related headers first and then falling back to the scheme of the servlet request.
+ *
+ * @param httpServletRequest the request
+ * @return the determined scheme
+ */
+ public static String determineProxiedScheme(final HttpServletRequest httpServletRequest) {
+ final String schemeHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_SCHEME_HTTP_HEADER, FORWARDED_PROTO_HTTP_HEADER);
+ return StringUtils.isBlank(schemeHeaderValue) ? httpServletRequest.getScheme() : schemeHeaderValue;
+ }
+
+ /**
+ * Determines the host based on considering proxy related headers first and falling back to the host of the servlet request.
+ *
+ * @param httpServletRequest the request
+ * @return the determined host
+ */
+ public static String determineProxiedHost(final HttpServletRequest httpServletRequest) {
+ final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
+ final String proxiedHost = determineProxiedHost(hostHeaderValue);
+ return StringUtils.isBlank(proxiedHost) ? httpServletRequest.getServerName() : proxiedHost;
+ }
+
+ /**
+ * Determines the host from the given header. The header value is intended to come from a header like X-ProxyHost or X-Forwarded-Host.
+ *
+ * @param hostHeaderValue the header value
+ * @return the determined host, or null if a host can't be determined
+ */
+ public static String determineProxiedHost(final String hostHeaderValue) {
+ final String host;
+ // check for a port in the proxied host header
+ String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
+ if (hostSplits.length >= 1 && hostSplits.length <= 2) {
+ // zero or one occurrence of ':', this is an IPv4 address
+ // strip off the port by reassigning host the 0th split
+ host = hostSplits[0];
+ } else if (hostSplits.length == 0) {
+ // hostHeaderValue passed in was null, no splits
+ host = null;
+ } else {
+ // hostHeaderValue has more than one occurrence of ":", IPv6 address
+ host = hostHeaderValue;
+ }
+ return host;
+ }
+
+ /**
+ * Determines the port based on first considering proxy related headers and falling back to the port of the servlet request.
+ *
+ * @param httpServletRequest the request
+ * @return the determined port
+ */
+ public static String determineProxiedPort(final HttpServletRequest httpServletRequest) {
+ final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
+ final String portHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER);
+
+ final String proxiedPort = determineProxiedPort(hostHeaderValue, portHeaderValue);
+ return StringUtils.isBlank(proxiedPort) ? String.valueOf(httpServletRequest.getServerPort()) : proxiedPort;
+ }
+
+ /**
+ * Determines the port based on the header values. The header values are intended to come from headers like X-ProxyHost/X-ProxyPort
+ * or X-Forwarded-Host/X-Forwarded-Port.
+ *
+ * @param hostHeaderValue the host header value
+ * @param portHeaderValue the host port value
+ * @return the determined port, or null if one can't be determined
+ */
+ public static String determineProxiedPort(final String hostHeaderValue, final String portHeaderValue) {
+ final String port;
+ // check for a port in the proxied host header
+ String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
+ // determine the proxied port
+ final String portFromHostHeader;
+ if (hostSplits.length == 2) {
+ // if the port is specified in the proxied host header, it will be overridden by the
+ // port specified in X-ProxyPort or X-Forwarded-Port
+ portFromHostHeader = hostSplits[1];
+ } else {
+ portFromHostHeader = null;
+ }
+ if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) {
+ logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header"));
+ }
+ port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null);
+ return port;
+ }
}
diff --git a/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.groovy b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsGroovyTest.groovy
similarity index 99%
rename from nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.groovy
rename to nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsGroovyTest.groovy
index 5465c2c..c8eb7ff 100644
--- a/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.groovy
+++ b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsGroovyTest.groovy
@@ -39,8 +39,8 @@ import javax.ws.rs.core.UriBuilderException
import java.security.cert.X509Certificate
@RunWith(JUnit4.class)
-class WebUtilsTest extends GroovyTestCase {
- private static final Logger logger = LoggerFactory.getLogger(WebUtilsTest.class)
+class WebUtilsGroovyTest extends GroovyTestCase {
+ private static final Logger logger = LoggerFactory.getLogger(WebUtilsGroovyTest.class)
static final String PCP_HEADER = "X-ProxyContextPath"
static final String FC_HEADER = "X-Forwarded-Context"
diff --git a/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.java b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.java
new file mode 100644
index 0000000..38690a7
--- /dev/null
+++ b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.web.util;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import javax.servlet.http.HttpServletRequest;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.when;
+import static org.junit.Assert.assertEquals;
+
+@RunWith(MockitoJUnitRunner.class)
+public class WebUtilsTest {
+
+ @Mock
+ private HttpServletRequest request;
+
+ // -- scheme tests
+
+ @Test
+ public void testDeterminedProxiedSchemeWhenNoHeaders() {
+ when(request.getHeader(any())).thenReturn(null);
+ when(request.getScheme()).thenReturn("https");
+ assertEquals("https", WebUtils.determineProxiedScheme(request));
+ }
+
+ @Test
+ public void testDeterminedProxiedSchemeWhenXProxySchemeAvailable() {
+ when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("http");
+ assertEquals("http", WebUtils.determineProxiedScheme(request));
+ }
+
+ @Test
+ public void testDeterminedProxiedSchemeWhenXForwardedProtoAvailable() {
+ when(request.getHeader(eq(WebUtils.FORWARDED_PROTO_HTTP_HEADER))).thenReturn("http");
+ assertEquals("http", WebUtils.determineProxiedScheme(request));
+ }
+
+ // -- host tests
+
+ @Test
+ public void testDetermineProxiedHostWhenNoHeaders() {
+ when(request.getHeader(any())).thenReturn(null);
+ when(request.getServerName()).thenReturn("localhost");
+ assertEquals("localhost", WebUtils.determineProxiedHost(request));
+ }
+
+ @Test
+ public void testDetermineProxiedHostWhenXProxyHostAvailable() {
+ when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host");
+ assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request));
+ }
+
+ @Test
+ public void testDetermineProxiedHostWhenXProxyHostAvailableWithPort() {
+ when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:443");
+ assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request));
+ }
+
+ @Test
+ public void testDetermineProxiedHostWhenXForwardedHostAvailable() {
+ when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host");
+ assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request));
+ }
+
+ @Test
+ public void testDetermineProxiedHostWhenXForwardedHostAvailableWithPort() {
+ when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:443");
+ assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request));
+ }
+
+ // -- port tests
+
+ @Test
+ public void testDetermineProxiedPortWhenNoHeaders() {
+ when(request.getServerPort()).thenReturn(443);
+ assertEquals("443", WebUtils.determineProxiedPort(request));
+ }
+
+ @Test
+ public void testDetermineProxiedPortWhenXProxyPortAvailable() {
+ when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host");
+ when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443");
+ assertEquals("8443", WebUtils.determineProxiedPort(request));
+ }
+
+ @Test
+ public void testDetermineProxiedPortWhenPortInXProxyHost() {
+ when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234");
+ assertEquals("1234", WebUtils.determineProxiedPort(request));
+ }
+
+ @Test
+ public void testDetermineProxiedPortWhenXProxyPortOverridesXProxy() {
+ when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234");
+ when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443");
+ assertEquals("8443", WebUtils.determineProxiedPort(request));
+ }
+
+ @Test
+ public void testDetermineProxiedPortWhenXForwardedPortAvailable() {
+ when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host");
+ when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443");
+ assertEquals("8443", WebUtils.determineProxiedPort(request));
+ }
+
+ @Test
+ public void testDetermineProxiedPortWhenPortInXForwardedHost() {
+ when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234");
+ assertEquals("1234", WebUtils.determineProxiedPort(request));
+ }
+
+ @Test
+ public void testDetermineProxiedPortWhenXForwardedPortOverridesXForwardedHost() {
+ when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234");
+ when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443");
+ assertEquals("8443", WebUtils.determineProxiedPort(request));
+ }
+
+}
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java
index 196ba27..b24b9a3 100644
--- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java
@@ -16,40 +16,8 @@
*/
package org.apache.nifi.web.api;
-import static javax.ws.rs.core.Response.Status.NOT_FOUND;
-import static org.apache.commons.lang3.StringUtils.isEmpty;
-import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME;
-import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE;
-
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
-import java.net.URI;
-import java.net.URISyntaxException;
-import java.nio.charset.StandardCharsets;
-import java.util.Collections;
-import java.util.Enumeration;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.TreeMap;
-import java.util.UUID;
-import java.util.concurrent.TimeUnit;
-import java.util.function.BiFunction;
-import java.util.function.Consumer;
-import java.util.function.Function;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import javax.ws.rs.core.CacheControl;
-import javax.ws.rs.core.Context;
-import javax.ws.rs.core.MediaType;
-import javax.ws.rs.core.MultivaluedHashMap;
-import javax.ws.rs.core.MultivaluedMap;
-import javax.ws.rs.core.Response;
-import javax.ws.rs.core.Response.ResponseBuilder;
-import javax.ws.rs.core.UriBuilder;
-import javax.ws.rs.core.UriBuilderException;
-import javax.ws.rs.core.UriInfo;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.authorization.AuthorizableLookup;
import org.apache.nifi.authorization.AuthorizeAccess;
@@ -92,6 +60,45 @@ import org.apache.nifi.web.util.WebUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.ws.rs.core.CacheControl;
+import javax.ws.rs.core.Context;
+import javax.ws.rs.core.MediaType;
+import javax.ws.rs.core.MultivaluedHashMap;
+import javax.ws.rs.core.MultivaluedMap;
+import javax.ws.rs.core.Response;
+import javax.ws.rs.core.Response.ResponseBuilder;
+import javax.ws.rs.core.UriBuilder;
+import javax.ws.rs.core.UriBuilderException;
+import javax.ws.rs.core.UriInfo;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+import static javax.ws.rs.core.Response.Status.NOT_FOUND;
+import static org.apache.commons.lang3.StringUtils.isEmpty;
+import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME;
+import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE;
+import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER;
+import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER;
+import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER;
+import static org.apache.nifi.web.util.WebUtils.FORWARDED_PROTO_HTTP_HEADER;
+import static org.apache.nifi.web.util.WebUtils.FORWARDED_HOST_HTTP_HEADER;
+import static org.apache.nifi.web.util.WebUtils.FORWARDED_PORT_HTTP_HEADER;
+
/**
* Base class for controllers.
*/
@@ -101,19 +108,6 @@ public abstract class ApplicationResource {
public static final String CLIENT_ID = "clientId";
public static final String DISCONNECTED_NODE_ACKNOWLEDGED = "disconnectedNodeAcknowledged";
- public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme";
- public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost";
- public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort";
- public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
-
- public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto";
- public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host";
- public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port";
- public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
-
- // Traefik-specific headers
- public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
-
protected static final String NON_GUARANTEED_ENDPOINT = "Note: This endpoint is subject to change as NiFi and it's REST API evolve.";
private static final Logger logger = LoggerFactory.getLogger(ApplicationResource.class);
@@ -157,8 +151,8 @@ public abstract class ApplicationResource {
final String hostHeaderValue = getFirstHeaderValue(PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
final String portHeaderValue = getFirstHeaderValue(PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER);
- final String host = determineProxiedHost(hostHeaderValue);
- final String port = determineProxiedPort(hostHeaderValue, portHeaderValue);
+ final String host = WebUtils.determineProxiedHost(hostHeaderValue);
+ final String port = WebUtils.determineProxiedPort(hostHeaderValue, portHeaderValue);
// Catch header poisoning
String allowedContextPaths = properties.getAllowedContextPaths();
@@ -194,44 +188,6 @@ public abstract class ApplicationResource {
return uri;
}
- private String determineProxiedHost(String hostHeaderValue) {
- final String host;
- // check for a port in the proxied host header
- String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
- if (hostSplits.length >= 1 && hostSplits.length <= 2) {
- // zero or one occurrence of ':', this is an IPv4 address
- // strip off the port by reassigning host the 0th split
- host = hostSplits[0];
- } else if (hostSplits.length == 0) {
- // hostHeaderValue passed in was null, no splits
- host = null;
- } else {
- // hostHeaderValue has more than one occurrence of ":", IPv6 address
- host = hostHeaderValue;
- }
- return host;
- }
-
- private String determineProxiedPort(String hostHeaderValue, String portHeaderValue) {
- final String port;
- // check for a port in the proxied host header
- String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
- // determine the proxied port
- final String portFromHostHeader;
- if (hostSplits.length == 2) {
- // if the port is specified in the proxied host header, it will be overridden by the
- // port specified in X-ProxyPort or X-Forwarded-Port
- portFromHostHeader = hostSplits[1];
- } else {
- portFromHostHeader = null;
- }
- if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) {
- logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header"));
- }
- port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null);
- return port;
- }
-
/**
* Edit the response headers to indicating no caching.
*
@@ -403,21 +359,7 @@ public abstract class ApplicationResource {
* @return the value for the first key found
*/
private String getFirstHeaderValue(final String... keys) {
- if (keys == null) {
- return null;
- }
-
- for (final String key : keys) {
- final String value = httpServletRequest.getHeader(key);
-
- // if we found an entry for this key, return the value
- if (value != null) {
- return value;
- }
- }
-
- // unable to find any matching keys
- return null;
+ return WebUtils.getFirstHeaderValue(httpServletRequest, keys);
}
/**
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java
index cbe3030..84573b0 100644
--- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java
@@ -16,31 +16,6 @@
*/
package org.apache.nifi.web.api;
-import static org.apache.nifi.web.api.ApplicationResource.PROXY_HOST_HTTP_HEADER;
-import static org.apache.nifi.web.api.ApplicationResource.PROXY_PORT_HTTP_HEADER;
-import static org.apache.nifi.web.api.ApplicationResource.PROXY_SCHEME_HTTP_HEADER;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-import java.io.InputStream;
-import java.lang.reflect.Field;
-import java.net.URI;
-import java.net.URISyntaxException;
-import java.net.URL;
-import javax.servlet.ServletContext;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import javax.ws.rs.core.Response;
-import javax.ws.rs.core.StreamingOutput;
-import javax.ws.rs.core.UriBuilder;
-import javax.ws.rs.core.UriInfo;
import org.apache.nifi.authorization.AuthorizableLookup;
import org.apache.nifi.authorization.resource.ResourceType;
import org.apache.nifi.remote.HttpRemoteSiteListener;
@@ -58,6 +33,32 @@ import org.apache.nifi.web.api.entity.TransactionResultEntity;
import org.junit.BeforeClass;
import org.junit.Test;
+import javax.servlet.ServletContext;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.ws.rs.core.Response;
+import javax.ws.rs.core.StreamingOutput;
+import javax.ws.rs.core.UriBuilder;
+import javax.ws.rs.core.UriInfo;
+import java.io.InputStream;
+import java.lang.reflect.Field;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.URL;
+
+import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER;
+import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER;
+import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
public class TestDataTransferResource {
@BeforeClass
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java
index b85f176..6ac659a 100644
--- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java
@@ -17,17 +17,19 @@
package org.apache.nifi.web.security.saml.impl;
import org.apache.nifi.web.security.saml.NiFiSAMLContextProvider;
+import org.apache.nifi.web.security.saml.impl.http.HttpServletRequestWithParameters;
+import org.apache.nifi.web.security.saml.impl.http.ProxyAwareHttpServletRequestWrapper;
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import org.opensaml.ws.transport.http.HttpServletResponseAdapter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.springframework.security.saml.context.SAMLContextProviderImpl;
import org.springframework.security.saml.context.SAMLMessageContext;
import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
/**
@@ -35,6 +37,8 @@ import java.util.Map;
*/
public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl implements NiFiSAMLContextProvider {
+ private static final Logger LOGGER = LoggerFactory.getLogger(NiFiSAMLContextProviderImpl.class);
+
@Override
public SAMLMessageContext getLocalEntity(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters)
throws MetadataProviderException {
@@ -60,55 +64,20 @@ public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl impleme
}
protected void populateGenericContext(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters, SAMLMessageContext context) {
-        HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(request, parameters);
-        HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, request.isSecure());
+        HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request); // exposes proxied scheme/host/port/URL so the SAML Destination check matches
+        LOGGER.debug("Populating SAMLContext - request wrapper URL is [{}]", requestWrapper.getRequestURL().toString());
+
+        HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(requestWrapper, parameters); // request parameters first, provided map as fallback/extra
+        HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, requestWrapper.isSecure()); // secure flag follows the (possibly proxied) scheme
// Store attribute which cannot be located from InTransport directly
-        request.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, request.getContextPath());
+        requestWrapper.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, requestWrapper.getContextPath()); // wrapper returns the local, non-proxy context path
context.setMetadataProvider(metadata);
context.setInboundMessageTransport(inTransport);
context.setOutboundMessageTransport(outTransport);
-        context.setMessageStorage(storageFactory.getMessageStorage(request));
+        context.setMessageStorage(storageFactory.getMessageStorage(requestWrapper));
}
-    /**
-     * Extends the HttpServletRequestAdapter with a provided set of parameters.
-     */
-    private static class HttpServletRequestWithParameters extends HttpServletRequestAdapter {
-
-        private final Map<String, String> providedParameters;
-
-        public HttpServletRequestWithParameters(HttpServletRequest request, Map<String,String> providedParameters) {
-            super(request);
-            this.providedParameters = providedParameters == null ? Collections.emptyMap() : providedParameters;
-        }
-
-        @Override
-        public String getParameterValue(String name) {
-            String value = super.getParameterValue(name);
-            if (value == null) {
-                value = providedParameters.get(name);
-            }
-            return value;
-        }
-
-        @Override
-        public List<String> getParameterValues(String name) {
-            List<String> combinedValues = new ArrayList<>();
-
-            List<String> initialValues = super.getParameterValues(name);
-            if (initialValues != null) {
-                combinedValues.addAll(initialValues);
-            }
-
-            String providedValue = providedParameters.get(name);
-            if (providedValue != null) {
-                combinedValues.add(providedValue);
-            }
-
-            return combinedValues;
-        }
-    }
}
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/HttpServletRequestWithParameters.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/HttpServletRequestWithParameters.java
new file mode 100644
index 0000000..0c91597
--- /dev/null
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/HttpServletRequestWithParameters.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.web.security.saml.impl.http;
+
+import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OpenSAML request adapter that supplements the wrapped request with caller-provided parameters (null is treated as empty).
+ */
+public class HttpServletRequestWithParameters extends HttpServletRequestAdapter {
+
+    private final Map<String, String> providedParameters;
+
+    public HttpServletRequestWithParameters(final HttpServletRequest request, final Map<String, String> providedParameters) {
+        super(request);
+        if (providedParameters == null) {
+            this.providedParameters = Collections.emptyMap();
+        } else {
+            this.providedParameters = providedParameters;
+        }
+    }
+
+    /** Returns the request's own value for {@code name}, or the provided value when the request has none. */
+    @Override
+    public String getParameterValue(final String name) {
+        final String requestValue = super.getParameterValue(name);
+        return requestValue != null ? requestValue : providedParameters.get(name);
+    }
+
+    /** Returns the request's values for {@code name} followed by the provided value, if there is one. */
+    @Override
+    public List<String> getParameterValues(final String name) {
+        final List<String> merged = new ArrayList<>();
+        final List<String> fromRequest = super.getParameterValues(name);
+        if (fromRequest != null) {
+            merged.addAll(fromRequest);
+        }
+        final String provided = providedParameters.get(name);
+        if (provided != null) {
+            merged.add(provided);
+        }
+        return merged;
+    }
+}
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/ProxyAwareHttpServletRequestWrapper.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/ProxyAwareHttpServletRequestWrapper.java
new file mode 100644
index 0000000..09121ec
--- /dev/null
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/ProxyAwareHttpServletRequestWrapper.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.web.security.saml.impl.http;
+
+import org.apache.nifi.web.util.WebUtils;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
+
+/**
+ * Extension of HttpServletRequestWrapper that respects proxied/forwarded header values for scheme, host, port, and context path.
+ * <p>
+ * If NiFi generates a SAML request using proxied values so that the IDP redirects back through the proxy, then this is needed
+ * so that when Open SAML checks the Destination in the SAML response, it will match with the values here.
+ * <p>
+ * This class is based on SAMLContextProviderLB from spring-security-saml.
+ */
+public class ProxyAwareHttpServletRequestWrapper extends HttpServletRequestWrapper {
+
+ private final String scheme;
+ private final String serverName;
+ private final int serverPort;
+ private final String proxyContextPath;
+ private final String contextPath;
+
+ public ProxyAwareHttpServletRequestWrapper(final HttpServletRequest request) {
+ super(request);
+ this.scheme = WebUtils.determineProxiedScheme(request);
+ this.serverName = WebUtils.determineProxiedHost(request);
+ this.serverPort = Integer.valueOf(WebUtils.determineProxiedPort(request));
+
+ final String tempProxyContextPath = WebUtils.normalizeContextPath(WebUtils.determineContextPath(request));
+ this.proxyContextPath = tempProxyContextPath.equals("/") ? "" : tempProxyContextPath;
+
+ this.contextPath = request.getContextPath();
+ }
+
+ @Override
+ public String getContextPath() {
+ return contextPath;
+ }
+
+ @Override
+ public String getScheme() {
+ return scheme;
+ }
+
+ @Override
+ public String getServerName() {
+ return serverName;
+ }
+
+ @Override
+ public int getServerPort() {
+ return serverPort;
+ }
+
+ @Override
+ public String getRequestURI() {
+ StringBuilder sb = new StringBuilder(contextPath);
+ sb.append(getServletPath());
+ return sb.toString();
+ }
+
+ @Override
+ public StringBuffer getRequestURL() {
+ StringBuffer sb = new StringBuffer();
+ sb.append(scheme).append("://").append(serverName);
+ sb.append(":").append(serverPort);
+ sb.append(proxyContextPath);
+ sb.append(contextPath);
+ sb.append(getServletPath());
+ if (getPathInfo() != null) sb.append(getPathInfo());
+ return sb;
+ }
+
+ @Override
+ public boolean isSecure() {
+ return "https".equalsIgnoreCase(scheme);
+ }
+
+}
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestHttpServletRequestWithParameters.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestHttpServletRequestWithParameters.java
new file mode 100644
index 0000000..694dea2
--- /dev/null
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestHttpServletRequestWithParameters.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.web.security.saml.impl.http;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
+
+import javax.servlet.http.HttpServletRequest;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class TestHttpServletRequestWithParameters {
+
+    @Mock
+    private HttpServletRequest request;
+
+    @Test
+    public void testGetParameterValueWhenNoExtraParameters() {
+        final String paramName = "fooParam";
+        final String paramValue = "fooValue";
+        when(request.getParameter(eq(paramName))).thenReturn(paramValue);
+
+        final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap());
+        final String result = requestAdapter.getParameterValue(paramName);
+        assertEquals(paramValue, result);
+    }
+
+    @Test
+    public void testGetParameterValueWhenExtraParameters() {
+        final String paramName = "fooParam";
+        final String paramValue = "fooValue";
+
+        final Map<String,String> extraParams = new HashMap<>();
+        extraParams.put(paramName, paramValue);
+
+        when(request.getParameter(any())).thenReturn(null);
+
+        final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams);
+        final String result = requestAdapter.getParameterValue(paramName);
+        assertEquals(paramValue, result);
+    }
+
+    @Test
+    public void testGetParameterValueWhenNullExtraParameters() {
+        final String paramName = "fooParam";
+
+        // A null map of extra parameters must be treated as empty rather than causing a NullPointerException
+        final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, null);
+        final String result = requestAdapter.getParameterValue(paramName);
+        assertNull(result);
+    }
+
+    @Test
+    public void testGetParameterValuePrefersRequestValueOverExtraParameter() {
+        final String paramName = "fooParam";
+        final String requestValue = "requestValue";
+        when(request.getParameter(eq(paramName))).thenReturn(requestValue);
+
+        final Map<String,String> extraParams = new HashMap<>();
+        extraParams.put(paramName, "extraValue");
+
+        final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams);
+        final String result = requestAdapter.getParameterValue(paramName);
+        assertEquals(requestValue, result);
+    }
+
+    @Test
+    public void testGetParameterValuesWhenNoExtraParameters() {
+        final String paramName = "fooParam";
+        final String paramValue = "fooValue";
+        when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue});
+
+        final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap());
+        final List<String> results = requestAdapter.getParameterValues(paramName);
+        assertEquals(1, results.size());
+        assertEquals(paramValue, results.get(0));
+    }
+
+    @Test
+    public void testGetParameterValuesWhenExtraParameters() {
+        final String paramName = "fooParam";
+        final String paramValue1 = "fooValue1";
+        when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue1});
+
+        final String paramValue2 = "fooValue2";
+        final Map<String,String> extraParams = new HashMap<>();
+        extraParams.put(paramName, paramValue2);
+
+        final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams);
+        final List<String> results = requestAdapter.getParameterValues(paramName);
+        assertEquals(2, results.size());
+        assertTrue(results.contains(paramValue1));
+        assertTrue(results.contains(paramValue2));
+    }
+}
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestProxyAwareHttpServletRequestWrapper.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestProxyAwareHttpServletRequestWrapper.java
new file mode 100644
index 0000000..e303c7e
--- /dev/null
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestProxyAwareHttpServletRequestWrapper.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.web.security.saml.impl.http;
+
+import org.apache.nifi.web.util.WebUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class TestProxyAwareHttpServletRequestWrapper {
+
+    @Mock
+    private HttpServletRequest request;
+
+    @Test
+    public void testWhenNotProxied() {
+        when(request.getScheme()).thenReturn("https");
+        when(request.getServerName()).thenReturn("localhost");
+        when(request.getServerPort()).thenReturn(8443);
+        when(request.getContextPath()).thenReturn("/nifi-api");
+        when(request.getServletPath()).thenReturn("/access/saml/metadata");
+        when(request.getHeader(any())).thenReturn(null);
+
+        final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
+        assertEquals("https://localhost:8443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
+    }
+
+    @Test
+    public void testWhenNotProxiedWithPathInfo() {
+        when(request.getScheme()).thenReturn("http");
+        when(request.getServerName()).thenReturn("localhost");
+        when(request.getServerPort()).thenReturn(8080);
+        when(request.getContextPath()).thenReturn("/nifi-api");
+        when(request.getServletPath()).thenReturn("");
+        when(request.getPathInfo()).thenReturn("/access/saml/login/consumer");
+        when(request.getHeader(any())).thenReturn(null);
+
+        final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
+        assertEquals("http://localhost:8080/nifi-api/access/saml/login/consumer", requestWrapper.getRequestURL().toString());
+        assertFalse(requestWrapper.isSecure());
+    }
+
+    @Test
+    public void testWhenProxied() {
+        when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https");
+        when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host");
+        when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443");
+        when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/proxy-context");
+        when(request.getContextPath()).thenReturn("/nifi-api");
+        when(request.getServletPath()).thenReturn("/access/saml/metadata");
+
+        final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
+        assertEquals("https://proxy-host:443/proxy-context/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
+        assertEquals(443, requestWrapper.getServerPort());
+        assertTrue(requestWrapper.isSecure());
+    }
+
+    @Test
+    public void testWhenProxiedWithEmptyProxyContextPath() {
+        when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https");
+        when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host");
+        when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443");
+        when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/");
+        when(request.getContextPath()).thenReturn("/nifi-api");
+        when(request.getServletPath()).thenReturn("/access/saml/metadata");
+
+        final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
+        assertEquals("https://proxy-host:443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
+    }
+}