Posted to commits@beam.apache.org by ch...@apache.org on 2019/03/30 01:31:14 UTC

[beam] branch master updated: Read from BQ queries with the storage API.

This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2399c72  Read from BQ queries with the storage API.
     new ce39f17  Merge pull request #8061: [BEAM-6841] Add support for reading query results using the BigQuery storage API.
2399c72 is described below

commit 2399c72041a74cbf30a565bceeb4ecdebd3fd956
Author: Kenneth Jung <km...@google.com>
AuthorDate: Wed Mar 13 12:37:18 2019 -0700

    Read from BQ queries with the storage API.
    
    This change creates a new BoundedSource object capable of executing a
    query in BigQuery and reading the query results using the BigQuery
    storage API (e.g. with Method#DIRECT_READ) rather than exporting them to GCS.
---
 runners/google-cloud-dataflow-java/build.gradle    |   1 +
 sdks/java/io/google-cloud-platform/build.gradle    |   1 +
 .../beam/sdk/io/gcp/bigquery/BigQueryIO.java       | 223 ++++++-
 .../sdk/io/gcp/bigquery/BigQueryQueryHelper.java   | 171 +++++
 .../sdk/io/gcp/bigquery/BigQueryQuerySource.java   | 159 +----
 .../gcp/bigquery/BigQueryStorageQuerySource.java   | 138 ++++
 .../io/gcp/bigquery/BigQueryStorageSourceBase.java | 133 ++++
 .../gcp/bigquery/BigQueryStorageTableSource.java   | 173 ++---
 .../sdk/io/gcp/bigquery/BigQueryIOReadTest.java    | 162 ++---
 .../io/gcp/bigquery/BigQueryIOStorageQueryIT.java  | 102 +++
 .../gcp/bigquery/BigQueryIOStorageQueryTest.java   | 712 +++++++++++++++++++++
 .../io/gcp/bigquery/BigQueryIOStorageReadTest.java |   2 +-
 .../sdk/io/gcp/bigquery/FakeBigQueryServices.java  |  39 +-
 .../beam/sdk/io/gcp/bigquery/FakeJobService.java   |   7 +-
 14 files changed, 1641 insertions(+), 382 deletions(-)
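
For orientation before the full diff: this change lets query results be consumed through the
existing BigQueryIO read API by combining fromQuery() with Method.DIRECT_READ. A minimal usage
sketch under assumed names follows (illustrative only; the project, dataset, table, and field
names are placeholders and the example class is not part of this patch):

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead;
    import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.values.PCollection;

    public class StorageApiQueryReadExample {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        // Run a standard-SQL query and read its results directly over the BigQuery
        // storage API instead of exporting the query results to GCS first.
        PCollection<String> names =
            p.apply(
                BigQueryIO.read(
                        (SchemaAndRecord record) -> record.getRecord().get("name").toString())
                    .fromQuery("SELECT name FROM `my-project.my_dataset.my_table`") // placeholder
                    .usingStandardSql()
                    .withMethod(TypedRead.Method.DIRECT_READ)
                    .withCoder(StringUtf8Coder.of()));

        p.run().waitUntilFinish();
      }
    }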

diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index 596ad7b..a408ed9 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -368,6 +368,7 @@ task googleCloudPlatformFnApiWorkerIntegrationTest(type: Test) {
 
   include '**/*IT.class'
   exclude '**/BigQueryIOReadIT.class'
+  exclude '**/BigQueryIOStorageQueryIT.class'
   exclude '**/BigQueryIOStorageReadIT.class'
   exclude '**/BigQueryIOStorageReadTableRowIT.class'
   exclude '**/PubsubReadIT.class'
diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle
index 7961c4cd..d544f6f 100644
--- a/sdks/java/io/google-cloud-platform/build.gradle
+++ b/sdks/java/io/google-cloud-platform/build.gradle
@@ -100,6 +100,7 @@ task integrationTest(type: Test) {
 
   include '**/*IT.class'
   exclude '**/BigQueryIOReadIT.class'
+  exclude '**/BigQueryIOStorageQueryIT.class'
   exclude '**/BigQueryIOStorageReadIT.class'
   exclude '**/BigQueryIOStorageReadTableRowIT.class'
   exclude '**/BigQueryToTableIT.class'
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index 172bb62..98b114f 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.gcp.bigquery;
 
 import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createJobIdToken;
+import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createTempTableReference;
 import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.getExtractJobId;
 import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.resolveTempLocation;
 import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
@@ -28,12 +29,16 @@ import com.google.api.services.bigquery.model.Job;
 import com.google.api.services.bigquery.model.JobConfigurationQuery;
 import com.google.api.services.bigquery.model.JobReference;
 import com.google.api.services.bigquery.model.JobStatistics;
+import com.google.api.services.bigquery.model.Table;
 import com.google.api.services.bigquery.model.TableReference;
 import com.google.api.services.bigquery.model.TableRow;
 import com.google.api.services.bigquery.model.TableSchema;
 import com.google.api.services.bigquery.model.TimePartitioning;
 import com.google.auto.value.AutoValue;
 import com.google.cloud.bigquery.storage.v1beta1.ReadOptions.TableReadOptions;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.CreateReadSessionRequest;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadSession;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.Stream;
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
@@ -50,8 +55,10 @@ import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
+import org.apache.beam.sdk.extensions.protobuf.ProtoCoder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.io.fs.MoveOptions;
 import org.apache.beam.sdk.io.fs.ResolveOptions;
 import org.apache.beam.sdk.io.fs.ResourceId;
@@ -63,11 +70,14 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.TimePartitioningToJso
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySourceBase.ExtractResult;
 import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinationsHelpers.ConstantSchemaDestinations;
 import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinationsHelpers.ConstantTimePartitioningDestinations;
 import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinationsHelpers.SchemaFromViewDestinations;
 import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinationsHelpers.TableFunctionDestinations;
+import org.apache.beam.sdk.io.gcp.bigquery.PassThroughThenCleanup.CleanupOperation;
+import org.apache.beam.sdk.io.gcp.bigquery.PassThroughThenCleanup.ContextContainer;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
@@ -686,6 +696,21 @@ public class BigQueryIO {
       return source;
     }
 
+    private BigQueryStorageQuerySource<T> createStorageQuerySource(
+        String stepUuid, Coder<T> outputCoder) {
+      return BigQueryStorageQuerySource.create(
+          stepUuid,
+          getQuery(),
+          getFlattenResults(),
+          getUseLegacySql(),
+          MoreObjects.firstNonNull(getQueryPriority(), QueryPriority.BATCH),
+          getQueryLocation(),
+          getKmsKey(),
+          getParseFn(),
+          outputCoder,
+          getBigQueryServices());
+    }
+
     private static final String QUERY_VALIDATION_FAILURE_ERROR =
         "Validation of query \"%1$s\" failed. If the query depends on an earlier stage of the"
             + " pipeline, This validation can be disabled using #withoutValidation.";
@@ -730,8 +755,6 @@ public class BigQueryIO {
           BigQueryHelpers.verifyTablePresence(datasetService, table.get());
         } else if (getQuery() != null) {
           checkArgument(
-              getMethod() != Method.DIRECT_READ, "Cannot read query results with DIRECT_READ");
-          checkArgument(
               getQuery().isAccessible(), "Cannot call validate if query is dynamically set.");
           JobService jobService = getBigQueryServices().getJobService(bqOptions);
           try {
@@ -774,8 +797,6 @@ public class BigQueryIO {
               BigQueryOptions.class.getSimpleName());
         }
       } else {
-        checkArgument(
-            getMethod() != Method.DIRECT_READ, "Method must not be DIRECT_READ if query is set");
         checkArgument(getQuery() != null, "Either from() or fromQuery() is required");
         checkArgument(
             getFlattenResults() != null, "flattenResults should not be null if query is set");
@@ -787,16 +808,14 @@ public class BigQueryIO {
       final Coder<T> coder = inferCoder(p.getCoderRegistry());
 
       if (getMethod() == Method.DIRECT_READ) {
-        return p.apply(
-            org.apache.beam.sdk.io.Read.from(
-                BigQueryStorageTableSource.create(
-                    getTableProvider(),
-                    getReadOptions(),
-                    getParseFn(),
-                    coder,
-                    getBigQueryServices())));
+        return expandForDirectRead(input, coder);
       }
 
+      checkArgument(
+          getReadOptions() == null,
+          "Invalid BigQueryIO.Read: Specifies table read options, "
+              + "which only applies when using Method.DIRECT_READ");
+
       final PCollectionView<String> jobIdTokenView;
       PCollection<String> jobIdTokenCollection;
       PCollection<T> rows;
@@ -914,6 +933,186 @@ public class BigQueryIO {
       return rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView));
     }
 
+    private PCollection<T> expandForDirectRead(PBegin input, Coder<T> outputCoder) {
+      ValueProvider<TableReference> tableProvider = getTableProvider();
+      Pipeline p = input.getPipeline();
+      if (tableProvider != null) {
+        // No job ID is required. Read directly from BigQuery storage.
+        return p.apply(
+            org.apache.beam.sdk.io.Read.from(
+                BigQueryStorageTableSource.create(
+                    tableProvider,
+                    getReadOptions(),
+                    getParseFn(),
+                    outputCoder,
+                    getBigQueryServices())));
+      }
+
+      checkArgument(
+          getReadOptions() == null,
+          "Invalid BigQueryIO.Read: Specifies table read options, "
+              + "which only applies when reading from a table");
+
+      //
+      // N.B. All of the code below exists because the BigQuery storage API can't (yet) read from
+      // all anonymous tables, so we need the job ID to reason about the name of the destination
+      // table for the query to read the data and subsequently delete the table and dataset. Once
+      // the storage API can handle anonymous tables, the storage source should be modified to use
+      // anonymous tables and all of the code related to job ID generation and table and dataset
+      // cleanup can be removed. [BEAM-6931]
+      //
+
+      PCollectionView<String> jobIdTokenView;
+      PCollection<T> rows;
+
+      if (!getWithTemplateCompatibility()) {
+        // Create a singleton job ID token at pipeline construction time.
+        String staticJobUuid = BigQueryHelpers.randomUUIDString();
+        jobIdTokenView =
+            p.apply("TriggerIdCreation", Create.of(staticJobUuid))
+                .apply("ViewId", View.asSingleton());
+        // Apply the traditional Source model.
+        rows =
+            p.apply(
+                org.apache.beam.sdk.io.Read.from(
+                    createStorageQuerySource(staticJobUuid, outputCoder)));
+      } else {
+        // Create a singleton job ID token at pipeline execution time.
+        PCollection<String> jobIdTokenCollection =
+            p.apply("TriggerIdCreation", Create.of("ignored"))
+                .apply(
+                    "CreateJobId",
+                    MapElements.via(
+                        new SimpleFunction<String, String>() {
+                          @Override
+                          public String apply(String input) {
+                            return BigQueryHelpers.randomUUIDString();
+                          }
+                        }));
+
+        jobIdTokenView = jobIdTokenCollection.apply("ViewId", View.asSingleton());
+
+        TupleTag<Stream> streamsTag = new TupleTag<>();
+        TupleTag<ReadSession> readSessionTag = new TupleTag<>();
+        TupleTag<String> tableSchemaTag = new TupleTag<>();
+
+        PCollectionTuple tuple =
+            jobIdTokenCollection.apply(
+                "RunQueryJob",
+                ParDo.of(
+                        new DoFn<String, Stream>() {
+                          @ProcessElement
+                          public void processElement(ProcessContext c) throws Exception {
+                            BigQueryOptions options =
+                                c.getPipelineOptions().as(BigQueryOptions.class);
+                            String jobUuid = c.element();
+                            // Execute the query and get the destination table holding the results.
+                            // The getTargetTable call runs a new instance of the query and returns
+                            // the destination table created to hold the results.
+                            BigQueryStorageQuerySource<T> querySource =
+                                createStorageQuerySource(jobUuid, outputCoder);
+                            Table queryResultTable = querySource.getTargetTable(options);
+
+                            // Create a read session without specifying a desired stream count and
+                            // let the BigQuery storage server pick the number of streams.
+                            CreateReadSessionRequest request =
+                                CreateReadSessionRequest.newBuilder()
+                                    .setParent("projects/" + options.getProject())
+                                    .setTableReference(
+                                        BigQueryHelpers.toTableRefProto(
+                                            queryResultTable.getTableReference()))
+                                    .setRequestedStreams(0)
+                                    .build();
+
+                            ReadSession readSession;
+                            try (StorageClient storageClient =
+                                getBigQueryServices().getStorageClient(options)) {
+                              readSession = storageClient.createReadSession(request);
+                            }
+
+                            for (Stream stream : readSession.getStreamsList()) {
+                              c.output(stream);
+                            }
+
+                            c.output(readSessionTag, readSession);
+                            c.output(
+                                tableSchemaTag,
+                                BigQueryHelpers.toJsonString(queryResultTable.getSchema()));
+                          }
+                        })
+                    .withOutputTags(
+                        streamsTag, TupleTagList.of(readSessionTag).and(tableSchemaTag)));
+
+        tuple.get(streamsTag).setCoder(ProtoCoder.of(Stream.class));
+        tuple.get(readSessionTag).setCoder(ProtoCoder.of(ReadSession.class));
+        tuple.get(tableSchemaTag).setCoder(StringUtf8Coder.of());
+
+        PCollectionView<ReadSession> readSessionView =
+            tuple.get(readSessionTag).apply("ReadSessionView", View.asSingleton());
+        PCollectionView<String> tableSchemaView =
+            tuple.get(tableSchemaTag).apply("TableSchemaView", View.asSingleton());
+
+        rows =
+            tuple
+                .get(streamsTag)
+                .apply(Reshuffle.viaRandomKey())
+                .apply(
+                    ParDo.of(
+                            new DoFn<Stream, T>() {
+                              @ProcessElement
+                              public void processElement(ProcessContext c) throws Exception {
+                                ReadSession readSession = c.sideInput(readSessionView);
+                                TableSchema tableSchema =
+                                    BigQueryHelpers.fromJsonString(
+                                        c.sideInput(tableSchemaView), TableSchema.class);
+                                Stream stream = c.element();
+
+                                BigQueryStorageStreamSource<T> streamSource =
+                                    BigQueryStorageStreamSource.create(
+                                        readSession,
+                                        stream,
+                                        tableSchema,
+                                        getParseFn(),
+                                        outputCoder,
+                                        getBigQueryServices());
+
+                                // Read all of the data from the stream. In the event that this work
+                                // item fails and is rescheduled, the same rows will be returned in
+                                // the same order.
+                                BoundedSource.BoundedReader<T> reader =
+                                    streamSource.createReader(c.getPipelineOptions());
+                                for (boolean more = reader.start(); more; more = reader.advance()) {
+                                  c.output(reader.getCurrent());
+                                }
+                              }
+                            })
+                        .withSideInputs(readSessionView, tableSchemaView))
+                .setCoder(outputCoder);
+      }
+
+      PassThroughThenCleanup.CleanupOperation cleanupOperation =
+          new CleanupOperation() {
+            @Override
+            void cleanup(ContextContainer c) throws Exception {
+              BigQueryOptions options = c.getPipelineOptions().as(BigQueryOptions.class);
+              String jobUuid = c.getJobId();
+
+              TableReference tempTable =
+                  createTempTableReference(
+                      options.getProject(), createJobIdToken(options.getJobName(), jobUuid));
+
+              DatasetService datasetService = getBigQueryServices().getDatasetService(options);
+              LOG.info("Deleting temporary table with query results {}", tempTable);
+              datasetService.deleteTable(tempTable);
+              LOG.info(
+                  "Deleting temporary dataset with query results {}", tempTable.getDatasetId());
+              datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
+            }
+          };
+
+      return rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView));
+    }
+
     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
new file mode 100644
index 0000000..a77cf55
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createJobIdToken;
+import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createTempTableReference;
+
+import com.google.api.services.bigquery.model.EncryptionConfiguration;
+import com.google.api.services.bigquery.model.Job;
+import com.google.api.services.bigquery.model.JobConfigurationQuery;
+import com.google.api.services.bigquery.model.JobReference;
+import com.google.api.services.bigquery.model.JobStatistics;
+import com.google.api.services.bigquery.model.TableReference;
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.Status;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.QueryPriority;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Helper object for executing query jobs in BigQuery.
+ *
+ * <p>This object is not serializable, and its state can be safely discarded across serialization
+ * boundaries for any associated source objects.
+ */
+class BigQueryQueryHelper {
+
+  private static final Integer JOB_POLL_MAX_RETRIES = Integer.MAX_VALUE;
+
+  private static final Logger LOG = LoggerFactory.getLogger(BigQueryQueryHelper.class);
+
+  public static JobStatistics dryRunQueryIfNeeded(
+      BigQueryServices bqServices,
+      BigQueryOptions options,
+      AtomicReference<JobStatistics> dryRunJobStats,
+      String query,
+      Boolean flattenResults,
+      Boolean useLegacySql,
+      @Nullable String location)
+      throws InterruptedException, IOException {
+    if (dryRunJobStats.get() == null) {
+      JobStatistics jobStatistics =
+          bqServices
+              .getJobService(options)
+              .dryRunQuery(
+                  options.getProject(),
+                  createBasicQueryConfig(query, flattenResults, useLegacySql),
+                  location);
+      dryRunJobStats.compareAndSet(null, jobStatistics);
+    }
+
+    return dryRunJobStats.get();
+  }
+
+  public static TableReference executeQuery(
+      BigQueryServices bqServices,
+      BigQueryOptions options,
+      AtomicReference<JobStatistics> dryRunJobStats,
+      String stepUuid,
+      String query,
+      Boolean flattenResults,
+      Boolean useLegacySql,
+      QueryPriority priority,
+      @Nullable String location,
+      @Nullable String kmsKey)
+      throws InterruptedException, IOException {
+    // Step 1: Find the effective location of the query.
+    String effectiveLocation = location;
+    DatasetService tableService = bqServices.getDatasetService(options);
+    if (effectiveLocation == null) {
+      List<TableReference> referencedTables =
+          dryRunQueryIfNeeded(
+                  bqServices,
+                  options,
+                  dryRunJobStats,
+                  query,
+                  flattenResults,
+                  useLegacySql,
+                  location)
+              .getQuery()
+              .getReferencedTables();
+      if (referencedTables != null && !referencedTables.isEmpty()) {
+        TableReference referencedTable = referencedTables.get(0);
+        effectiveLocation = tableService.getTable(referencedTable).getLocation();
+      }
+    }
+
+    // Step 2: Create a temporary dataset in the query location.
+    String jobIdToken = createJobIdToken(options.getJobName(), stepUuid);
+    TableReference queryResultTable = createTempTableReference(options.getProject(), jobIdToken);
+    LOG.info("Creating temporary dataset {} for query results", queryResultTable.getDatasetId());
+
+    tableService.createDataset(
+        queryResultTable.getProjectId(),
+        queryResultTable.getDatasetId(),
+        effectiveLocation,
+        "Temporary tables for query results of job " + options.getJobName(),
+        TimeUnit.DAYS.toMillis(1));
+
+    // Step 3: Execute the query. Generate a transient (random) query job ID, because this code may
+    // be retried after the temporary dataset and table have been deleted by a previous attempt --
+    // in that case, we want to regenerate the temporary dataset and table, and we'll need a fresh
+    // query ID to do that.
+    String queryJobId = jobIdToken + "-query-" + BigQueryHelpers.randomUUIDString();
+    LOG.info(
+        "Exporting query results into temporary table {} using job {}",
+        queryResultTable,
+        queryJobId);
+
+    JobReference jobReference =
+        new JobReference()
+            .setProjectId(options.getProject())
+            .setLocation(effectiveLocation)
+            .setJobId(queryJobId);
+
+    JobConfigurationQuery queryConfiguration =
+        createBasicQueryConfig(query, flattenResults, useLegacySql)
+            .setAllowLargeResults(true)
+            .setDestinationTable(queryResultTable)
+            .setCreateDisposition("CREATE_IF_NEEDED")
+            .setWriteDisposition("WRITE_TRUNCATE")
+            .setPriority(priority.name());
+
+    if (kmsKey != null) {
+      queryConfiguration.setDestinationEncryptionConfiguration(
+          new EncryptionConfiguration().setKmsKeyName(kmsKey));
+    }
+
+    JobService jobService = bqServices.getJobService(options);
+    jobService.startQueryJob(jobReference, queryConfiguration);
+    Job job = jobService.pollJob(jobReference, JOB_POLL_MAX_RETRIES);
+    if (BigQueryHelpers.parseStatus(job) != Status.SUCCEEDED) {
+      throw new IOException(
+          String.format(
+              "Query job %s failed, status: %s",
+              queryJobId, BigQueryHelpers.statusToPrettyString(job.getStatus())));
+    }
+
+    LOG.info("Query job {} completed", queryJobId);
+    return queryResultTable;
+  }
+
+  private static JobConfigurationQuery createBasicQueryConfig(
+      String query, Boolean flattenResults, Boolean useLegacySql) {
+    return new JobConfigurationQuery()
+        .setQuery(query)
+        .setFlattenResults(flattenResults)
+        .setUseLegacySql(useLegacySql);
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java
index a937959..375cc4f 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java
@@ -21,21 +21,14 @@ import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createJobIdTok
 import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createTempTableReference;
 import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkNotNull;
 
-import com.google.api.services.bigquery.model.EncryptionConfiguration;
-import com.google.api.services.bigquery.model.Job;
-import com.google.api.services.bigquery.model.JobConfigurationQuery;
-import com.google.api.services.bigquery.model.JobReference;
 import com.google.api.services.bigquery.model.JobStatistics;
 import com.google.api.services.bigquery.model.TableReference;
 import java.io.IOException;
 import java.io.ObjectInputStream;
-import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.Status;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.QueryPriority;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService;
-import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -47,6 +40,7 @@ import org.slf4j.LoggerFactory;
 /** A {@link BigQuerySourceBase} for querying BigQuery tables. */
 @VisibleForTesting
 class BigQueryQuerySource<T> extends BigQuerySourceBase<T> {
+
   private static final Logger LOG = LoggerFactory.getLogger(BigQueryQuerySource.class);
 
   static <T> BigQueryQuerySource<T> create(
@@ -76,11 +70,12 @@ class BigQueryQuerySource<T> extends BigQuerySourceBase<T> {
   private final ValueProvider<String> query;
   private final Boolean flattenResults;
   private final Boolean useLegacySql;
-  private transient AtomicReference<JobStatistics> dryRunJobStats;
   private final QueryPriority priority;
   private final String location;
   private final String kmsKey;
 
+  private transient AtomicReference<JobStatistics> dryRunJobStats;
+
   private BigQueryQuerySource(
       String stepUuid,
       ValueProvider<String> query,
@@ -96,62 +91,51 @@ class BigQueryQuerySource<T> extends BigQuerySourceBase<T> {
     this.query = checkNotNull(query, "query");
     this.flattenResults = checkNotNull(flattenResults, "flattenResults");
     this.useLegacySql = checkNotNull(useLegacySql, "useLegacySql");
-    this.dryRunJobStats = new AtomicReference<>();
     this.priority = priority;
     this.location = location;
     this.kmsKey = kmsKey;
+    dryRunJobStats = new AtomicReference<>();
+  }
+
+  /**
+   * Since the query helper reference is declared as transient, neither the AtomicReference nor the
+   * structure it refers to are persisted across serialization boundaries. The code below is
+   * resilient to the QueryHelper object disappearing in between method calls, but the reference
+   * object must be recreated at deserialization time.
+   */
+  private void readObject(ObjectInputStream in) throws ClassNotFoundException, IOException {
+    in.defaultReadObject();
+    dryRunJobStats = new AtomicReference<>();
   }
 
   @Override
   public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
-    BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
-    return dryRunQueryIfNeeded(bqOptions).getQuery().getTotalBytesProcessed();
+    return BigQueryQueryHelper.dryRunQueryIfNeeded(
+            bqServices,
+            options.as(BigQueryOptions.class),
+            dryRunJobStats,
+            query.get(),
+            flattenResults,
+            useLegacySql,
+            location)
+        .getQuery()
+        .getTotalBytesProcessed();
   }
 
   @Override
   protected TableReference getTableToExtract(BigQueryOptions bqOptions)
       throws IOException, InterruptedException {
-    // 1. Find the location of the query.
-    String location = this.location;
-    DatasetService tableService = bqServices.getDatasetService(bqOptions);
-    if (location == null) {
-      // If location was not provided we try to determine it from the tables referenced by the
-      // Query. This will only work for BQ locations US and EU.
-      List<TableReference> referencedTables =
-          dryRunQueryIfNeeded(bqOptions).getQuery().getReferencedTables();
-      if (referencedTables != null && !referencedTables.isEmpty()) {
-        TableReference queryTable = referencedTables.get(0);
-        location = tableService.getTable(queryTable).getLocation();
-      }
-    }
-
-    String jobIdToken = createJobIdToken(bqOptions.getJobName(), stepUuid);
-
-    // 2. Create the temporary dataset in the query location.
-    TableReference tableToExtract = createTempTableReference(bqOptions.getProject(), jobIdToken);
-
-    LOG.info("Creating temporary dataset {} for query results", tableToExtract.getDatasetId());
-    tableService.createDataset(
-        tableToExtract.getProjectId(),
-        tableToExtract.getDatasetId(),
-        location,
-        "Temporary tables for query results of job " + bqOptions.getJobName(),
-        // Set a TTL of 1 day on the temporary tables, which ought to be enough in all cases:
-        // the temporary tables are used only to immediately extract them into files.
-        // They are normally cleaned up, but in case of job failure the cleanup step may not run,
-        // and then they'll get deleted after the TTL.
-        24 * 3600 * 1000L /* 1 day */);
-
-    // 3. Execute the query.
-    executeQuery(
-        jobIdToken,
-        bqOptions.getProject(),
-        tableToExtract,
-        bqServices.getJobService(bqOptions),
+    return BigQueryQueryHelper.executeQuery(
+        bqServices,
+        bqOptions,
+        dryRunJobStats,
+        stepUuid,
+        query.get(),
+        flattenResults,
+        useLegacySql,
+        priority,
         location,
         kmsKey);
-
-    return tableToExtract;
   }
 
   @Override
@@ -172,79 +156,4 @@ class BigQueryQuerySource<T> extends BigQuerySourceBase<T> {
     super.populateDisplayData(builder);
     builder.add(DisplayData.item("query", query));
   }
-
-  private synchronized JobStatistics dryRunQueryIfNeeded(BigQueryOptions bqOptions)
-      throws InterruptedException, IOException {
-    if (dryRunJobStats.get() == null) {
-      JobStatistics jobStats =
-          bqServices
-              .getJobService(bqOptions)
-              .dryRunQuery(bqOptions.getProject(), createBasicQueryConfig(), this.location);
-      dryRunJobStats.compareAndSet(null, jobStats);
-    }
-    return dryRunJobStats.get();
-  }
-
-  private void executeQuery(
-      String jobIdToken,
-      String executingProject,
-      TableReference destinationTable,
-      JobService jobService,
-      String bqLocation,
-      String kmsKey)
-      throws IOException, InterruptedException {
-    // Generate a transient (random) query job ID, because this code may be retried after the
-    // temporary dataset and table have already been deleted by a previous attempt -
-    // in that case we want to re-generate the temporary dataset and table, and we'll need
-    // a fresh query job to do that.
-    String queryJobId = jobIdToken + "-query-" + BigQueryHelpers.randomUUIDString();
-
-    LOG.info(
-        "Exporting query results into temporary table {} using job {}",
-        destinationTable,
-        queryJobId);
-
-    JobReference jobRef =
-        new JobReference()
-            .setProjectId(executingProject)
-            .setLocation(bqLocation)
-            .setJobId(queryJobId);
-
-    JobConfigurationQuery queryConfig =
-        createBasicQueryConfig()
-            .setAllowLargeResults(true)
-            .setCreateDisposition("CREATE_IF_NEEDED")
-            .setDestinationTable(destinationTable)
-            .setPriority(this.priority.name())
-            // Overwrite contents of the temporary table - it can only already exist if this
-            // is a retry of the splitting task, in which case we must not produce duplicate data.
-            .setWriteDisposition("WRITE_TRUNCATE");
-    if (kmsKey != null) {
-      queryConfig.setDestinationEncryptionConfiguration(
-          new EncryptionConfiguration().setKmsKeyName(kmsKey));
-    }
-
-    jobService.startQueryJob(jobRef, queryConfig);
-    Job job = jobService.pollJob(jobRef, JOB_POLL_MAX_RETRIES);
-    if (BigQueryHelpers.parseStatus(job) != Status.SUCCEEDED) {
-      throw new IOException(
-          String.format(
-              "Query job %s failed, status: %s.",
-              queryJobId, BigQueryHelpers.statusToPrettyString(job.getStatus())));
-    }
-
-    LOG.info("Query job {} completed", queryJobId);
-  }
-
-  private JobConfigurationQuery createBasicQueryConfig() {
-    return new JobConfigurationQuery()
-        .setFlattenResults(flattenResults)
-        .setQuery(query.get())
-        .setUseLegacySql(useLegacySql);
-  }
-
-  private void readObject(ObjectInputStream in) throws ClassNotFoundException, IOException {
-    in.defaultReadObject();
-    dryRunJobStats = new AtomicReference<>();
-  }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageQuerySource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageQuerySource.java
new file mode 100644
index 0000000..9112e56
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageQuerySource.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.api.services.bigquery.model.JobStatistics;
+import com.google.api.services.bigquery.model.Table;
+import com.google.api.services.bigquery.model.TableReference;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.QueryPriority;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+
+/** A {@link org.apache.beam.sdk.io.Source} representing reading the results of a query. */
+@Experimental(Experimental.Kind.SOURCE_SINK)
+public class BigQueryStorageQuerySource<T> extends BigQueryStorageSourceBase<T> {
+
+  public static <T> BigQueryStorageQuerySource<T> create(
+      String stepUuid,
+      ValueProvider<String> queryProvider,
+      Boolean flattenResults,
+      Boolean useLegacySql,
+      QueryPriority priority,
+      @Nullable String location,
+      @Nullable String kmsKey,
+      SerializableFunction<SchemaAndRecord, T> parseFn,
+      Coder<T> outputCoder,
+      BigQueryServices bqServices) {
+    return new BigQueryStorageQuerySource<>(
+        stepUuid,
+        queryProvider,
+        flattenResults,
+        useLegacySql,
+        priority,
+        location,
+        kmsKey,
+        parseFn,
+        outputCoder,
+        bqServices);
+  }
+
+  private final String stepUuid;
+  private final ValueProvider<String> queryProvider;
+  private final Boolean flattenResults;
+  private final Boolean useLegacySql;
+  private final QueryPriority priority;
+  private final String location;
+  private final String kmsKey;
+
+  private transient AtomicReference<JobStatistics> dryRunJobStats;
+
+  private BigQueryStorageQuerySource(
+      String stepUuid,
+      ValueProvider<String> queryProvider,
+      Boolean flattenResults,
+      Boolean useLegacySql,
+      QueryPriority priority,
+      @Nullable String location,
+      @Nullable String kmsKey,
+      SerializableFunction<SchemaAndRecord, T> parseFn,
+      Coder<T> outputCoder,
+      BigQueryServices bqServices) {
+    super(null, parseFn, outputCoder, bqServices);
+    this.stepUuid = checkNotNull(stepUuid, "stepUuid");
+    this.queryProvider = checkNotNull(queryProvider, "queryProvider");
+    this.flattenResults = checkNotNull(flattenResults, "flattenResults");
+    this.useLegacySql = checkNotNull(useLegacySql, "useLegacySql");
+    this.priority = checkNotNull(priority, "priority");
+    this.location = location;
+    this.kmsKey = kmsKey;
+    this.dryRunJobStats = new AtomicReference<>();
+  }
+
+  private void readObject(ObjectInputStream in) throws ClassNotFoundException, IOException {
+    in.defaultReadObject();
+    dryRunJobStats = new AtomicReference<>();
+  }
+
+  @Override
+  public void populateDisplayData(DisplayData.Builder builder) {
+    super.populateDisplayData(builder);
+    builder.add(DisplayData.item("query", queryProvider).withLabel("Query"));
+  }
+
+  @Override
+  public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
+    return BigQueryQueryHelper.dryRunQueryIfNeeded(
+            bqServices,
+            options.as(BigQueryOptions.class),
+            dryRunJobStats,
+            queryProvider.get(),
+            flattenResults,
+            useLegacySql,
+            location)
+        .getQuery()
+        .getTotalBytesProcessed();
+  }
+
+  @Override
+  protected Table getTargetTable(BigQueryOptions options) throws Exception {
+    TableReference queryResultTable =
+        BigQueryQueryHelper.executeQuery(
+            bqServices,
+            options,
+            dryRunJobStats,
+            stepUuid,
+            queryProvider.get(),
+            flattenResults,
+            useLegacySql,
+            priority,
+            location,
+            kmsKey);
+    return bqServices.getDatasetService(options).getTable(queryResultTable);
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java
new file mode 100644
index 0000000..ccd1f12
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.api.services.bigquery.model.Table;
+import com.google.cloud.bigquery.storage.v1beta1.ReadOptions.TableReadOptions;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.CreateReadSessionRequest;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadSession;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.Stream;
+import java.io.IOException;
+import java.util.List;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
+
+/**
+ * A base class for {@link BoundedSource} implementations which read from BigQuery using the
+ * BigQuery storage API.
+ */
+@Experimental(Experimental.Kind.SOURCE_SINK)
+abstract class BigQueryStorageSourceBase<T> extends BoundedSource<T> {
+
+  /**
+   * The maximum number of streams which will be requested when creating a read session, regardless
+   * of the desired bundle size.
+   */
+  private static final int MAX_SPLIT_COUNT = 10_000;
+
+  /**
+   * The minimum number of streams which will be requested when creating a read session, regardless
+   * of the desired bundle size. Note that the server may still choose to return fewer than ten
+   * streams based on the layout of the table.
+   */
+  private static final int MIN_SPLIT_COUNT = 10;
+
+  protected final TableReadOptions tableReadOptions;
+  protected final SerializableFunction<SchemaAndRecord, T> parseFn;
+  protected final Coder<T> outputCoder;
+  protected final BigQueryServices bqServices;
+
+  BigQueryStorageSourceBase(
+      @Nullable TableReadOptions tableReadOptions,
+      SerializableFunction<SchemaAndRecord, T> parseFn,
+      Coder<T> outputCoder,
+      BigQueryServices bqServices) {
+    this.tableReadOptions = tableReadOptions;
+    this.parseFn = checkNotNull(parseFn, "parseFn");
+    this.outputCoder = checkNotNull(outputCoder, "outputCoder");
+    this.bqServices = checkNotNull(bqServices, "bqServices");
+  }
+
+  /**
+   * Returns the table to read from at split time. This is currently never an anonymous table, but
+   * it can be a named table which was created to hold the results of a query.
+   */
+  protected abstract Table getTargetTable(BigQueryOptions options) throws Exception;
+
+  @Override
+  public Coder<T> getOutputCoder() {
+    return outputCoder;
+  }
+
+  @Override
+  public List<BigQueryStorageStreamSource<T>> split(
+      long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
+    BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
+    Table targetTable = getTargetTable(bqOptions);
+    int streamCount = 0;
+    if (desiredBundleSizeBytes > 0) {
+      long tableSizeBytes = (targetTable != null) ? targetTable.getNumBytes() : 0;
+      streamCount = (int) Math.min(tableSizeBytes / desiredBundleSizeBytes, MAX_SPLIT_COUNT);
+    }
+
+    streamCount = Math.max(streamCount, MIN_SPLIT_COUNT);
+
+    CreateReadSessionRequest.Builder requestBuilder =
+        CreateReadSessionRequest.newBuilder()
+            .setParent("projects/" + bqOptions.getProject())
+            .setTableReference(BigQueryHelpers.toTableRefProto(targetTable.getTableReference()))
+            .setRequestedStreams(streamCount);
+
+    if (tableReadOptions != null) {
+      requestBuilder.setReadOptions(tableReadOptions);
+    }
+
+    ReadSession readSession;
+    try (StorageClient client = bqServices.getStorageClient(bqOptions)) {
+      readSession = client.createReadSession(requestBuilder.build());
+    }
+
+    if (readSession.getStreamsList().isEmpty()) {
+      // The underlying table is empty or all rows have been pruned.
+      return ImmutableList.of();
+    }
+
+    List<BigQueryStorageStreamSource<T>> sources = Lists.newArrayList();
+    for (Stream stream : readSession.getStreamsList()) {
+      sources.add(
+          BigQueryStorageStreamSource.create(
+              readSession, stream, targetTable.getSchema(), parseFn, outputCoder, bqServices));
+    }
+
+    return ImmutableList.copyOf(sources);
+  }
+
+  @Override
+  public BoundedReader<T> createReader(PipelineOptions options) throws IOException {
+    throw new UnsupportedOperationException("BigQuery storage source must be split before reading");
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java
index bdd9037..e48acf7 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java
@@ -23,46 +23,24 @@ import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Precondi
 import com.google.api.services.bigquery.model.Table;
 import com.google.api.services.bigquery.model.TableReference;
 import com.google.cloud.bigquery.storage.v1beta1.ReadOptions.TableReadOptions;
-import com.google.cloud.bigquery.storage.v1beta1.Storage.CreateReadSessionRequest;
-import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadSession;
-import com.google.cloud.bigquery.storage.v1beta1.Storage.Stream;
-import com.google.cloud.bigquery.storage.v1beta1.TableReferenceProto;
 import java.io.IOException;
+import java.io.ObjectInputStream;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.TableRefToTableRefProto;
-import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.ValueProvider;
-import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Strings;
-import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
-import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /** A {@link org.apache.beam.sdk.io.Source} representing reading from a table. */
 @Experimental(Experimental.Kind.SOURCE_SINK)
-public class BigQueryStorageTableSource<T> extends BoundedSource<T> {
-
-  /**
-   * The maximum number of streams which will be requested when creating a read session, regardless
-   * of the desired bundle size.
-   */
-  private static final int MAX_SPLIT_COUNT = 10_000;
-
-  /**
-   * The minimum number of streams which will be requested when creating a read session, regardless
-   * of the desired bundle size. Note that the server may still choose to return fewer than ten
-   * streams based on the layout of the table.
-   */
-  private static final int MIN_SPLIT_COUNT = 10;
+public class BigQueryStorageTableSource<T> extends BigQueryStorageSourceBase<T> {
 
   private static final Logger LOG = LoggerFactory.getLogger(BigQueryStorageTableSource.class);
 
@@ -73,142 +51,71 @@ public class BigQueryStorageTableSource<T> extends BoundedSource<T> {
       Coder<T> outputCoder,
       BigQueryServices bqServices) {
     return new BigQueryStorageTableSource<>(
-        NestedValueProvider.of(
-            checkNotNull(tableRefProvider, "tableRefProvider"), new TableRefToTableRefProto()),
-        readOptions,
-        parseFn,
-        outputCoder,
-        bqServices);
+        tableRefProvider, readOptions, parseFn, outputCoder, bqServices);
   }
 
-  private final ValueProvider<TableReferenceProto.TableReference> tableRefProtoProvider;
-  private final TableReadOptions readOptions;
-  private final SerializableFunction<SchemaAndRecord, T> parseFn;
-  private final Coder<T> outputCoder;
-  private final BigQueryServices bqServices;
-  private final AtomicReference<Long> tableSizeBytes;
+  private final ValueProvider<TableReference> tableReferenceProvider;
+
+  private transient AtomicReference<Table> cachedTable;
 
   private BigQueryStorageTableSource(
-      ValueProvider<TableReferenceProto.TableReference> tableRefProtoProvider,
+      ValueProvider<TableReference> tableRefProvider,
       @Nullable TableReadOptions readOptions,
       SerializableFunction<SchemaAndRecord, T> parseFn,
       Coder<T> outputCoder,
       BigQueryServices bqServices) {
-    this.tableRefProtoProvider = checkNotNull(tableRefProtoProvider, "tableRefProtoProvider");
-    this.readOptions = readOptions;
-    this.parseFn = checkNotNull(parseFn, "parseFn");
-    this.outputCoder = checkNotNull(outputCoder, "outputCoder");
-    this.bqServices = checkNotNull(bqServices, "bqServices");
-    this.tableSizeBytes = new AtomicReference<>();
+    super(readOptions, parseFn, outputCoder, bqServices);
+    this.tableReferenceProvider = checkNotNull(tableRefProvider, "tableRefProvider");
+    cachedTable = new AtomicReference<>();
   }
 
-  @Override
-  public Coder<T> getOutputCoder() {
-    return outputCoder;
+  private void readObject(ObjectInputStream in) throws ClassNotFoundException, IOException {
+    in.defaultReadObject();
+    cachedTable = new AtomicReference<>();
   }
 
   @Override
   public void populateDisplayData(DisplayData.Builder builder) {
     super.populateDisplayData(builder);
     builder.addIfNotNull(
-        DisplayData.item("table", BigQueryHelpers.displayTableRefProto(tableRefProtoProvider))
+        DisplayData.item("table", BigQueryHelpers.displayTable(tableReferenceProvider))
             .withLabel("Table"));
   }
 
-  private TableReferenceProto.TableReference getTargetTable(BigQueryOptions bqOptions)
-      throws IOException {
-    TableReferenceProto.TableReference tableReferenceProto = tableRefProtoProvider.get();
-    return setDefaultProjectIfAbsent(bqOptions, tableReferenceProto);
-  }
-
-  private TableReferenceProto.TableReference setDefaultProjectIfAbsent(
-      BigQueryOptions bqOptions, TableReferenceProto.TableReference tableReferenceProto) {
-    if (Strings.isNullOrEmpty(tableReferenceProto.getProjectId())) {
-      checkState(
-          !Strings.isNullOrEmpty(bqOptions.getProject()),
-          "No project ID set in %s or %s, cannot construct a complete %s",
-          TableReferenceProto.TableReference.class.getSimpleName(),
-          BigQueryOptions.class.getSimpleName(),
-          TableReferenceProto.TableReference.class.getSimpleName());
-      LOG.info(
-          "Project ID not set in {}. Using default project from {}.",
-          TableReferenceProto.TableReference.class.getSimpleName(),
-          BigQueryOptions.class.getSimpleName());
-      tableReferenceProto =
-          tableReferenceProto.toBuilder().setProjectId(bqOptions.getProject()).build();
-    }
-    return tableReferenceProto;
-  }
-
-  private List<String> getSelectedFields() {
-    if (readOptions != null && !readOptions.getSelectedFieldsList().isEmpty()) {
-      return readOptions.getSelectedFieldsList();
-    }
-    return null;
-  }
-
   @Override
   public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
-    if (tableSizeBytes.get() == null) {
-      BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
-      TableReferenceProto.TableReference tableReferenceProto =
-          setDefaultProjectIfAbsent(bqOptions, tableRefProtoProvider.get());
-      TableReference tableReference = BigQueryHelpers.toTableRef(tableReferenceProto);
-      Table table =
-          bqServices.getDatasetService(bqOptions).getTable(tableReference, getSelectedFields());
-      tableSizeBytes.compareAndSet(null, table.getNumBytes());
-    }
-    return tableSizeBytes.get();
+    return getTargetTable(options.as(BigQueryOptions.class)).getNumBytes();
   }
 
   @Override
-  public List<BigQueryStorageStreamSource<T>> split(
-      long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
-    BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
-    TableReferenceProto.TableReference tableReferenceProto =
-        setDefaultProjectIfAbsent(bqOptions, tableRefProtoProvider.get());
-    TableReference tableReference = BigQueryHelpers.toTableRef(tableReferenceProto);
-    Table table =
-        bqServices.getDatasetService(bqOptions).getTable(tableReference, getSelectedFields());
-    long tableSizeBytes = (table != null) ? table.getNumBytes() : 0;
-
-    int streamCount = 0;
-    if (desiredBundleSizeBytes > 0) {
-      streamCount = (int) Math.min(tableSizeBytes / desiredBundleSizeBytes, MAX_SPLIT_COUNT);
-    }
-
-    CreateReadSessionRequest.Builder requestBuilder =
-        CreateReadSessionRequest.newBuilder()
-            .setParent("projects/" + bqOptions.getProject())
-            .setTableReference(tableReferenceProto)
-            .setRequestedStreams(Math.max(streamCount, MIN_SPLIT_COUNT));
-
-    if (readOptions != null) {
-      requestBuilder.setReadOptions(readOptions);
-    }
-
-    ReadSession readSession;
-    try (StorageClient client = bqServices.getStorageClient(bqOptions)) {
-      readSession = client.createReadSession(requestBuilder.build());
-    }
-
-    if (readSession.getStreamsList().isEmpty()) {
-      // The underlying table is empty or has no rows which can be read.
-      return ImmutableList.of();
-    }
-
-    List<BigQueryStorageStreamSource<T>> sources = Lists.newArrayList();
-    for (Stream stream : readSession.getStreamsList()) {
-      sources.add(
-          BigQueryStorageStreamSource.create(
-              readSession, stream, table.getSchema(), parseFn, outputCoder, bqServices));
+  protected Table getTargetTable(BigQueryOptions options) throws Exception {
+    if (cachedTable.get() == null) {
+      TableReference tableReference = tableReferenceProvider.get();
+      if (Strings.isNullOrEmpty(tableReference.getProjectId())) {
+        checkState(
+            !Strings.isNullOrEmpty(options.getProject()),
+            "No project ID set in %s or %s, cannot construct a complete %s",
+            TableReference.class.getSimpleName(),
+            BigQueryOptions.class.getSimpleName(),
+            TableReference.class.getSimpleName());
+        LOG.info(
+            "Project ID not set in {}. Using default project from {}.",
+            TableReference.class.getSimpleName(),
+            BigQueryOptions.class.getSimpleName());
+        tableReference.setProjectId(options.getProject());
+      }
+      Table table =
+          bqServices.getDatasetService(options).getTable(tableReference, getSelectedFields());
+      cachedTable.compareAndSet(null, table);
     }
 
-    return ImmutableList.copyOf(sources);
+    return cachedTable.get();
   }
 
-  @Override
-  public BoundedReader<T> createReader(PipelineOptions options) throws IOException {
-    throw new UnsupportedOperationException("BigQuery table source must be split before reading");
+  private List<String> getSelectedFields() {
+    if (tableReadOptions != null && !tableReadOptions.getSelectedFieldsList().isEmpty()) {
+      return tableReadOptions.getSelectedFieldsList();
+    }
+    return null;
   }
 }
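
A note for readers of the refactor above: when the resolved TableReference carries no
project ID, getTargetTable now borrows the default project from BigQueryOptions before
looking the table up, and the resulting Table is cached for reuse. A minimal illustrative
sketch of that fallback (resolveProjectId is a hypothetical helper name, not part of the
Beam API):

    import com.google.api.services.bigquery.model.TableReference;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions;

    // Mirrors the project-defaulting step performed inside getTargetTable above.
    static TableReference resolveProjectId(TableReference ref, BigQueryOptions options) {
      if (ref.getProjectId() == null || ref.getProjectId().isEmpty()) {
        // No project on the reference; fall back to the pipeline's default project.
        ref.setProjectId(options.getProject());
      }
      return ref;
    }
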
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java
index db8d134..883bbdd 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java
@@ -26,11 +26,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThat;
 
-import com.google.api.services.bigquery.model.Job;
 import com.google.api.services.bigquery.model.JobStatistics;
 import com.google.api.services.bigquery.model.JobStatistics2;
-import com.google.api.services.bigquery.model.JobStatistics4;
-import com.google.api.services.bigquery.model.JobStatus;
 import com.google.api.services.bigquery.model.Streamingbuffer;
 import com.google.api.services.bigquery.model.Table;
 import com.google.api.services.bigquery.model.TableFieldSchema;
@@ -617,22 +614,17 @@ public class BigQueryIOReadTest implements Serializable {
   @Test
   public void testBigQueryQuerySourceEstimatedSize() throws Exception {
 
-    List<TableRow> data =
-        ImmutableList.of(
-            new TableRow().set("name", "A").set("number", 10L),
-            new TableRow().set("name", "B").set("number", 11L),
-            new TableRow().set("name", "C").set("number", 12L));
+    String queryString = "fake query string";
 
     PipelineOptions options = PipelineOptionsFactory.create();
     BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
     bqOptions.setProject("project");
     String stepUuid = "testStepUuid";
 
-    String query = FakeBigQueryServices.encodeQuery(data);
     BigQueryQuerySource<TableRow> bqSource =
         BigQueryQuerySource.create(
             stepUuid,
-            ValueProvider.StaticValueProvider.of(query),
+            ValueProvider.StaticValueProvider.of(queryString),
             true /* flattenResults */,
             true /* useLegacySql */,
             fakeBqServices,
@@ -644,7 +636,7 @@ public class BigQueryIOReadTest implements Serializable {
 
     fakeJobService.expectDryRunQuery(
         bqOptions.getProject(),
-        query,
+        queryString,
         new JobStatistics().setQuery(new JobStatistics2().setTotalBytesProcessed(100L)));
 
     assertEquals(100, bqSource.getEstimatedSizeBytes(bqOptions));
@@ -652,21 +644,31 @@ public class BigQueryIOReadTest implements Serializable {
 
   @Test
   public void testBigQueryQuerySourceInitSplit() throws Exception {
-    TableReference dryRunTable = new TableReference();
-
-    Job queryJob = new Job();
-    JobStatistics queryJobStats = new JobStatistics();
-    JobStatistics2 queryStats = new JobStatistics2();
-    queryStats.setReferencedTables(ImmutableList.of(dryRunTable));
-    queryJobStats.setQuery(queryStats);
-    queryJob.setStatus(new JobStatus()).setStatistics(queryJobStats);
-
-    Job extractJob = new Job();
-    JobStatistics extractJobStats = new JobStatistics();
-    JobStatistics4 extractStats = new JobStatistics4();
-    extractStats.setDestinationUriFileCounts(ImmutableList.of(1L));
-    extractJobStats.setExtract(extractStats);
-    extractJob.setStatus(new JobStatus()).setStatistics(extractJobStats);
+
+    PipelineOptions options = PipelineOptionsFactory.create();
+    BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
+    bqOptions.setProject("project");
+
+    TableReference sourceTableRef = BigQueryHelpers.parseTableSpec("project:dataset.table");
+
+    fakeDatasetService.createDataset(
+        sourceTableRef.getProjectId(),
+        sourceTableRef.getDatasetId(),
+        "asia-northeast1",
+        "Fake plastic tree^H^H^H^Htables",
+        null);
+
+    fakeDatasetService.createTable(
+        new Table().setTableReference(sourceTableRef).setLocation("asia-northeast1"));
+
+    Table queryResultTable =
+        new Table()
+            .setSchema(
+                new TableSchema()
+                    .setFields(
+                        ImmutableList.of(
+                            new TableFieldSchema().setName("name").setType("STRING"),
+                            new TableFieldSchema().setName("number").setType("INTEGER"))));
 
     List<TableRow> expected =
         ImmutableList.of(
@@ -677,31 +679,27 @@ public class BigQueryIOReadTest implements Serializable {
             new TableRow().set("name", "e").set("number", 5L),
             new TableRow().set("name", "f").set("number", 6L));
 
-    PipelineOptions options = PipelineOptionsFactory.create();
-    BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
-    bqOptions.setProject("project");
+    String encodedQuery = FakeBigQueryServices.encodeQueryResult(queryResultTable, expected);
+
     String stepUuid = "testStepUuid";
 
     TableReference tempTableReference =
         createTempTableReference(
-            bqOptions.getProject(), createJobIdToken(bqOptions.getJobName(), stepUuid));
-    fakeDatasetService.createDataset(
-        bqOptions.getProject(), tempTableReference.getDatasetId(), "", "", null);
-    fakeDatasetService.createTable(
-        new Table()
-            .setTableReference(tempTableReference)
-            .setSchema(
-                new TableSchema()
-                    .setFields(
-                        ImmutableList.of(
-                            new TableFieldSchema().setName("name").setType("STRING"),
-                            new TableFieldSchema().setName("number").setType("INTEGER")))));
+            bqOptions.getProject(), createJobIdToken(options.getJobName(), stepUuid));
+
+    fakeJobService.expectDryRunQuery(
+        bqOptions.getProject(),
+        encodedQuery,
+        new JobStatistics()
+            .setQuery(
+                new JobStatistics2()
+                    .setTotalBytesProcessed(100L)
+                    .setReferencedTables(ImmutableList.of(sourceTableRef, tempTableReference))));
 
-    String query = FakeBigQueryServices.encodeQuery(expected);
     BoundedSource<TableRow> bqSource =
         BigQueryQuerySource.create(
             stepUuid,
-            ValueProvider.StaticValueProvider.of(query),
+            ValueProvider.StaticValueProvider.of(encodedQuery),
             true /* flattenResults */,
             true /* useLegacySql */,
             fakeBqServices,
@@ -710,22 +708,8 @@ public class BigQueryIOReadTest implements Serializable {
             QueryPriority.BATCH,
             null,
             null);
-    options.setTempLocation(testFolder.getRoot().getAbsolutePath());
-
-    TableReference queryTable =
-        new TableReference()
-            .setProjectId(bqOptions.getProject())
-            .setDatasetId(tempTableReference.getDatasetId())
-            .setTableId(tempTableReference.getTableId());
 
-    fakeJobService.expectDryRunQuery(
-        bqOptions.getProject(),
-        query,
-        new JobStatistics()
-            .setQuery(
-                new JobStatistics2()
-                    .setTotalBytesProcessed(100L)
-                    .setReferencedTables(ImmutableList.of(queryTable))));
+    options.setTempLocation(testFolder.getRoot().getAbsolutePath());
 
     List<TableRow> read =
         convertStringsToLong(
@@ -736,28 +720,27 @@ public class BigQueryIOReadTest implements Serializable {
     assertEquals(2, sources.size());
   }
 
+  /**
+   * This test simulates the scenario where the SQL text which is executed by the query job doesn't
+   * by itself refer to any tables (e.g. "SELECT 17 AS value"), and thus there are no referenced
+   * tables when the dry run of the query is performed.
+   */
   @Test
-  public void testBigQueryNoTableQuerySourceInitSplit() throws Exception {
-    TableReference dryRunTable = new TableReference();
-
-    Job queryJob = new Job();
-    JobStatistics queryJobStats = new JobStatistics();
-    JobStatistics2 queryStats = new JobStatistics2();
-    queryStats.setReferencedTables(ImmutableList.of(dryRunTable));
-    queryJobStats.setQuery(queryStats);
-    queryJob.setStatus(new JobStatus()).setStatistics(queryJobStats);
-
-    Job extractJob = new Job();
-    JobStatistics extractJobStats = new JobStatistics();
-    JobStatistics4 extractStats = new JobStatistics4();
-    extractStats.setDestinationUriFileCounts(ImmutableList.of(1L));
-    extractJobStats.setExtract(extractStats);
-    extractJob.setStatus(new JobStatus()).setStatistics(extractJobStats);
+  public void testBigQueryQuerySourceInitSplit_NoReferencedTables() throws Exception {
 
-    String stepUuid = "testStepUuid";
+    PipelineOptions options = PipelineOptionsFactory.create();
+    BigQueryOptions bqOptions = options.as(BigQueryOptions.class);
+    bqOptions.setProject("project");
+
+    Table queryResultTable =
+        new Table()
+            .setSchema(
+                new TableSchema()
+                    .setFields(
+                        ImmutableList.of(
+                            new TableFieldSchema().setName("name").setType("STRING"),
+                            new TableFieldSchema().setName("number").setType("INTEGER"))));
 
-    TableReference tempTableReference =
-        createTempTableReference("project-id", createJobIdToken(options.getJobName(), stepUuid));
     List<TableRow> expected =
         ImmutableList.of(
             new TableRow().set("name", "a").set("number", 1L),
@@ -766,33 +749,24 @@ public class BigQueryIOReadTest implements Serializable {
             new TableRow().set("name", "d").set("number", 4L),
             new TableRow().set("name", "e").set("number", 5L),
             new TableRow().set("name", "f").set("number", 6L));
-    fakeDatasetService.createDataset(
-        tempTableReference.getProjectId(), tempTableReference.getDatasetId(), "", "", null);
-    Table table =
-        new Table()
-            .setTableReference(tempTableReference)
-            .setSchema(
-                new TableSchema()
-                    .setFields(
-                        ImmutableList.of(
-                            new TableFieldSchema().setName("name").setType("STRING"),
-                            new TableFieldSchema().setName("number").setType("INTEGER"))));
-    fakeDatasetService.createTable(table);
 
-    String query = FakeBigQueryServices.encodeQuery(expected);
+    String encodedQuery = FakeBigQueryServices.encodeQueryResult(queryResultTable, expected);
+
+    String stepUuid = "testStepUuid";
+
     fakeJobService.expectDryRunQuery(
-        "project-id",
-        query,
+        bqOptions.getProject(),
+        encodedQuery,
         new JobStatistics()
             .setQuery(
                 new JobStatistics2()
                     .setTotalBytesProcessed(100L)
-                    .setReferencedTables(ImmutableList.of(table.getTableReference()))));
+                    .setReferencedTables(ImmutableList.of())));
 
     BoundedSource<TableRow> bqSource =
         BigQueryQuerySource.create(
             stepUuid,
-            ValueProvider.StaticValueProvider.of(query),
+            ValueProvider.StaticValueProvider.of(encodedQuery),
             true /* flattenResults */,
             true /* useLegacySql */,
             fakeBqServices,
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java
new file mode 100644
index 0000000..d619382
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import java.util.Map;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.ExperimentalOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Integration tests for {@link BigQueryIO#read(SerializableFunction)} using {@link
+ * Method#DIRECT_READ} to read query results. This test runs a simple "SELECT *" query over a
+ * pre-defined table and asserts that the number of records read is equal to the expected count.
+ */
+@RunWith(JUnit4.class)
+public class BigQueryIOStorageQueryIT {
+
+  private static final Map<String, Long> EXPECTED_NUM_RECORDS =
+      ImmutableMap.of(
+          "empty", 0L,
+          "1M", 10592L,
+          "1G", 11110839L,
+          "1T", 11110839000L);
+
+  private static final String DATASET_ID = "big_query_storage";
+  private static final String TABLE_PREFIX = "storage_read_";
+
+  private BigQueryIOStorageQueryOptions options;
+
+  /** Customized {@link TestPipelineOptions} for BigQueryIOStorageQuery pipelines. */
+  public interface BigQueryIOStorageQueryOptions extends TestPipelineOptions, ExperimentalOptions {
+    @Description("The table to be queried")
+    @Validation.Required
+    String getInputTable();
+
+    void setInputTable(String table);
+
+    @Description("The expected number of records")
+    @Validation.Required
+    long getNumRecords();
+
+    void setNumRecords(long numRecords);
+  }
+
+  private void setUpTestEnvironment(String tableSize) {
+    PipelineOptionsFactory.register(BigQueryIOStorageQueryOptions.class);
+    options = TestPipeline.testingPipelineOptions().as(BigQueryIOStorageQueryOptions.class);
+    options.setNumRecords(EXPECTED_NUM_RECORDS.get(tableSize));
+    String project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
+    options.setInputTable(project + '.' + DATASET_ID + '.' + TABLE_PREFIX + tableSize);
+  }
+
+  private void runBigQueryIOStorageQueryPipeline() {
+    Pipeline p = Pipeline.create(options);
+    PCollection<Long> count =
+        p.apply(
+                "Query",
+                BigQueryIO.read(TableRowParser.INSTANCE)
+                    .fromQuery("SELECT * FROM `" + options.getInputTable() + "`")
+                    .usingStandardSql()
+                    .withMethod(Method.DIRECT_READ))
+            .apply("Count", Count.globally());
+    PAssert.thatSingleton(count).isEqualTo(options.getNumRecords());
+    p.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testBigQueryStorageQuery1G() throws Exception {
+    setUpTestEnvironment("1G");
+    runBigQueryIOStorageQueryPipeline();
+  }
+}
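
For context, the user-facing pattern this integration test exercises, and which this commit
enables, is reading the results of a query directly through the BigQuery storage API instead
of exporting them to GCS first. A minimal sketch; the query text, table name, and enclosing
method are placeholders, and the usual BigQueryIO and TableRow imports are assumed:

    // Sketch only: read query results with Method.DIRECT_READ.
    static PCollection<TableRow> readQueryResultsDirectly(Pipeline pipeline) {
      return pipeline.apply(
          "ReadQueryResults",
          BigQueryIO.read(TableRowParser.INSTANCE)
              .fromQuery("SELECT name, number FROM `my-project.my_dataset.my_table`")
              .usingStandardSql()
              .withMethod(Method.DIRECT_READ));
    }
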
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java
new file mode 100644
index 0000000..847aa9a
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java
@@ -0,0 +1,712 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createJobIdToken;
+import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.createTempTableReference;
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.hasItem;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.withSettings;
+import static org.testng.Assert.assertFalse;
+
+import com.google.api.services.bigquery.model.JobStatistics;
+import com.google.api.services.bigquery.model.JobStatistics2;
+import com.google.api.services.bigquery.model.Table;
+import com.google.api.services.bigquery.model.TableFieldSchema;
+import com.google.api.services.bigquery.model.TableReference;
+import com.google.api.services.bigquery.model.TableRow;
+import com.google.api.services.bigquery.model.TableSchema;
+import com.google.cloud.bigquery.storage.v1beta1.AvroProto.AvroRows;
+import com.google.cloud.bigquery.storage.v1beta1.AvroProto.AvroSchema;
+import com.google.cloud.bigquery.storage.v1beta1.ReadOptions.TableReadOptions;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.CreateReadSessionRequest;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadRowsRequest;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadRowsResponse;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadSession;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.Stream;
+import com.google.cloud.bigquery.storage.v1beta1.Storage.StreamPosition;
+import com.google.protobuf.ByteString;
+import java.io.ByteArrayOutputStream;
+import java.util.Collection;
+import java.util.List;
+import java.util.Set;
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericData.Record;
+import org.apache.avro.generic.GenericDatumWriter;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.io.Encoder;
+import org.apache.avro.io.EncoderFactory;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder;
+import org.apache.beam.sdk.extensions.protobuf.ProtoCoder;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.QueryPriority;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.rules.TemporaryFolder;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.junit.runners.model.Statement;
+
+/** Tests for {@link BigQueryIO#read(SerializableFunction)} using {@link Method#DIRECT_READ} to read query results. */
+@RunWith(JUnit4.class)
+public class BigQueryIOStorageQueryTest {
+
+  private transient BigQueryOptions options;
+  private transient TemporaryFolder testFolder = new TemporaryFolder();
+  private transient TestPipeline p;
+
+  @Rule
+  public final transient TestRule folderThenPipeline =
+      new TestRule() {
+        @Override
+        public Statement apply(Statement base, Description description) {
+          // We need to set up the temporary folder, and then set up the TestPipeline based on the
+          // chosen folder. Unfortunately, rule evaluation order is unspecified, unrelated to field
+          // order, and separate from construction, so this TestRule has to be created manually.
+          Statement withPipeline =
+              new Statement() {
+                @Override
+                public void evaluate() throws Throwable {
+                  options = TestPipeline.testingPipelineOptions().as(BigQueryOptions.class);
+                  options.setProject("project-id");
+                  options.setTempLocation(testFolder.getRoot().getAbsolutePath());
+                  p = TestPipeline.fromOptions(options);
+                  p.apply(base, description).evaluate();
+                }
+              };
+
+          return testFolder.apply(withPipeline, description);
+        }
+      };
+
+  @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+  private FakeDatasetService fakeDatasetService = new FakeDatasetService();
+  private FakeJobService fakeJobService = new FakeJobService();
+
+  private FakeBigQueryServices fakeBigQueryServices =
+      new FakeBigQueryServices()
+          .withDatasetService(fakeDatasetService)
+          .withJobService(fakeJobService);
+
+  @Before
+  public void setUp() throws Exception {
+    FakeDatasetService.setUp();
+  }
+
+  private static final String DEFAULT_QUERY = "SELECT * FROM `dataset.table` LIMIT 1";
+
+  @Test
+  public void testDefaultQueryBasedSource() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead();
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertTrue(typedRead.getValidate());
+    assertTrue(typedRead.getFlattenResults());
+    assertTrue(typedRead.getUseLegacySql());
+    assertNull(typedRead.getQueryPriority());
+    assertNull(typedRead.getQueryLocation());
+    assertNull(typedRead.getKmsKey());
+    assertFalse(typedRead.getWithTemplateCompatibility());
+  }
+
+  @Test
+  public void testQueryBasedSourceWithCustomQuery() throws Exception {
+    TypedRead<TableRow> typedRead =
+        BigQueryIO.read(new TableRowParser())
+            .fromQuery("SELECT * FROM `google.com:project.dataset.table`")
+            .withCoder(TableRowJsonCoder.of());
+    checkTypedReadQueryObject(typedRead, "SELECT * FROM `google.com:project.dataset.table`");
+  }
+
+  @Test
+  public void testQueryBasedSourceWithoutValidation() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead().withoutValidation();
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertFalse(typedRead.getValidate());
+  }
+
+  @Test
+  public void testQueryBasedSourceWithoutResultFlattening() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead().withoutResultFlattening();
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertFalse(typedRead.getFlattenResults());
+  }
+
+  @Test
+  public void testQueryBasedSourceWithStandardSql() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead().usingStandardSql();
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertFalse(typedRead.getUseLegacySql());
+  }
+
+  @Test
+  public void testQueryBasedSourceWithPriority() throws Exception {
+    TypedRead<TableRow> typedRead =
+        getDefaultTypedRead().withQueryPriority(QueryPriority.INTERACTIVE);
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertEquals(QueryPriority.INTERACTIVE, typedRead.getQueryPriority());
+  }
+
+  @Test
+  public void testQueryBasedSourceWithQueryLocation() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead().withQueryLocation("US");
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertEquals("US", typedRead.getQueryLocation());
+  }
+
+  @Test
+  public void testQueryBasedSourceWithKmsKey() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead().withKmsKey("kms_key");
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertEquals("kms_key", typedRead.getKmsKey());
+  }
+
+  @Test
+  public void testQueryBasedSourceWithTemplateCompatibility() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead().withTemplateCompatibility();
+    checkTypedReadQueryObject(typedRead, DEFAULT_QUERY);
+    assertTrue(typedRead.getWithTemplateCompatibility());
+  }
+
+  private TypedRead<TableRow> getDefaultTypedRead() {
+    return BigQueryIO.read(new TableRowParser())
+        .fromQuery(DEFAULT_QUERY)
+        .withCoder(TableRowJsonCoder.of())
+        .withMethod(Method.DIRECT_READ);
+  }
+
+  private void checkTypedReadQueryObject(TypedRead typedRead, String query) {
+    assertNull(typedRead.getTable());
+    assertEquals(query, typedRead.getQuery().get());
+  }
+
+  @Test
+  public void testBuildQueryBasedSourceWithReadOptions() throws Exception {
+    TableReadOptions readOptions = TableReadOptions.newBuilder().setRowRestriction("a > 5").build();
+    TypedRead<TableRow> typedRead = getDefaultTypedRead().withReadOptions(readOptions);
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage(
+        "Invalid BigQueryIO.Read: Specifies table read options, "
+            + "which only applies when reading from a table");
+    p.apply(typedRead);
+    p.run();
+  }
+
+  @Test
+  public void testDisplayData() throws Exception {
+    TypedRead<TableRow> typedRead = getDefaultTypedRead();
+    DisplayData displayData = DisplayData.from(typedRead);
+    assertThat(displayData, hasDisplayItem("query", DEFAULT_QUERY));
+  }
+
+  @Test
+  public void testEvaluatedDisplayData() throws Exception {
+    DisplayDataEvaluator evaluator = DisplayDataEvaluator.create();
+    TypedRead<TableRow> typedRead = getDefaultTypedRead();
+    Set<DisplayData> displayData = evaluator.displayDataForPrimitiveSourceTransforms(typedRead);
+    assertThat(displayData, hasItem(hasDisplayItem("query")));
+  }
+
+  @Test
+  public void testName() {
+    assertEquals("BigQueryIO.TypedRead", getDefaultTypedRead().getName());
+  }
+
+  @Test
+  public void testCoderInference() {
+    SerializableFunction<SchemaAndRecord, KV<ByteString, ReadSession>> parseFn =
+        new SerializableFunction<SchemaAndRecord, KV<ByteString, ReadSession>>() {
+          @Override
+          public KV<ByteString, ReadSession> apply(SchemaAndRecord input) {
+            return null;
+          }
+        };
+
+    assertEquals(
+        KvCoder.of(ByteStringCoder.of(), ProtoCoder.of(ReadSession.class)),
+        BigQueryIO.read(parseFn).inferCoder(CoderRegistry.createDefault()));
+  }
+
+  @Test
+  public void testQuerySourceEstimatedSize() throws Exception {
+
+    String fakeQuery = "fake query text";
+
+    fakeJobService.expectDryRunQuery(
+        options.getProject(),
+        fakeQuery,
+        new JobStatistics().setQuery(new JobStatistics2().setTotalBytesProcessed(125L)));
+
+    BigQueryStorageQuerySource<TableRow> querySource =
+        BigQueryStorageQuerySource.create(
+            /* stepUuid = */ "stepUuid",
+            ValueProvider.StaticValueProvider.of(fakeQuery),
+            /* flattenResults = */ true,
+            /* useLegacySql = */ true,
+            /* priority = */ QueryPriority.INTERACTIVE,
+            /* location = */ null,
+            /* kmsKey = */ null,
+            new TableRowParser(),
+            TableRowJsonCoder.of(),
+            fakeBigQueryServices);
+
+    assertEquals(125L, querySource.getEstimatedSizeBytes(options));
+  }
+
+  @Test
+  public void testQuerySourceInitialSplit() throws Exception {
+    doQuerySourceInitialSplit(1024L, 1024, 50);
+  }
+
+  @Test
+  public void testQuerySourceInitialSplit_MinSplitCount() throws Exception {
+    doQuerySourceInitialSplit(1024L * 1024L, 10, 1);
+  }
+
+  @Test
+  public void testQuerySourceInitialSplit_MaxSplitCount() throws Exception {
+    doQuerySourceInitialSplit(10, 10_000, 200);
+  }
+
+  private void doQuerySourceInitialSplit(
+      long bundleSize, int requestedStreamCount, int expectedStreamCount) throws Exception {
+
+    TableReference sourceTableRef = BigQueryHelpers.parseTableSpec("project:dataset.table");
+
+    fakeDatasetService.createDataset(
+        sourceTableRef.getProjectId(),
+        sourceTableRef.getDatasetId(),
+        "asia-northeast1",
+        "Fake plastic tree^H^H^H^Htables",
+        null);
+
+    fakeDatasetService.createTable(
+        new Table().setTableReference(sourceTableRef).setLocation("asia-northeast1"));
+
+    Table queryResultTable =
+        new Table()
+            .setSchema(
+                new TableSchema()
+                    .setFields(
+                        ImmutableList.of(
+                            new TableFieldSchema().setName("name").setType("STRING"),
+                            new TableFieldSchema().setName("number").setType("INTEGER"))))
+            .setNumBytes(1024L * 1024L);
+
+    String encodedQuery = FakeBigQueryServices.encodeQueryResult(queryResultTable);
+
+    fakeJobService.expectDryRunQuery(
+        options.getProject(),
+        encodedQuery,
+        new JobStatistics()
+            .setQuery(
+                new JobStatistics2()
+                    .setTotalBytesProcessed(1024L * 1024L)
+                    .setReferencedTables(ImmutableList.of(sourceTableRef))));
+
+    String stepUuid = "testStepUuid";
+
+    TableReference tempTableReference =
+        createTempTableReference(
+            options.getProject(), createJobIdToken(options.getJobName(), stepUuid));
+
+    CreateReadSessionRequest expectedRequest =
+        CreateReadSessionRequest.newBuilder()
+            .setParent("projects/" + options.getProject())
+            .setTableReference(BigQueryHelpers.toTableRefProto(tempTableReference))
+            .setRequestedStreams(requestedStreamCount)
+            .build();
+
+    ReadSession.Builder builder = ReadSession.newBuilder();
+    for (int i = 0; i < expectedStreamCount; i++) {
+      builder.addStreams(Stream.newBuilder().setName("stream-" + i));
+    }
+
+    StorageClient fakeStorageClient = mock(StorageClient.class);
+    when(fakeStorageClient.createReadSession(expectedRequest)).thenReturn(builder.build());
+
+    BigQueryStorageQuerySource<TableRow> querySource =
+        BigQueryStorageQuerySource.create(
+            stepUuid,
+            ValueProvider.StaticValueProvider.of(encodedQuery),
+            /* flattenResults = */ true,
+            /* useLegacySql = */ true,
+            /* priority = */ QueryPriority.BATCH,
+            /* location = */ null,
+            /* kmsKey = */ null,
+            new TableRowParser(),
+            TableRowJsonCoder.of(),
+            new FakeBigQueryServices()
+                .withDatasetService(fakeDatasetService)
+                .withJobService(fakeJobService)
+                .withStorageClient(fakeStorageClient));
+
+    List<? extends BoundedSource<TableRow>> sources = querySource.split(bundleSize, options);
+    assertEquals(expectedStreamCount, sources.size());
+  }
+
+  /**
+   * This test simulates the scenario where the SQL text which is executed by the query job doesn't
+   * by itself refer to any tables (e.g. "SELECT 17 AS value"), and thus there are no referenced
+   * tables when the dry run of the query is performed.
+   */
+  @Test
+  public void testQuerySourceInitialSplit_NoReferencedTables() throws Exception {
+
+    Table queryResultTable =
+        new Table()
+            .setSchema(
+                new TableSchema()
+                    .setFields(
+                        ImmutableList.of(
+                            new TableFieldSchema().setName("name").setType("STRING"),
+                            new TableFieldSchema().setName("number").setType("INTEGER"))))
+            .setNumBytes(1024L * 1024L);
+
+    String encodedQuery = FakeBigQueryServices.encodeQueryResult(queryResultTable);
+
+    fakeJobService.expectDryRunQuery(
+        options.getProject(),
+        encodedQuery,
+        new JobStatistics()
+            .setQuery(
+                new JobStatistics2()
+                    .setTotalBytesProcessed(1024L * 1024L)
+                    .setReferencedTables(ImmutableList.of())));
+
+    String stepUuid = "testStepUuid";
+
+    TableReference tempTableReference =
+        createTempTableReference(
+            options.getProject(), createJobIdToken(options.getJobName(), stepUuid));
+
+    CreateReadSessionRequest expectedRequest =
+        CreateReadSessionRequest.newBuilder()
+            .setParent("projects/" + options.getProject())
+            .setTableReference(BigQueryHelpers.toTableRefProto(tempTableReference))
+            .setRequestedStreams(1024)
+            .build();
+
+    ReadSession.Builder builder = ReadSession.newBuilder();
+    for (int i = 0; i < 1024; i++) {
+      builder.addStreams(Stream.newBuilder().setName("stream-" + i));
+    }
+
+    StorageClient fakeStorageClient = mock(StorageClient.class);
+    when(fakeStorageClient.createReadSession(expectedRequest)).thenReturn(builder.build());
+
+    BigQueryStorageQuerySource<TableRow> querySource =
+        BigQueryStorageQuerySource.create(
+            stepUuid,
+            ValueProvider.StaticValueProvider.of(encodedQuery),
+            /* flattenResults = */ true,
+            /* useLegacySql = */ true,
+            /* priority = */ QueryPriority.BATCH,
+            /* location = */ null,
+            /* kmsKey = */ null,
+            new TableRowParser(),
+            TableRowJsonCoder.of(),
+            new FakeBigQueryServices()
+                .withDatasetService(fakeDatasetService)
+                .withJobService(fakeJobService)
+                .withStorageClient(fakeStorageClient));
+
+    List<? extends BoundedSource<TableRow>> sources = querySource.split(1024, options);
+    assertEquals(1024, sources.size());
+  }
+
+  private static final String AVRO_SCHEMA_STRING =
+      "{\"namespace\": \"example.avro\",\n"
+          + " \"type\": \"record\",\n"
+          + " \"name\": \"RowRecord\",\n"
+          + " \"fields\": [\n"
+          + "     {\"name\": \"name\", \"type\": \"string\"},\n"
+          + "     {\"name\": \"number\", \"type\": \"long\"}\n"
+          + " ]\n"
+          + "}";
+
+  private static final Schema AVRO_SCHEMA = new Schema.Parser().parse(AVRO_SCHEMA_STRING);
+
+  private static final TableSchema TABLE_SCHEMA =
+      new TableSchema()
+          .setFields(
+              ImmutableList.of(
+                  new TableFieldSchema().setName("name").setType("STRING").setMode("REQUIRED"),
+                  new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED")));
+
+  private static GenericRecord createRecord(String name, long number, Schema schema) {
+    GenericRecord genericRecord = new Record(schema);
+    genericRecord.put("name", name);
+    genericRecord.put("number", number);
+    return genericRecord;
+  }
+
+  private static final EncoderFactory ENCODER_FACTORY = EncoderFactory.get();
+
+  private static ReadRowsResponse createResponse(
+      Schema schema, Collection<GenericRecord> genericRecords) throws Exception {
+    GenericDatumWriter<GenericRecord> writer = new GenericDatumWriter<>(schema);
+    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    Encoder binaryEncoder = ENCODER_FACTORY.binaryEncoder(outputStream, null);
+    for (GenericRecord genericRecord : genericRecords) {
+      writer.write(genericRecord, binaryEncoder);
+    }
+
+    binaryEncoder.flush();
+
+    return ReadRowsResponse.newBuilder()
+        .setAvroRows(
+            AvroRows.newBuilder()
+                .setSerializedBinaryRows(ByteString.copyFrom(outputStream.toByteArray()))
+                .setRowCount(genericRecords.size()))
+        .build();
+  }
+
+  private static final class ParseKeyValue
+      implements SerializableFunction<SchemaAndRecord, KV<String, Long>> {
+    @Override
+    public KV<String, Long> apply(SchemaAndRecord input) {
+      return KV.of(
+          input.getRecord().get("name").toString(), (Long) input.getRecord().get("number"));
+    }
+  }
+
+  @Test
+  public void testQuerySourceInitialSplit_EmptyResult() throws Exception {
+
+    TableReference sourceTableRef = BigQueryHelpers.parseTableSpec("project:dataset.table");
+
+    fakeDatasetService.createDataset(
+        sourceTableRef.getProjectId(),
+        sourceTableRef.getDatasetId(),
+        "asia-northeast1",
+        "Fake plastic tree^H^H^H^Htables",
+        null);
+
+    fakeDatasetService.createTable(
+        new Table().setTableReference(sourceTableRef).setLocation("asia-northeast1"));
+
+    Table queryResultTable =
+        new Table()
+            .setSchema(
+                new TableSchema()
+                    .setFields(
+                        ImmutableList.of(
+                            new TableFieldSchema().setName("name").setType("STRING"),
+                            new TableFieldSchema().setName("number").setType("INTEGER"))))
+            .setNumBytes(0L);
+
+    String encodedQuery = FakeBigQueryServices.encodeQueryResult(queryResultTable);
+
+    fakeJobService.expectDryRunQuery(
+        options.getProject(),
+        encodedQuery,
+        new JobStatistics()
+            .setQuery(
+                new JobStatistics2()
+                    .setTotalBytesProcessed(1024L * 1024L)
+                    .setReferencedTables(ImmutableList.of(sourceTableRef))));
+
+    String stepUuid = "testStepUuid";
+
+    TableReference tempTableReference =
+        createTempTableReference(
+            options.getProject(), createJobIdToken(options.getJobName(), stepUuid));
+
+    CreateReadSessionRequest expectedRequest =
+        CreateReadSessionRequest.newBuilder()
+            .setParent("projects/" + options.getProject())
+            .setTableReference(BigQueryHelpers.toTableRefProto(tempTableReference))
+            .setRequestedStreams(10)
+            .build();
+
+    ReadSession emptyReadSession = ReadSession.newBuilder().build();
+    StorageClient fakeStorageClient = mock(StorageClient.class);
+    when(fakeStorageClient.createReadSession(expectedRequest)).thenReturn(emptyReadSession);
+
+    BigQueryStorageQuerySource<TableRow> querySource =
+        BigQueryStorageQuerySource.create(
+            stepUuid,
+            ValueProvider.StaticValueProvider.of(encodedQuery),
+            /* flattenResults = */ true,
+            /* useLegacySql = */ true,
+            /* priority = */ QueryPriority.BATCH,
+            /* location = */ null,
+            /* kmsKey = */ null,
+            new TableRowParser(),
+            TableRowJsonCoder.of(),
+            new FakeBigQueryServices()
+                .withDatasetService(fakeDatasetService)
+                .withJobService(fakeJobService)
+                .withStorageClient(fakeStorageClient));
+
+    List<? extends BoundedSource<TableRow>> sources = querySource.split(1024L, options);
+    assertTrue(sources.isEmpty());
+  }
+
+  @Test
+  public void testQuerySourceCreateReader() throws Exception {
+    BigQueryStorageQuerySource<TableRow> querySource =
+        BigQueryStorageQuerySource.create(
+            /* stepUuid = */ "testStepUuid",
+            ValueProvider.StaticValueProvider.of("SELECT * FROM `dataset.table`"),
+            /* flattenResults = */ false,
+            /* useLegacySql = */ false,
+            /* priority = */ QueryPriority.INTERACTIVE,
+            /* location = */ "asia-northeast1",
+            /* kmsKey = */ null,
+            new TableRowParser(),
+            TableRowJsonCoder.of(),
+            fakeBigQueryServices);
+
+    thrown.expect(UnsupportedOperationException.class);
+    thrown.expectMessage("BigQuery storage source must be split before reading");
+    querySource.createReader(options);
+  }
+
+  @Test
+  public void testReadFromBigQueryIO() throws Exception {
+    doReadFromBigQueryIO(false);
+  }
+
+  @Test
+  public void testReadFromBigQueryIOWithTemplateCompatibility() throws Exception {
+    doReadFromBigQueryIO(true);
+  }
+
+  private void doReadFromBigQueryIO(boolean templateCompatibility) throws Exception {
+
+    TableReference sourceTableRef = BigQueryHelpers.parseTableSpec("project:dataset.table");
+
+    fakeDatasetService.createDataset(
+        sourceTableRef.getProjectId(),
+        sourceTableRef.getDatasetId(),
+        "asia-northeast1",
+        "Fake plastic tree^H^H^H^Htables",
+        null);
+
+    fakeDatasetService.createTable(
+        new Table().setTableReference(sourceTableRef).setLocation("asia-northeast1"));
+
+    Table queryResultTable =
+        new Table()
+            .setSchema(
+                new TableSchema()
+                    .setFields(
+                        ImmutableList.of(
+                            new TableFieldSchema().setName("name").setType("STRING"),
+                            new TableFieldSchema().setName("number").setType("INTEGER"))))
+            .setNumBytes(0L);
+
+    String encodedQuery = FakeBigQueryServices.encodeQueryResult(queryResultTable);
+
+    fakeJobService.expectDryRunQuery(
+        options.getProject(),
+        encodedQuery,
+        new JobStatistics()
+            .setQuery(
+                new JobStatistics2()
+                    .setTotalBytesProcessed(1024L * 1024L)
+                    .setReferencedTables(ImmutableList.of(sourceTableRef))));
+
+    ReadSession readSession =
+        ReadSession.newBuilder()
+            .setName("readSessionName")
+            .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING))
+            .addStreams(Stream.newBuilder().setName("streamName"))
+            .build();
+
+    ReadRowsRequest expectedReadRowsRequest =
+        ReadRowsRequest.newBuilder()
+            .setReadPosition(
+                StreamPosition.newBuilder().setStream(Stream.newBuilder().setName("streamName")))
+            .build();
+
+    List<GenericRecord> records =
+        Lists.newArrayList(
+            createRecord("A", 1, AVRO_SCHEMA),
+            createRecord("B", 2, AVRO_SCHEMA),
+            createRecord("C", 3, AVRO_SCHEMA),
+            createRecord("D", 4, AVRO_SCHEMA));
+
+    List<ReadRowsResponse> readRowsResponses =
+        Lists.newArrayList(
+            createResponse(AVRO_SCHEMA, records.subList(0, 2)),
+            createResponse(AVRO_SCHEMA, records.subList(2, 4)));
+
+    //
+    // Note that since the temporary table name is generated by the pipeline, we can't match the
+    // expected create read session request exactly. For now, match against any appropriately typed
+    // proto object.
+    //
+
+    StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable());
+    when(fakeStorageClient.createReadSession(any())).thenReturn(readSession);
+    when(fakeStorageClient.readRows(expectedReadRowsRequest)).thenReturn(readRowsResponses);
+
+    BigQueryIO.TypedRead<KV<String, Long>> typedRead =
+        BigQueryIO.read(new ParseKeyValue())
+            .fromQuery(encodedQuery)
+            .withMethod(Method.DIRECT_READ)
+            .withTestServices(
+                new FakeBigQueryServices()
+                    .withDatasetService(fakeDatasetService)
+                    .withJobService(fakeJobService)
+                    .withStorageClient(fakeStorageClient));
+
+    if (templateCompatibility) {
+      typedRead = typedRead.withTemplateCompatibility();
+    }
+
+    PCollection<KV<String, Long>> output = p.apply(typedRead);
+
+    PAssert.that(output)
+        .containsInAnyOrder(
+            ImmutableList.of(KV.of("A", 1L), KV.of("B", 2L), KV.of("C", 3L), KV.of("D", 4L)));
+
+    p.run();
+  }
+}
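
A side note on the three *InitialSplit cases above: the requested stream counts they expect
follow the clamping logic visible in the BigQueryStorageTableSource code removed earlier in
this diff, i.e. table bytes divided by the desired bundle size, bounded below and above. The
bounds of 10 and 10,000 used here are inferred from the test expectations rather than quoted
from the new source, so treat this as an illustrative reconstruction:

    // Requested streams = clamp(tableSizeBytes / desiredBundleSizeBytes, min, max).
    static int requestedStreams(long tableSizeBytes, long desiredBundleSizeBytes) {
      final int minSplitCount = 10; // inferred from the MinSplitCount case
      final int maxSplitCount = 10_000; // inferred from the MaxSplitCount case
      int streamCount = 0;
      if (desiredBundleSizeBytes > 0) {
        streamCount = (int) Math.min(tableSizeBytes / desiredBundleSizeBytes, maxSplitCount);
      }
      return Math.max(streamCount, minSplitCount);
    }

    // For the 1 MiB fake result table used above: a 1024-byte bundle size yields 1024
    // requested streams, a 1 MiB bundle size hits the lower bound of 10, and a 10-byte
    // bundle size hits the upper bound of 10,000.
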
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java
index 2762329..3de15f3 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java
@@ -541,7 +541,7 @@ public class BigQueryIOStorageReadTest {
             new FakeBigQueryServices().withDatasetService(fakeDatasetService));
 
     thrown.expect(UnsupportedOperationException.class);
-    thrown.expectMessage("BigQuery table source must be split before reading");
+    thrown.expectMessage("BigQuery storage source must be split before reading");
     tableSource.createReader(options);
   }
 
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeBigQueryServices.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeBigQueryServices.java
index 1e60c1e..b10f960 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeBigQueryServices.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeBigQueryServices.java
@@ -18,14 +18,18 @@
 package org.apache.beam.sdk.io.gcp.bigquery;
 
 import com.google.api.client.util.Base64;
+import com.google.api.services.bigquery.model.Table;
 import com.google.api.services.bigquery.model.TableRow;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.List;
 import org.apache.beam.sdk.annotations.Experimental;
-import org.apache.beam.sdk.coders.Coder.Context;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
 
 /** A fake implementation of BigQuery's query service. */
 @Experimental(Experimental.Kind.SOURCE_SINK)
@@ -64,21 +68,28 @@ public class FakeBigQueryServices implements BigQueryServices {
     return storageClient;
   }
 
-  static List<TableRow> rowsFromEncodedQuery(String query) throws IOException {
-    ListCoder<TableRow> listCoder = ListCoder.of(TableRowJsonCoder.of());
-    ByteArrayInputStream input = new ByteArrayInputStream(Base64.decodeBase64(query));
-    List<TableRow> rows = listCoder.decode(input, Context.OUTER);
-    for (TableRow row : rows) {
-      convertNumbers(row);
-    }
-    return rows;
+  static String encodeQueryResult(Table table) throws IOException {
+    return encodeQueryResult(table, ImmutableList.of());
+  }
+
+  static String encodeQueryResult(Table table, List<TableRow> rows) throws IOException {
+    KvCoder<String, List<TableRow>> coder =
+        KvCoder.of(StringUtf8Coder.of(), ListCoder.of(TableRowJsonCoder.of()));
+    KV<String, List<TableRow>> kv = KV.of(BigQueryHelpers.toJsonString(table), rows);
+    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    coder.encode(kv, outputStream);
+    return Base64.encodeBase64String(outputStream.toByteArray());
   }
 
-  static String encodeQuery(List<TableRow> rows) throws IOException {
-    ListCoder<TableRow> listCoder = ListCoder.of(TableRowJsonCoder.of());
-    ByteArrayOutputStream output = new ByteArrayOutputStream();
-    listCoder.encode(rows, output, Context.OUTER);
-    return Base64.encodeBase64String(output.toByteArray());
+  static KV<Table, List<TableRow>> decodeQueryResult(String queryResult) throws IOException {
+    KvCoder<String, List<TableRow>> coder =
+        KvCoder.of(StringUtf8Coder.of(), ListCoder.of(TableRowJsonCoder.of()));
+    ByteArrayInputStream inputStream = new ByteArrayInputStream(Base64.decodeBase64(queryResult));
+    KV<String, List<TableRow>> kv = coder.decode(inputStream);
+    Table table = BigQueryHelpers.fromJsonString(kv.getKey(), Table.class);
+    List<TableRow> rows = kv.getValue();
+    rows.forEach(FakeBigQueryServices::convertNumbers);
+    return KV.of(table, rows);
   }
 
   // Longs tend to get converted back to Integers due to JSON serialization. Convert them back.
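
To make the fake-services change easier to follow: the "query" string these tests hand to
the source is not SQL at all but a Base64 payload that packs the result table (schema and
size) together with its rows, which FakeJobService decodes when it "runs" the query job.
A small round-trip sketch; the field name and row value are made up for illustration:

    // Illustrative round trip through the new encode/decode helpers above.
    static void demoFakeQueryRoundTrip() throws IOException {
      Table resultTable =
          new Table()
              .setSchema(
                  new TableSchema()
                      .setFields(
                          ImmutableList.of(
                              new TableFieldSchema().setName("name").setType("STRING"))));
      List<TableRow> resultRows = ImmutableList.of(new TableRow().set("name", "a"));

      String fakeQuery = FakeBigQueryServices.encodeQueryResult(resultTable, resultRows);
      KV<Table, List<TableRow>> decoded = FakeBigQueryServices.decodeQueryResult(fakeQuery);
      // decoded.getKey() carries the Table with its schema; decoded.getValue() the rows.
    }
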
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeJobService.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeJobService.java
index d03033e..c5ce93b 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeJobService.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FakeJobService.java
@@ -70,6 +70,7 @@ import org.apache.beam.sdk.util.BackOffAdapter;
 import org.apache.beam.sdk.util.FluentBackoff;
 import org.apache.beam.sdk.util.MimeTypes;
 import org.apache.beam.sdk.util.Transport;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.HashBasedTable;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
@@ -433,9 +434,9 @@ public class FakeJobService implements JobService, Serializable {
 
   private JobStatus runQueryJob(JobConfigurationQuery query)
       throws IOException, InterruptedException {
-    List<TableRow> rows = FakeBigQueryServices.rowsFromEncodedQuery(query.getQuery());
-    datasetService.createTable(new Table().setTableReference(query.getDestinationTable()));
-    datasetService.insertAll(query.getDestinationTable(), rows, null);
+    KV<Table, List<TableRow>> result = FakeBigQueryServices.decodeQueryResult(query.getQuery());
+    datasetService.createTable(result.getKey().setTableReference(query.getDestinationTable()));
+    datasetService.insertAll(query.getDestinationTable(), result.getValue(), null);
     return new JobStatus().setState("DONE");
   }