You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by jk...@apache.org on 2017/06/28 03:16:52 UTC
[3/4] beam git commit: Read api with naive implementation
Read api with naive implementation
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a21a6d79
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a21a6d79
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a21a6d79
Branch: refs/heads/master
Commit: a21a6d797777ad38b927bd5d44c63306d85a752a
Parents: 454f1c4
Author: Mairbek Khadikov <ma...@google.com>
Authored: Mon Jun 19 13:28:52 2017 -0700
Committer: Eugene Kirpichov <ki...@google.com>
Committed: Tue Jun 27 18:38:21 2017 -0700
----------------------------------------------------------------------
pom.xml | 12 +
sdks/java/io/google-cloud-platform/pom.xml | 16 +-
.../sdk/io/gcp/spanner/AbstractSpannerFn.java | 17 +
.../sdk/io/gcp/spanner/CreateTransactionFn.java | 51 ++
.../sdk/io/gcp/spanner/NaiveSpannerReadFn.java | 65 +++
.../beam/sdk/io/gcp/spanner/SpannerConfig.java | 29 +-
.../beam/sdk/io/gcp/spanner/SpannerIO.java | 479 ++++++++++++++++---
.../sdk/io/gcp/spanner/SpannerWriteGroupFn.java | 17 +
.../beam/sdk/io/gcp/spanner/Transaction.java | 33 ++
.../beam/sdk/io/gcp/GcpApiSurfaceTest.java | 10 +
.../sdk/io/gcp/spanner/FakeServiceFactory.java | 82 ++++
.../sdk/io/gcp/spanner/SpannerIOReadTest.java | 275 +++++++++++
.../beam/sdk/io/gcp/spanner/SpannerIOTest.java | 314 ------------
.../sdk/io/gcp/spanner/SpannerIOWriteTest.java | 258 ++++++++++
.../beam/sdk/io/gcp/spanner/SpannerReadIT.java | 169 +++++++
15 files changed, 1432 insertions(+), 395 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index f06568b..069191c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -161,6 +161,7 @@
<compiler.error.flag>-Werror</compiler.error.flag>
<compiler.default.pkginfo.flag>-Xpkginfo:always</compiler.default.pkginfo.flag>
<compiler.default.exclude>nothing</compiler.default.exclude>
+ <gax-grpc.version>0.20.0</gax-grpc.version>
</properties>
<packaging>pom</packaging>
@@ -638,6 +639,12 @@
</dependency>
<dependency>
+ <groupId>com.google.api</groupId>
+ <artifactId>gax-grpc</artifactId>
+ <version>${gax-grpc.version}</version>
+ </dependency>
+
+ <dependency>
<groupId>com.google.api-client</groupId>
<artifactId>google-api-client</artifactId>
<version>${google-clients.version}</version>
@@ -852,6 +859,11 @@
</dependency>
<dependency>
+ <groupId>com.google.cloud</groupId>
+ <artifactId>google-cloud-core-grpc</artifactId>
+ <version>${grpc.version}</version>
+ </dependency>
+ <dependency>
<groupId>com.google.cloud.bigtable</groupId>
<artifactId>bigtable-protos</artifactId>
<version>${bigtable.version}</version>
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/pom.xml
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml
index 6737eea..94066c7 100644
--- a/sdks/java/io/google-cloud-platform/pom.xml
+++ b/sdks/java/io/google-cloud-platform/pom.xml
@@ -93,7 +93,12 @@
<dependency>
<groupId>com.google.api</groupId>
- <artifactId>api-common</artifactId>
+ <artifactId>gax-grpc</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.google.cloud</groupId>
+ <artifactId>google-cloud-core-grpc</artifactId>
</dependency>
<dependency>
@@ -255,12 +260,17 @@
<dependency>
<groupId>org.apache.commons</groupId>
- <artifactId>commons-text</artifactId>
- <scope>test</scope>
+ <artifactId>commons-lang3</artifactId>
+ <scope>provided</scope>
</dependency>
<!-- Test dependencies -->
<dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-text</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.beam</groupId>
<artifactId>beam-sdks-java-core</artifactId>
<classifier>tests</classifier>
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java
index 08f7fa9..00008f1 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package org.apache.beam.sdk.io.gcp.spanner;
import com.google.cloud.spanner.DatabaseClient;
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java
new file mode 100644
index 0000000..da8e8b1
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import com.google.cloud.spanner.ReadOnlyTransaction;
+import com.google.cloud.spanner.ResultSet;
+import com.google.cloud.spanner.Statement;
+
+/** Creates a batch transaction. */
+class CreateTransactionFn extends AbstractSpannerFn<Object, Transaction> {
+
+ private final SpannerIO.CreateTransaction config;
+
+ CreateTransactionFn(SpannerIO.CreateTransaction config) {
+ this.config = config;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ try (ReadOnlyTransaction readOnlyTransaction =
+ databaseClient().readOnlyTransaction(config.getTimestampBound())) {
+ // Run a dummy sql statement to force the RPC and obtain the timestamp from the server.
+ ResultSet resultSet = readOnlyTransaction.executeQuery(Statement.of("SELECT 1"));
+ while (resultSet.next()) {
+ // do nothing
+ }
+ Transaction tx = Transaction.create(readOnlyTransaction.getReadTimestamp());
+ c.output(tx);
+ }
+ }
+
+ @Override
+ SpannerConfig getSpannerConfig() {
+ return config.getSpannerConfig();
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java
new file mode 100644
index 0000000..d193b95
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import com.google.cloud.spanner.ReadOnlyTransaction;
+import com.google.cloud.spanner.ResultSet;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.TimestampBound;
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * The simplest read function implementation: executes the configured query or read inside a
+ * single read-only transaction and outputs every returned row. Parallelism support is coming.
+ */
+@VisibleForTesting
+class NaiveSpannerReadFn extends AbstractSpannerFn<Object, Struct> {
+  private final SpannerIO.Read config;
+
+  NaiveSpannerReadFn(SpannerIO.Read config) {
+    this.config = config;
+  }
+
+  @Override
+  SpannerConfig getSpannerConfig() {
+    return config.getSpannerConfig();
+  }
+
+  /**
+   * Runs the configured operation and outputs each row as a {@link Struct}.
+   *
+   * <p>Reads at the timestamp of the side-input {@link Transaction} when one is configured, and
+   * with {@link TimestampBound#strong()} otherwise.
+   */
+  @ProcessElement
+  public void processElement(ProcessContext c) throws Exception {
+    TimestampBound timestampBound = TimestampBound.strong();
+    if (config.getTransaction() != null) {
+      Transaction transaction = c.sideInput(config.getTransaction());
+      timestampBound = TimestampBound.ofReadTimestamp(transaction.timestamp());
+    }
+    try (ReadOnlyTransaction readOnlyTransaction =
+        databaseClient().readOnlyTransaction(timestampBound)) {
+      ResultSet resultSet = execute(readOnlyTransaction);
+      while (resultSet.next()) {
+        c.output(resultSet.getCurrentRowAsStruct());
+      }
+    }
+  }
+
+  /**
+   * Starts the configured operation: the query if one is set, otherwise a read of the table
+   * through the configured index, or a plain table read when no index is set.
+   */
+  private ResultSet execute(ReadOnlyTransaction readOnlyTransaction) {
+    if (config.getQuery() != null) {
+      return readOnlyTransaction.executeQuery(config.getQuery());
+    }
+    if (config.getIndex() != null) {
+      return readOnlyTransaction.readUsingIndex(
+          config.getTable(), config.getIndex(), config.getKeySet(), config.getColumns());
+    }
+    return readOnlyTransaction.read(config.getTable(), config.getKeySet(), config.getColumns());
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
index 4cb8aa2..02716fb 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package org.apache.beam.sdk.io.gcp.spanner;
import static com.google.common.base.Preconditions.checkNotNull;
@@ -17,8 +34,6 @@ import org.apache.beam.sdk.transforms.display.DisplayData;
@AutoValue
public abstract class SpannerConfig implements Serializable {
- private static final long serialVersionUID = -5680874609304170301L;
-
@Nullable
abstract ValueProvider<String> getProjectId();
@@ -49,7 +64,7 @@ public abstract class SpannerConfig implements Serializable {
return builder().build();
}
- public static Builder builder() {
+ static Builder builder() {
return new AutoValue_SpannerConfig.Builder();
}
@@ -79,14 +94,12 @@ public abstract class SpannerConfig implements Serializable {
@AutoValue.Builder
public abstract static class Builder {
-
abstract Builder setProjectId(ValueProvider<String> projectId);
abstract Builder setInstanceId(ValueProvider<String> instanceId);
abstract Builder setDatabaseId(ValueProvider<String> databaseId);
-
abstract Builder setServiceFactory(ServiceFactory<Spanner, SpannerOptions> serviceFactory);
public abstract SpannerConfig build();
@@ -115,4 +128,10 @@ public abstract class SpannerConfig implements Serializable {
public SpannerConfig withDatabaseId(String databaseId) {
return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId));
}
+
+ /** Returns a configuration rebuilt from this one with the given service factory set. */
+ @VisibleForTesting
+ SpannerConfig withServiceFactory(ServiceFactory<Spanner, SpannerOptions> serviceFactory) {
+ return toBuilder().setServiceFactory(serviceFactory).build();
+ }
+
}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index 791c7e7..acf9285 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -17,23 +17,38 @@
*/
package org.apache.beam.sdk.io.gcp.spanner;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+
import com.google.auto.value.AutoValue;
import com.google.cloud.ServiceFactory;
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.KeySet;
import com.google.cloud.spanner.Mutation;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.TimestampBound;
import com.google.common.annotations.VisibleForTesting;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PDone;
/**
@@ -42,7 +57,69 @@ import org.apache.beam.sdk.values.PDone;
*
* <h3>Reading from Cloud Spanner</h3>
*
- * <p>This functionality is not yet implemented.
+ * <p>To read from Cloud Spanner, apply {@link SpannerIO.Read} transformation. It will return a
+ * {@link PCollection} of {@link Struct Structs}, where each element represents
+ * an individual row returned from the read operation. Both Query and Read APIs are supported.
+ * See more information about <a href="https://cloud.google.com/spanner/docs/reads">reading from
+ * Cloud Spanner</a>.
+ *
+ * <p>To execute a <strong>query</strong>, specify a {@link SpannerIO.Read#withQuery(Statement)} or
+ * {@link SpannerIO.Read#withQuery(String)} during the construction of the transform.
+ *
+ * <pre>{@code
+ * PCollection<Struct> rows = p.apply(
+ * SpannerIO.read()
+ * .withInstanceId(instanceId)
+ * .withDatabaseId(dbId)
+ * .withQuery("SELECT id, name, email FROM users"));
+ * }</pre>
+ *
+ * <p>To use the Read API, specify a {@link SpannerIO.Read#withTable(String) table name} and
+ * a {@link SpannerIO.Read#withColumns(List) list of columns}.
+ *
+ * <pre>{@code
+ * PCollection<Struct> rows = p.apply(
+ * SpannerIO.read()
+ * .withInstanceId(instanceId)
+ * .withDatabaseId(dbId)
+ * .withTable("users")
+ * .withColumns("id", "name", "email"));
+ * }</pre>
+ *
+ * <p>To read using an index, specify the index name using {@link SpannerIO.Read#withIndex}.
+ *
+ * <p>The transform is guaranteed to be executed on a consistent snapshot of data, utilizing the
+ * power of read only transactions. Staleness of data can be controlled using
+ * {@link SpannerIO.Read#withTimestampBound} or {@link SpannerIO.Read#withTimestamp(Timestamp)}
+ * methods. <a href="https://cloud.google.com/spanner/docs/transactions">Read more</a> about
+ * transactions in Cloud Spanner.
+ *
+ * <p>It is possible to read several {@link PCollection PCollections} within a single transaction.
+ * Apply {@link SpannerIO#createTransaction()} transform, that lazily creates a transaction. The
+ * result of this transformation can be passed to read operation using
+ * {@link SpannerIO.Read#withTransaction(PCollectionView)}.
+ *
+ * <pre>{@code
+ * SpannerConfig spannerConfig = ...
+ *
+ * PCollectionView<Transaction> tx =
+ * p.apply(
+ * SpannerIO.createTransaction()
+ * .withSpannerConfig(spannerConfig)
+ * .withTimestampBound(TimestampBound.strong()));
+ *
+ * PCollection<Struct> users = p.apply(
+ * SpannerIO.read()
+ * .withSpannerConfig(spannerConfig)
+ * .withQuery("SELECT name, email FROM users")
+ * .withTransaction(tx));
+ *
+ * PCollection<Struct> tweets = p.apply(
+ * SpannerIO.read()
+ * .withSpannerConfig(spannerConfig)
+ * .withQuery("SELECT user, tweet, date FROM tweets")
+ * .withTransaction(tx));
+ * }</pre>
*
* <h3>Writing to Cloud Spanner</h3>
*
@@ -86,6 +163,33 @@ public class SpannerIO {
private static final long DEFAULT_BATCH_SIZE_BYTES = 1024 * 1024; // 1 MB
/**
+ * Creates an uninitialized instance of {@link Read}. Before use, the {@link Read} must be
+ * configured with a {@link Read#withInstanceId} and {@link Read#withDatabaseId} that identify the
+ * Cloud Spanner database.
+ */
+ @Experimental
+ public static Read read() {
+ return new AutoValue_SpannerIO_Read.Builder()
+ .setSpannerConfig(SpannerConfig.create())
+ .setTimestampBound(TimestampBound.strong())
+ .setKeySet(KeySet.all())
+ .build();
+ }
+
+ /**
+ * Returns a transform that creates a batch transaction. By default,
+ * {@link TimestampBound#strong()} transaction is created, to override this use
+ * {@link CreateTransaction#withTimestampBound(TimestampBound)}.
+ */
+ @Experimental
+ public static CreateTransaction createTransaction() {
+ return new AutoValue_SpannerIO_CreateTransaction.Builder()
+ .setSpannerConfig(SpannerConfig.create())
+ .setTimestampBound(TimestampBound.strong())
+ .build();
+ }
+
+ /**
* Creates an uninitialized instance of {@link Write}. Before use, the {@link Write} must be
* configured with a {@link Write#withInstanceId} and {@link Write#withDatabaseId} that identify
* the Cloud Spanner database being written.
@@ -93,11 +197,286 @@ public class SpannerIO {
@Experimental
public static Write write() {
return new AutoValue_SpannerIO_Write.Builder()
+ .setSpannerConfig(SpannerConfig.create())
.setBatchSizeBytes(DEFAULT_BATCH_SIZE_BYTES)
.build();
}
/**
+ * A {@link PTransform} that reads data from Google Cloud Spanner.
+ *
+ * @see SpannerIO
+ */
+  @Experimental(Experimental.Kind.SOURCE_SINK)
+  @AutoValue
+  public abstract static class Read extends PTransform<PBegin, PCollection<Struct>> {
+
+    abstract SpannerConfig getSpannerConfig();
+
+    @Nullable
+    abstract TimestampBound getTimestampBound();
+
+    @Nullable
+    abstract Statement getQuery();
+
+    @Nullable
+    abstract String getTable();
+
+    @Nullable
+    abstract String getIndex();
+
+    @Nullable
+    abstract List<String> getColumns();
+
+    @Nullable
+    abstract KeySet getKeySet();
+
+    @Nullable
+    abstract PCollectionView<Transaction> getTransaction();
+
+    abstract Builder toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder {
+
+      abstract Builder setSpannerConfig(SpannerConfig spannerConfig);
+
+      abstract Builder setTimestampBound(TimestampBound timestampBound);
+
+      abstract Builder setQuery(Statement statement);
+
+      abstract Builder setTable(String table);
+
+      abstract Builder setIndex(String index);
+
+      abstract Builder setColumns(List<String> columns);
+
+      abstract Builder setKeySet(KeySet keySet);
+
+      abstract Builder setTransaction(PCollectionView<Transaction> transaction);
+
+      abstract Read build();
+    }
+
+    /** Specifies the Cloud Spanner configuration. */
+    public Read withSpannerConfig(SpannerConfig spannerConfig) {
+      return toBuilder().setSpannerConfig(spannerConfig).build();
+    }
+
+    /** Specifies the Cloud Spanner project. */
+    public Read withProjectId(String projectId) {
+      return withProjectId(ValueProvider.StaticValueProvider.of(projectId));
+    }
+
+    /** Specifies the Cloud Spanner project. */
+    public Read withProjectId(ValueProvider<String> projectId) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withProjectId(projectId));
+    }
+
+    /** Specifies the Cloud Spanner instance. */
+    public Read withInstanceId(String instanceId) {
+      return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId));
+    }
+
+    /** Specifies the Cloud Spanner instance. */
+    public Read withInstanceId(ValueProvider<String> instanceId) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withInstanceId(instanceId));
+    }
+
+    /** Specifies the Cloud Spanner database. */
+    public Read withDatabaseId(String databaseId) {
+      return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId));
+    }
+
+    /** Specifies the Cloud Spanner database. */
+    public Read withDatabaseId(ValueProvider<String> databaseId) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withDatabaseId(databaseId));
+    }
+
+    @VisibleForTesting
+    Read withServiceFactory(ServiceFactory<Spanner, SpannerOptions> serviceFactory) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withServiceFactory(serviceFactory));
+    }
+
+    /**
+     * Specifies the transaction to read in, for example the result of
+     * {@link SpannerIO#createTransaction()}. When set, no separate transaction is created.
+     */
+    public Read withTransaction(PCollectionView<Transaction> transaction) {
+      return toBuilder().setTransaction(transaction).build();
+    }
+
+    /** Reads at the given timestamp; shorthand for a read-timestamp {@link TimestampBound}. */
+    public Read withTimestamp(Timestamp timestamp) {
+      return withTimestampBound(TimestampBound.ofReadTimestamp(timestamp));
+    }
+
+    /** Specifies the timestamp bound of the read only transaction. */
+    public Read withTimestampBound(TimestampBound timestampBound) {
+      return toBuilder().setTimestampBound(timestampBound).build();
+    }
+
+    /** Specifies the table to read from when using the Read API. */
+    public Read withTable(String table) {
+      return toBuilder().setTable(table).build();
+    }
+
+    /** Specifies the columns to read when using the Read API. */
+    public Read withColumns(String... columns) {
+      return withColumns(Arrays.asList(columns));
+    }
+
+    /** Specifies the columns to read when using the Read API. */
+    public Read withColumns(List<String> columns) {
+      return toBuilder().setColumns(columns).build();
+    }
+
+    /** Specifies the query to execute. */
+    public Read withQuery(Statement statement) {
+      return toBuilder().setQuery(statement).build();
+    }
+
+    /** Specifies the SQL text of the query to execute. */
+    public Read withQuery(String sql) {
+      return withQuery(Statement.of(sql));
+    }
+
+    /** Specifies the keys to read when using the Read API. */
+    public Read withKeySet(KeySet keySet) {
+      return toBuilder().setKeySet(keySet).build();
+    }
+
+    /** Specifies the index to read through when using the Read API. */
+    public Read withIndex(String index) {
+      return toBuilder().setIndex(index).build();
+    }
+
+
+    /**
+     * Checks the Spanner configuration, that a timestamp bound is set, and that either a query
+     * or a table with a non-empty column list is configured.
+     */
+    @Override
+    public void validate(PipelineOptions options) {
+      getSpannerConfig().validate(options);
+      checkNotNull(
+          getTimestampBound(),
+          "SpannerIO.read() runs in a read only transaction and requires timestamp to be set "
+              + "with withTimestampBound or withTimestamp method");
+
+      if (getQuery() != null) {
+        // TODO: validate query?
+      } else if (getTable() != null) {
+        // Assume read
+        checkNotNull(
+            getColumns(),
+            "For a read operation SpannerIO.read() requires a list of "
+                + "columns to set with withColumns method");
+        checkArgument(
+            !getColumns().isEmpty(),
+            "For a read operation SpannerIO.read() requires a"
+                + " list of columns to set with withColumns method");
+      } else {
+        throw new IllegalArgumentException(
+            "SpannerIO.read() requires configuring query or read operation.");
+      }
+    }
+
+    /**
+     * Reads inside the transaction supplied through {@link #withTransaction} when there is one;
+     * otherwise creates a transaction with the configured timestamp bound and reads inside it.
+     */
+    @Override
+    public PCollection<Struct> expand(PBegin input) {
+      // A transaction supplied by the caller wins; creating another one here would silently
+      // break the "several reads in one transaction" use case.
+      PCollectionView<Transaction> transaction = getTransaction();
+      if (transaction == null && getTimestampBound() != null) {
+        // Pass the bound along explicitly: createTransaction() defaults to a strong bound.
+        transaction =
+            input.apply(
+                createTransaction()
+                    .withTimestampBound(getTimestampBound())
+                    .withSpannerConfig(getSpannerConfig()));
+      }
+      Read config = this;
+      List<PCollectionView<Transaction>> sideInputs = Collections.emptyList();
+      if (transaction != null) {
+        config = config.withTransaction(transaction);
+        sideInputs = Collections.singletonList(transaction);
+      }
+      return input
+          .apply(Create.of(1))
+          .apply(
+              "Execute query", ParDo.of(new NaiveSpannerReadFn(config)).withSideInputs(sideInputs));
+    }
+  }
+
+  /**
+   * A {@link PTransform} that creates a transaction.
+   *
+   * @see SpannerIO
+   */
+  @Experimental(Experimental.Kind.SOURCE_SINK)
+  @AutoValue
+  public abstract static class CreateTransaction
+      extends PTransform<PBegin, PCollectionView<Transaction>> {
+
+    abstract SpannerConfig getSpannerConfig();
+
+    @Nullable
+    abstract TimestampBound getTimestampBound();
+
+    abstract Builder toBuilder();
+
+    /**
+     * Runs {@link CreateTransactionFn} over a single-element input and exposes its output as a
+     * singleton view.
+     */
+    @Override
+    public PCollectionView<Transaction> expand(PBegin input) {
+      return input.apply(Create.of(1))
+          .apply("Create transaction", ParDo.of(new CreateTransactionFn(this)))
+          .apply("As PCollectionView", View.<Transaction>asSingleton());
+    }
+
+    /** Specifies the Cloud Spanner configuration. */
+    public CreateTransaction withSpannerConfig(SpannerConfig spannerConfig) {
+      return toBuilder().setSpannerConfig(spannerConfig).build();
+    }
+
+    /** Specifies the Cloud Spanner project. */
+    public CreateTransaction withProjectId(String projectId) {
+      return withProjectId(ValueProvider.StaticValueProvider.of(projectId));
+    }
+
+    /** Specifies the Cloud Spanner project. */
+    public CreateTransaction withProjectId(ValueProvider<String> projectId) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withProjectId(projectId));
+    }
+
+    /** Specifies the Cloud Spanner instance. */
+    public CreateTransaction withInstanceId(String instanceId) {
+      return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId));
+    }
+
+    /** Specifies the Cloud Spanner instance. */
+    public CreateTransaction withInstanceId(ValueProvider<String> instanceId) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withInstanceId(instanceId));
+    }
+
+    /** Specifies the Cloud Spanner database. */
+    public CreateTransaction withDatabaseId(String databaseId) {
+      return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId));
+    }
+
+    /** Specifies the Cloud Spanner database. */
+    public CreateTransaction withDatabaseId(ValueProvider<String> databaseId) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withDatabaseId(databaseId));
+    }
+
+    @VisibleForTesting
+    CreateTransaction withServiceFactory(
+        ServiceFactory<Spanner, SpannerOptions> serviceFactory) {
+      SpannerConfig config = getSpannerConfig();
+      return withSpannerConfig(config.withServiceFactory(serviceFactory));
+    }
+
+    /** Specifies the timestamp bound of the transaction to create. */
+    public CreateTransaction withTimestampBound(TimestampBound timestampBound) {
+      return toBuilder().setTimestampBound(timestampBound).build();
+    }
+
+    /** Checks the Spanner configuration and that a timestamp bound is set. */
+    @Override
+    public void validate(PipelineOptions options) {
+      getSpannerConfig().validate(options);
+      // The bound is handed straight to readOnlyTransaction(...) by CreateTransactionFn, so
+      // reject a missing one here instead of failing later while the pipeline runs.
+      checkNotNull(
+          getTimestampBound(),
+          "SpannerIO.createTransaction() requires a timestamp bound to be set "
+              + "with withTimestampBound method");
+    }
+
+    /** A builder for {@link CreateTransaction}. */
+    @AutoValue.Builder public abstract static class Builder {
+
+      public abstract Builder setSpannerConfig(SpannerConfig spannerConfig);
+
+      public abstract Builder setTimestampBound(TimestampBound newTimestampBound);
+
+      public abstract CreateTransaction build();
+    }
+  }
+
+
+ /**
* A {@link PTransform} that writes {@link Mutation} objects to Google Cloud Spanner.
*
* @see SpannerIO
@@ -106,8 +485,6 @@ public class SpannerIO {
@AutoValue
public abstract static class Write extends PTransform<PCollection<Mutation>, PDone> {
- private static final long serialVersionUID = 1920175411827980145L;
-
abstract SpannerConfig getSpannerConfig();
abstract long getBatchSizeBytes();
@@ -119,95 +496,53 @@ public class SpannerIO {
abstract Builder setSpannerConfig(SpannerConfig spannerConfig);
- abstract SpannerConfig.Builder spannerConfigBuilder();
-
abstract Builder setBatchSizeBytes(long batchSizeBytes);
abstract Write build();
}
- /**
- * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner project.
- *
- * <p>Does not modify this object.
- */
+ /** Specifies the Cloud Spanner configuration. */
+ public Write withSpannerConfig(SpannerConfig spannerConfig) {
+ return toBuilder().setSpannerConfig(spannerConfig).build();
+ }
+
+ /** Specifies the Cloud Spanner project. */
public Write withProjectId(String projectId) {
return withProjectId(ValueProvider.StaticValueProvider.of(projectId));
}
- /**
- * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner project.
- *
- * <p>Does not modify this object.
- */
+ /** Specifies the Cloud Spanner project. */
public Write withProjectId(ValueProvider<String> projectId) {
- Write.Builder builder = toBuilder();
- builder.spannerConfigBuilder().setProjectId(projectId);
- return builder.build();
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withProjectId(projectId));
}
- /**
- * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
- * instance.
- *
- * <p>Does not modify this object.
- */
+ /** Specifies the Cloud Spanner instance. */
public Write withInstanceId(String instanceId) {
return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId));
}
- /**
- * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
- * instance.
- *
- * <p>Does not modify this object.
- */
+ /** Specifies the Cloud Spanner instance. */
public Write withInstanceId(ValueProvider<String> instanceId) {
- Write.Builder builder = toBuilder();
- builder.spannerConfigBuilder().setInstanceId(instanceId);
- return builder.build();
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withInstanceId(instanceId));
}
- /**
- * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
- * config.
- *
- * <p>Does not modify this object.
- */
- public Write withSpannerConfig(SpannerConfig spannerConfig) {
- return toBuilder().setSpannerConfig(spannerConfig).build();
- }
-
-
- /**
- * Returns a new {@link SpannerIO.Write} with a new batch size limit.
- *
- * <p>Does not modify this object.
- */
- public Write withBatchSizeBytes(long batchSizeBytes) {
- return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
- }
-
- /**
- * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
- * database.
- *
- * <p>Does not modify this object.
- */
+ /** Specifies the Cloud Spanner database. */
public Write withDatabaseId(String databaseId) {
return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId));
}
- /**
- * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
- * database.
- *
- * <p>Does not modify this object.
- */
+ /** Specifies the Cloud Spanner database. */
public Write withDatabaseId(ValueProvider<String> databaseId) {
- Write.Builder builder = toBuilder();
- builder.spannerConfigBuilder().setDatabaseId(databaseId);
- return builder.build();
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withDatabaseId(databaseId));
+ }
+
+ @VisibleForTesting
+ Write withServiceFactory(ServiceFactory<Spanner, SpannerOptions> serviceFactory) {
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withServiceFactory(serviceFactory));
}
/**
@@ -217,11 +552,9 @@ public class SpannerIO {
return new WriteGrouped(this);
}
- @VisibleForTesting
- Write withServiceFactory(ServiceFactory<Spanner, SpannerOptions> serviceFactory) {
- Write.Builder builder = toBuilder();
- builder.spannerConfigBuilder().setServiceFactory(serviceFactory);
- return builder.build();
+ /** Specifies the batch size limit. */
+ public Write withBatchSizeBytes(long batchSizeBytes) {
+ return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
}
@Override
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java
index aed4832..34a11da 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package org.apache.beam.sdk.io.gcp.spanner;
import com.google.cloud.spanner.AbortedException;
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java
new file mode 100644
index 0000000..22af3b8
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import com.google.auto.value.AutoValue;
+import com.google.cloud.Timestamp;
+import java.io.Serializable;
+
+/** A serializable value holding the timestamp of a Cloud Spanner transaction. */
+@AutoValue
+public abstract class Transaction implements Serializable {
+
+ abstract Timestamp timestamp();
+
+ public static Transaction create(Timestamp timestamp) {
+ return new AutoValue_Transaction(timestamp);
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java
index 91caded..8aac417 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java
@@ -52,6 +52,7 @@ public class GcpApiSurfaceTest {
@SuppressWarnings("unchecked")
final Set<Matcher<Class<?>>> allowedClasses =
ImmutableSet.of(
+ classesInPackage("com.google.api.core"),
classesInPackage("com.google.api.client.googleapis"),
classesInPackage("com.google.api.client.http"),
classesInPackage("com.google.api.client.json"),
@@ -60,9 +61,18 @@ public class GcpApiSurfaceTest {
classesInPackage("com.google.auth"),
classesInPackage("com.google.bigtable.v2"),
classesInPackage("com.google.cloud.bigtable.config"),
+ classesInPackage("com.google.spanner.v1"),
+ Matchers.<Class<?>>equalTo(com.google.api.gax.grpc.ApiException.class),
Matchers.<Class<?>>equalTo(com.google.cloud.bigtable.grpc.BigtableClusterName.class),
Matchers.<Class<?>>equalTo(com.google.cloud.bigtable.grpc.BigtableInstanceName.class),
Matchers.<Class<?>>equalTo(com.google.cloud.bigtable.grpc.BigtableTableName.class),
+ Matchers.<Class<?>>equalTo(com.google.cloud.BaseServiceException.class),
+ Matchers.<Class<?>>equalTo(com.google.cloud.BaseServiceException.Error.class),
+ Matchers.<Class<?>>equalTo(com.google.cloud.BaseServiceException.ExceptionData.class),
+ Matchers.<Class<?>>equalTo(com.google.cloud.BaseServiceException.ExceptionData.Builder
+ .class),
+ Matchers.<Class<?>>equalTo(com.google.cloud.RetryHelper.RetryHelperException.class),
+ Matchers.<Class<?>>equalTo(com.google.cloud.grpc.BaseGrpcServiceException.class),
Matchers.<Class<?>>equalTo(com.google.cloud.ByteArray.class),
Matchers.<Class<?>>equalTo(com.google.cloud.Date.class),
Matchers.<Class<?>>equalTo(com.google.cloud.Timestamp.class),
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java
new file mode 100644
index 0000000..753d807
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.withSettings;
+
+import com.google.cloud.ServiceFactory;
+import com.google.cloud.spanner.DatabaseClient;
+import com.google.cloud.spanner.DatabaseId;
+import com.google.cloud.spanner.Spanner;
+import com.google.cloud.spanner.SpannerOptions;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import javax.annotation.concurrent.GuardedBy;
+import org.mockito.Matchers;
+
+/**
+ * A serialization-friendly {@link ServiceFactory} that maintains a mock {@link Spanner} and
+ * a mock {@link DatabaseClient}.
+ */
+class FakeServiceFactory
+ implements ServiceFactory<Spanner, SpannerOptions>, Serializable {
+
+ // Mocks are kept in static lists; the serializable factory itself only stores its index.
+ private static final Object lock = new Object();
+
+ @GuardedBy("lock")
+ private static final List<Spanner> mockSpanners = new ArrayList<>();
+
+ @GuardedBy("lock")
+ private static final List<DatabaseClient> mockDatabaseClients = new ArrayList<>();
+
+ @GuardedBy("lock")
+ private static int count = 0;
+
+ private final int index;
+
+ public FakeServiceFactory() {
+ synchronized (lock) {
+ index = count++;
+ mockSpanners.add(mock(Spanner.class, withSettings().serializable()));
+ mockDatabaseClients.add(mock(DatabaseClient.class, withSettings().serializable()));
+ }
+ when(mockSpanner().getDatabaseClient(Matchers.any(DatabaseId.class)))
+ .thenReturn(mockDatabaseClient());
+ }
+
+ DatabaseClient mockDatabaseClient() {
+ synchronized (lock) {
+ return mockDatabaseClients.get(index);
+ }
+ }
+
+ Spanner mockSpanner() {
+ synchronized (lock) {
+ return mockSpanners.get(index);
+ }
+ }
+
+ @Override
+ public Spanner create(SpannerOptions serviceOptions) {
+ return mockSpanner();
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
new file mode 100644
index 0000000..e5d4e72
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static org.junit.Assert.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.KeySet;
+import com.google.cloud.spanner.ReadOnlyTransaction;
+import com.google.cloud.spanner.ResultSets;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.TimestampBound;
+import com.google.cloud.spanner.Type;
+import com.google.cloud.spanner.Value;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.DoFnTester;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+/** Unit tests for the read path of {@link SpannerIO}. */
+@RunWith(JUnit4.class)
+public class SpannerIOReadTest implements Serializable {
+ @Rule
+ public final transient TestPipeline pipeline = TestPipeline.create();
+ @Rule
+ public final transient ExpectedException thrown = ExpectedException.none();
+
+ private FakeServiceFactory serviceFactory;
+ private ReadOnlyTransaction mockTx;
+
+ private Type fakeType = Type.struct(Type.StructField.of("id", Type.int64()),
+ Type.StructField.of("name", Type.string()));
+
+ private List<Struct> fakeRows = Arrays.asList(
+ Struct.newBuilder().add("id", Value.int64(1)).add("name", Value.string("Alice")).build(),
+ Struct.newBuilder().add("id", Value.int64(2)).add("name", Value.string("Bob")).build());
+
+ @Before
+ @SuppressWarnings("unchecked")
+ public void setUp() throws Exception {
+ serviceFactory = new FakeServiceFactory();
+ mockTx = Mockito.mock(ReadOnlyTransaction.class);
+ }
+
+ @Test
+ public void emptyTransform() throws Exception {
+ SpannerIO.Read read = SpannerIO.read();
+ thrown.expect(NullPointerException.class);
+ thrown.expectMessage("requires instance id to be set with");
+ read.validate(null);
+ }
+
+ @Test
+ public void emptyInstanceId() throws Exception {
+ SpannerIO.Read read = SpannerIO.read().withDatabaseId("123");
+ thrown.expect(NullPointerException.class);
+ thrown.expectMessage("requires instance id to be set with");
+ read.validate(null);
+ }
+
+ @Test
+ public void emptyDatabaseId() throws Exception {
+ SpannerIO.Read read = SpannerIO.read().withInstanceId("123");
+ thrown.expect(NullPointerException.class);
+ thrown.expectMessage("requires database id to be set with");
+ read.validate(null);
+ }
+
+ @Test
+ public void emptyQuery() throws Exception {
+ SpannerIO.Read read =
+ SpannerIO.read().withInstanceId("123").withDatabaseId("aaa").withTimestamp(Timestamp.now());
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("requires configuring query or read operation");
+ read.validate(null);
+ }
+
+ @Test
+ public void emptyColumns() throws Exception {
+ SpannerIO.Read read =
+ SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withTable("users");
+ thrown.expect(NullPointerException.class);
+ thrown.expectMessage("requires a list of columns");
+ read.validate(null);
+ }
+
+ @Test
+ public void validRead() throws Exception {
+ SpannerIO.Read read =
+ SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withTable("users")
+ .withColumns("id", "name", "email");
+ read.validate(null);
+ }
+
+ @Test
+ public void validQuery() throws Exception {
+ SpannerIO.Read read =
+ SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withQuery("SELECT * FROM users");
+ read.validate(null);
+ }
+
+ @Test
+ public void runQuery() throws Exception {
+ SpannerIO.Read read =
+ SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withQuery("SELECT * FROM users")
+ .withServiceFactory(serviceFactory);
+
+ NaiveSpannerReadFn readFn = new NaiveSpannerReadFn(read);
+ DoFnTester<Object, Struct> fnTester = DoFnTester.of(readFn);
+
+ when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class)))
+ .thenReturn(mockTx);
+ when(mockTx.executeQuery(any(Statement.class)))
+ .thenReturn(ResultSets.forRows(fakeType, fakeRows));
+
+ List<Struct> result = fnTester.processBundle(1);
+ assertThat(result, Matchers.<Struct>iterableWithSize(2));
+
+ verify(serviceFactory.mockDatabaseClient()).readOnlyTransaction(TimestampBound
+ .strong());
+ verify(mockTx).executeQuery(Statement.of("SELECT * FROM users"));
+ }
+
+ @Test
+ public void runRead() throws Exception {
+ SpannerIO.Read read =
+ SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withTable("users")
+ .withColumns("id", "name")
+ .withServiceFactory(serviceFactory);
+
+ NaiveSpannerReadFn readFn = new NaiveSpannerReadFn(read);
+ DoFnTester<Object, Struct> fnTester = DoFnTester.of(readFn);
+
+ when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class)))
+ .thenReturn(mockTx);
+ when(mockTx.read("users", KeySet.all(), Arrays.asList("id", "name")))
+ .thenReturn(ResultSets.forRows(fakeType, fakeRows));
+
+ List<Struct> result = fnTester.processBundle(1);
+ assertThat(result, Matchers.<Struct>iterableWithSize(2));
+
+ verify(serviceFactory.mockDatabaseClient()).readOnlyTransaction(TimestampBound.strong());
+ verify(mockTx).read("users", KeySet.all(), Arrays.asList("id", "name"));
+ }
+
+ @Test
+ public void runReadUsingIndex() throws Exception {
+ SpannerIO.Read read =
+ SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withTable("users")
+ .withColumns("id", "name")
+ .withIndex("theindex")
+ .withServiceFactory(serviceFactory);
+
+ NaiveSpannerReadFn readFn = new NaiveSpannerReadFn(read);
+ DoFnTester<Object, Struct> fnTester = DoFnTester.of(readFn);
+
+ when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class)))
+ .thenReturn(mockTx);
+ when(mockTx.readUsingIndex("users", "theindex", KeySet.all(), Arrays.asList("id", "name")))
+ .thenReturn(ResultSets.forRows(fakeType, fakeRows));
+
+ List<Struct> result = fnTester.processBundle(1);
+ assertThat(result, Matchers.<Struct>iterableWithSize(2));
+
+ verify(serviceFactory.mockDatabaseClient()).readOnlyTransaction(TimestampBound.strong());
+ verify(mockTx).readUsingIndex("users", "theindex", KeySet.all(), Arrays.asList("id", "name"));
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void readPipeline() throws Exception {
+ Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
+
+ PCollectionView<Transaction> tx = pipeline
+ .apply("tx", SpannerIO.createTransaction()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withServiceFactory(serviceFactory));
+
+ PCollection<Struct> one = pipeline.apply("read q", SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withQuery("SELECT * FROM users")
+ .withServiceFactory(serviceFactory)
+ .withTransaction(tx));
+ PCollection<Struct> two = pipeline.apply("read r", SpannerIO.read()
+ .withInstanceId("123")
+ .withDatabaseId("aaa")
+ .withTimestamp(Timestamp.now())
+ .withTable("users")
+ .withColumns("id", "name")
+ .withServiceFactory(serviceFactory)
+ .withTransaction(tx));
+
+ when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class)))
+ .thenReturn(mockTx);
+
+ when(mockTx.executeQuery(Statement.of("SELECT 1"))).thenReturn(ResultSets.forRows(Type.struct(),
+ Collections.<Struct>emptyList()));
+
+ when(mockTx.executeQuery(Statement.of("SELECT * FROM users")))
+ .thenReturn(ResultSets.forRows(fakeType, fakeRows));
+ when(mockTx.read("users", KeySet.all(), Arrays.asList("id", "name")))
+ .thenReturn(ResultSets.forRows(fakeType, fakeRows));
+ when(mockTx.getReadTimestamp()).thenReturn(timestamp);
+
+ PAssert.that(one).containsInAnyOrder(fakeRows);
+ PAssert.that(two).containsInAnyOrder(fakeRows);
+
+ pipeline.run();
+
+ verify(serviceFactory.mockDatabaseClient(), times(2))
+ .readOnlyTransaction(TimestampBound.ofReadTimestamp(timestamp));
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
deleted file mode 100644
index abeac0a..0000000
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
+++ /dev/null
@@ -1,314 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.gcp.spanner;
-
-import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
-import static org.hamcrest.Matchers.hasSize;
-import static org.junit.Assert.assertThat;
-import static org.mockito.Mockito.argThat;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-import static org.mockito.Mockito.withSettings;
-
-import com.google.cloud.ServiceFactory;
-import com.google.cloud.spanner.DatabaseClient;
-import com.google.cloud.spanner.DatabaseId;
-import com.google.cloud.spanner.Mutation;
-import com.google.cloud.spanner.Spanner;
-import com.google.cloud.spanner.SpannerOptions;
-import com.google.common.collect.Iterables;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import javax.annotation.concurrent.GuardedBy;
-
-import org.apache.beam.sdk.testing.NeedsRunner;
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.DoFnTester;
-import org.apache.beam.sdk.transforms.display.DisplayData;
-import org.apache.beam.sdk.values.PCollection;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.experimental.categories.Category;
-import org.junit.rules.ExpectedException;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-import org.mockito.ArgumentMatcher;
-import org.mockito.Matchers;
-
-
-/**
- * Unit tests for {@link SpannerIO}.
- */
-@RunWith(JUnit4.class)
-public class SpannerIOTest implements Serializable {
- @Rule public final transient TestPipeline pipeline = TestPipeline.create();
- @Rule public transient ExpectedException thrown = ExpectedException.none();
-
- private FakeServiceFactory serviceFactory;
-
- @Before
- @SuppressWarnings("unchecked")
- public void setUp() throws Exception {
- serviceFactory = new FakeServiceFactory();
- }
-
- @Test
- public void emptyTransform() throws Exception {
- SpannerIO.Write write = SpannerIO.write();
- thrown.expect(NullPointerException.class);
- thrown.expectMessage("requires instance id to be set with");
- write.validate(null);
- }
-
- @Test
- public void emptyInstanceId() throws Exception {
- SpannerIO.Write write = SpannerIO.write().withDatabaseId("123");
- thrown.expect(NullPointerException.class);
- thrown.expectMessage("requires instance id to be set with");
- write.validate(null);
- }
-
- @Test
- public void emptyDatabaseId() throws Exception {
- SpannerIO.Write write = SpannerIO.write().withInstanceId("123");
- thrown.expect(NullPointerException.class);
- thrown.expectMessage("requires database id to be set with");
- write.validate(null);
- }
-
- @Test
- @Category(NeedsRunner.class)
- public void singleMutationPipeline() throws Exception {
- Mutation mutation = Mutation.newInsertOrUpdateBuilder("test").set("one").to(2).build();
- PCollection<Mutation> mutations = pipeline.apply(Create.of(mutation));
-
- mutations.apply(
- SpannerIO.write()
- .withProjectId("test-project")
- .withInstanceId("test-instance")
- .withDatabaseId("test-database")
- .withServiceFactory(serviceFactory));
- pipeline.run();
- verify(serviceFactory.mockSpanner())
- .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
- verify(serviceFactory.mockDatabaseClient(), times(1))
- .writeAtLeastOnce(argThat(new IterableOfSize(1)));
- }
-
- @Test
- @Category(NeedsRunner.class)
- public void singleMutationGroupPipeline() throws Exception {
- Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
- Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build();
- Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build();
- PCollection<MutationGroup> mutations = pipeline
- .apply(Create.<MutationGroup>of(g(one, two, three)));
- mutations.apply(
- SpannerIO.write()
- .withProjectId("test-project")
- .withInstanceId("test-instance")
- .withDatabaseId("test-database")
- .withServiceFactory(serviceFactory)
- .grouped());
- pipeline.run();
- verify(serviceFactory.mockSpanner())
- .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
- verify(serviceFactory.mockDatabaseClient(), times(1))
- .writeAtLeastOnce(argThat(new IterableOfSize(3)));
- }
-
- @Test
- public void batching() throws Exception {
- MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build());
- MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build());
- SpannerIO.Write write =
- SpannerIO.write()
- .withProjectId("test-project")
- .withInstanceId("test-instance")
- .withDatabaseId("test-database")
- .withBatchSizeBytes(1000000000)
- .withServiceFactory(serviceFactory);
- SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
- DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
- fnTester.processBundle(Arrays.asList(one, two));
-
- verify(serviceFactory.mockSpanner())
- .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
- verify(serviceFactory.mockDatabaseClient(), times(1))
- .writeAtLeastOnce(argThat(new IterableOfSize(2)));
- }
-
- @Test
- public void batchingGroups() throws Exception {
- MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build());
- MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build());
- MutationGroup three = g(Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build());
-
- // Have a room to accumulate one more item.
- long batchSize = MutationSizeEstimator.sizeOf(one) + 1;
-
- SpannerIO.Write write =
- SpannerIO.write()
- .withProjectId("test-project")
- .withInstanceId("test-instance")
- .withDatabaseId("test-database")
- .withBatchSizeBytes(batchSize)
- .withServiceFactory(serviceFactory);
- SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
- DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
- fnTester.processBundle(Arrays.asList(one, two, three));
-
- verify(serviceFactory.mockSpanner())
- .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
- verify(serviceFactory.mockDatabaseClient(), times(1))
- .writeAtLeastOnce(argThat(new IterableOfSize(2)));
- verify(serviceFactory.mockDatabaseClient(), times(1))
- .writeAtLeastOnce(argThat(new IterableOfSize(1)));
- }
-
- @Test
- public void noBatching() throws Exception {
- MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build());
- MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build());
- SpannerIO.Write write =
- SpannerIO.write()
- .withProjectId("test-project")
- .withInstanceId("test-instance")
- .withDatabaseId("test-database")
- .withBatchSizeBytes(0) // turn off batching.
- .withServiceFactory(serviceFactory);
- SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
- DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
- fnTester.processBundle(Arrays.asList(one, two));
-
- verify(serviceFactory.mockSpanner())
- .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
- verify(serviceFactory.mockDatabaseClient(), times(2))
- .writeAtLeastOnce(argThat(new IterableOfSize(1)));
- }
-
- @Test
- public void groups() throws Exception {
- Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
- Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build();
- Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build();
-
- // Smallest batch size
- long batchSize = 1;
-
- SpannerIO.Write write =
- SpannerIO.write()
- .withProjectId("test-project")
- .withInstanceId("test-instance")
- .withDatabaseId("test-database")
- .withBatchSizeBytes(batchSize)
- .withServiceFactory(serviceFactory);
- SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
- DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
- fnTester.processBundle(Arrays.asList(g(one, two, three)));
-
- verify(serviceFactory.mockSpanner())
- .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
- verify(serviceFactory.mockDatabaseClient(), times(1))
- .writeAtLeastOnce(argThat(new IterableOfSize(3)));
- }
-
- @Test
- public void displayData() throws Exception {
- SpannerIO.Write write =
- SpannerIO.write()
- .withProjectId("test-project")
- .withInstanceId("test-instance")
- .withDatabaseId("test-database")
- .withBatchSizeBytes(123);
-
- DisplayData data = DisplayData.from(write);
- assertThat(data.items(), hasSize(4));
- assertThat(data, hasDisplayItem("projectId", "test-project"));
- assertThat(data, hasDisplayItem("instanceId", "test-instance"));
- assertThat(data, hasDisplayItem("databaseId", "test-database"));
- assertThat(data, hasDisplayItem("batchSizeBytes", 123));
- }
-
- private static class FakeServiceFactory
- implements ServiceFactory<Spanner, SpannerOptions>, Serializable {
- // Marked as static so they could be returned by serviceFactory, which is serializable.
- private static final Object lock = new Object();
-
- @GuardedBy("lock")
- private static final List<Spanner> mockSpanners = new ArrayList<>();
-
- @GuardedBy("lock")
- private static final List<DatabaseClient> mockDatabaseClients = new ArrayList<>();
-
- @GuardedBy("lock")
- private static int count = 0;
-
- private final int index;
-
- public FakeServiceFactory() {
- synchronized (lock) {
- index = count++;
- mockSpanners.add(mock(Spanner.class, withSettings().serializable()));
- mockDatabaseClients.add(mock(DatabaseClient.class, withSettings().serializable()));
- }
- when(mockSpanner().getDatabaseClient(Matchers.any(DatabaseId.class)))
- .thenReturn(mockDatabaseClient());
- }
-
- DatabaseClient mockDatabaseClient() {
- synchronized (lock) {
- return mockDatabaseClients.get(index);
- }
- }
-
- Spanner mockSpanner() {
- synchronized (lock) {
- return mockSpanners.get(index);
- }
- }
-
- @Override
- public Spanner create(SpannerOptions serviceOptions) {
- return mockSpanner();
- }
- }
-
- private static class IterableOfSize extends ArgumentMatcher<Iterable<Mutation>> {
- private final int size;
-
- private IterableOfSize(int size) {
- this.size = size;
- }
-
- @Override
- public boolean matches(Object argument) {
- return argument instanceof Iterable && Iterables.size((Iterable<?>) argument) == size;
- }
- }
-
- private static MutationGroup g(Mutation m, Mutation... other) {
- return MutationGroup.create(m, other);
- }
-}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
new file mode 100644
index 0000000..09cdb8e
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+import static org.hamcrest.Matchers.hasSize;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.argThat;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import com.google.cloud.spanner.DatabaseId;
+import com.google.cloud.spanner.Mutation;
+import com.google.common.collect.Iterables;
+import java.io.Serializable;
+import java.util.Arrays;
+
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFnTester;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentMatcher;
+
+/**
+ * Unit tests for the write side of {@link SpannerIO}: argument validation of
+ * {@link SpannerIO.Write}, the write pipelines, batching in {@link SpannerWriteGroupFn} and
+ * display data.
+ *
+ * <p>Interactions are verified on the mocks exposed by {@link FakeServiceFactory}.
+ */
+@RunWith(JUnit4.class)
+public class SpannerIOWriteTest implements Serializable {
+  private static final String PROJECT_ID = "test-project";
+  private static final String INSTANCE_ID = "test-instance";
+  private static final String DATABASE_ID = "test-database";
+
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+  @Rule public final transient ExpectedException thrown = ExpectedException.none();
+
+  private FakeServiceFactory serviceFactory;
+
+  /** Creates a fresh service factory (and therefore fresh mocks) for every test. */
+  @Before
+  public void setUp() throws Exception {
+    serviceFactory = new FakeServiceFactory();
+  }
+
+  /** Validation of a completely unconfigured write complains about the instance id. */
+  @Test
+  public void emptyTransform() throws Exception {
+    SpannerIO.Write write = SpannerIO.write();
+    thrown.expect(NullPointerException.class);
+    thrown.expectMessage("requires instance id to be set with");
+    write.validate(null);
+  }
+
+  /** Validation fails when only the database id is configured. */
+  @Test
+  public void emptyInstanceId() throws Exception {
+    SpannerIO.Write write = SpannerIO.write().withDatabaseId("123");
+    thrown.expect(NullPointerException.class);
+    thrown.expectMessage("requires instance id to be set with");
+    write.validate(null);
+  }
+
+  /** Validation fails when only the instance id is configured. */
+  @Test
+  public void emptyDatabaseId() throws Exception {
+    SpannerIO.Write write = SpannerIO.write().withInstanceId("123");
+    thrown.expect(NullPointerException.class);
+    thrown.expectMessage("requires database id to be set with");
+    write.validate(null);
+  }
+
+  /** A pipeline with a single mutation results in one write carrying one mutation. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void singleMutationPipeline() throws Exception {
+    PCollection<Mutation> mutations = pipeline.apply(Create.of(mutation("one", 2)));
+
+    mutations.apply(baseWrite().withServiceFactory(serviceFactory));
+    pipeline.run();
+
+    verifyDatabaseClientRequested();
+    verifyWrites(1, 1);
+  }
+
+  /** A pipeline with a single three-mutation group results in one write of three mutations. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void singleMutationGroupPipeline() throws Exception {
+    MutationGroup group = g(mutation("one", 1), mutation("two", 2), mutation("three", 3));
+    PCollection<MutationGroup> mutations = pipeline.apply(Create.<MutationGroup>of(group));
+
+    mutations.apply(baseWrite().withServiceFactory(serviceFactory).grouped());
+    pipeline.run();
+
+    verifyDatabaseClientRequested();
+    verifyWrites(1, 3);
+  }
+
+  /** With a very large batch limit both groups are expected to go out in a single write. */
+  @Test
+  public void batching() throws Exception {
+    MutationGroup one = g(mutation("one", 1));
+    MutationGroup two = g(mutation("two", 2));
+    SpannerIO.Write write =
+        baseWrite().withBatchSizeBytes(1000000000).withServiceFactory(serviceFactory);
+
+    processBundle(write, one, two);
+
+    verifyDatabaseClientRequested();
+    verifyWrites(1, 2);
+  }
+
+  /**
+   * With a batch limit just above the estimated size of one group, three single-mutation groups
+   * are expected to be written as one batch of two mutations and one batch of one mutation.
+   */
+  @Test
+  public void batchingGroups() throws Exception {
+    MutationGroup one = g(mutation("one", 1));
+    MutationGroup two = g(mutation("two", 2));
+    MutationGroup three = g(mutation("three", 3));
+
+    // Leave room to accumulate one more item after the first one.
+    long batchSize = MutationSizeEstimator.sizeOf(one) + 1;
+    SpannerIO.Write write =
+        baseWrite().withBatchSizeBytes(batchSize).withServiceFactory(serviceFactory);
+
+    processBundle(write, one, two, three);
+
+    verifyDatabaseClientRequested();
+    verifyWrites(1, 2);
+    verifyWrites(1, 1);
+  }
+
+  /** A batch size of zero turns batching off: every group is written on its own. */
+  @Test
+  public void noBatching() throws Exception {
+    MutationGroup one = g(mutation("one", 1));
+    MutationGroup two = g(mutation("two", 2));
+    SpannerIO.Write write =
+        baseWrite()
+            .withBatchSizeBytes(0) // turn off batching.
+            .withServiceFactory(serviceFactory);
+
+    processBundle(write, one, two);
+
+    verifyDatabaseClientRequested();
+    verifyWrites(2, 1);
+  }
+
+  /**
+   * Even with the smallest batch limit, the three mutations of one group are expected to be
+   * written together in a single call.
+   */
+  @Test
+  public void groups() throws Exception {
+    // Smallest batch size
+    long batchSize = 1;
+    SpannerIO.Write write =
+        baseWrite().withBatchSizeBytes(batchSize).withServiceFactory(serviceFactory);
+
+    processBundle(write, g(mutation("one", 1), mutation("two", 2), mutation("three", 3)));
+
+    verifyDatabaseClientRequested();
+    verifyWrites(1, 3);
+  }
+
+  /** The transform exposes exactly its four configured values as display data. */
+  @Test
+  public void displayData() throws Exception {
+    SpannerIO.Write write = baseWrite().withBatchSizeBytes(123);
+
+    DisplayData data = DisplayData.from(write);
+    assertThat(data.items(), hasSize(4));
+    assertThat(data, hasDisplayItem("projectId", PROJECT_ID));
+    assertThat(data, hasDisplayItem("instanceId", INSTANCE_ID));
+    assertThat(data, hasDisplayItem("databaseId", DATABASE_ID));
+    assertThat(data, hasDisplayItem("batchSizeBytes", 123));
+  }
+
+  /** Returns a write transform configured with the test project, instance and database ids. */
+  private static SpannerIO.Write baseWrite() {
+    return SpannerIO.write()
+        .withProjectId(PROJECT_ID)
+        .withInstanceId(INSTANCE_ID)
+        .withDatabaseId(DATABASE_ID);
+  }
+
+  /** Builds an insert-or-update mutation on table "test" that sets one integer column. */
+  private static Mutation mutation(String column, long value) {
+    return Mutation.newInsertOrUpdateBuilder("test").set(column).to(value).build();
+  }
+
+  /** Feeds the groups as a single bundle through a {@link SpannerWriteGroupFn} for the write. */
+  private static void processBundle(SpannerIO.Write write, MutationGroup... groups)
+      throws Exception {
+    DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(new SpannerWriteGroupFn(write));
+    fnTester.processBundle(Arrays.asList(groups));
+  }
+
+  /** Verifies that a database client was requested once for the configured database. */
+  private void verifyDatabaseClientRequested() {
+    verify(serviceFactory.mockSpanner())
+        .getDatabaseClient(DatabaseId.of(PROJECT_ID, INSTANCE_ID, DATABASE_ID));
+  }
+
+  /**
+   * Verifies that {@code writeAtLeastOnce} was called exactly {@code calls} times with an
+   * iterable holding {@code mutationsPerCall} mutations.
+   */
+  private void verifyWrites(int calls, int mutationsPerCall) {
+    verify(serviceFactory.mockDatabaseClient(), times(calls))
+        .writeAtLeastOnce(argThat(new IterableOfSize(mutationsPerCall)));
+  }
+
+  /** Matches any {@link Iterable} argument that contains exactly {@code size} elements. */
+  private static class IterableOfSize extends ArgumentMatcher<Iterable<Mutation>> {
+    private final int size;
+
+    private IterableOfSize(int size) {
+      this.size = size;
+    }
+
+    @Override
+    public boolean matches(Object argument) {
+      return argument instanceof Iterable && Iterables.size((Iterable<?>) argument) == size;
+    }
+  }
+
+  /** Shorthand for {@link MutationGroup#create}. */
+  private static MutationGroup g(Mutation m, Mutation... other) {
+    return MutationGroup.create(m, other);
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/a21a6d79/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java
new file mode 100644
index 0000000..f5d7cbd
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import com.google.cloud.spanner.Database;
+import com.google.cloud.spanner.DatabaseAdminClient;
+import com.google.cloud.spanner.DatabaseClient;
+import com.google.cloud.spanner.DatabaseId;
+import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.Operation;
+import com.google.cloud.spanner.Spanner;
+import com.google.cloud.spanner.SpannerOptions;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.TimestampBound;
+import com.google.spanner.admin.database.v1.CreateDatabaseMetadata;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * End-to-end test of Cloud Spanner Source.
+ *
+ * <p>Each test creates a database with a randomly generated name, writes {@code NUM_ROWS} rows
+ * into a table, reads them back with {@link SpannerIO#read()} and drops the database afterwards.
+ */
+@RunWith(JUnit4.class)
+public class SpannerReadIT {
+
+  // Upper bound on the length of generated database names.
+  // NOTE(review): presumably mirrors the Cloud Spanner database id length limit; confirm.
+  private static final int MAX_DB_NAME_LENGTH = 30;
+
+  /** Number of rows written before the read and expected back from it. */
+  private static final long NUM_ROWS = 5L;
+
+  @Rule public final transient TestPipeline p = TestPipeline.create();
+
+  /** Pipeline options for this test. */
+  public interface SpannerTestPipelineOptions extends TestPipelineOptions {
+    @Description("Project ID for Spanner")
+    @Default.String("apache-beam-testing")
+    String getProjectId();
+    void setProjectId(String value);
+
+    @Description("Instance ID to write to in Spanner")
+    @Default.String("beam-test")
+    String getInstanceId();
+    void setInstanceId(String value);
+
+    @Description("Database ID prefix to write to in Spanner")
+    @Default.String("beam-testdb")
+    String getDatabaseIdPrefix();
+    void setDatabaseIdPrefix(String value);
+
+    @Description("Table name")
+    @Default.String("users")
+    String getTable();
+    void setTable(String value);
+  }
+
+  private Spanner spanner;
+  private DatabaseAdminClient databaseAdminClient;
+  private SpannerTestPipelineOptions options;
+  private String databaseName;
+
+  /** Connects to Spanner and creates a fresh database holding the (empty) test table. */
+  @Before
+  public void setUp() throws Exception {
+    PipelineOptionsFactory.register(SpannerTestPipelineOptions.class);
+    options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class);
+
+    spanner = SpannerOptions.newBuilder().setProjectId(options.getProjectId()).build().getService();
+
+    databaseName = generateDatabaseName();
+
+    databaseAdminClient = spanner.getDatabaseAdminClient();
+
+    // Drop any leftover database with the same name before creating it.
+    databaseAdminClient.dropDatabase(options.getInstanceId(), databaseName);
+
+    Operation<Database, CreateDatabaseMetadata> op =
+        databaseAdminClient.createDatabase(
+            options.getInstanceId(),
+            databaseName,
+            Collections.singleton(
+                "CREATE TABLE "
+                    + options.getTable()
+                    + " ("
+                    + " Key INT64,"
+                    + " Value STRING(MAX),"
+                    + ") PRIMARY KEY (Key)"));
+    op.waitFor();
+  }
+
+  /** Writes {@code NUM_ROWS} rows directly, then checks the pipeline reads the same count. */
+  @Test
+  public void testRead() throws Exception {
+    DatabaseClient databaseClient =
+        spanner.getDatabaseClient(
+            DatabaseId.of(options.getProjectId(), options.getInstanceId(), databaseName));
+
+    List<Mutation> mutations = new ArrayList<>();
+    for (long key = 0; key < NUM_ROWS; key++) {
+      mutations.add(
+          Mutation.newInsertOrUpdateBuilder(options.getTable())
+              .set("key")
+              .to(key)
+              .set("value")
+              .to(RandomStringUtils.random(100, true, true))
+              .build());
+    }
+
+    databaseClient.writeAtLeastOnce(mutations);
+
+    SpannerConfig spannerConfig =
+        SpannerConfig.create()
+            .withProjectId(options.getProjectId())
+            .withInstanceId(options.getInstanceId())
+            .withDatabaseId(databaseName);
+
+    PCollectionView<Transaction> tx =
+        p.apply(
+            SpannerIO.createTransaction()
+                .withSpannerConfig(spannerConfig)
+                .withTimestampBound(TimestampBound.strong()));
+
+    PCollection<Struct> output =
+        p.apply(
+            SpannerIO.read()
+                .withSpannerConfig(spannerConfig)
+                .withQuery("SELECT * FROM " + options.getTable())
+                .withTransaction(tx));
+
+    PCollection<Long> rowCount = output.apply("Count rows", Count.<Struct>globally());
+    PAssert.thatSingleton(rowCount).isEqualTo(NUM_ROWS);
+    p.run();
+  }
+
+  /**
+   * Drops the test database and closes the Spanner client.
+   *
+   * <p>JUnit runs this method even when {@link #setUp()} failed part-way, so the fields may
+   * still be null; the client is closed even if dropping the database throws.
+   */
+  @After
+  public void tearDown() throws Exception {
+    if (spanner == null) {
+      // setUp failed before the client was created; there is nothing to release.
+      return;
+    }
+    try {
+      // databaseAdminClient is assigned after options and databaseName, so they are set here.
+      if (databaseAdminClient != null) {
+        databaseAdminClient.dropDatabase(options.getInstanceId(), databaseName);
+      }
+    } finally {
+      spanner.close();
+    }
+  }
+
+  /**
+   * Builds a database name of the form {@code <prefix>-<random>} that is exactly
+   * {@code MAX_DB_NAME_LENGTH} characters long, using lower-case alphanumeric random characters.
+   *
+   * @throws IllegalArgumentException if the configured prefix leaves no room for a random part
+   */
+  private String generateDatabaseName() {
+    String prefix = options.getDatabaseIdPrefix();
+    int randomLength = MAX_DB_NAME_LENGTH - 1 - prefix.length();
+    if (randomLength < 1) {
+      throw new IllegalArgumentException(
+          "Database id prefix '"
+              + prefix
+              + "' is too long: at most "
+              + (MAX_DB_NAME_LENGTH - 2)
+              + " characters are allowed");
+    }
+    // Locale.ROOT keeps the result ASCII regardless of the JVM default locale (for example,
+    // a Turkish locale would otherwise lower-case 'I' to a dotless 'ı').
+    String random =
+        RandomStringUtils.randomAlphanumeric(randomLength).toLowerCase(java.util.Locale.ROOT);
+    return prefix + "-" + random;
+  }
+}