You are viewing a plain text version of this content. The canonical link for it is here.
Posted to jira@kafka.apache.org by GitBox <gi...@apache.org> on 2021/10/19 21:18:30 UTC

[GitHub] [kafka] kirktrue commented on a change in pull request #11284: KAFKA-13202: KIP-768: Extend SASL/OAUTHBEARER with Support for OIDC

kirktrue commented on a change in pull request #11284:
URL: https://github.com/apache/kafka/pull/11284#discussion_r732251460



##########
File path: clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.oauthbearer.secured;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.UnsupportedEncodingException;
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.SSLSocketFactory;
+import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <code>HttpAccessTokenRetriever</code> is an {@link AccessTokenRetriever} that will
+ * communicate with an OAuth/OIDC provider directly via HTTP to post client credentials
+ * ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG})
+ * to a publicized token endpoint URL
+ * ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URI}).
+ *
+ * @see AccessTokenRetriever
+ * @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG
+ * @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG
+ * @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG
+ * @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URI
+ */
+
+public class HttpAccessTokenRetriever implements AccessTokenRetriever {
+
+    private static final Logger log = LoggerFactory.getLogger(HttpAccessTokenRetriever.class);
+
+    private static final Set<Integer> UNRETRYABLE_HTTP_CODES;
+
+    public static final String AUTHORIZATION_HEADER = "Authorization";
+
+    static {
+        // HTTP status codes for which repeating the same request is not attempted. The
+        // list is intentionally not exhaustive: codes defined in other RFCs
+        // (e.g. https://datatracker.ietf.org/doc/html/rfc6585) are left out for now. A
+        // code that is absent from this set is simply treated as retryable, so the worst
+        // case for an omission is an unnecessary retry.
+        Set<Integer> codes = new HashSet<>();
+        Collections.addAll(codes,
+            HttpURLConnection.HTTP_BAD_REQUEST,
+            HttpURLConnection.HTTP_UNAUTHORIZED,
+            HttpURLConnection.HTTP_PAYMENT_REQUIRED,
+            HttpURLConnection.HTTP_FORBIDDEN,
+            HttpURLConnection.HTTP_NOT_FOUND,
+            HttpURLConnection.HTTP_BAD_METHOD,
+            HttpURLConnection.HTTP_NOT_ACCEPTABLE,
+            HttpURLConnection.HTTP_PROXY_AUTH,
+            HttpURLConnection.HTTP_CONFLICT,
+            HttpURLConnection.HTTP_GONE,
+            HttpURLConnection.HTTP_LENGTH_REQUIRED,
+            HttpURLConnection.HTTP_PRECON_FAILED,
+            HttpURLConnection.HTTP_ENTITY_TOO_LARGE,
+            HttpURLConnection.HTTP_REQ_TOO_LONG,
+            HttpURLConnection.HTTP_UNSUPPORTED_TYPE,
+            HttpURLConnection.HTTP_NOT_IMPLEMENTED,
+            HttpURLConnection.HTTP_VERSION);
+        UNRETRYABLE_HTTP_CODES = codes;
+    }
+
+    private final String clientId;
+
+    private final String clientSecret;
+
+    private final String scope;
+
+    private final SSLSocketFactory sslSocketFactory;
+
+    private final String tokenEndpointUri;
+
+    private final long loginRetryBackoffMs;
+
+    private final long loginRetryBackoffMaxMs;
+
+    private final Integer loginConnectTimeoutMs;
+
+    private final Integer loginReadTimeoutMs;
+
+    /**
+     * Creates a retriever that posts the given client credentials to the token endpoint.
+     *
+     * @param clientId               Client ID; must not be <code>null</code>
+     * @param clientSecret           Client secret; must not be <code>null</code>
+     * @param scope                  Scope to request; may be <code>null</code>
+     * @param sslSocketFactory       Socket factory applied to HTTPS connections; may be
+     *                               <code>null</code>
+     * @param tokenEndpointUri       Token endpoint location; must not be <code>null</code>
+     * @param loginRetryBackoffMs    Backoff value handed to the retry helper
+     * @param loginRetryBackoffMaxMs Maximum backoff value handed to the retry helper
+     * @param loginConnectTimeoutMs  Connect timeout in milliseconds; <code>null</code> leaves
+     *                               the connection's own setting untouched
+     * @param loginReadTimeoutMs     Read timeout in milliseconds; <code>null</code> leaves
+     *                               the connection's own setting untouched
+     *
+     * @throws NullPointerException if a mandatory argument is <code>null</code>
+     */
+    public HttpAccessTokenRetriever(String clientId,
+        String clientSecret,
+        String scope,
+        SSLSocketFactory sslSocketFactory,
+        String tokenEndpointUri,
+        long loginRetryBackoffMs,
+        long loginRetryBackoffMaxMs,
+        Integer loginConnectTimeoutMs,
+        Integer loginReadTimeoutMs) {
+        // Mandatory settings: a null value fails fast here
+        this.clientId = Objects.requireNonNull(clientId);
+        this.clientSecret = Objects.requireNonNull(clientSecret);
+        this.tokenEndpointUri = Objects.requireNonNull(tokenEndpointUri);
+
+        // Optional settings: stored exactly as supplied
+        this.scope = scope;
+        this.sslSocketFactory = sslSocketFactory;
+        this.loginConnectTimeoutMs = loginConnectTimeoutMs;
+        this.loginReadTimeoutMs = loginReadTimeoutMs;
+
+        // Retry tuning
+        this.loginRetryBackoffMs = loginRetryBackoffMs;
+        this.loginRetryBackoffMaxMs = loginRetryBackoffMaxMs;
+    }
+
+    /**
+     * Retrieves a JWT access token in its serialized three-part form. The implementation
+     * is free to determine how it should be retrieved but should not perform validation
+     * on the result.
+     *
+     * <b>Note</b>: This is a blocking function and callers should be aware that the
+     * implementation communicates over a network. The facility in the
+     * {@link javax.security.auth.spi.LoginModule} from which this is ultimately called
+     * does not provide an asynchronous approach.
+     *
+     * @return Non-<code>null</code> JWT access token string
+     *
+     * @throws IOException Thrown on errors related to IO during retrieval
+     */
+
+    @Override
+    public String retrieve() throws IOException {
+        String authorizationHeader = formatAuthorizationHeader(clientId, clientSecret);
+        String requestBody = formatRequestBody(scope);
+
+        Retry<String> retry = new Retry<>(Time.SYSTEM,
+            loginRetryBackoffMs,
+            loginRetryBackoffMaxMs);
+
+        Map<String, String> headers = Collections.singletonMap(AUTHORIZATION_HEADER, authorizationHeader);
+
+        String responseBody = retry.execute(() -> {
+            // A fresh connection is opened for every attempt made by the retry helper
+            HttpURLConnection con = (HttpURLConnection) new URL(tokenEndpointUri).openConnection();
+
+            if (sslSocketFactory != null && con instanceof HttpsURLConnection)
+                ((HttpsURLConnection) con).setSSLSocketFactory(sslSocketFactory);
+
+            try {
+                return post(con, headers, requestBody, loginConnectTimeoutMs, loginReadTimeoutMs);
+            } finally {
+                con.disconnect();
+            }
+        });
+
+        // The response body carries the access token, which is a credential, so only its
+        // size is logged and never its content.
+        log.debug("retrieve - received a response body of {} characters",
+            responseBody == null ? 0 : responseBody.length());
+
+        return parseAccessToken(responseBody);
+    }
+
+    /**
+     * Issues an HTTP POST on the given connection and returns the response body.
+     *
+     * @param con              Connection to use for the request
+     * @param headers          Extra request headers; may be <code>null</code>
+     * @param requestBody      Body to send; may be <code>null</code> for no body
+     * @param connectTimeoutMs Connect timeout in milliseconds; may be <code>null</code>
+     * @param readTimeoutMs    Read timeout in milliseconds; may be <code>null</code>
+     *
+     * @return Response body of a successful request
+     *
+     * @throws IOException          Thrown on IO errors or an unsuccessful response code
+     * @throws UnretryableException Thrown when the response code is known to be non-transient
+     */
+    public static String post(HttpURLConnection con,
+        Map<String, String> headers,
+        String requestBody,
+        Integer connectTimeoutMs,
+        Integer readTimeoutMs)
+        throws IOException, UnretryableException {
+        // Phase 1: configure the connection and send the request
+        handleInput(con, headers, requestBody, connectTimeoutMs, readTimeoutMs);
+
+        // Phase 2: read the response and map its status to a result or an exception
+        String responseBody = handleOutput(con);
+        return responseBody;
+    }
+
+    /**
+     * Configures the connection for a POST, connects, and writes the request body (if any).
+     *
+     * @param con              Connection to configure and write to
+     * @param headers          Extra request headers; may be <code>null</code>
+     * @param requestBody      Body to send as UTF-8; may be <code>null</code> for no body
+     * @param connectTimeoutMs Connect timeout in milliseconds; <code>null</code> leaves it unset
+     * @param readTimeoutMs    Read timeout in milliseconds; <code>null</code> leaves it unset
+     *
+     * @throws IOException Thrown on errors connecting or writing the request
+     */
+    private static void handleInput(HttpURLConnection con,
+        Map<String, String> headers,
+        String requestBody,
+        Integer connectTimeoutMs,
+        Integer readTimeoutMs)
+        throws IOException, UnretryableException {
+        log.debug("handleInput - starting post for {}", con.getURL());
+        con.setRequestMethod("POST");
+        con.setRequestProperty("Accept", "application/json");
+
+        if (headers != null) {
+            for (Map.Entry<String, String> header : headers.entrySet())
+                con.setRequestProperty(header.getKey(), header.getValue());
+        }
+
+        con.setRequestProperty("Cache-Control", "no-cache");
+
+        // Encode the body once so that the advertised Content-Length is the number of
+        // bytes actually written; String.length() counts UTF-16 chars, which differs from
+        // the UTF-8 byte count as soon as the body contains non-ASCII characters.
+        byte[] requestBodyBytes = requestBody != null ? requestBody.getBytes(StandardCharsets.UTF_8) : null;
+
+        if (requestBodyBytes != null) {
+            con.setRequestProperty("Content-Length", String.valueOf(requestBodyBytes.length));
+            con.setDoOutput(true);
+        }
+
+        con.setUseCaches(false);
+
+        if (connectTimeoutMs != null)
+            con.setConnectTimeout(connectTimeoutMs);
+
+        if (readTimeoutMs != null)
+            con.setReadTimeout(readTimeoutMs);
+
+        log.debug("handleInput - preparing to connect to {}", con.getURL());
+        con.connect();
+
+        if (requestBodyBytes != null) {
+            try (OutputStream os = con.getOutputStream()) {
+                log.debug("handleInput - preparing to write request body to {}", con.getURL());
+                os.write(requestBodyBytes);
+            }
+        }
+    }
+
+    /**
+     * Reads the response from the connection and converts it into a result or an exception.
+     *
+     * @param con Connection on which the request has already been sent
+     *
+     * @return Non-empty response body when the response code is 200 or 201
+     *
+     * @throws IOException Thrown when the body of a successful response is empty, or when
+     *                     the response code is neither 200 nor 201. Codes listed in
+     *                     <code>UNRETRYABLE_HTTP_CODES</code> are reported through an
+     *                     {@link UnretryableException} wrapping the <code>IOException</code>.
+     */
+    static String handleOutput(final HttpURLConnection con) throws IOException {
+        int responseCode = con.getResponseCode();
+        log.debug("handleOutput - responseCode: {}", responseCode);
+
+        boolean success = responseCode == HttpURLConnection.HTTP_OK || responseCode == HttpURLConnection.HTTP_CREATED;
+        String responseBody = null;
+
+        try (InputStream is = con.getInputStream()) {
+            ByteArrayOutputStream os = new ByteArrayOutputStream();
+            log.debug("handleOutput - preparing to read response body from {}", con.getURL());
+            copy(is, os);
+            responseBody = os.toString(StandardCharsets.UTF_8.name());
+        } catch (Exception e) {
+            // For HTTP error statuses getInputStream() throws and the server's explanation,
+            // if any, is only available from the error stream. Read it on a best-effort
+            // basis so that it shows up in the warning logged below.
+            if (!success) {
+                try (InputStream es = con.getErrorStream()) {
+                    if (es != null) {
+                        ByteArrayOutputStream os = new ByteArrayOutputStream();
+                        copy(es, os);
+                        responseBody = os.toString(StandardCharsets.UTF_8.name());
+                    }
+                } catch (Exception errorStreamFailure) {
+                    e.addSuppressed(errorStreamFailure);
+                }
+            }
+
+            if (responseBody == null)
+                log.warn("handleOutput - error retrieving data", e);
+        }
+
+        if (success) {
+            if (responseBody == null || responseBody.isEmpty())
+                throw new IOException(String.format("The token endpoint response was unexpectedly empty despite response code %s from %s", responseCode, con.getURL()));
+
+            // A successful body contains the access token, so log its size, not its content
+            log.debug("handleOutput - responseCode: {}, response length: {}", responseCode, responseBody.length());
+
+            return responseBody;
+        } else {
+            log.warn("handleOutput - error response code: {}, error response body: {}", responseCode, responseBody);
+
+            if (UNRETRYABLE_HTTP_CODES.contains(responseCode)) {
+                // We know that this is a non-transient error, so let's not keep retrying the
+                // request unnecessarily.
+                throw new UnretryableException(new IOException(String.format("The response code %s was encountered reading the token endpoint response; will not attempt further retries", responseCode)));
+            } else {
+                // We don't know if this is a transient (retryable) error or not, so let's assume
+                // it is.
+                throw new IOException(String.format("The unexpected response code %s was encountered reading the token endpoint response", responseCode));
+            }
+        }
+    }
+
+    /**
+     * Copies every remaining byte of the input stream to the output stream. Neither
+     * stream is closed by this method.
+     *
+     * @param is Stream to read from until end-of-stream
+     * @param os Stream to write to
+     *
+     * @throws IOException Thrown if reading or writing fails
+     */
+    static void copy(InputStream is, OutputStream os) throws IOException {
+        byte[] chunk = new byte[4096];
+
+        // read() reports end-of-stream with -1; any other value is the number of bytes
+        // placed at the start of the chunk
+        for (int count = is.read(chunk); count != -1; count = is.read(chunk))
+            os.write(chunk, 0, count);
+    }
+
+    /**
+     * Extracts the <code>access_token</code> value from the JSON token endpoint response.
+     *
+     * @param responseBody JSON document returned by the token endpoint
+     *
+     * @return Value of the top-level <code>access_token</code> field, after sanitization
+     *
+     * @throws IOException Thrown if the response cannot be parsed as JSON or does not
+     *                     contain an <code>access_token</code> field
+     */
+    static String parseAccessToken(String responseBody) throws IOException {
+        ObjectMapper mapper = new ObjectMapper();
+        JsonNode rootNode = mapper.readTree(responseBody);
+
+        // Depending on the Jackson version, readTree() may yield null for empty content
+        JsonNode accessTokenNode = rootNode == null ? null : rootNode.at("/access_token");
+
+        // at() never returns null: an absent path is reported as a "missing" node, so that
+        // case has to be checked explicitly for the error below to be reachable.
+        if (accessTokenNode == null || accessTokenNode.isMissingNode())
+            throw new IOException("The token endpoint response did not contain an access_token value");
+
+        return sanitizeString("The token endpoint response access_token", accessTokenNode.textValue());
+    }
+
+    /**
+     * Builds the value of the HTTP <code>Authorization</code> header for the "Basic"
+     * scheme: the literal <code>Basic</code> followed by the Base64 encoding of
+     * <code>clientId:clientSecret</code>.
+     *
+     * @param clientId     Client ID to send to the token endpoint
+     * @param clientSecret Client secret to send to the token endpoint
+     *
+     * @return Header value of the form <code>Basic &lt;base64&gt;</code>
+     *
+     * @throws IOException Declared for the sanitization of the two arguments
+     */
+    static String formatAuthorizationHeader(String clientId, String clientSecret) throws IOException {
+        clientId = sanitizeString("The token endpoint request clientId", clientId);
+        clientSecret = sanitizeString("The token endpoint request clientSecret", clientSecret);
+
+        String s = String.format("%s:%s", clientId, clientSecret);
+
+        // The Basic scheme uses the standard Base64 alphabet ('+' and '/'). The URL-safe
+        // variant substitutes '-' and '_', which a server decoding standard Base64 will
+        // reject or decode to different credentials.
+        String encoded = Base64.getEncoder().encodeToString(Utils.utf8(s));
+        return String.format("Basic %s", encoded);
+    }
+
+    static String formatRequestBody(String scope) throws IOException {
+        try {
+            StringBuilder requestParameters = new StringBuilder();
+            requestParameters.append("grant_type=client_credentials");
+
+            if (scope != null && !scope.trim().isEmpty()) {
+                scope = scope.trim();
+                String encodedScope = URLEncoder.encode(scope, StandardCharsets.UTF_8.name());

Review comment:
       The request body parameters are _URL_-encoded, but the `Authorization` header value after `Basic` needs to be _base 64_-encoded.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscribe@kafka.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org