You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/03/20 02:32:37 UTC

[arrow] branch master updated: ARROW-4871: [Java/Flight] Handle large Flight messages

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 5dc54da  ARROW-4871: [Java/Flight] Handle large Flight messages
5dc54da is described below

commit 5dc54da921ea5b0788c80e2db969940de442b7ac
Author: David Li <Da...@twosigma.com>
AuthorDate: Tue Mar 19 21:32:29 2019 -0500

    ARROW-4871: [Java/Flight] Handle large Flight messages
    
    Equivalent to https://issues.apache.org/jira/browse/ARROW-4421, but for Java.
    
    Author: David Li <Da...@twosigma.com>
    
    Closes #3898 from lihalite/java-large-flight and squashes the following commits:
    
    a74e5c88a <David Li> Move Flight gRPC config values to constants
    209b0a127 <David Li> Accept large messages in Java Flight client/server
    2ce372b5c <David Li> Spawn new Flight server for each integration test
    90f84b955 <David Li> Test large batch sizes in Flight integration
---
 integration/integration_test.py                    |  33 +++--
 .../java/org/apache/arrow/flight/FlightClient.java |  10 +-
 .../java/org/apache/arrow/flight/FlightServer.java |   4 +
 .../org/apache/arrow/flight/TestLargeMessage.java  | 164 +++++++++++++++++++++
 4 files changed, 196 insertions(+), 15 deletions(-)

diff --git a/integration/integration_test.py b/integration/integration_test.py
index 17e7afb..d6cb965 100644
--- a/integration/integration_test.py
+++ b/integration/integration_test.py
@@ -896,7 +896,7 @@ def generate_dictionary_case():
                           dictionaries=[dict1, dict2])
 
 
-def get_generated_json_files(tempdir=None):
+def get_generated_json_files(tempdir=None, flight=False):
     tempdir = tempdir or tempfile.mkdtemp()
 
     def _temp_path():
@@ -911,6 +911,10 @@ def get_generated_json_files(tempdir=None):
         generate_dictionary_case()
     ]
 
+    if flight:
+        file_objs.append(generate_primitive_case([32 * 1024],
+                                                 name='large_batch'))
+
     generated_paths = []
     for file_obj in file_objs:
         out_path = os.path.join(tempdir, 'generated_' +
@@ -951,11 +955,9 @@ class IntegrationRunner(object):
         clients = filter(lambda t: (t.FLIGHT_CLIENT and t.CONSUMER),
                          self.testers)
         for server, client in itertools.product(servers, clients):
-            try:
-                self._compare_flight_implementations(server, client)
-            except Exception:
-                traceback.print_exc()
-                failures.append((server, client, sys.exc_info()))
+            for failure in self._compare_flight_implementations(server,
+                                                                client):
+                failures.append(failure)
         return failures
 
     def _compare_implementations(self, producer, consumer):
@@ -1004,15 +1006,19 @@ class IntegrationRunner(object):
         )
         print('##########################################################')
 
-        with producer.flight_server():
-            for json_path in self.json_files:
-                print('=' * 58)
-                print('Testing file {0}'.format(json_path))
-                print('=' * 58)
+        for json_path in self.json_files:
+            print('=' * 58)
+            print('Testing file {0}'.format(json_path))
+            print('=' * 58)
 
+            with producer.flight_server():
                 # Have the client upload the file, then download and
                 # compare
-                consumer.flight_request(producer.FLIGHT_PORT, json_path)
+                try:
+                    consumer.flight_request(producer.FLIGHT_PORT, json_path)
+                except Exception:
+                    traceback.print_exc()
+                    yield (producer, consumer, sys.exc_info())
 
 
 class Tester(object):
@@ -1297,7 +1303,8 @@ def run_all_tests(run_flight=False, debug=False, tempdir=None):
                JavaTester(debug=debug),
                JSTester(debug=debug)]
     static_json_files = get_static_json_files()
