Posted to commits@beam.apache.org by pa...@apache.org on 2022/09/02 16:51:56 UTC

[beam] branch master updated: Adding support for Beam Schema Rows with BQ DIRECT_READ (#22926)

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 948af30a5b6 Adding support for Beam Schema Rows with BQ DIRECT_READ (#22926)
948af30a5b6 is described below

commit 948af30a5b665fe74b7052b673e95ff5f5fc426a
Author: Pablo Estrada <pa...@users.noreply.github.com>
AuthorDate: Fri Sep 2 09:51:50 2022 -0700

    Adding support for Beam Schema Rows with BQ DIRECT_READ (#22926)
    
    * Adding support for Beam Schema Rows with BQ DIRECT_READ
    
    * Fixing for trimmed-out fields
    
    * refactor
    
    * Fix NPE in FakeJobService
    
    * Addressing comments
---
 .../beam/sdk/io/gcp/bigquery/BigQueryIO.java       | 85 ++++++++++++++++------
 .../beam/sdk/io/gcp/bigquery/BigQueryUtils.java    |  2 +-
 .../beam/sdk/io/gcp/testing/FakeJobService.java    |  4 +-
 .../io/gcp/bigquery/BigQueryIOStorageReadIT.java   | 36 +++++++++
 .../io/gcp/bigquery/BigQueryIOStorageReadTest.java | 77 ++++++++++++++++++++
 5 files changed, 178 insertions(+), 26 deletions(-)
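
For context, a minimal usage sketch of what this commit enables: with DIRECT_READ, the output of readTableRowsWithSchema() now carries a Beam schema, so schema transforms such as Convert.toRows() work the same way they already did for export-based reads. This is modeled on the integration test added below; the table spec and class name are placeholders, not part of the commit.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.schemas.transforms.Convert;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.Row;

    public class DirectReadRowsSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
        // With this commit, the DIRECT_READ output carries the Beam schema, so
        // Convert.toRows() can be applied just as with export-based reads.
        PCollection<Row> rows =
            p.apply(
                    "Read",
                    BigQueryIO.readTableRowsWithSchema()
                        .from("my-project:my_dataset.my_table") // placeholder table spec
                        .withMethod(Method.DIRECT_READ))
                .apply(Convert.toRows());
        p.run().waitUntilFinish();
      }
    }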

diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index 59db2f95cf9..3a6280c3038 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -1068,16 +1068,20 @@ public class BigQueryIO {
       checkArgument(getParseFn() != null, "A parseFn is required");
 
       // if both toRowFn and fromRowFn values are set, enable Beam schema support
-      boolean beamSchemaEnabled = false;
+      Pipeline p = input.getPipeline();
+      final BigQuerySourceDef sourceDef = createSourceDef();
+
+      Schema beamSchema = null;
       if (getTypeDescriptor() != null && getToBeamRowFn() != null && getFromBeamRowFn() != null) {
-        beamSchemaEnabled = true;
+        BigQueryOptions bqOptions = p.getOptions().as(BigQueryOptions.class);
+        beamSchema = sourceDef.getBeamSchema(bqOptions);
+        beamSchema = getFinalSchema(beamSchema, getSelectedFields());
       }
 
-      Pipeline p = input.getPipeline();
       final Coder<T> coder = inferCoder(p.getCoderRegistry());
 
       if (getMethod() == TypedRead.Method.DIRECT_READ) {
-        return expandForDirectRead(input, coder);
+        return expandForDirectRead(input, coder, beamSchema);
       }
 
       checkArgument(
@@ -1090,7 +1094,6 @@ public class BigQueryIO {
           "Invalid BigQueryIO.Read: Specifies row restriction, "
               + "which only applies when using Method.DIRECT_READ");
 
-      final BigQuerySourceDef sourceDef = createSourceDef();
       final PCollectionView<String> jobIdTokenView;
       PCollection<String> jobIdTokenCollection;
       PCollection<T> rows;
@@ -1221,33 +1224,60 @@ public class BigQueryIO {
 
       rows = rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView));
 
-      if (beamSchemaEnabled) {
-        BigQueryOptions bqOptions = p.getOptions().as(BigQueryOptions.class);
-        Schema beamSchema = sourceDef.getBeamSchema(bqOptions);
-        SerializableFunction<T, Row> toBeamRow = getToBeamRowFn().apply(beamSchema);
-        SerializableFunction<Row, T> fromBeamRow = getFromBeamRowFn().apply(beamSchema);
-
-        rows.setSchema(beamSchema, getTypeDescriptor(), toBeamRow, fromBeamRow);
+      if (beamSchema != null) {
+        rows.setSchema(
+            beamSchema,
+            getTypeDescriptor(),
+            getToBeamRowFn().apply(beamSchema),
+            getFromBeamRowFn().apply(beamSchema));
       }
       return rows;
     }
 
-    private PCollection<T> expandForDirectRead(PBegin input, Coder<T> outputCoder) {
+    private static Schema getFinalSchema(
+        Schema beamSchema, ValueProvider<List<String>> selectedFields) {
+      List<Schema.Field> flds =
+          beamSchema.getFields().stream()
+              .filter(
+                  field -> {
+                    if (selectedFields != null
+                        && selectedFields.isAccessible()
+                        && selectedFields.get() != null) {
+                      return selectedFields.get().contains(field.getName());
+                    } else {
+                      return true;
+                    }
+                  })
+              .collect(Collectors.toList());
+      return Schema.builder().addFields(flds).build();
+    }
+
+    private PCollection<T> expandForDirectRead(
+        PBegin input, Coder<T> outputCoder, Schema beamSchema) {
       ValueProvider<TableReference> tableProvider = getTableProvider();
       Pipeline p = input.getPipeline();
       if (tableProvider != null) {
         // No job ID is required. Read directly from BigQuery storage.
-        return p.apply(
-            org.apache.beam.sdk.io.Read.from(
-                BigQueryStorageTableSource.create(
-                    tableProvider,
-                    getFormat(),
-                    getSelectedFields(),
-                    getRowRestriction(),
-                    getParseFn(),
-                    outputCoder,
-                    getBigQueryServices(),
-                    getProjectionPushdownApplied())));
+        PCollection<T> rows =
+            p.apply(
+                org.apache.beam.sdk.io.Read.from(
+                    BigQueryStorageTableSource.create(
+                        tableProvider,
+                        getFormat(),
+                        getSelectedFields(),
+                        getRowRestriction(),
+                        getParseFn(),
+                        outputCoder,
+                        getBigQueryServices(),
+                        getProjectionPushdownApplied())));
+        if (beamSchema != null) {
+          rows.setSchema(
+              beamSchema,
+              getTypeDescriptor(),
+              getToBeamRowFn().apply(beamSchema),
+              getFromBeamRowFn().apply(beamSchema));
+        }
+        return rows;
       }
 
       checkArgument(
@@ -1437,6 +1467,13 @@ public class BigQueryIO {
             }
           };
 
+      if (beamSchema != null) {
+        rows.setSchema(
+            beamSchema,
+            getTypeDescriptor(),
+            getToBeamRowFn().apply(beamSchema),
+            getFromBeamRowFn().apply(beamSchema));
+      }
       return rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView));
     }
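
To make the field-trimming step above concrete: getFinalSchema keeps only the fields named via withSelectedFields(), and falls back to the full table schema when no selection is supplied. A self-contained sketch of the same filtering, with a schema and selection invented for illustration:

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;
    import org.apache.beam.sdk.schemas.Schema;
    import org.apache.beam.sdk.schemas.Schema.FieldType;

    public class TrimSchemaSketch {
      public static void main(String[] args) {
        Schema tableSchema =
            Schema.builder()
                .addNullableField("string_field", FieldType.STRING)
                .addNullableField("int_field", FieldType.INT64)
                .build();
        List<String> selectedFields = Arrays.asList("string_field"); // placeholder selection
        // Same filtering as getFinalSchema above: keep only the selected fields;
        // with no selection, the full schema would pass through unchanged.
        Schema trimmed =
            Schema.builder()
                .addFields(
                    tableSchema.getFields().stream()
                        .filter(f -> selectedFields.contains(f.getName()))
                        .collect(Collectors.toList()))
                .build();
        System.out.println(trimmed); // only string_field remains
      }
    }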
 
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java
index 06db56234b5..e152beb623d 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java
@@ -645,7 +645,7 @@ public class BigQueryUtils {
         return null;
       } else {
         throw new IllegalArgumentException(
-            "Received null value for non-nullable field " + field.getName());
+            "Received null value for non-nullable field \"" + field.getName() + "\"");
       }
     }
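
The hunk above only quotes the field name in the error message. For illustration, a hedged sketch of where that message surfaces, assuming the usual BigQueryUtils.toBeamRow(Schema, TableRow) conversion path; the schema and row here are invented:

    import com.google.api.services.bigquery.model.TableRow;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils;
    import org.apache.beam.sdk.schemas.Schema;

    public class NullFieldMessageSketch {
      public static void main(String[] args) {
        Schema schema = Schema.builder().addStringField("name").build(); // non-nullable field
        TableRow row = new TableRow(); // "name" is absent, i.e. null
        try {
          BigQueryUtils.toBeamRow(schema, row);
        } catch (IllegalArgumentException e) {
          // With this change: Received null value for non-nullable field "name"
          System.out.println(e.getMessage());
        }
      }
    }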
 
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
index 569ea67003e..d2f6816806c 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
@@ -320,7 +320,9 @@ public class FakeJobService implements JobService, Serializable {
                               "Job %s failed: %s", job.job.getConfiguration(), e.toString())));
           List<ResourceId> sourceFiles =
               filesForLoadJobs.get(jobRef.getProjectId(), jobRef.getJobId());
-          FileSystems.delete(sourceFiles);
+          if (sourceFiles != null) {
+            FileSystems.delete(sourceFiles);
+          }
         }
         return JSON_FACTORY.fromString(JSON_FACTORY.toString(job.job), Job.class);
       }
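
On the NPE fix above: filesForLoadJobs appears to be a keyed lookup that returns null when no files were ever registered for a job, so the delete call needed a guard. A generic illustration of the pattern, with names invented for this sketch:

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class GuardedDeleteSketch {
      public static void main(String[] args) {
        Map<String, List<String>> filesForLoadJobs = new HashMap<>();
        // A lookup for a job that never registered files returns null...
        List<String> sourceFiles = filesForLoadJobs.get("job-with-no-files");
        // ...so an unguarded delete(sourceFiles) would throw; the commit adds this check.
        if (sourceFiles != null) {
          sourceFiles.forEach(System.out::println); // stand-in for FileSystems.delete(...)
        }
      }
    }
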
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java
index b7c50c8054a..bb576be7dd1 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.io.gcp.bigquery;
 
+import static org.junit.Assert.assertEquals;
+
 import com.google.cloud.bigquery.storage.v1.DataFormat;
 import java.util.Map;
 import org.apache.beam.sdk.Pipeline;
@@ -32,6 +34,7 @@ import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.transforms.Convert;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestPipelineOptions;
@@ -123,6 +126,39 @@ public class BigQueryIOStorageReadIT {
     runBigQueryIOStorageReadPipeline();
   }
 
+  @Test
+  public void testBigQueryStorageReadWithAvro() throws Exception {
+    storageReadWithSchema(DataFormat.AVRO);
+  }
+
+  @Test
+  public void testBigQueryStorageReadWithArrow() throws Exception {
+    storageReadWithSchema(DataFormat.ARROW);
+  }
+
+  private void storageReadWithSchema(DataFormat format) {
+    setUpTestEnvironment("multi_field", format);
+
+    Schema multiFieldSchema =
+        Schema.builder()
+            .addNullableField("string_field", FieldType.STRING)
+            .addNullableField("int_field", FieldType.INT64)
+            .build();
+
+    Pipeline p = Pipeline.create(options);
+    PCollection<Row> tableContents =
+        p.apply(
+                "Read",
+                BigQueryIO.readTableRowsWithSchema()
+                    .from(options.getInputTable())
+                    .withMethod(Method.DIRECT_READ)
+                    .withFormat(options.getDataFormat()))
+            .apply(Convert.toRows());
+    PAssert.thatSingleton(tableContents.apply(Count.globally())).isEqualTo(options.getNumRecords());
+    assertEquals(tableContents.getSchema(), multiFieldSchema);
+    p.run().waitUntilFinish();
+  }
+
   /**
    * Tests a pipeline where {@link
    * org.apache.beam.runners.core.construction.graph.ProjectionPushdownOptimizer} may do
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java
index e9a98a52359..518c4a80cdb 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java
@@ -102,6 +102,7 @@ import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
 import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.transforms.Convert;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -1481,6 +1482,82 @@ public class BigQueryIOStorageReadTest {
     p.run();
   }
 
+  @Test
+  public void testReadFromBigQueryIOWithBeamSchema() throws Exception {
+    fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null);
+    TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table");
+    Table table = new Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA);
+    fakeDatasetService.createTable(table);
+
+    CreateReadSessionRequest expectedCreateReadSessionRequest =
+        CreateReadSessionRequest.newBuilder()
+            .setParent("projects/project-id")
+            .setReadSession(
+                ReadSession.newBuilder()
+                    .setTable("projects/foo.com:project/datasets/dataset/tables/table")
+                    .setReadOptions(
+                        ReadSession.TableReadOptions.newBuilder().addSelectedFields("name"))
+                    .setDataFormat(DataFormat.AVRO))
+            .setMaxStreamCount(10)
+            .build();
+
+    ReadSession readSession =
+        ReadSession.newBuilder()
+            .setName("readSessionName")
+            .setAvroSchema(AvroSchema.newBuilder().setSchema(TRIMMED_AVRO_SCHEMA_STRING))
+            .addStreams(ReadStream.newBuilder().setName("streamName"))
+            .setDataFormat(DataFormat.AVRO)
+            .build();
+
+    ReadRowsRequest expectedReadRowsRequest =
+        ReadRowsRequest.newBuilder().setReadStream("streamName").build();
+
+    List<GenericRecord> records =
+        Lists.newArrayList(
+            createRecord("A", TRIMMED_AVRO_SCHEMA),
+            createRecord("B", TRIMMED_AVRO_SCHEMA),
+            createRecord("C", TRIMMED_AVRO_SCHEMA),
+            createRecord("D", TRIMMED_AVRO_SCHEMA));
+
+    List<ReadRowsResponse> readRowsResponses =
+        Lists.newArrayList(
+            createResponse(TRIMMED_AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.50),
+            createResponse(TRIMMED_AVRO_SCHEMA, records.subList(2, 4), 0.5, 0.75));
+
+    StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable());
+    when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest))
+        .thenReturn(readSession);
+    when(fakeStorageClient.readRows(expectedReadRowsRequest, ""))
+        .thenReturn(new FakeBigQueryServerStream<>(readRowsResponses));
+
+    PCollection<Row> output =
+        p.apply(
+                BigQueryIO.readTableRowsWithSchema()
+                    .from("foo.com:project:dataset.table")
+                    .withMethod(Method.DIRECT_READ)
+                    .withSelectedFields(Lists.newArrayList("name"))
+                    .withFormat(DataFormat.AVRO)
+                    .withTestServices(
+                        new FakeBigQueryServices()
+                            .withDatasetService(fakeDatasetService)
+                            .withStorageClient(fakeStorageClient)))
+            .apply(Convert.toRows());
+
+    org.apache.beam.sdk.schemas.Schema beamSchema =
+        org.apache.beam.sdk.schemas.Schema.of(
+            org.apache.beam.sdk.schemas.Schema.Field.of(
+                "name", org.apache.beam.sdk.schemas.Schema.FieldType.STRING));
+    PAssert.that(output)
+        .containsInAnyOrder(
+            ImmutableList.of(
+                Row.withSchema(beamSchema).addValue("A").build(),
+                Row.withSchema(beamSchema).addValue("B").build(),
+                Row.withSchema(beamSchema).addValue("C").build(),
+                Row.withSchema(beamSchema).addValue("D").build()));
+
+    p.run();
+  }
+
   @Test
   public void testReadFromBigQueryIOArrow() throws Exception {
     fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null);