You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nifi.apache.org by ex...@apache.org on 2022/12/30 21:08:53 UTC

[nifi] branch main updated: NIFI-10456 Added Client Authentication Strategy option to OAuth2 Provider

This is an automated email from the ASF dual-hosted git repository.

exceptionfactory pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/main by this push:
     new 4716c8d715 NIFI-10456 Added Client Authentication Strategy option to OAuth2 Provider
4716c8d715 is described below

commit 4716c8d715918352ebd9e7aa897881127aa37c7d
Author: Esa Lindqvist <es...@compile.fi>
AuthorDate: Wed Dec 14 14:19:27 2022 +0200

    NIFI-10456 Added Client Authentication Strategy option to OAuth2 Provider
    
    StandardOauth2AccessTokenProvider has been updated with a new property `Client Authentication Strategy` supporting Basic Authentication as recommended in RFC 6749. The changes maintain the current default implementation using Request Body parameters.
    
    This closes #6782
    
    Co-authored-by: David Handermann <ex...@apache.org>
    Signed-off-by: David Handermann <ex...@apache.org>
---
 .../nifi/oauth2/ClientAuthenticationStrategy.java  | 45 +++++++++++
 .../oauth2/StandardOauth2AccessTokenProvider.java  | 59 ++++++++++----
 .../StandardOauth2AccessTokenProviderTest.java     | 90 +++++++++++++++++++++-
 3 files changed, 173 insertions(+), 21 deletions(-)

diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/ClientAuthenticationStrategy.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/ClientAuthenticationStrategy.java
new file mode 100644
index 0000000000..66cddc1e07
--- /dev/null
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/ClientAuthenticationStrategy.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.oauth2;
+
+import org.apache.nifi.components.DescribedValue;
+
+public enum ClientAuthenticationStrategy implements DescribedValue {
+    REQUEST_BODY("Send client authentication in request body. RFC 6749 Section 2.3.1 recommends Basic Authentication instead of request body."),
+    BASIC_AUTHENTICATION("Send client authentication using HTTP Basic authentication.");
+
+    private final String description; // text returned from getDescription()
+
+    ClientAuthenticationStrategy(final String strategyDescription) {
+        description = strategyDescription;
+    }
+
+    @Override
+    public String getValue() { // constant name; the token provider maps it back with valueOf()
+        return name();
+    }
+
+    @Override
+    public String getDisplayName() { // same string as the property value
+        return getValue();
+    }
+
+    @Override
+    public String getDescription() {
+        return this.description;
+    }
+}
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
index 22dce3a033..39b74ff88c 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
@@ -19,6 +19,7 @@ package org.apache.nifi.oauth2;
 import com.fasterxml.jackson.databind.DeserializationFeature;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.PropertyNamingStrategies;
+import okhttp3.Credentials;
 import okhttp3.FormBody;
 import okhttp3.OkHttpClient;
 import okhttp3.Request;
@@ -58,7 +59,8 @@ import java.util.concurrent.TimeUnit;
 
 @Tags({"oauth2", "provider", "authorization", "access token", "http"})
 @CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
-    " Uses Resource Owner Password Credentials Grant.")
+    " Can use either Resource Owner Password Credentials Grant or Client Credentials Grant." +
+    " Client authentication can be done with either HTTP Basic authentication or in the request body.")
 public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider, VerifiableControllerService {
     public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder()
         .name("authorization-server-url")
@@ -69,6 +71,15 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
         .build();
 
+    public static final PropertyDescriptor CLIENT_AUTHENTICATION_STRATEGY = new PropertyDescriptor.Builder()
+        .name("client-authentication-strategy")
+        .displayName("Client Authentication Strategy")
+        .description("Strategy for authenticating the client against the OAuth2 token provider service.")
+        .allowableValues(ClientAuthenticationStrategy.class)
+        .defaultValue(ClientAuthenticationStrategy.REQUEST_BODY.getValue()) // default keeps the existing request body parameters behavior
+        .required(true)
+        .build();
+
     public static AllowableValue RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE = new AllowableValue(
         "password",
         "User Password",
@@ -136,13 +147,13 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         .build();
 
     public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder()
-            .name("refresh-window")
-            .displayName("Refresh Window")
-            .description("The service will attempt to refresh tokens expiring within the refresh window, subtracting the configured duration from the token expiration.")
-            .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR)
-            .defaultValue("0 s")
-            .required(true)
-            .build();
+        .name("refresh-window")
+        .displayName("Refresh Window")
+        .description("The service will attempt to refresh tokens expiring within the refresh window, subtracting the configured duration from the token expiration.")
+        .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR)
+        .defaultValue("0 s")
+        .required(true)
+        .build();
 
     public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
         .name("ssl-context-service")
