You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nifi.apache.org by pv...@apache.org on 2022/01/19 12:29:30 UTC
[nifi] branch main updated: NIFI-9065 Add support for OAuth2AccessTokenProvider in InvokeHTTP
This is an automated email from the ASF dual-hosted git repository.
pvillard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git
The following commit(s) were added to refs/heads/main by this push:
new aa61494 NIFI-9065 Add support for OAuth2AccessTokenProvider in InvokeHTTP
aa61494 is described below
commit aa61494fc3a68b4806784f67ad837ee821d26da4
Author: Tamas Palfy <ta...@gmail.com>
AuthorDate: Tue Jul 27 15:19:53 2021 +0200
NIFI-9065 Add support for OAuth2AccessTokenProvider in InvokeHTTP
Signed-off-by: Pierre Villard <pi...@gmail.com>
This closes #5319.
---
.../nifi-standard-processors/pom.xml | 5 +
.../nifi/processors/standard/InvokeHTTP.java | 52 ++-
.../nifi/processors/standard/InvokeHTTPTest.java | 95 +++++
.../java/org/apache/nifi/oauth2/AccessToken.java | 63 +++-
.../nifi/oauth2/OAuth2AccessTokenProvider.java | 30 ++
.../apache/nifi/oauth2/OAuth2TokenProvider.java | 3 +
.../nifi/oauth2/OAuth2TokenProviderImpl.java | 3 +
.../oauth2/StandardOauth2AccessTokenProvider.java | 300 +++++++++++++++
.../org.apache.nifi.controller.ControllerService | 2 +
.../nifi/oauth2/OAuth2TokenProviderImplTest.java | 4 +-
.../StandardOauth2AccessTokenProviderTest.java | 411 +++++++++++++++++++++
11 files changed, 944 insertions(+), 24 deletions(-)
diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml
index 84fffaf..8f6ca2d 100644
--- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml
+++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml
@@ -375,6 +375,11 @@
<version>1.16.0-SNAPSHOT</version>
</dependency>
<dependency>
+ <groupId>org.apache.nifi</groupId>
+ <artifactId>nifi-oauth2-provider-api</artifactId>
+ <version>1.16.0-SNAPSHOT</version>
+ </dependency>
+ <dependency>
<groupId>org.apache.sshd</groupId>
<artifactId>sshd-core</artifactId>
<scope>test</scope>
diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java
index f9977bc..fbfea08 100644
--- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java
+++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java
@@ -42,6 +42,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@@ -90,6 +91,7 @@ import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.flowfile.attributes.CoreAttributes;
import org.apache.nifi.logging.ComponentLog;
+import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.processor.AbstractProcessor;
import org.apache.nifi.processor.DataUnit;
import org.apache.nifi.processor.ProcessContext;
@@ -494,6 +496,13 @@ public class InvokeHTTP extends AbstractProcessor {
.allowableValues("True", "False")
.build();
+ /** Optional OAuth2 access token provider; its token is sent as a Bearer Authorization header. */
+ public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder()
+ .name("oauth2-access-token-provider")
+ .displayName("OAuth2 Access Token Provider")
+ .description("Enables managed retrieval of OAuth2 Bearer Token applied to HTTP requests using the Authorization Header."
+ + " Cannot be configured together with the Username and Password properties.")
+ .identifiesControllerService(OAuth2AccessTokenProvider.class)
+ .required(false)
+ .build();
+
public static final PropertyDescriptor FLOW_FILE_NAMING_STRATEGY = new PropertyDescriptor.Builder()
.name("flow-file-naming-strategy")
.description("Determines the strategy used for setting the filename attribute of the FlowFile.")
@@ -527,6 +536,7 @@ public class InvokeHTTP extends AbstractProcessor {
PROP_USERAGENT,
PROP_BASIC_AUTH_USERNAME,
PROP_BASIC_AUTH_PASSWORD,
+ OAUTH2_ACCESS_TOKEN_PROVIDER,
PROXY_CONFIGURATION_SERVICE,
PROP_PROXY_HOST,
PROP_PROXY_PORT,
@@ -595,6 +605,8 @@ public class InvokeHTTP extends AbstractProcessor {
private volatile boolean useChunked = false;
+ // Assigned by initOauth2AccessTokenProvider() when the processor is scheduled: holds the configured
+ // OAUTH2_ACCESS_TOKEN_PROVIDER, or Optional.empty() when that property is not set.
+ // NOTE(review): the field has no initializer, so it is null until that @OnScheduled method has run.
+ private volatile Optional<OAuth2AccessTokenProvider> oauth2AccessTokenProviderOptional;
+
private final AtomicReference<OkHttpClient> okHttpClientAtomicReference = new AtomicReference<>();
@Override
@@ -728,6 +740,19 @@ public class InvokeHTTP extends AbstractProcessor {
.build());
}
+ boolean usingUserNamePasswordAuthorization = validationContext.getProperty(PROP_BASIC_AUTH_USERNAME).isSet()
+ || validationContext.getProperty(PROP_BASIC_AUTH_PASSWORD).isSet();
+
+ boolean usingOAuth2Authorization = validationContext.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet();
+
+ if (usingUserNamePasswordAuthorization && usingOAuth2Authorization) {
+ results.add(new ValidationResult.Builder()
+ .subject("Authorization properties")
+ .valid(false)
+ .explanation("OAuth2 Authorization cannot be configured together with Username and Password properties")
+ .build());
+ }
+
return results;
}
@@ -806,6 +831,19 @@ public class InvokeHTTP extends AbstractProcessor {
okHttpClientAtomicReference.set(okHttpClientBuilder.build());
}
+ /**
+ * Resolves the optional OAuth2 access token provider each time the processor is scheduled.
+ * When a provider is configured, an access token is requested once up front. The token itself is not kept
+ * here, but any exception thrown by the provider propagates out of this method.
+ */
+ @OnScheduled
+ public void initOauth2AccessTokenProvider(final ProcessContext context) {
+ if (!context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) {
+ oauth2AccessTokenProviderOptional = Optional.empty();
+ return;
+ }
+
+ final OAuth2AccessTokenProvider provider =
+ context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class);
+ provider.getAccessDetails(); // result intentionally discarded
+ oauth2AccessTokenProviderOptional = Optional.of(provider);
+ }
+
private void setAuthenticator(OkHttpClient.Builder okHttpClientBuilder, ProcessContext context) {
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
@@ -1034,11 +1072,17 @@ public class InvokeHTTP extends AbstractProcessor {
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
// If the username/password properties are set then check if digest auth is being used
- if (!authUser.isEmpty() && "false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
- final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue());
+ if ("false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
+ if (!authUser.isEmpty()) {
+ final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue());
- String credential = Credentials.basic(authUser, authPass);
- requestBuilder.header("Authorization", credential);
+ String credential = Credentials.basic(authUser, authPass);
+ requestBuilder.header("Authorization", credential);
+ } else {
+ oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider ->
+ requestBuilder.addHeader("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken())
+ );
+ }
}
// set the request method
diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java
index 469ba88..ed5d0b8 100644
--- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java
+++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java
@@ -21,6 +21,7 @@ import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.flowfile.attributes.CoreAttributes;
+import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.standard.http.FlowFileNamingStrategy;
import org.apache.nifi.reporting.InitializationException;
@@ -34,6 +35,7 @@ import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.apache.nifi.web.util.ssl.SslContextUtils;
+import org.mockito.Answers;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
@@ -774,6 +776,99 @@ public class InvokeHTTPTest {
);
}
+ @Test
+ public void testValidWhenOAuth2Set() throws Exception {
+ // GIVEN
+ String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
+ addOAuth2AccessTokenProvider(oauth2AccessTokenProviderId);
+
+ setUrlProperty();
+
+ // WHEN
+ runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
+
+ // THEN
+ runner.assertValid();
+ }
+
+ @Test
+ public void testInvalidWhenOAuth2AndUserNameSet() throws Exception {
+ // GIVEN
+ String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
+ addOAuth2AccessTokenProvider(oauth2AccessTokenProviderId);
+
+ setUrlProperty();
+
+ // WHEN
+ runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
+ runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_USERNAME, "userName");
+
+ // THEN
+ runner.assertNotValid();
+ }
+
+ @Test
+ public void testInvalidWhenOAuth2AndPasswordSet() throws Exception {
+ // GIVEN
+ String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
+ addOAuth2AccessTokenProvider(oauth2AccessTokenProviderId);
+
+ setUrlProperty();
+
+ // WHEN
+ runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
+ runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_PASSWORD, "password");
+
+ // THEN
+ runner.assertNotValid();
+ }
+
+ @Test
+ public void testOAuth2AuthorizationHeader() throws Exception {
+ // GIVEN
+ String accessToken = "access_token";
+
+ String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
+ OAuth2AccessTokenProvider oauth2AccessTokenProvider = addOAuth2AccessTokenProvider(oauth2AccessTokenProviderId);
+ when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken);
+
+ setUrlProperty();
+
+ mockWebServer.enqueue(new MockResponse());
+
+ // WHEN
+ runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
+ runner.enqueue("unimportant");
+ runner.run();
+
+ // THEN
+ RecordedRequest recordedRequest = mockWebServer.takeRequest();
+
+ String actualAuthorizationHeader = recordedRequest.getHeader("Authorization");
+ assertEquals("Bearer " + accessToken, actualAuthorizationHeader);
+ }
+
+ /**
+ * Registers and enables a deep-stubbed mock {@link OAuth2AccessTokenProvider} under the given identifier.
+ *
+ * @param oauth2AccessTokenProviderId identifier used both for the mock's getIdentifier() and for registration
+ * @return the registered mock, so that callers can stub further behavior before running the processor
+ */
+ private OAuth2AccessTokenProvider addOAuth2AccessTokenProvider(String oauth2AccessTokenProviderId) throws Exception {
+ OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
+ when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
+
+ runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
+ runner.enableControllerService(oauth2AccessTokenProvider);
+
+ return oauth2AccessTokenProvider;
+ }
+
+ /**
+ * Sets {@link InvokeHTTP#PROP_URL} to the value returned by getMockWebServerUrl().
+ */
private void setUrlProperty() {
runner.setProperty(InvokeHTTP.PROP_URL, getMockWebServerUrl());
}
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java
index 2e261dd..622c9b0 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java
@@ -17,53 +17,80 @@
package org.apache.nifi.oauth2;
+import java.time.Duration;
+import java.time.Instant;
+
+/**
+ * OAuth2 access token together with the metadata returned by the authorization server.
+ * The fetch time is captured at construction so that expiry can be derived from {@code expires_in}.
+ */
public class AccessToken {
private String accessToken;
private String refreshToken;
private String tokenType;
- private Integer expires;
- private String scope;
+ private long expiresIn;
+ private String scopes;
+
+ private final Instant fetchTime;
- private Long fetchTime;
+ /**
+ * Safety margin, in seconds, subtracted from {@code expiresIn}: a token is reported as expired this many
+ * seconds before the expiry announced by the authorization server, so that it is renewed ahead of time.
+ * (Previously 5000, which, being applied in seconds, made most real-world tokens look expired immediately.)
+ */
+ public static final int EXPIRY_MARGIN = 5;
- public AccessToken(String accessToken,
- String refreshToken,
- String tokenType,
- Integer expires,
- String scope) {
+ public AccessToken() {
+ this.fetchTime = Instant.now();
+ }
+
+ public AccessToken(String accessToken, String refreshToken, String tokenType, long expiresIn, String scopes) {
+ this();
this.accessToken = accessToken;
- this.tokenType = tokenType;
this.refreshToken = refreshToken;
- this.expires = expires;
- this.scope = scope;
- this.fetchTime = System.currentTimeMillis();
+ this.tokenType = tokenType;
+ this.expiresIn = expiresIn;
+ this.scopes = scopes;
}
public String getAccessToken() {
return accessToken;
}
+ public void setAccessToken(String accessToken) {
+ this.accessToken = accessToken;
+ }
+
public String getRefreshToken() {
return refreshToken;
}
+ public void setRefreshToken(String refreshToken) {
+ this.refreshToken = refreshToken;
+ }
+
public String getTokenType() {
return tokenType;
}
- public Integer getExpires() {
- return expires;
+ public void setTokenType(String tokenType) {
+ this.tokenType = tokenType;
+ }
+
+ public long getExpiresIn() {
+ return expiresIn;
}
- public String getScope() {
- return scope;
+ public void setExpiresIn(long expiresIn) {
+ this.expiresIn = expiresIn;
}
- public Long getFetchTime() {
+ public String getScopes() {
+ return scopes;
+ }
+
+ public void setScopes(String scopes) {
+ this.scopes = scopes;
+ }
+
+ /**
+ * Alias of {@link #setScopes(String)} matching the standard OAuth2 token response field {@code scope}
+ * (RFC 6749, section 5.1), so that snake_case JSON deserialization populates the scopes.
+ */
+ public void setScope(String scope) {
+ this.scopes = scope;
+ }
+
+ public Instant getFetchTime() {
return fetchTime;
}
+ /**
+ * @return true when the current time is past fetch time + ({@code expiresIn} - {@link #EXPIRY_MARGIN}) seconds
+ */
public boolean isExpired() {
- return System.currentTimeMillis() >= ( fetchTime + (expires * 1000) );
+ boolean expired = Duration.between(Instant.now(), fetchTime.plusSeconds(expiresIn - EXPIRY_MARGIN)).isNegative();
+
+ return expired;
}
}
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2AccessTokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2AccessTokenProvider.java
new file mode 100644
index 0000000..9cfc757
--- /dev/null
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2AccessTokenProvider.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.oauth2;
+
+import org.apache.nifi.controller.ControllerService;
+
+/**
+ * Controller service that provides OAuth2 access details: the access token plus the metadata
+ * returned by the authorization server (refresh token, token type, expiry, scopes).
+ * Consumers typically send the access token as an HTTP {@code Authorization: Bearer <token>} header.
+ */
+public interface OAuth2AccessTokenProvider extends ControllerService {
+
+ /**
+ * @return A valid access token (refreshed automatically if needed) and additional metadata (provided by the OAuth2 access server)
+ */
+ AccessToken getAccessDetails();
+}
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java
index 964ae43..707ff6a 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java
@@ -29,7 +29,10 @@ import java.util.List;
/**
* Interface for defining a credential-providing controller service for oauth2 processes.
+ *
+ * @deprecated use {@link OAuth2AccessTokenProvider} instead
*/
+@Deprecated
public interface OAuth2TokenProvider extends ControllerService {
PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
.name("oauth2-ssl-context")
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java
index 056c4f2..cc06309 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java
@@ -30,6 +30,7 @@ import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
+import org.apache.nifi.annotation.documentation.DeprecationNotice;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.annotation.lifecycle.OnEnabled;
import org.apache.nifi.components.PropertyDescriptor;
@@ -39,6 +40,8 @@ import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.ssl.SSLContextService;
import org.apache.nifi.util.StringUtils;
+@Deprecated
+@DeprecationNotice(alternatives = {StandardOauth2AccessTokenProvider.class})
@Tags({"oauth2", "provider", "authorization" })
@CapabilityDescription("This controller service provides a way of working with access and refresh tokens via the " +
"password and client_credential grant flows in the OAuth2 specification. It is meant to provide a way for components " +
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
new file mode 100644
index 0000000..6c1b05e
--- /dev/null
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.oauth2;
+
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.PropertyNamingStrategies;
+import okhttp3.FormBody;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.Response;
+import org.apache.nifi.annotation.documentation.CapabilityDescription;
+import org.apache.nifi.annotation.documentation.Tags;
+import org.apache.nifi.annotation.lifecycle.OnDisabled;
+import org.apache.nifi.annotation.lifecycle.OnEnabled;
+import org.apache.nifi.components.AllowableValue;
+import org.apache.nifi.components.PropertyDescriptor;
+import org.apache.nifi.components.ValidationContext;
+import org.apache.nifi.components.ValidationResult;
+import org.apache.nifi.components.Validator;
+import org.apache.nifi.controller.AbstractControllerService;
+import org.apache.nifi.controller.ConfigurationContext;
+import org.apache.nifi.expression.ExpressionLanguageScope;
+import org.apache.nifi.processor.exception.ProcessException;
+import org.apache.nifi.processor.util.StandardValidators;
+import org.apache.nifi.ssl.SSLContextService;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.X509TrustManager;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Standard {@link OAuth2AccessTokenProvider}. It obtains an access token from the configured authorization server
+ * using either the Resource Owner Password Credentials Grant or the Client Credentials Grant, and caches it.
+ * Once the token reports itself expired, it is refreshed with the refresh token when one is available,
+ * falling back to acquiring a new token.
+ */
+@Tags({"oauth2", "provider", "authorization", "access token", "http"})
+@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
+ " Supports the Resource Owner Password Credentials Grant and the Client Credentials Grant.")
+public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider {
+ public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder()
+ .name("authorization-server-url")
+ .displayName("Authorization Server URL")
+ .description("The URL of the authorization server that issues access tokens.")
+ .required(true)
+ .addValidator(StandardValidators.URL_VALIDATOR)
+ .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
+ .build();
+
+ public static final AllowableValue RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE = new AllowableValue(
+ "password",
+ "User Password",
+ "Resource Owner Password Credentials Grant. Used to access resources available to users. Requires username and password and usually Client ID and Client Secret"
+ );
+ public static final AllowableValue CLIENT_CREDENTIALS_GRANT_TYPE = new AllowableValue(
+ "client_credentials",
+ "Client Credentials",
+ "Client Credentials Grant. Used to access resources available to clients. Requires Client ID and Client Secret"
+ );
+
+ public static final PropertyDescriptor GRANT_TYPE = new PropertyDescriptor.Builder()
+ .name("grant-type")
+ .displayName("Grant Type")
+ .description("The OAuth2 Grant Type to be used when acquiring an access token.")
+ .required(true)
+ .allowableValues(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE, CLIENT_CREDENTIALS_GRANT_TYPE)
+ .defaultValue(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())
+ .build();
+
+ public static final PropertyDescriptor USERNAME = new PropertyDescriptor.Builder()
+ .name("service-user-name")
+ .displayName("Username")
+ .description("Username on the service that is being accessed.")
+ .dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
+ .required(true)
+ .addValidator(StandardValidators.NON_BLANK_VALIDATOR)
+ .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
+ .build();
+
+ public static final PropertyDescriptor PASSWORD = new PropertyDescriptor.Builder()
+ .name("service-password")
+ .displayName("Password")
+ .description("Password for the username on the service that is being accessed.")
+ .dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
+ .required(true)
+ .sensitive(true)
+ .addValidator(StandardValidators.NON_BLANK_VALIDATOR)
+ .build();
+
+ public static final PropertyDescriptor CLIENT_ID = new PropertyDescriptor.Builder()
+ .name("client-id")
+ .displayName("Client ID")
+ .description("The client identifier sent to the authorization server. Required when the Client Credentials grant type is used.")
+ .required(false)
+ .addValidator(StandardValidators.NON_BLANK_VALIDATOR)
+ .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
+ .build();
+
+ public static final PropertyDescriptor CLIENT_SECRET = new PropertyDescriptor.Builder()
+ .name("client-secret")
+ .displayName("Client Secret")
+ .description("The client secret sent to the authorization server together with the Client ID.")
+ .dependsOn(CLIENT_ID)
+ .required(true)
+ .sensitive(true)
+ .addValidator(StandardValidators.NON_BLANK_VALIDATOR)
+ .build();
+
+ public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
+ .name("ssl-context-service")
+ .displayName("SSL Context Service")
+ .description("SSL Context Service used by the HTTP client that requests access tokens from the authorization server.")
+ .addValidator(Validator.VALID)
+ .identifiesControllerService(SSLContextService.class)
+ .required(false)
+ .build();
+
+ private static final List<PropertyDescriptor> PROPERTIES = Collections.unmodifiableList(Arrays.asList(
+ AUTHORIZATION_SERVER_URL,
+ GRANT_TYPE,
+ USERNAME,
+ PASSWORD,
+ CLIENT_ID,
+ CLIENT_SECRET,
+ SSL_CONTEXT
+ ));
+
+ public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper()
+ .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+ .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
+
+ private volatile String authorizationServerUrl;
+ private volatile OkHttpClient httpClient;
+
+ private volatile String grantType;
+ private volatile String username;
+ private volatile String password;
+ private volatile String clientId;
+ private volatile String clientSecret;
+
+ // Cached token; null until the first getAccessDetails() call and again after the service is disabled
+ private volatile AccessToken accessDetails;
+
+ @Override
+ public List<PropertyDescriptor> getSupportedPropertyDescriptors() {
+ return PROPERTIES;
+ }
+
+ @OnEnabled
+ public void onEnabled(ConfigurationContext context) {
+ authorizationServerUrl = context.getProperty(AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue();
+
+ httpClient = createHttpClient(context);
+
+ grantType = context.getProperty(GRANT_TYPE).getValue();
+ username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue();
+ password = context.getProperty(PASSWORD).getValue();
+ clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
+ clientSecret = context.getProperty(CLIENT_SECRET).getValue();
+ }
+
+ // Synchronized with getAccessDetails() so the cache cannot be cleared in the middle of a refresh check
+ @OnDisabled
+ public synchronized void onDisabled() {
+ accessDetails = null;
+ }
+
+ @Override
+ protected Collection<ValidationResult> customValidate(ValidationContext validationContext) {
+ final List<ValidationResult> validationResults = new ArrayList<>(super.customValidate(validationContext));
+
+ if (
+ validationContext.getProperty(GRANT_TYPE).getValue().equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())
+ && !validationContext.getProperty(CLIENT_ID).isSet()
+ ) {
+ validationResults.add(new ValidationResult.Builder().subject(CLIENT_ID.getDisplayName())
+ .valid(false)
+ .explanation(String.format(
+ "When '%s' is set to '%s', '%s' is required",
+ GRANT_TYPE.getDisplayName(),
+ CLIENT_CREDENTIALS_GRANT_TYPE.getDisplayName(),
+ CLIENT_ID.getDisplayName())
+ )
+ .build());
+ }
+
+ return validationResults;
+ }
+
+ protected OkHttpClient createHttpClient(ConfigurationContext context) {
+ OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder();
+
+ SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class);
+ if (sslService != null) {
+ final X509TrustManager trustManager = sslService.createTrustManager();
+ SSLContext sslContext = sslService.createContext();
+ clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustManager);
+ }
+
+ return clientBuilder.build();
+ }
+
+ /**
+ * Returns the cached access token.
+ * <ul>
+ * <li>If no token is cached, a new one is acquired.</li>
+ * <li>If the cached token reports itself expired, it is refreshed when a refresh token exists.
+ * A new token is acquired when there is no refresh token or refreshing fails.</li>
+ * </ul>
+ * Synchronized because this controller service is shared by concurrently running components:
+ * without it, simultaneous callers could each fire their own token request.
+ */
+ @Override
+ public synchronized AccessToken getAccessDetails() {
+ if (this.accessDetails == null) {
+ acquireAccessDetails();
+ } else if (this.accessDetails.isExpired()) {
+ if (this.accessDetails.getRefreshToken() == null) {
+ acquireAccessDetails();
+ } else {
+ try {
+ refreshAccessDetails();
+ } catch (Exception e) {
+ getLogger().info("Couldn't refresh access token", e);
+ acquireAccessDetails();
+ }
+ }
+ }
+
+ return accessDetails;
+ }
+
+ private void acquireAccessDetails() {
+ getLogger().debug("Getting a new access token");
+
+ FormBody.Builder acquireTokenBuilder = new FormBody.Builder();
+
+ if (grantType.equals(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())) {
+ acquireTokenBuilder.add("grant_type", "password")
+ .add("username", username)
+ .add("password", password);
+ } else if (grantType.equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())) {
+ acquireTokenBuilder.add("grant_type", "client_credentials");
+ }
+
+ if (clientId != null) {
+ acquireTokenBuilder.add("client_id", clientId);
+ acquireTokenBuilder.add("client_secret", clientSecret);
+ }
+
+ RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
+
+ Request acquireTokenRequest = new Request.Builder()
+ .url(authorizationServerUrl)
+ .post(acquireTokenRequestBody)
+ .build();
+
+ this.accessDetails = getAccessDetails(acquireTokenRequest);
+ }
+
+ private void refreshAccessDetails() {
+ getLogger().debug("Refreshing access token");
+
+ FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
+ .add("grant_type", "refresh_token")
+ .add("refresh_token", this.accessDetails.getRefreshToken());
+
+ if (clientId != null) {
+ refreshTokenBuilder.add("client_id", clientId);
+ refreshTokenBuilder.add("client_secret", clientSecret);
+ }
+
+ RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
+
+ Request refreshRequest = new Request.Builder()
+ .url(authorizationServerUrl)
+ .post(refreshTokenRequestBody)
+ .build();
+
+ this.accessDetails = getAccessDetails(refreshRequest);
+ }
+
+ /**
+ * Executes a token request and parses the JSON response into an {@link AccessToken}.
+ * The response is closed via try-with-resources on every path.
+ *
+ * @throws ProcessException when the server answers with a status other than HTTP 200
+ * @throws UncheckedIOException when the request or JSON parsing fails with an I/O error
+ */
+ private AccessToken getAccessDetails(Request newRequest) {
+ try (Response response = httpClient.newCall(newRequest).execute()) {
+ String responseBody = response.body().string();
+ if (response.code() != 200) {
+ getLogger().error(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", response.code(), responseBody));
+ throw new ProcessException(String.format("OAuth2 access token request failed [HTTP %d]", response.code()));
+ }
+
+ return ACCESS_DETAILS_MAPPER.readValue(responseBody, AccessToken.class);
+ } catch (IOException e) {
+ throw new UncheckedIOException("OAuth2 access token request failed", e);
+ }
+ }
+}
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService
index 75e29d0..b1de8ed 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService
@@ -13,3 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
org.apache.nifi.oauth2.OAuth2TokenProviderImpl
+org.apache.nifi.oauth2.StandardOauth2AccessTokenProvider
+
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java
index 25f0449..b8dc62e 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java
@@ -107,7 +107,7 @@ public class OAuth2TokenProviderImplTest {
+    // Checks the AccessToken fields against the values put into the mocked token response (token map below).
+    // NOTE(review): expires_in was raised from 300 to 5300 here and in the mocked response, presumably so that
+    // isExpired() stays false under the reworked AccessToken expiry logic - confirm against AccessToken.java.
private void assertAccessTokenFound(final AccessToken accessToken) {
assertNotNull(accessToken);
assertEquals("access token", accessToken.getAccessToken());
- assertEquals(300, accessToken.getExpires().intValue());
+ assertEquals(5300, accessToken.getExpiresIn());
assertEquals("BEARER", accessToken.getTokenType());
assertFalse(accessToken.isExpired());
}
@@ -117,7 +117,7 @@ public class OAuth2TokenProviderImplTest {
token.put("access_token", "access token");
token.put("refresh_token", "refresh token");
token.put("token_type", "BEARER");
- token.put("expires_in", 300);
+ token.put("expires_in", 5300);
token.put("scope", "test scope");
final String accessToken = new ObjectMapper().writeValueAsString(token);
mockWebServer.enqueue(new MockResponse().setResponseCode(200).addHeader("Content-Type", "application/json").setBody(accessToken));
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
new file mode 100644
index 0000000..54de7a0
--- /dev/null
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
@@ -0,0 +1,411 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.oauth2;
+
+import okhttp3.MediaType;
+import okhttp3.OkHttpClient;
+import okhttp3.Protocol;
+import okhttp3.Request;
+import okhttp3.Response;
+import okhttp3.ResponseBody;
+import org.apache.nifi.controller.ConfigurationContext;
+import org.apache.nifi.logging.ComponentLog;
+import org.apache.nifi.processor.Processor;
+import org.apache.nifi.processor.exception.ProcessException;
+import org.apache.nifi.util.NoOpProcessor;
+import org.apache.nifi.util.TestRunner;
+import org.apache.nifi.util.TestRunners;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Answers;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link StandardOauth2AccessTokenProvider}: grant type validation, acquiring a new access token,
+ * refreshing an expired one, and falling back to acquiring a new token when the refresh attempt fails.
+ */
+public class StandardOauth2AccessTokenProviderTest {
+    private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl";
+    private static final String USERNAME = "username";
+    private static final String PASSWORD = "password";
+    private static final String CLIENT_ID = "clientId";
+    private static final String CLIENT_SECRET = "clientSecret";
+
+    /** Error log format the provider uses when the authorization server answers with a non-200 status. */
+    private static final String HTTP_ERROR_LOG_FORMAT = "OAuth2 access token request failed [HTTP %d], response:%n%s";
+
+    private StandardOauth2AccessTokenProvider testSubject;
+
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private OkHttpClient mockHttpClient;
+
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private ConfigurationContext mockContext;
+
+    @Mock
+    private ComponentLog mockLogger;
+    @Captor
+    private ArgumentCaptor<String> debugCaptor;
+    @Captor
+    private ArgumentCaptor<String> errorCaptor;
+    @Captor
+    private ArgumentCaptor<Throwable> throwableCaptor;
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.initMocks(this);
+
+        // Swap in the mocked HTTP client and logger so tests can script responses and verify log output
+        testSubject = new StandardOauth2AccessTokenProvider() {
+            @Override
+            protected OkHttpClient createHttpClient(ConfigurationContext context) {
+                return mockHttpClient;
+            }
+
+            @Override
+            protected ComponentLog getLogger() {
+                return mockLogger;
+            }
+        };
+
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue());
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue()).thenReturn(AUTHORIZATION_SERVER_URL);
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.USERNAME).evaluateAttributeExpressions().getValue()).thenReturn(USERNAME);
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
+
+        testSubject.onEnabled(mockContext);
+    }
+
+    @Test
+    public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception {
+        // GIVEN
+        Processor processor = new NoOpProcessor();
+        TestRunner runner = TestRunners.newTestRunner(processor);
+
+        runner.addControllerService("testSubject", testSubject);
+
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
+
+        // WHEN
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
+
+        // THEN
+        runner.assertNotValid(testSubject);
+    }
+
+    @Test
+    public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception {
+        // GIVEN
+        Processor processor = new NoOpProcessor();
+        TestRunner runner = TestRunners.newTestRunner(processor);
+
+        runner.addControllerService("testSubject", testSubject);
+
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
+
+        // WHEN
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId");
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret");
+
+        // THEN
+        runner.assertValid(testSubject);
+    }
+
+    @Test
+    public void testAcquireNewToken() throws Exception {
+        String accessTokenValue = "access_token_value";
+
+        // GIVEN
+        Response response = buildResponse(
+                200,
+                "{ \"access_token\":\"" + accessTokenValue + "\" }"
+        );
+
+        when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
+
+        // WHEN
+        String actual = testSubject.getAccessDetails().getAccessToken();
+
+        // THEN
+        assertEquals(accessTokenValue, actual);
+    }
+
+    @Test
+    public void testRefreshToken() throws Exception {
+        // GIVEN
+        String firstToken = "first_token";
+        String expectedToken = "second_token";
+
+        // expires_in of 0 makes the first token stale, so the second getAccessDetails() call has to refresh it
+        Response response1 = buildResponse(
+                200,
+                "{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
+        );
+
+        Response response2 = buildResponse(
+                200,
+                "{ \"access_token\":\"" + expectedToken + "\" }"
+        );
+
+        when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
+
+        // WHEN
+        testSubject.getAccessDetails();
+        String actualToken = testSubject.getAccessDetails().getAccessToken();
+
+        // THEN
+        assertEquals(expectedToken, actualToken);
+    }
+
+    @Test
+    public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
+        // GIVEN
+        String refreshErrorMessage = "refresh_error";
+        String acquireErrorMessage = "acquire_error";
+
+        mockHttpResponses(
+                buildSuccessfulInitResponse(),
+                new IOException(refreshErrorMessage),
+                new IOException(acquireErrorMessage)
+        );
+
+        // Get a good accessDetails so we can have a refresh a second time
+        testSubject.getAccessDetails();
+
+        // WHEN
+        UncheckedIOException actualException = assertThrows(
+                UncheckedIOException.class,
+                () -> testSubject.getAccessDetails()
+        );
+
+        // THEN
+        checkLoggedDebugWhenRefreshFails();
+
+        checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
+
+        checkError(new UncheckedIOException("OAuth2 access token request failed", new IOException(acquireErrorMessage)), actualException);
+    }
+
+    @Test
+    public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
+        // GIVEN
+        String refreshErrorMessage = "refresh_error";
+        String expectedToken = "expected_token";
+
+        Response successfulAcquireResponse = buildResponse(
+                200,
+                "{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
+        );
+
+        mockHttpResponses(
+                buildSuccessfulInitResponse(),
+                new IOException(refreshErrorMessage),
+                successfulAcquireResponse
+        );
+
+        // Get a good accessDetails so we can have a refresh a second time
+        testSubject.getAccessDetails();
+
+        // WHEN
+        String actualToken = testSubject.getAccessDetails().getAccessToken();
+
+        // THEN
+        checkLoggedDebugWhenRefreshFails();
+
+        checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
+
+        assertEquals(expectedToken, actualToken);
+    }
+
+    @Test
+    public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception {
+        // GIVEN
+        String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }";
+        String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }";
+
+        mockHttpResponses(
+                buildSuccessfulInitResponse(),
+                buildResponse(500, errorRefreshResponseBody),
+                buildResponse(503, errorAcquireResponseBody)
+        );
+
+        List<String> expectedLoggedError = Arrays.asList(
+                String.format(HTTP_ERROR_LOG_FORMAT, 500, errorRefreshResponseBody),
+                String.format(HTTP_ERROR_LOG_FORMAT, 503, errorAcquireResponseBody)
+        );
+
+        // Get a good accessDetails so we can have a refresh a second time
+        testSubject.getAccessDetails();
+
+        // WHEN
+        ProcessException actualException = assertThrows(
+                ProcessException.class,
+                () -> testSubject.getAccessDetails()
+        );
+
+        // THEN
+        checkLoggedDebugWhenRefreshFails();
+
+        checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
+
+        checkLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
+
+        checkError(new ProcessException("OAuth2 access token request failed [HTTP 503]"), actualException);
+    }
+
+    @Test
+    public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
+        // GIVEN
+        String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }";
+        String expectedToken = "expected_token";
+
+        Response successfulAcquireResponse = buildResponse(
+                200,
+                "{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
+        );
+
+        mockHttpResponses(
+                buildSuccessfulInitResponse(),
+                buildResponse(500, expectedRefreshErrorResponse),
+                successfulAcquireResponse
+        );
+
+        List<String> expectedLoggedError = Arrays.asList(String.format(HTTP_ERROR_LOG_FORMAT, 500, expectedRefreshErrorResponse));
+
+        // Get a good accessDetails so we can have a refresh a second time
+        testSubject.getAccessDetails();
+
+        // WHEN
+        String actualToken = testSubject.getAccessDetails().getAccessToken();
+
+        // THEN
+        checkLoggedDebugWhenRefreshFails();
+
+        checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
+
+        checkLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
+
+        assertEquals(expectedToken, actualToken);
+    }
+
+    /**
+     * Scripts the mocked HTTP client so that the n-th call to {@code execute()} yields the n-th outcome:
+     * an {@link IOException} outcome is thrown, any other outcome must be a {@link Response} and is returned.
+     * Calls beyond the scripted outcomes fail with an {@link IllegalStateException}.
+     */
+    private void mockHttpResponses(Object... outcomes) throws IOException {
+        AtomicInteger callCounter = new AtomicInteger(0);
+        when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
+            int callIndex = callCounter.getAndIncrement();
+            if (callIndex >= outcomes.length) {
+                throw new IllegalStateException("Test improperly defined mock HTTP responses.");
+            }
+
+            Object outcome = outcomes[callIndex];
+            if (outcome instanceof IOException) {
+                throw (IOException) outcome;
+            }
+            return (Response) outcome;
+        });
+    }
+
+    /** Builds a successful token response whose token is already stale (expires_in 0) and carries a refresh token. */
+    private Response buildSuccessfulInitResponse() {
+        return buildResponse(
+                200,
+                "{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
+        );
+    }
+
+    /** Builds a JSON HTTP response with the given status code and body. */
+    private Response buildResponse(int code, String body) {
+        return new Response.Builder()
+                .request(new Request.Builder()
+                        .url("http://unimportant_but_required")
+                        .build()
+                )
+                .protocol(Protocol.HTTP_2)
+                .message("unimportant_but_required")
+                .code(code)
+                .body(ResponseBody.create(
+                        // Explicit charset: the platform default would make the test environment-dependent
+                        body.getBytes(StandardCharsets.UTF_8),
+                        MediaType.parse("application/json"))
+                )
+                .build();
+    }
+
+    /** Verifies the debug trail of: initial acquire, failed refresh, then fallback acquire. */
+    private void checkLoggedDebugWhenRefreshFails() {
+        verify(mockLogger, times(3)).debug(debugCaptor.capture());
+        List<String> actualDebugMessages = debugCaptor.getAllValues();
+
+        assertEquals(
+                Arrays.asList("Getting a new access token", "Refreshing access token", "Getting a new access token"),
+                actualDebugMessages
+        );
+    }
+
+    /** Verifies that exactly the expected error messages were logged, in order. */
+    private void checkLoggedErrorWhenRefreshReturnsBadHTTPResponse(List<String> expectedLoggedError) {
+        verify(mockLogger, times(expectedLoggedError.size())).error(errorCaptor.capture());
+        List<String> actualLoggedError = errorCaptor.getAllValues();
+
+        assertEquals(expectedLoggedError, actualLoggedError);
+    }
+
+    /** Verifies that the refresh failure was logged at info level with the expected throwable. */
+    private void checkLoggedRefreshError(Throwable expectedRefreshError) {
+        verify(mockLogger).info(eq("Couldn't refresh access token"), throwableCaptor.capture());
+        Throwable actualRefreshError = throwableCaptor.getValue();
+
+        checkError(expectedRefreshError, actualRefreshError);
+    }
+
+    /** Compares two throwables by class and message, and their direct causes likewise. */
+    private void checkError(Throwable expectedError, Throwable actualError) {
+        assertEquals(expectedError.getClass(), actualError.getClass());
+        assertEquals(expectedError.getMessage(), actualError.getMessage());
+
+        Throwable expectedCause = expectedError.getCause();
+        Throwable actualCause = actualError.getCause();
+        // Report a mismatch as an assertion failure rather than a NullPointerException when only one side has a cause
+        assertEquals("Presence of exception cause", expectedCause != null, actualCause != null);
+        if (expectedCause != null) {
+            assertEquals(expectedCause.getClass(), actualCause.getClass());
+            assertEquals(expectedCause.getMessage(), actualCause.getMessage());
+        }
+    }
+}