You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2019/06/24 08:06:23 UTC

[arrow] branch master updated: ARROW-5643: [FlightRPC] Add ability to override SSL hostname checking

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new f180a53  ARROW-5643: [FlightRPC] Add ability to override SSL hostname checking
f180a53 is described below

commit f180a53b4da5ed932ecb8aa1e56b961cba899e61
Author: David Li <li...@gmail.com>
AuthorDate: Mon Jun 24 10:06:13 2019 +0200

    ARROW-5643: [FlightRPC] Add ability to override SSL hostname checking
    
    Adds the ability to override hostname checks, so you can connect to localhost over TLS but still verify that the certificate is for some other domain.
    
    Example: when deploying on Kubernetes with headless services, clients connect directly to backend services and do load balancing themselves. Thus all instances of an application must present a certificate for the same hostname. To do health checks in such an environment, you can't connect to the TLS hostname (which may resolve to a different instance); you need to connect to localhost, and override the hostname check.
    
    Also needs https://github.com/apache/arrow-testing/pull/5
    
    Author: David Li <li...@gmail.com>
    
    Closes #4608 from lihalite/flight-tls-java and squashes the following commits:
    
    581fc7582 <David Li> Add ability to override SSL hostname checking
---
 cpp/src/arrow/flight/client.cc                     |   5 +
 cpp/src/arrow/flight/client.h                      |   4 +
 cpp/src/arrow/flight/flight-test.cc                |  18 +++
 java/flight/pom.xml                                |  15 ++-
 .../java/org/apache/arrow/flight/FlightClient.java |  11 ++
 .../java/org/apache/arrow/flight/FlightServer.java |  17 ++-
 .../org/apache/arrow/flight/FlightTestUtil.java    |  44 +++++++
 .../test/java/org/apache/arrow/flight/TestTls.java | 130 +++++++++++++++++++++
 python/pyarrow/_flight.pyx                         |   8 +-
 python/pyarrow/includes/libarrow_flight.pxd        |   1 +
 python/pyarrow/tests/test_flight.py                |  15 +++
 testing                                            |   2 +-
 12 files changed, 259 insertions(+), 11 deletions(-)

diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 2b7c699..1926928 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -259,6 +259,11 @@ class FlightClient::FlightClientImpl {
     args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, 100);
     // Receive messages of any size
     args.SetMaxReceiveMessageSize(-1);
+
+    if (options.override_hostname != "") {
+      args.SetSslTargetNameOverride(options.override_hostname);
+    }
+
     stub_ = pb::FlightService::NewStub(
         grpc::CreateCustomChannel(grpc_uri.str(), creds, args));
     return Status::OK();
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 689c9f8..b8a5d4f 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -59,7 +59,11 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {
 
 class ARROW_FLIGHT_EXPORT FlightClientOptions {
  public:
+  /// \brief Root certificates to use for validating server
+  /// certificates.
   std::string tls_root_certs;
+  /// \brief Override the hostname checked by TLS. Use with caution.
+  std::string override_hostname;
 };
 
 /// \brief Client class for Arrow Flight RPC services (gRPC-based).
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index b295878..3c0b67c 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -675,5 +675,23 @@ TEST_F(TestTls, DoAction) {
   ASSERT_EQ(result->body->ToString(), "Hello, world!");
 }
 
+TEST_F(TestTls, OverrideHostname) {
+  std::unique_ptr<FlightClient> client;
+  auto client_options = FlightClientOptions();
+  client_options.override_hostname = "fakehostname";
+  CertKeyPair root_cert;
+  ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
+  client_options.tls_root_certs = root_cert.pem_cert;
+  ASSERT_OK(FlightClient::Connect(server_->location(), client_options, &client));
+  // Connect() is asserted OK above; only the RPC below is expected to fail with IOError.
+  FlightCallOptions options;
+  options.timeout = TimeoutDuration{5.0};
+  Action action;
+  action.type = "test";
+  action.body = Buffer::FromString("");
+  std::unique_ptr<ResultStream> results;
+  ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
+}
+
 }  // namespace flight
 }  // namespace arrow
diff --git a/java/flight/pom.xml b/java/flight/pom.xml
index 7d01a6e..b03fbe6 100644
--- a/java/flight/pom.xml
+++ b/java/flight/pom.xml
@@ -1,10 +1,10 @@
 <?xml version="1.0"?>
-<!-- Copyright (C) 2017-2018 Dremio Corporation Licensed under the Apache 
-  License, Version 2.0 (the "License"); you may not use this file except in 
-  compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 
-  Unless required by applicable law or agreed to in writing, software distributed 
-  under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES 
-  OR CONDITIONS OF ANY KIND, either express or implied. See the License for 
+<!-- Copyright (C) 2017-2018 Dremio Corporation Licensed under the Apache
+  License, Version 2.0 (the "License"); you may not use this file except in
+  compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+  Unless required by applicable law or agreed to in writing, software distributed
+  under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+  OR CONDITIONS OF ANY KIND, either express or implied. See the License for
   the specific language governing permissions and limitations under the License. -->
 <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
   <modelVersion>4.0.0</modelVersion>