-    generated_json_files = get_generated_json_files(tempdir=tempdir)
+    generated_json_files = get_generated_json_files(tempdir=tempdir,
+                                                    flight=run_flight)
     json_files = static_json_files + generated_json_files
 
     runner = IntegrationRunner(json_files, testers,
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index bd126b5..b7a7e3f 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -56,6 +56,8 @@ import io.grpc.stub.StreamObserver;
 
 public class FlightClient implements AutoCloseable {
   private static final int PENDING_REQUESTS = 5;
+  /** The maximum number of trace events to keep on the gRPC Channel. This value disables channel tracing. */
+  private static final int MAX_CHANNEL_TRACE_EVENTS = 0;
   private final BufferAllocator allocator;
   private final ManagedChannel channel;
   private final FlightServiceBlockingStub blockingStub;
@@ -68,8 +70,12 @@ public class FlightClient implements AutoCloseable {
    * Construct client for accessing RouteGuide server using the existing channel.
    */
   public FlightClient(BufferAllocator incomingAllocator, Location location) {
-    final ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forAddress(location.getHost(),
-        location.getPort()).maxTraceEvents(0).usePlaintext();
+    final ManagedChannelBuilder<?> channelBuilder =
+        ManagedChannelBuilder.forAddress(location.getHost(),
+        location.getPort())
+            .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
+            .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE)
+            .usePlaintext();
     this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
     channel = channelBuilder.build();
     blockingStub = FlightServiceGrpc.newBlockingStub(channel).withInterceptors(authInterceptor);
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index 1681f3b..80c1624 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -36,12 +36,16 @@ public class FlightServer implements AutoCloseable {
 
   private final Server server;
 
+  /** The maximum size of an individual gRPC message. This effectively disables the limit. */
+  static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE;
+
   public FlightServer(
       BufferAllocator allocator,
       int port,
       FlightProducer producer,
       ServerAuthHandler authHandler) {
     this.server = ServerBuilder.forPort(port)
+        .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE)
         .addService(
             ServerInterceptors.intercept(
                 new FlightBindingService(allocator, producer, authHandler),
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
new file mode 100644
index 0000000..a26bd86
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.sql.Types;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.stream.Stream;
+
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Sends one batch of ten int32 columns x 128 * 1024 rows (5 MiB of values) through Flight get and put. */
+public class TestLargeMessage {
+  /**
+   * Make sure a Flight client accepts large message payloads by default.
+   */
+  @Test
+  public void getLargeMessage() throws Exception {
+    try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+         final Producer producer = new Producer(a);
+         final FlightServer s =
+             FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) {
+      try (FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()))) {
+        FlightStream stream = client.getStream(new Ticket(new byte[]{}));
+        int batches = 0;
+        try (VectorSchemaRoot root = stream.getRoot()) {
+          while (stream.next()) {
+            batches++;
+            for (final Field field : root.getSchema().getFields()) {
+              final IntVector iv = (IntVector) root.getVector(field.getName());
+              for (int i = 0; i < root.getRowCount(); i++) {
+                Assert.assertEquals(i, iv.get(i));
+              }
+            }
+          }
+        }
+        stream.close();
+        Assert.assertTrue("no record batch was received from the server", batches > 0);
+      }
+    }
+  }
+
+  /**
+   * Make sure a Flight server accepts large message payloads by default.
+   */
+  @Test
+  public void putLargeMessage() throws Exception {
+    try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+         final Producer producer = new Producer(a);
+         final FlightServer s =
+             FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) {
+      try (FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()));
+           BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE);
+           VectorSchemaRoot root = generateData(testAllocator)) {
+        final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root);
+        listener.putNext();
+        listener.completed();
+        Assert.assertEquals(Flight.PutResult.getDefaultInstance(), listener.getResult());
+      }
+    }
+  }
+
+  /** Builds ten nullable int32 columns (c1..c10) of 128 * 1024 rows each, where row i holds the value i. */
+  private static VectorSchemaRoot generateData(BufferAllocator allocator) {
+    final int size = 128 * 1024;
+    final List<String> fieldNames = Arrays.asList("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10");
+    final Stream<Field> fields = fieldNames.stream()
+        .map(fieldName -> new Field(fieldName, FieldType.nullable(new ArrowType.Int(32, true)), null));
+    final Schema schema = new Schema(fields::iterator, null);
+    final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+    root.allocateNew();
+    for (final String fieldName : fieldNames) {
+      final IntVector iv = (IntVector) root.getVector(fieldName);
+      iv.setValueCount(size);
+      for (int i = 0; i < size; i++) {
+        iv.set(i, i);
+      }
+    }
+    root.setRowCount(size);
+    return root;
+  }
+
+  /** Test producer: every get streams one generateData() batch, every put is drained; the rest are stubs. */
+  private static class Producer implements FlightProducer, AutoCloseable {
+    private final BufferAllocator allocator;
+
+    Producer(BufferAllocator allocator) {
+      this.allocator = allocator;
+    }
+
+    @Override
+    public void getStream(Ticket ticket, ServerStreamListener listener) {
+      try (VectorSchemaRoot root = generateData(allocator)) {
+        listener.start(root);
+        listener.putNext();
+        listener.completed();
+      }
+    }
+
+    @Override
+    public void listFlights(Criteria criteria, StreamListener<FlightInfo> listener) {
+      // Intentionally empty: not called by the tests in this class.
+    }
+
+    @Override
+    public FlightInfo getFlightInfo(FlightDescriptor descriptor) {
+      return null;
+    }
+
+    @Override
+    public Callable<Flight.PutResult> acceptPut(FlightStream flightStream) {
+      return () -> {
+        try (VectorSchemaRoot root = flightStream.getRoot()) {
+          while (flightStream.next()) {
+            // Drain every uploaded batch; the contents are not inspected.
+          }
+          return Flight.PutResult.getDefaultInstance();
+        }
+      };
+    }
+
+    @Override
+    public Result doAction(Action action) {
+      return null;
+    }
+
+    @Override
+    public void listActions(StreamListener<ActionType> listener) {
+      // Intentionally empty: not called by the tests in this class.
+    }
+
+    @Override
+    public void close() throws Exception {
+      allocator.close();
+    }
+  }
+}