@@ -163,6 +174,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
 
     private static final List<PropertyDescriptor> PROPERTIES = Collections.unmodifiableList(Arrays.asList(
         AUTHORIZATION_SERVER_URL,
+        CLIENT_AUTHENTICATION_STRATEGY,
         GRANT_TYPE,
         USERNAME,
         PASSWORD,
@@ -174,6 +186,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         HTTP_PROTOCOL_STRATEGY
     ));
 
+    private static final String AUTHORIZATION_HEADER = "Authorization"; // carries Credentials.basic(clientId, clientSecret) when BASIC_AUTHENTICATION is selected
+
     public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper()
         .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
         .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
@@ -181,6 +195,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
     private volatile String authorizationServerUrl;
     private volatile OkHttpClient httpClient;
 
+    private volatile ClientAuthenticationStrategy clientAuthenticationStrategy; // selects form-body client_id/client_secret vs. a Basic Authorization header
     private volatile String grantType;
     private volatile String username;
     private volatile String password;
@@ -202,6 +217,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
 
         httpClient = createHttpClient(context);
 
+        clientAuthenticationStrategy = ClientAuthenticationStrategy.valueOf(context.getProperty(CLIENT_AUTHENTICATION_STRATEGY).getValue());
         grantType = context.getProperty(GRANT_TYPE).getValue();
         username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue();
         password = context.getProperty(PASSWORD).getValue();
@@ -288,7 +304,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
             acquireTokenBuilder.add("grant_type", "client_credentials");
         }
 
-        if (clientId != null) {
+        if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) {
             acquireTokenBuilder.add("client_id", clientId);
             acquireTokenBuilder.add("client_secret", clientSecret);
         }
@@ -298,11 +314,15 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         }
 
         RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
-
-        Request acquireTokenRequest = new Request.Builder()
+        Request.Builder acquireTokenRequestBuilder = new Request.Builder()
             .url(authorizationServerUrl)
-            .post(acquireTokenRequestBody)
-            .build();
+            .post(acquireTokenRequestBody);
+
+        if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
+            acquireTokenRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
+        }
+
+        Request acquireTokenRequest = acquireTokenRequestBuilder.build();
 
         this.accessDetails = getAccessDetails(acquireTokenRequest);
     }
@@ -314,7 +334,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
             .add("grant_type", "refresh_token")
             .add("refresh_token", this.accessDetails.getRefreshToken());
 
-        if (clientId != null) {
+        if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) {
             refreshTokenBuilder.add("client_id", clientId);
             refreshTokenBuilder.add("client_secret", clientSecret);
         }
@@ -325,10 +345,15 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
 
         RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
 
-        Request refreshRequest = new Request.Builder()
+        Request.Builder refreshRequestBuilder = new Request.Builder()
             .url(authorizationServerUrl)
-            .post(refreshTokenRequestBody)
-            .build();
+            .post(refreshTokenRequestBody);
+
+        if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
+            refreshRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
+        }
+
+        Request refreshRequest = refreshRequestBuilder.build();
 
         this.accessDetails = getAccessDetails(refreshRequest);
     }
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
index 9aa1d5006d..e26fc6bf3f 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
@@ -22,6 +22,8 @@ import okhttp3.Protocol;
 import okhttp3.Request;
 import okhttp3.Response;
 import okhttp3.ResponseBody;
+import okio.Buffer;
+
 import org.apache.nifi.components.ConfigVerificationResult;
 import org.apache.nifi.components.PropertyDescriptor;
 import org.apache.nifi.controller.ConfigurationContext;