@@ -137,6 +137,9 @@
         <artifactId>maven-surefire-plugin</artifactId>
         <configuration>
           <enableAssertions>false</enableAssertions>
+          <systemPropertyVariables>
+            <arrow.test.dataRoot>${project.basedir}/../../testing/data</arrow.test.dataRoot>
+          </systemPropertyVariables>
         </configuration>
       </plugin>
       <plugin>
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index c70e1fd..37e4514 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -354,6 +354,7 @@ public class FlightClient implements AutoCloseable {
     private InputStream trustedCertificates = null;
     private InputStream clientCertificate = null;
     private InputStream clientKey = null;
+    private String overrideHostname = null;
 
     private Builder() {
     }
@@ -371,6 +372,12 @@ public class FlightClient implements AutoCloseable {
       return this;
     }
 
+    /** Override the hostname checked for TLS. Use with caution in production. */
+    public Builder overrideHostname(final String hostname) {
+      this.overrideHostname = hostname;
+      return this;
+    }
+
     /** Set the maximum inbound message size. */
     public Builder maxInboundMessageSize(int maxSize) {
       Preconditions.checkArgument(maxSize > 0);
@@ -461,6 +468,10 @@ public class FlightClient implements AutoCloseable {
             throw new RuntimeException(e);
           }
         }
+
+        if (this.overrideHostname != null) {
+          builder.overrideAuthority(this.overrideHostname);
+        }
       } else {
         builder.usePlaintext();
       }
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index eaea044..cd59a75 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -72,10 +72,23 @@ public class FlightServer implements AutoCloseable {
     server.awaitTermination();
   }
 
+  /** Request that the server shut down. */
+  public void shutdown() {
+    server.shutdown();
+  }
+
+  /**
+   * Wait for the server to shut down with a timeout.
+   * @return true if the server shut down successfully.
+   */
+  public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException {
+    return server.awaitTermination(timeout, unit);
+  }
+
   /** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */
   public void close() throws InterruptedException {
-    server.shutdown();
-    final boolean terminated = server.awaitTermination(3000, TimeUnit.MILLISECONDS);
+    shutdown();
+    final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS);
     if (terminated) {
       logger.debug("Server was terminated within 3s");
       return;
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
index f6b9e86..3cb09ef 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
@@ -17,8 +17,14 @@
 
 package org.apache.arrow.flight;
 
+import java.io.File;
 import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
 import java.util.Random;
 import java.util.function.Function;
 
@@ -26,9 +32,12 @@ import java.util.function.Function;
  * Utility methods and constants for testing flight servers.
  */
 public class FlightTestUtil {
+
   private static final Random RANDOM = new Random();
 
   public static final String LOCALHOST = "localhost";
+  public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA";
+  public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot";
 
   /**
    * Returns a a FlightServer (actually anything that is startable)
@@ -62,6 +71,30 @@ public class FlightTestUtil {
     return server;
   }
 
+  static Path getTestDataRoot() {
+    String path = System.getenv(TEST_DATA_ENV_VAR);
+    if (path == null) {
+      path = System.getProperty(TEST_DATA_PROPERTY);
+    }
+    return Paths.get(Objects.requireNonNull(path,
+        String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.",
+            TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY)));
+  }
+
+  static Path getFlightTestDataRoot() {
+    return getTestDataRoot().resolve("flight");
+  }
+
+  static Path exampleTlsRootCert() {
+    return getFlightTestDataRoot().resolve("root-ca.pem");
+  }
+
+  static List<CertKeyPair> exampleTlsCerts() {
+    final Path root = getFlightTestDataRoot();
+    return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem").toFile(), root.resolve("cert0.pkcs1").toFile()),
+        new CertKeyPair(root.resolve("cert1.pem").toFile(), root.resolve("cert1.pkcs1").toFile()));
+  }
+
   static boolean isEpollAvailable() {
     try {
       Class<?> epoll = Class.forName("io.netty.channel.epoll.Epoll");
@@ -84,6 +117,17 @@ public class FlightTestUtil {
     return isEpollAvailable() || isKqueueAvailable();
   }
 
+  public static class CertKeyPair {
+
+    public final File cert;
+    public final File key;
+
+    public CertKeyPair(File cert, File key) {
+      this.cert = cert;
+      this.key = key;
+    }
+  }
+
   private FlightTestUtil() {
   }
 }
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
new file mode 100644
index 0000000..c22304d
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.FlightClient.Builder;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for TLS in Flight.
+ */
+public class TestTls {
+
+  /**
+   * Test a basic request over TLS.
+   */
+  @Test
+  public void connectTls() {
+    test((builder) -> {
+      try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
+          final FlightClient client = builder.trustedCertificates(roots).build()) {
+        final Iterator<Result> responses = client.doAction(new Action("hello-world"));
+        final byte[] response = responses.next().getBody();
+        Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8));
+        Assert.assertFalse(responses.hasNext());
+      } catch (InterruptedException | IOException e) {
+        throw new RuntimeException(e);
+      }
+    });
+  }
+
+  /**
+   * Make sure that connections are rejected when the root certificate isn't trusted.
+   */
+  @Test(expected = io.grpc.StatusRuntimeException.class)
+  public void rejectInvalidCert() {
+    test((builder) -> {
+      try (final FlightClient client = builder.build()) {
+        final Iterator<Result> responses = client.doAction(new Action("hello-world"));
+        responses.next().getBody();
+        Assert.fail("Call should have failed");
+      } catch (InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    });
+  }
+
+  /**
+   * Make sure that connections are rejected when the hostname doesn't match.
+   */
+  @Test(expected = io.grpc.StatusRuntimeException.class)
+  public void rejectHostname() {
+    test((builder) -> {
+      try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
+          final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname")
+              .build()) {
+        final Iterator<Result> responses = client.doAction(new Action("hello-world"));
+        responses.next().getBody();
+        Assert.fail("Call should have failed");
+      } catch (InterruptedException | IOException e) {
+        throw new RuntimeException(e);
+      }
+    });
+  }
+
+  /** Runs testFn with a client builder aimed at a localhost TLS server that uses the first example cert. */
+  void test(Consumer<Builder> testFn) {
+    final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0);
+    try (
+        BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+        Producer producer = new Producer();
+        FlightServer s =
+            FlightTestUtil.getStartedServer(
+                (port) -> {
+                  try {
+                    return FlightServer.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, port), producer)
+                        .useTls(certKey.cert, certKey.key)
+                        .build();
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                })) {
+      final Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort()));
+      testFn.accept(builder);
+    } catch (InterruptedException | IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  static class Producer extends NoOpFlightProducer implements AutoCloseable {
+    @Override
+    public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
+      if (action.getType().equals("hello-world")) {
+        listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
+        listener.onCompleted();
+        return; // Stream already completed; falling through would also call onError on it.
+      }
+      listener.onError(new UnsupportedOperationException("Invalid action " + action.getType()));
+    }
+
+    @Override
+    public void close() {
+    }
+  }
+}
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index c916e6b..7ca83a9 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -419,7 +419,7 @@ cdef class FlightClient:
                         .format(self.__class__.__name__))
 
     @staticmethod
