You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2021/02/04 13:40:48 UTC

[arrow] branch master updated: ARROW-9586: [FlightRPC][Java] implement per-call allocator

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 53026f9  ARROW-9586: [FlightRPC][Java] implement per-call allocator
53026f9 is described below

commit 53026f9e3582cfcfff82498eda5a66852a8f2db0
Author: David Li <li...@gmail.com>
AuthorDate: Thu Feb 4 08:39:17 2021 -0500

    ARROW-9586: [FlightRPC][Java] implement per-call allocator
    
    This allows gRPC servers to be configured such that every DoGet, DoPut, or DoExchange call gets its own child allocator that is closed when the RPC ends. This includes temporary allocations made by Flight itself, e.g. in serializing and deserializing data. This way, we can more tightly control memory usage and track memory leaks in Flight and in the application at the granularity of a particular RPC, instead of for the server as a whole. It also means we can track fine-grained metrics for [...]
    
    Closes #8265 from lidavidm/arrow-9586
    
    Authored-by: David Li <li...@gmail.com>
    Signed-off-by: David Li <li...@gmail.com>
---
 .../apache/arrow/flight/FlightBindingService.java  |   6 +-
 .../org/apache/arrow/flight/FlightService.java     |  18 +-
 java/flight/flight-grpc/pom.xml                    |  14 +
 .../flight/AllocatorClosingServerInterceptor.java  | 150 ++++++
 .../arrow/flight/ArrowMessageMarshaller.java       |  65 +++
 .../org/apache/arrow/flight/FlightGrpcUtils.java   |  28 ++
 .../apache/arrow/flight/FlightHandlerRegistry.java | 174 +++++++
 .../apache/arrow/flight/TestPerCallAllocator.java  | 522 +++++++++++++++++++++
 8 files changed, 972 insertions(+), 5 deletions(-)

diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java
index ba5249b..7ab3127 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java
@@ -43,9 +43,9 @@ import io.grpc.stub.StreamObserver;
  */
 class FlightBindingService implements BindableService {
 
-  private static final String DO_GET = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoGet");
-  private static final String DO_PUT = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoPut");
-  private static final String DO_EXCHANGE = MethodDescriptor.generateFullMethodName(
+  static final String DO_GET = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoGet");
+  static final String DO_PUT = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoPut");
+  static final String DO_EXCHANGE = MethodDescriptor.generateFullMethodName(
       FlightConstants.SERVICE, "DoExchange");
   private static final Set<String> OVERRIDE_METHODS = ImmutableSet.of(DO_GET, DO_PUT, DO_EXCHANGE);
 
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java
index 4fb0dea..073d4de 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java
@@ -42,6 +42,7 @@ import org.slf4j.LoggerFactory;
 
 import com.google.common.base.Strings;
 
+import io.grpc.Context;
 import io.grpc.stub.ServerCallStreamObserver;
 import io.grpc.stub.StreamObserver;
 
@@ -52,6 +53,8 @@ class FlightService extends FlightServiceImplBase {
 
   private static final Logger logger = LoggerFactory.getLogger(FlightService.class);
   private static final int PENDING_REQUESTS = 5;
+  static final Context.Key<BufferAllocator> PER_CALL_ALLOCATOR =
+      Context.key("org.apache.arrow.flight.FlightGrpcUtils.PER_CALL_ALLOCATOR");
 
   private final BufferAllocator allocator;
   private final FlightProducer producer;
@@ -221,7 +224,7 @@ class FlightService extends FlightServiceImplBase {
     final StreamPipe<PutResult, Flight.PutResult> ackStream = StreamPipe
         .wrap(responseObserver, PutResult::toProtocol, this::handleExceptionWithMiddleware);
     final FlightStream fs = new FlightStream(
-        allocator,
+        getCallAllocator(),
         PENDING_REQUESTS,
         /* server-upload streams are not cancellable */null,
         responseObserver::request);
@@ -352,7 +355,7 @@ class FlightService extends FlightServiceImplBase {
         responseObserver,
         this::handleExceptionWithMiddleware);
     final FlightStream fs = new FlightStream(
-        allocator,
+        getCallAllocator(),
         PENDING_REQUESTS,
         /* server-upload streams are not cancellable */null,
         responseObserver::request);
@@ -378,6 +381,17 @@ class FlightService extends FlightServiceImplBase {
   }
 
   /**
+   * Helper method to get either the per-call allocator (if enabled) or the shared allocator.
+   */
+  private BufferAllocator getCallAllocator() {
+    // Context.Key#get() resolves against the current gRPC Context; null means nothing is attached.
+    final BufferAllocator perCall = PER_CALL_ALLOCATOR.get();
+    return perCall != null
+        ? perCall
+        : this.allocator;
+  }
+
+  /**
    * Call context for the service.
    */
   static class CallContext implements FlightProducer.CallContext {
diff --git a/java/flight/flight-grpc/pom.xml b/java/flight/flight-grpc/pom.xml
index 8399642..0d28d57 100644
--- a/java/flight/flight-grpc/pom.xml
+++ b/java/flight/flight-grpc/pom.xml
@@ -51,6 +51,11 @@
     </dependency>
     <dependency>
       <groupId>io.grpc</groupId>
+      <artifactId>grpc-context</artifactId>
+      <version>${dep.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
       <artifactId>grpc-core</artifactId>
       <version>${dep.grpc.version}</version>
     </dependency>
@@ -71,6 +76,11 @@
        <version>${project.version}</version>
        <scope>runtime</scope>
      </dependency>
+    <dependency>
+       <groupId>org.apache.arrow</groupId>
+       <artifactId>arrow-vector</artifactId>
+       <version>${project.version}</version>
+     </dependency>
      <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-protobuf</artifactId>
@@ -90,6 +100,10 @@
        <artifactId>grpc-api</artifactId>
        <version>${dep.grpc.version}</version>
      </dependency>
+     <dependency>
+       <groupId>org.slf4j</groupId>
+       <artifactId>slf4j-api</artifactId>
+     </dependency>
   </dependencies>
 
   <build>
diff --git a/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/AllocatorClosingServerInterceptor.java b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/AllocatorClosingServerInterceptor.java
new file mode 100644
index 0000000..0e173ed
--- /dev/null
+++ b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/AllocatorClosingServerInterceptor.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import io.grpc.ForwardingServerCallListener;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+
+/**
+ * A server interceptor that closes {@link ArrowMessageMarshaller} instances.
+ *
+ * <p>This must be the FIRST interceptor run (which is the LAST one registered), as interceptors can choose to
+ * abort a call chain at any point, and in that case, the application would leak a BufferAllocator.
+ */
+public class AllocatorClosingServerInterceptor implements ServerInterceptor {
+  private static final Logger LOGGER = LoggerFactory.getLogger(AllocatorClosingServerInterceptor.class);
+  private final BiConsumer<ServerCall<?, ?>, BufferAllocator> callback;
+  private final AtomicInteger outstandingCalls;
+
+  /**
+   * Create a new interceptor.
+   *
+   * @param callback Invoked with the call and each per-call BufferAllocator right before that allocator is closed.
+   *     Applications can use this to record metrics about the allocator.
+   */
+  public AllocatorClosingServerInterceptor(BiConsumer<ServerCall<?, ?>, BufferAllocator> callback) {
+    this.callback = callback;
+    this.outstandingCalls = new AtomicInteger(0);
+  }
+
+  /** Get the number of tracked calls whose cleanup has not run yet (useful for metrics reporting). */
+  public int getOutstandingCalls() {
+    return outstandingCalls.get();
+  }
+
+  /**
+   * Wait for all tracked calls to finish.
+   *
+   * <p>gRPC does not wait for all onCancel/onComplete callbacks to finish on server shutdown. This method polls the
+   * outstanding-call counter every 100 ms so that you can ensure all those callbacks are finished (and hence all
+   * child allocators are closed). Should only be called after shutting down the gRPC server, before program exit.
+   *
+   * @param duration The maximum time to wait; very large values are safe and simply wait until calls finish.
+   * @param unit The unit of {@code duration}.
+   * @throws InterruptedException if interrupted during waiting.
+   * @throws CancellationException if the timeout expires and calls have not yet finished.
+   */
+  public void awaitTermination(long duration, TimeUnit unit) throws InterruptedException, CancellationException {
+    final long start = System.nanoTime();  // compared by elapsed time below, which cannot overflow the deadline
+    while (outstandingCalls.get() > 0 && System.nanoTime() - start < unit.toNanos(duration)) {
+      Thread.sleep(100);
+    }
+    final int remaining = outstandingCalls.get();  // read once so the check and the message agree
+    if (remaining > 0) {
+      throw new CancellationException(
+          "Timed out after " + duration + " " + unit + " with " + remaining + " outstanding calls");
+    }
+  }
+
+  @Override
+  public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+      ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
+    final Object requestMarshaller = call.getMethodDescriptor().getRequestMarshaller();
+    final Object responseMarshaller = call.getMethodDescriptor().getResponseMarshaller();
+    // A Set, so a marshaller instance that serves both directions is only tracked once.
+    final Set<ArrowMessageMarshaller> tracked = new HashSet<>();
+    if (requestMarshaller instanceof ArrowMessageMarshaller) {
+      tracked.add((ArrowMessageMarshaller) requestMarshaller);
+    }
+    if (responseMarshaller instanceof ArrowMessageMarshaller) {
+      tracked.add((ArrowMessageMarshaller) responseMarshaller);
+    }
+    final ServerCall.Listener<ReqT> downstream = next.startCall(call, headers);
+    return tracked.isEmpty() ? downstream : new AllocatorClosingServerCallListener<>(downstream, call, tracked);
+  }
+
+  private class AllocatorClosingServerCallListener<ReqT, RespT>
+      extends ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT> {
+    private final ServerCall<ReqT, RespT> call;
+    // Despite the name, these are the marshallers; each one owns a per-call allocator and closes it on close().
+    private final Set<ArrowMessageMarshaller> allocators;
+
+    public AllocatorClosingServerCallListener(ServerCall.Listener<ReqT> delegate,
+                                              ServerCall<ReqT, RespT> call, Set<ArrowMessageMarshaller> allocators) {
+      super(delegate);
+      // The call counts as outstanding from construction until cleanup() has run.
+      AllocatorClosingServerInterceptor.this.outstandingCalls.getAndIncrement();
+      this.call = call;
+      this.allocators = allocators;
+    }
+
+    /** Reports every allocator to the callback, closes the marshallers, then forwards the terminal event. */
+    private void cleanup(Runnable next) {
+      try {
+        allocators.forEach(marshaller -> callback.accept(call, marshaller.getAllocator()));
+      } catch (RuntimeException e) {
+        // Log and carry on: a failing metrics callback must neither go unreported nor prevent the close below.
+        LOGGER.warn("Error in per-call allocator callback", e);
+      }
+      try {
+        AutoCloseables.close(allocators);
+      } catch (Exception e) {
+        LOGGER.warn("Error closing per-call allocators", e);
+      } finally {
+        // Always release the call slot and let the wrapped listener observe the event.
+        outstandingCalls.decrementAndGet();
+        next.run();
+      }
+    }
+
+    @Override
+    public void onCancel() {
+      cleanup(super::onCancel);
+    }
+
+    @Override
+    public void onComplete() {
+      cleanup(super::onComplete);
+    }
+  }
+}
diff --git a/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/ArrowMessageMarshaller.java b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/ArrowMessageMarshaller.java
new file mode 100644
index 0000000..a98689f
--- /dev/null
+++ b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/ArrowMessageMarshaller.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.InputStream;
+
+import org.apache.arrow.memory.BufferAllocator;
+
+import io.grpc.MethodDescriptor;
+
+/**
+ * An {@link ArrowMessage} marshaller bundled with the {@link BufferAllocator} it was created from.
+ *
+ * <p>Closing the marshaller closes that allocator, which is how {@link AllocatorClosingServerInterceptor} releases
+ * it. Serialization itself is delegated to the marshaller from {@link ArrowMessage#createMarshaller}; the distinct
+ * type exists so the interceptor can recognise marshallers that carry an allocator it is expected to close.
+ */
+final class ArrowMessageMarshaller implements AutoCloseable, MethodDescriptor.Marshaller<ArrowMessage> {
+  private final BufferAllocator allocator;
+  private final MethodDescriptor.Marshaller<ArrowMessage> delegate;
+
+  ArrowMessageMarshaller(BufferAllocator allocator) {
+    this.allocator = allocator;
+    this.delegate = ArrowMessage.createMarshaller(this.allocator);
+  }
+
+  /**
+   * Returns the allocator that {@link #close()} will close.
+   *
+   * @see AllocatorClosingServerInterceptor
+   */
+  BufferAllocator getAllocator() {
+    return this.allocator;
+  }
+
+  @Override
+  public ArrowMessage parse(InputStream stream) {
+    return this.delegate.parse(stream);
+  }
+
+  @Override
+  public InputStream stream(ArrowMessage value) {
+    return this.delegate.stream(value);
+  }
+
+  @Override
+  public void close() {
+    this.allocator.close();
+  }
+}
diff --git a/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java
index eb5e492..019ba2a 100644
--- a/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java
+++ b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java
@@ -29,6 +29,8 @@ import io.grpc.BindableService;
 import io.grpc.CallOptions;
 import io.grpc.ClientCall;
 import io.grpc.ConnectivityState;
+import io.grpc.Context;
+import io.grpc.HandlerRegistry;
 import io.grpc.ManagedChannel;
 import io.grpc.MethodDescriptor;
 
@@ -125,6 +127,13 @@ public class FlightGrpcUtils {
     }
   }
 
+  /**
+   * A gRPC Context key that gives you access to a per-call allocator.
+   *
+   * @see #createHandlerRegistry(BufferAllocator, BindableService)
+   */
+  public static final Context.Key<BufferAllocator> PER_CALL_ALLOCATOR = FlightService.PER_CALL_ALLOCATOR;
+
   private FlightGrpcUtils() {}
 
   /**
@@ -158,4 +167,23 @@ public class FlightGrpcUtils {
       BufferAllocator incomingAllocator, ManagedChannel channel) {
     return new FlightClient(incomingAllocator, new NonClosingProxyManagedChannel(channel), Collections.emptyList());
   }
+
+  /**
+   * Create a gRPC handler registry from a Flight service.
+   *
+   * <p>This handler registry will intercept DoGet/DoPut/DoExchange calls such that a fresh child allocator is created
+   * for each call, allowing finer-grained memory usage tracking and more protection against leaks.
+   *
+   * <p>The per-call allocator can be accessed via {@link #PER_CALL_ALLOCATOR}.
+   *
+   * <p>This must be used with {@link AllocatorClosingServerInterceptor}, which closes each child allocator when its
+   * call ends.
+   *
+   * @param allocator The parent allocator; every child allocator is created from it and capped at its limit.
+   * @param flightService The org.apache.arrow.flight.impl.Flight gRPC service.
+   * @return A gRPC HandlerRegistry that can be passed to a server builder.
+   */
+  public static HandlerRegistry createHandlerRegistry(BufferAllocator allocator, BindableService flightService) {
+    return new FlightHandlerRegistry(allocator, flightService.bindService());
+  }
 }
diff --git a/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightHandlerRegistry.java b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightHandlerRegistry.java
new file mode 100644
index 0000000..1b34548
--- /dev/null
+++ b/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightHandlerRegistry.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.memory.BufferAllocator;
+
+import io.grpc.Context;
+import io.grpc.HandlerRegistry;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerMethodDefinition;
+import io.grpc.ServerServiceDefinition;
+import io.grpc.Status;
+
+/**
+ * A gRPC HandlerRegistry that creates a new Arrow allocator for each call to a Flight "data" method
+ * (DoGet, DoPut, or DoExchange).
+ */
+public final class FlightHandlerRegistry extends HandlerRegistry {
+  private final BufferAllocator allocator;
+  private final ServerServiceDefinition delegate;
+  private final ServerMethodDefinition<Flight.Ticket, ArrowMessage> doGetMethod;
+  private final ServerMethodDefinition<ArrowMessage, Flight.PutResult> doPutMethod;
+  private final ServerMethodDefinition<ArrowMessage, ArrowMessage> doExchangeMethod;
+  private final AtomicInteger counter;
+
+  @SuppressWarnings("unchecked")
+  FlightHandlerRegistry(BufferAllocator allocator, ServerServiceDefinition delegate) {
+    this.allocator = allocator;
+    this.delegate = delegate;
+    // Unchecked casts: getMethod() returns ServerMethodDefinition<?, ?>, so the type arguments are asserted here.
+    this.doGetMethod = (ServerMethodDefinition<Flight.Ticket, ArrowMessage>)
+        Objects.requireNonNull(delegate.getMethod(FlightBindingService.DO_GET), "missing DoGet method");
+    this.doPutMethod = (ServerMethodDefinition<ArrowMessage, Flight.PutResult>)
+        Objects.requireNonNull(delegate.getMethod(FlightBindingService.DO_PUT), "missing DoPut method");
+    this.doExchangeMethod = (ServerMethodDefinition<ArrowMessage, ArrowMessage>)
+        Objects.requireNonNull(delegate.getMethod(FlightBindingService.DO_EXCHANGE), "missing DoExchange method");
+    this.counter = new AtomicInteger(0);
+  }
+
+  @Override
+  public ServerMethodDefinition<?, ?> lookupMethod(String methodName, String authority) {
+    // Each data method gets a fresh child allocator, shared by its marshaller and its call handler.
+    if (FlightBindingService.DO_GET.equals(methodName)) {
+      final BufferAllocator perCall = newChildAllocator("DoGet");
+      // Only the response side of DoGet has its marshaller replaced.
+      final ArrowMessageMarshaller responses = new ArrowMessageMarshaller(perCall);
+      return ServerMethodDefinition.create(
+          doGetMethod.getMethodDescriptor().toBuilder().setResponseMarshaller(responses).build(),
+          new AllocatorInjectingServerCallHandler<>(doGetMethod.getServerCallHandler(), perCall));
+    }
+    if (FlightBindingService.DO_PUT.equals(methodName)) {
+      final BufferAllocator perCall = newChildAllocator("DoPut");
+      // Only the request side of DoPut has its marshaller replaced.
+      final ArrowMessageMarshaller requests = new ArrowMessageMarshaller(perCall);
+      return ServerMethodDefinition.create(
+          doPutMethod.getMethodDescriptor().toBuilder().setRequestMarshaller(requests).build(),
+          new AllocatorInjectingServerCallHandler<>(doPutMethod.getServerCallHandler(), perCall));
+    }
+    if (FlightBindingService.DO_EXCHANGE.equals(methodName)) {
+      final BufferAllocator perCall = newChildAllocator("DoExchange");
+      // One marshaller instance serves both directions of the exchange.
+      final ArrowMessageMarshaller both = new ArrowMessageMarshaller(perCall);
+      return ServerMethodDefinition.create(
+          doExchangeMethod.getMethodDescriptor().toBuilder()
+              .setRequestMarshaller(both)
+              .setResponseMarshaller(both)
+              .build(),
+          new AllocatorInjectingServerCallHandler<>(doExchangeMethod.getServerCallHandler(), perCall));
+    }
+    return delegate.getMethod(methodName);
+  }
+
+  /**
+   * Create a sequentially numbered child of the parent allocator for one call to the given method.
+   * @param methodName The Flight method being called.
+   */
+  private BufferAllocator newChildAllocator(final String methodName) {
+    final int sequence = counter.getAndIncrement();
+    final String childName = allocator.getName() + "-" + methodName + "-" + sequence;
+    return allocator.newChildAllocator(childName, 0, allocator.getLimit());
+  }
+
+  /**
+   * A ServerCallHandler that makes the per-call Arrow allocator visible through the gRPC context.
+   *
+   * @param <ReqT> The request type.
+   * @param <RespT> The response type.
+   */
+  static final class AllocatorInjectingServerCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
+    private final BufferAllocator allocator;
+    private final ServerCallHandler<ReqT, RespT> delegate;
+
+    AllocatorInjectingServerCallHandler(ServerCallHandler<ReqT, RespT> delegate, BufferAllocator allocator) {
+      this.allocator = allocator;
+      this.delegate = delegate;
+    }
+
+    @Override
+    public ServerCall.Listener<ReqT> startCall(ServerCall<ReqT, RespT> call, Metadata headers) {
+      try {
+        final ServerCall.Listener<ReqT> started = Context.current()
+            .withValue(FlightGrpcUtils.PER_CALL_ALLOCATOR, allocator)
+            .call(() -> delegate.startCall(call, headers));
+        return new AllocatorInjectingServerCallListener<>(started, allocator);
+      } catch (Exception e) {
+        // Starting the call failed: close the allocator, fail the RPC, and hand back a no-op listener.
+        allocator.close();
+        call.close(Status.INTERNAL.withCause(e).withDescription("Internal error: " + e), new Metadata());
+        return new ServerCall.Listener<ReqT>() {};
+      }
+    }
+  }
+
+  /**
+   * A listener that runs each callback of the wrapped listener with the Arrow allocator set in the gRPC context.
+   * @param <ReqT> The request type.
+   */
+  static final class AllocatorInjectingServerCallListener<ReqT> extends ServerCall.Listener<ReqT> {
+    private final BufferAllocator allocator;
+    private final ServerCall.Listener<ReqT> delegate;
+
+    AllocatorInjectingServerCallListener(ServerCall.Listener<ReqT> delegate, BufferAllocator allocator) {
+      this.allocator = allocator;
+      this.delegate = delegate;
+    }
+
+    @Override
+    public void onReady() {
+      Context.current().withValue(FlightGrpcUtils.PER_CALL_ALLOCATOR, allocator).run(delegate::onReady);
+    }
+
+    @Override
+    public void onMessage(ReqT message) {
+      Context.current().withValue(FlightGrpcUtils.PER_CALL_ALLOCATOR, allocator).run(() -> delegate.onMessage(message));
+    }
+
+    @Override
+    public void onHalfClose() {
+      Context.current().withValue(FlightGrpcUtils.PER_CALL_ALLOCATOR, allocator).run(delegate::onHalfClose);
+    }
+
+    @Override
+    public void onComplete() {
+      Context.current().withValue(FlightGrpcUtils.PER_CALL_ALLOCATOR, allocator).run(delegate::onComplete);
+    }
+
+    @Override
+    public void onCancel() {
+      Context.current().withValue(FlightGrpcUtils.PER_CALL_ALLOCATOR, allocator).run(delegate::onCancel);
+    }
+  }
+}
diff --git a/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestPerCallAllocator.java b/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestPerCallAllocator.java
new file mode 100644
index 0000000..1a3790a
--- /dev/null
+++ b/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestPerCallAllocator.java
@@ -0,0 +1,522 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.flight.impl.FlightServiceGrpc;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+import com.google.common.collect.Iterables;
+
+import io.grpc.BindableService;
+import io.grpc.HandlerRegistry;
+import io.grpc.MethodDescriptor;
+import io.grpc.Server;
+import io.grpc.ServerBuilder;
+import io.grpc.ServerCall;
+import io.grpc.ServerMethodDefinition;
+
+public class TestPerCallAllocator {
+  /**
+   * Run a DoGet that completes normally while the callback handed to the interceptor is
+   * artificially slow.
+   */
+  @Test
+  public void testDoGet() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(allocator, (client) -> {
+        try (final FlightStream stream = client.getStream(new Ticket(Producer.TICKET_EMPTY))) {
+          assertEquals(Producer.SCHEMA_DOUBLES, stream.getRoot().getSchema());
+          assertTrue(stream.next());
+          assertFalse(stream.next());
+        }
+      }, (call, serverAllocator) -> {
+          // gRPC server shutdown does not actually wait for all onCompleted/onCancel callbacks.
+          // Test that we handle this properly by forcing those callbacks to take a while.
+          // This does increase test runtime by the given duration here.
+          try {
+            Thread.sleep(1500);
+          } catch (InterruptedException e) {
+            // Restore the interrupt flag for the thread's owner instead of printing and dropping it.
+            Thread.currentThread().interrupt();
+          }
+        });
+    }
+  }
+
+  @Test
+  public void testDoGetError() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(allocator, (client) -> {
+        try (final FlightStream stream = client.getStream(new Ticket(Producer.TICKET_ERROR))) {
+          final FlightRuntimeException err = assertThrows(FlightRuntimeException.class, stream::getRoot);
+          assertEquals("expected", err.status().description());
+        }
+      }, (call, serverAllocator) -> { });
+    }
+  }
+
+  /**
+   * Read a few batches from a never-ending DoGet stream, then cancel it from the client while the
+   * callback handed to the interceptor is artificially slow.
+   */
+  @Test
+  public void testDoGetInfinite() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(allocator, (client) -> {
+        try (final FlightStream stream = client.getStream(new Ticket(Producer.TICKET_INFINITE))) {
+          assertEquals(Producer.SCHEMA_DOUBLES, stream.getRoot().getSchema());
+          assertTrue(stream.next());
+          assertTrue(stream.next());
+          assertTrue(stream.next());
+          stream.cancel("", null);
+        }
+      }, (call, serverAllocator) -> {
+          // gRPC server shutdown does not actually wait for all onCompleted/onCancel callbacks.
+          // Test that we handle this properly by forcing those callbacks to take a while.
+          try {
+            Thread.sleep(2500);
+          } catch (InterruptedException e) {
+            // Restore the interrupt flag for the thread's owner instead of printing and dropping it.
+            Thread.currentThread().interrupt();
+          }
+        });
+    }
+  }
+
+  /**
+   * Upload several batches of doubles with DoPut, then read them back with DoGet and check every
+   * value, while the callback handed to the interceptor is artificially slow.
+   */
+  @Test
+  public void testDoPut() throws Exception {
+    int batches = 3;
+    int rowsPerBatch = 512;
+    byte[] command = "dataset-doubles".getBytes(StandardCharsets.US_ASCII);
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(allocator, (client) -> {
+        FlightDescriptor descriptor = FlightDescriptor.command(command);
+        SyncPutListener reader = new SyncPutListener();
+        try (final VectorSchemaRoot root = VectorSchemaRoot.create(Producer.SCHEMA_DOUBLES, allocator)) {
+          FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, reader);
+          // The counter keeps increasing across batches so each row has a distinct expected value.
+          double counter = 0.0;
+          Float8Vector vector = (Float8Vector) root.getVector(0);
+          for (int batch = 0; batch < batches; batch++) {
+            for (int row = 0; row < rowsPerBatch; row++) {
+              vector.setSafe(row, counter);
+              counter++;
+            }
+            root.setRowCount(rowsPerBatch);
+            writer.putNext();
+          }
+          writer.completed();
+          reader.getResult();
+        }
+
+        // Read the dataset back under the same key and replay the counter to verify the contents.
+        try (final FlightStream fs = client.getStream(new Ticket(command))) {
+          double counter = 0.0;
+          assertEquals(Producer.SCHEMA_DOUBLES, fs.getSchema());
+          Float8Vector vector = (Float8Vector) fs.getRoot().getVector(0);
+          for (int batch = 0; batch < batches; batch++) {
+            assertTrue(fs.next());
+            assertEquals(rowsPerBatch, fs.getRoot().getRowCount());
+            for (int row = 0; row < rowsPerBatch; row++) {
+              assertFalse(vector.isNull(row));
+              assertEquals(counter, vector.get(row), /* delta */0.1);
+              counter++;
+            }
+          }
+          assertFalse(fs.next());
+        }
+      }, (call, serverAllocator) -> {
+          // gRPC server shutdown does not actually wait for all onCompleted/onCancel callbacks.
+          // Test that we handle this properly by forcing those callbacks to take a while.
+          try {
+            Thread.sleep(2500);
+          } catch (InterruptedException e) {
+            // Restore the interrupt flag for the thread's owner instead of printing and dropping it.
+            Thread.currentThread().interrupt();
+          }
+        });
+    }
+  }
+
+  /**
+   * DoPut where the server-side handler never drains the uploaded stream: once finishing
+   * successfully, once finishing with a CANCELLED status that the client must observe.
+   */
+  @Test
+  public void testDoPutIgnore() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(
+          allocator,
+          (client) -> {
+            // First upload: the server acknowledges completion without reading the batches.
+            final SyncPutListener ignoredAck = new SyncPutListener();
+            try (final VectorSchemaRoot data = VectorSchemaRoot.create(Producer.SCHEMA_DOUBLES, allocator)) {
+              final FlightDescriptor ignoreDescriptor = FlightDescriptor.command(Producer.TICKET_IGNORE);
+              final FlightClient.ClientStreamListener upload = client.startPut(ignoreDescriptor, data, ignoredAck);
+              data.setRowCount(1024);
+              for (int sent = 0; sent < 3; sent++) {
+                upload.putNext();
+              }
+              upload.completed();
+              ignoredAck.getResult();
+            }
+
+            // Second upload: the server answers with an error instead of reading the batches.
+            final SyncPutListener cancelledAck = new SyncPutListener();
+            try (final VectorSchemaRoot data = VectorSchemaRoot.create(Producer.SCHEMA_DOUBLES, allocator)) {
+              final FlightDescriptor cancelDescriptor = FlightDescriptor.command(Producer.TICKET_CANCEL);
+              final FlightClient.ClientStreamListener upload = client.startPut(cancelDescriptor, data, cancelledAck);
+              data.setRowCount(1024);
+              for (int sent = 0; sent < 3; sent++) {
+                upload.putNext();
+              }
+              upload.completed();
+              final FlightRuntimeException thrown =
+                  assertThrows(FlightRuntimeException.class, cancelledAck::getResult);
+              assertEquals(thrown.toString(), FlightStatusCode.CANCELLED, thrown.status().code());
+            }
+          },
+          (call, serverAllocator) -> { });
+    }
+  }
+
+
+  /** Test that the server cleans up Arrow data even if the producer doesn't drain the stream. */
+  @Test
+  public void testDoExchangeIgnore() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(allocator, (client) -> {
+        FlightDescriptor descriptor = FlightDescriptor.command(Producer.TICKET_IGNORE);
+        try (final VectorSchemaRoot root = VectorSchemaRoot.create(Producer.SCHEMA_DOUBLES, allocator);
+             FlightClient.ExchangeReaderWriter stream = client.doExchange(descriptor)) {
+          stream.getWriter().start(root);
+          root.setRowCount(1024);
+          stream.getWriter().putNext();
+          stream.getWriter().putNext();
+          stream.getWriter().putNext();
+          stream.getWriter().completed();
+          while (stream.getReader().next()) {
+            // Drain the stream
+          }
+        }
+      }, (call, serverAllocator) -> { });
+    }
+  }
+
+  /**
+   * Read a few batches from a never-ending DoExchange response, then cancel from the client while
+   * the callback handed to the interceptor is artificially slow.
+   */
+  @Test
+  public void testDoExchangeInfinite() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(allocator, (client) -> {
+        try (final FlightClient.ExchangeReaderWriter stream =
+                 client.doExchange(FlightDescriptor.command(Producer.TICKET_INFINITE))) {
+          assertEquals(Producer.SCHEMA_DOUBLES, stream.getReader().getSchema());
+          assertTrue(stream.getReader().next());
+          assertTrue(stream.getReader().next());
+          assertTrue(stream.getReader().next());
+          stream.getReader().cancel("", null);
+        }
+      }, (call, serverAllocator) -> {
+          // gRPC server shutdown does not actually wait for all onCompleted/onCancel callbacks.
+          // Test that we handle this properly by forcing those callbacks to take a while.
+          try {
+            Thread.sleep(2500);
+          } catch (InterruptedException e) {
+            // Restore the interrupt flag for the thread's owner instead of printing and dropping it.
+            Thread.currentThread().interrupt();
+          }
+        });
+    }
+  }
+
+  /**
+   * DoExchange must surface server-side errors to the client: once when the server fails
+   * immediately, and once when it fails after the client has started writing.
+   */
+  @Test
+  public void testDoExchangeError() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      test(
+          allocator,
+          (client) -> {
+            final FlightDescriptor errorDescriptor = FlightDescriptor.command(Producer.TICKET_ERROR);
+            try (final FlightClient.ExchangeReaderWriter exchange = client.doExchange(errorDescriptor)) {
+              final FlightStream responses = exchange.getReader();
+              final FlightRuntimeException thrown =
+                  assertThrows(FlightRuntimeException.class, responses::next);
+              assertEquals("expected", thrown.status().description());
+            }
+
+            final FlightDescriptor cancelDescriptor = FlightDescriptor.command(Producer.TICKET_CANCEL);
+            try (final FlightClient.ExchangeReaderWriter exchange = client.doExchange(cancelDescriptor);
+                 VectorSchemaRoot data = VectorSchemaRoot.create(Producer.SCHEMA_DOUBLES, allocator)) {
+              exchange.getWriter().start(data);
+              exchange.getWriter().putNext();
+              final FlightStream responses = exchange.getReader();
+              final FlightRuntimeException thrown =
+                  assertThrows(FlightRuntimeException.class, responses::next);
+              assertEquals(thrown.toString(), "expected", thrown.status().description());
+            }
+          },
+          (call, serverAllocator) -> { });
+    }
+  }
+
+  /** Ensure the custom registry properly hooks up request/response marshallers. */
+  @Test
+  public void testRegistry() {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      final Producer producer = new Producer(allocator);
+      final ExecutorService executor = Executors.newCachedThreadPool();
+      try {
+        final BindableService service =
+            FlightGrpcUtils.createFlightService(allocator, producer, ServerAuthHandler.NO_OP, executor);
+        final HandlerRegistry registry = FlightGrpcUtils.createHandlerRegistry(allocator, service);
+
+        checkMethodDefinition(
+            registry.lookupMethod(FlightServiceGrpc.getDoGetMethod().getFullMethodName()),
+            /* requestMarshallerOwnsAllocator */ false,
+            /* responseMarshallerOwnsAllocator */ true);
+        checkMethodDefinition(
+            registry.lookupMethod(FlightServiceGrpc.getDoPutMethod().getFullMethodName()),
+            /* requestMarshallerOwnsAllocator */ true,
+            /* responseMarshallerOwnsAllocator */ false);
+        checkMethodDefinition(
+            registry.lookupMethod(FlightServiceGrpc.getDoExchangeMethod().getFullMethodName()),
+            /* requestMarshallerOwnsAllocator */ true,
+            /* responseMarshallerOwnsAllocator */ true);
+        // Methods that carry no Arrow data are still served; unknown methods are not.
+        assertNotNull(registry.lookupMethod(FlightServiceGrpc.getDoActionMethod().getFullMethodName()));
+        assertNull(registry.lookupMethod("/unknown.Service/Unknown"));
+      } finally {
+        // The executor was previously never shut down; release it even when an assertion fails.
+        executor.shutdown();
+      }
+    }
+  }
+
+  /**
+   * Check that the given method definition has the right server call handler/marshallers defined.
+   * @param methodDefinition The method definition to check.
+   * @param requestMarshallerOwnsAllocator If true, ensure the request marshaller is an ArrowMessageMarshaller.
+   * @param responseMarshallerOwnsAllocator If true, ensure the response marshaller is an ArrowMessageMarshaller.
+   */
+  private void checkMethodDefinition(
+      ServerMethodDefinition<?, ?> methodDefinition,
+      boolean requestMarshallerOwnsAllocator,
+      boolean responseMarshallerOwnsAllocator) {
+    assertNotNull(methodDefinition);
+    assertTrue(
+        methodDefinition.getServerCallHandler() instanceof FlightHandlerRegistry.AllocatorInjectingServerCallHandler);
+    MethodDescriptor.Marshaller<?> requestMarshaller = methodDefinition.getMethodDescriptor().getRequestMarshaller();
+    MethodDescriptor.Marshaller<?> responseMarshaller = methodDefinition.getMethodDescriptor().getResponseMarshaller();
+    if (requestMarshallerOwnsAllocator) {
+      assertTrue(requestMarshaller instanceof ArrowMessageMarshaller);
+      ((ArrowMessageMarshaller) requestMarshaller).close();
+    } else {
+      assertFalse(requestMarshaller instanceof ArrowMessageMarshaller);
+    }
+    if (responseMarshallerOwnsAllocator) {
+      assertTrue(responseMarshaller instanceof ArrowMessageMarshaller);
+      ((ArrowMessageMarshaller) responseMarshaller).close();
+    } else {
+      assertFalse(responseMarshaller instanceof ArrowMessageMarshaller);
+    }
+  }
+
+  /**
+   * Start a gRPC server (port 0, i.e. a port picked at bind time) serving a fresh {@link Producer},
+   * run the given client-side test against it, then tear everything down.
+   *
+   * @param allocator Allocator handed to the Flight service, the handler registry and the client;
+   *     it is also the parent of the producer's own allocator.
+   * @param test The client-side test body.
+   * @param callback Passed to the {@link AllocatorClosingServerInterceptor} constructor; the tests
+   *     use it to slow down per-call cleanup (see that class for when exactly it is invoked).
+   * @throws Exception if setup, the test body, or teardown fails.
+   */
+  private void test(
+      BufferAllocator allocator,
+      CheckedConsumer<FlightClient> test,
+      BiConsumer<ServerCall<?, ?>, BufferAllocator> callback)
+      throws Exception {
+    final Producer producer = new Producer(allocator.newChildAllocator("Producer", 0, allocator.getLimit()));
+    final ExecutorService executor = Executors.newCachedThreadPool();
+    final BindableService service =
+        FlightGrpcUtils.createFlightService(allocator, producer, ServerAuthHandler.NO_OP, executor);
+    AllocatorClosingServerInterceptor interceptor = new AllocatorClosingServerInterceptor(callback);
+    // The Flight service is reached through the fallback registry rather than addService().
+    Server server = ServerBuilder.forPort(0)
+        .fallbackHandlerRegistry(FlightGrpcUtils.createHandlerRegistry(allocator, service))
+        .intercept(interceptor)
+        .build();
+    server.start();
+    Location location = Location.forGrpcInsecure("localhost", server.getPort());
+    try (FlightClient client = FlightClient.builder(allocator, location).build()) {
+      test.accept(client);
+    } finally {
+      // Teardown order: stop the server, release the producer's data, stop the executor, and
+      // finally wait on the interceptor (the tests' comments note that server shutdown alone does
+      // not wait for all onCompleted/onCancel callbacks).
+      server.shutdown();
+      server.awaitTermination();
+      producer.close();
+      executor.shutdown();
+      executor.awaitTermination(30, TimeUnit.SECONDS);
+      interceptor.awaitTermination(30, TimeUnit.SECONDS);
+    }
+  }
+
+  /**
+   * A single-argument operation that, unlike {@link java.util.function.Consumer}, may throw a
+   * checked exception.
+   *
+   * @param <T> The type of the argument.
+   */
+  @FunctionalInterface
+  interface CheckedConsumer<T> {
+    /**
+     * Perform the operation.
+     *
+     * @param t The argument.
+     * @throws Exception if the operation fails.
+     */
+    void accept(T t) throws Exception;
+  }
+
+  /**
+   * Flight producer used by these tests. Each handler first checks that a per-call allocator is
+   * available from {@link FlightGrpcUtils#PER_CALL_ALLOCATOR}, then picks a behavior based on the
+   * ticket or command bytes. Datasets uploaded through DoPut are kept until {@link #close()}.
+   */
+  static final class Producer extends NoOpFlightProducer implements AutoCloseable {
+    static final Schema SCHEMA_DOUBLES =
+        new Schema(Collections.singletonList(
+            Field.nullable("doubles", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))));
+    static final byte[] TICKET_ERROR = "error".getBytes(StandardCharsets.US_ASCII);
+    static final byte[] TICKET_CANCEL = "cancel".getBytes(StandardCharsets.US_ASCII);
+    static final byte[] TICKET_EMPTY = "empty".getBytes(StandardCharsets.US_ASCII);
+    static final byte[] TICKET_INFINITE = "infinite".getBytes(StandardCharsets.US_ASCII);
+    static final byte[] TICKET_IGNORE = "ignore".getBytes(StandardCharsets.US_ASCII);
+
+    /** Allocator that uploaded batches are transferred to; closed together with this producer. */
+    final BufferAllocator sharedAllocator;
+    /** Uploaded datasets, keyed by the decoded DoPut command. */
+    final Map<String, Dataset> datasets;
+
+    Producer(BufferAllocator allocator) {
+      this.sharedAllocator = allocator;
+      this.datasets = new HashMap<>();
+    }
+
+    /** Close all stored datasets, then the shared allocator. */
+    @Override
+    public void close() throws Exception {
+      AutoCloseables.close(Iterables.concat(datasets.values(), Collections.singleton(sharedAllocator)));
+    }
+
+    /**
+     * Fetch the per-call allocator from the current gRPC context.
+     *
+     * @param onError Receives an INTERNAL error if no allocator is present.
+     * @return The allocator, or null if none is present (in which case {@code onError} was called).
+     */
+    BufferAllocator assertHasAllocator(Consumer<Throwable> onError) {
+      final BufferAllocator allocator = FlightGrpcUtils.PER_CALL_ALLOCATOR.get();
+      if (allocator == null) {
+        onError.accept(CallStatus.INTERNAL.withDescription("Per call allocator does not exist").toRuntimeException());
+        return null;
+      }
+      return allocator;
+    }
+
+    /**
+     * Serve a stream: an error, an endless stream, a single empty batch, or a previously uploaded
+     * dataset, depending on the ticket. Unknown tickets are answered with NOT_FOUND.
+     */
+    @Override
+    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+      BufferAllocator allocator = assertHasAllocator(listener::error);
+      if (allocator == null) {
+        return;
+      }
+      if (Arrays.equals(ticket.getBytes(), TICKET_ERROR)) {
+        listener.error(CallStatus.INTERNAL.withDescription("expected").toRuntimeException());
+      } else if (Arrays.equals(ticket.getBytes(), TICKET_INFINITE)) {
+        try (final VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA_DOUBLES, allocator)) {
+          listener.start(root);
+          while (!listener.isCancelled()) {
+            listener.putNext();
+            Thread.sleep(500);
+          }
+        } catch (InterruptedException e) {
+          Thread.currentThread().interrupt();
+        } finally {
+          listener.completed();
+        }
+      } else if (Arrays.equals(ticket.getBytes(), TICKET_EMPTY)) {
+        try (final VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA_DOUBLES, allocator)) {
+          listener.start(root);
+          listener.putNext();
+        } finally {
+          // Must call completed() after closing the root, or else we'll tell gRPC the call finished
+          // (causing the allocator to be closed) with outstanding allocations
+          listener.completed();
+        }
+      } else {
+        // Decode with an explicit charset so the key never depends on the platform default.
+        String key = new String(ticket.getBytes(), StandardCharsets.UTF_8);
+        Dataset dataset = datasets.get(key);
+        if (dataset == null) {
+          listener.error(CallStatus.NOT_FOUND.withDescription(key).toRuntimeException());
+          return;
+        }
+        try (final VectorSchemaRoot root = VectorSchemaRoot.create(dataset.schema, allocator)) {
+          VectorLoader loader = new VectorLoader(root);
+          listener.start(root);
+          for (ArrowRecordBatch batch : dataset.batches) {
+            loader.load(batch);
+            listener.putNext();
+          }
+        }
+        // As in the branch above, only signal completion once the root has been closed, so the
+        // call is not reported as finished while the root still holds allocations.
+        listener.completed();
+      }
+    }
+
+    /**
+     * Accept an upload: acknowledge or fail without draining it for the special commands,
+     * otherwise store every batch under the decoded command so that DoGet can serve it later.
+     */
+    @Override
+    public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+      return () -> {
+        BufferAllocator allocator = assertHasAllocator(ackStream::onError);
+        if (allocator == null) {
+          return;
+        }
+        if (Arrays.equals(flightStream.getDescriptor().getCommand(), TICKET_IGNORE)) {
+          // Don't drain the stream, but make sure it actually contains data
+          flightStream.getRoot();
+          ackStream.onCompleted();
+        } else if (Arrays.equals(flightStream.getDescriptor().getCommand(), TICKET_CANCEL)) {
+          // Don't drain the stream, but make sure it actually contains data
+          flightStream.getRoot();
+          ackStream.onError(CallStatus.CANCELLED.withDescription("expected").toRuntimeException());
+        } else {
+          // Same explicit charset as getStream() so both sides derive the same key.
+          String key = new String(flightStream.getDescriptor().getCommand(), StandardCharsets.UTF_8);
+          final VectorUnloader unloader = new VectorUnloader(flightStream.getRoot());
+          final List<ArrowRecordBatch> batches = new ArrayList<>();
+          while (flightStream.next()) {
+            try (ArrowRecordBatch arb = unloader.getRecordBatch()) {
+              // Move the data to the shared allocator so it is no longer owned by the call's one.
+              batches.add(arb.cloneWithTransfer(sharedAllocator));
+            }
+          }
+          datasets.put(key, new Dataset(flightStream.getRoot().getSchema(), batches));
+          ackStream.onCompleted();
+        }
+      };
+    }
+
+    /**
+     * Handle a bidirectional call: complete immediately, fail, fail after the client has sent a
+     * schema, or stream batches until cancelled, depending on the command.
+     */
+    @Override
+    public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
+      BufferAllocator allocator = assertHasAllocator(writer::error);
+      if (allocator == null) {
+        return;
+      }
+
+      byte[] command = reader.getDescriptor().getCommand();
+      if (Arrays.equals(command, TICKET_IGNORE)) {
+        writer.completed();
+      } else if (Arrays.equals(command, TICKET_ERROR)) {
+        writer.error(CallStatus.INTERNAL.withDescription("expected").toRuntimeException());
+      } else if (Arrays.equals(command, TICKET_CANCEL)) {
+        // Don't drain the stream, but make sure it actually contains data
+        reader.getRoot();
+        writer.error(CallStatus.CANCELLED.withDescription("expected").toRuntimeException());
+      } else if (Arrays.equals(command, TICKET_INFINITE)) {
+        try (final VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA_DOUBLES, allocator)) {
+          writer.start(root);
+          while (!writer.isCancelled()) {
+            writer.putNext();
+            Thread.sleep(500);
+          }
+        } catch (InterruptedException e) {
+          Thread.currentThread().interrupt();
+        } finally {
+          writer.completed();
+        }
+      } else {
+        // Previously an unrecognized command returned without signalling the client at all;
+        // report it explicitly instead.
+        writer.error(CallStatus.NOT_FOUND.withDescription("Unknown command").toRuntimeException());
+      }
+    }
+  }
+
+  /** A schema together with the record batches stored for it; closing it closes the batches. */
+  static final class Dataset implements AutoCloseable {
+    /** Schema shared by all batches. */
+    final Schema schema;
+    /** The stored record batches, in upload order. */
+    final List<ArrowRecordBatch> batches;
+
+    Dataset(Schema schema, List<ArrowRecordBatch> batches) {
+      this.batches = batches;
+      this.schema = schema;
+    }
+
+    /** Close every stored record batch. */
+    @Override
+    public void close() throws Exception {
+      AutoCloseables.close(this.batches);
+    }
+  }
+}