@@ -46,7 +48,9 @@ import org.mockito.junit.jupiter.MockitoExtension;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.nio.charset.Charset;
 import java.util.Arrays;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -61,6 +65,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeast;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -91,6 +96,8 @@ public class StandardOauth2AccessTokenProviderTest {
     private ArgumentCaptor<String> errorCaptor;
     @Captor
     private ArgumentCaptor<Throwable> throwableCaptor;
+    @Captor
+    private ArgumentCaptor<Request> requestCaptor; // captures requests passed to mockHttpClient.newCall() for header/body assertions
 
     @BeforeEach
     public void setUp() {
@@ -113,6 +120,7 @@ public class StandardOauth2AccessTokenProviderTest {
         when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
         when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
         when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
 
         testSubject.onEnabled(mockContext);
     }
@@ -125,7 +133,7 @@ public class StandardOauth2AccessTokenProviderTest {
 
         runner.addControllerService("testSubject", testSubject);
 
-        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL);
 
         // WHEN
         runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
@@ -142,12 +150,50 @@ public class StandardOauth2AccessTokenProviderTest {
 
         runner.addControllerService("testSubject", testSubject);
 
-        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL);
+
+        // WHEN
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, CLIENT_ID);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, CLIENT_SECRET);
+
+        // THEN
+        runner.assertValid(testSubject);
+    }
+
+    @Test
+    public void testInvalidWhenClientAuthenticationStrategyIsInvalid() throws Exception {
+        // GIVEN a service otherwise configured validly for the client credentials grant,
+        // so the strategy property is the only thing that can fail validation
+        TestRunner runner = TestRunners.newTestRunner(new NoOpProcessor());
+        runner.addControllerService("testSubject", testSubject);
+
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, CLIENT_ID);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, CLIENT_SECRET);
+
+        // WHEN the strategy is not one of the ClientAuthenticationStrategy constants
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY, "UNKNOWN");
+
+        // THEN validation rejects the service
+        runner.assertNotValid(testSubject);
+    }
+
+    @Test
+    public void testValidWhenClientAuthenticationStrategyIsValid() throws Exception {
+        // GIVEN
+        Processor processor = new NoOpProcessor();
+        TestRunner runner = TestRunners.newTestRunner(processor);
+
+        runner.addControllerService("testSubject", testSubject);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL);
         runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
-        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId");
-        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret");
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, CLIENT_ID);
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, CLIENT_SECRET);
+
+        // WHEN
+        runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY, ClientAuthenticationStrategy.REQUEST_BODY.getValue());
 
         // THEN
         runner.assertValid(testSubject);
@@ -250,6 +296,42 @@ public class StandardOauth2AccessTokenProviderTest {
         assertEquals(expectedToken, actualToken);
     }
 
+    @Test
+    public void testBasicAuthentication() throws Exception {
+        // GIVEN setUp enabled BASIC_AUTHENTICATION; expect standard padded Base64 of "id:secret" (ISO-8859-1 is OkHttp Credentials.basic's default charset)
+        Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}");
+        when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
+        String expected = "Basic " + Base64.getEncoder().encodeToString((CLIENT_ID + ":" + CLIENT_SECRET).getBytes(Charset.forName("ISO-8859-1")));
+
+        // WHEN
+        testSubject.getAccessDetails();
+
+        // THEN the token request carries the client credentials in the Authorization header
+        verify(mockHttpClient, atLeast(1)).newCall(requestCaptor.capture());
+        assertEquals(expected, requestCaptor.getValue().header("Authorization"));
+    }
+
+    @Test
+    public void testRequestBodyAuthentication() throws Exception {
+        // GIVEN the service re-enabled with the client credentials grant and REQUEST_BODY (setUp enabled BASIC_AUTHENTICATION)
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
+        when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
+        testSubject.onEnabled(mockContext);
+
+        Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}");
+        when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
+        String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID + "&client_secret=" + CLIENT_SECRET; // assumes the constants need no percent-encoding
+
+        // WHEN
+        testSubject.getAccessDetails();
+
+        // THEN credentials are form parameters; FormBody writes UTF-8, so decode explicitly instead of with the platform default charset
+        Buffer buffer = new Buffer();
+        verify(mockHttpClient, atLeast(1)).newCall(requestCaptor.capture());
+        requestCaptor.getValue().body().writeTo(buffer);
+        assertEquals(expected, buffer.readString(Charset.forName("UTF-8")));
+    }
+
     @Test
     public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
         // GIVEN