-    def connect(location, tls_root_certs=None):
+    def connect(location, tls_root_certs=None, override_hostname=None):
         """
         Connect to a Flight service on the given host and port.
 
@@ -428,8 +428,10 @@ cdef class FlightClient:
         location : Location
             location to connect to
 
-        tls_root_certs : bytes
+        tls_root_certs : bytes or None
             PEM-encoded
+        override_hostname : str or None
+            Override the hostname checked by TLS. Insecure, use with caution.
         """
         cdef:
             FlightClient result = FlightClient.__new__(FlightClient)
@@ -439,6 +441,8 @@ cdef class FlightClient:
 
         if tls_root_certs:
             c_options.tls_root_certs = tobytes(tls_root_certs)
+        if override_hostname:
+            c_options.override_hostname = tobytes(override_hostname)
 
         with nogil:
             check_status(CFlightClient.Connect(c_location, c_options,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 14d1ed1..61e9571 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -170,6 +170,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
     cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
         CFlightClientOptions()
         c_string tls_root_certs
+        c_string override_hostname
 
     cdef cppclass CFlightClient" arrow::flight::FlightClient":
         @staticmethod
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index f4c9cc1..3088a7a 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -570,3 +570,18 @@ def test_tls_do_get():
             server_location, tls_root_certs=certs["root_cert"])
         data = client.do_get(flight.Ticket(b'ints')).read_all()
         assert data.equals(table)
+
+
+def test_tls_override_hostname():
+    """Check that incorrectly overriding the hostname fails."""
+    certs = example_tls_certs()
+
+    with flight_server(
+            ConstantFlightServer, tls_certificates=certs["certificates"],
+            connect_args=dict(tls_root_certs=certs["root_cert"]),
+    ) as server_location:
+        client = flight.FlightClient.connect(
+            server_location, tls_root_certs=certs["root_cert"],
+            override_hostname="fakehostname")
+        with pytest.raises(pa.ArrowIOError):
+            client.do_get(flight.Ticket(b'ints'))
diff --git a/testing b/testing
index 12f9dbd..a674dac 160000
--- a/testing
+++ b/testing
@@ -1 +1 @@
-Subproject commit 12f9dbd2a37eea6fa370e108a1d797ee1167724a
+Subproject commit a674dac190c5fc626964c9b611c67552fa2e530d