Posted to commits@beam.apache.org by ch...@apache.org on 2020/05/21 22:39:32 UTC
[beam] branch master updated: [BEAM-9722] added SnowflakeIO with Read operation (#11360)
This is an automated email from the ASF dual-hosted git repository.
chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 73fa135 [BEAM-9722] added SnowflakeIO with Read operation (#11360)
73fa135 is described below
commit 73fa1356e4bc2d04e51c06cbf7c3aa860264f9ee
Author: Dariusz Aniszewski <da...@polidea.com>
AuthorDate: Fri May 22 00:39:16 2020 +0200
[BEAM-9722] added SnowflakeIO with Read operation (#11360)
* [BEAM-9722] added SnowflakeIO with Read operation
* [BEAM-9722] Added SnowflakeCloudProvider to enable use of various clouds with Snowflake
* [BEAM-9722] added docstrings for public methods
* [BEAM-9722] Changed cleanup of staged GCS files to use Beam FileSystems
* [BEAM-9722] Added javadocs for public methods in DataSourceConfiguration
* add testing p8 file to RAT exclude
refactor SnowflakeCredentials
add information about files possibly left on cloud storage
small docs changes
* documentation changes
* [BEAM-9722] Added TestRule and changed Unit tests to use pipeline.run
* [BEAM-9722] Renamed Snowflake Read unit test and applied spotless
* [BEAM-9722] remove SnowflakeCloudProvider interface
* [BEAM-9722] doc changes
* [BEAM-9722] add `withoutValidation` to disable verifying connection to Snowflake during pipeline construction
* [BEAM-9722] added MoveOption and removed leftover file
* [BEAM-9722] fixed tests. Add tests for `withQuery`
* [BEAM-9722] make `CopyIntoStageFn` retryable
* [BEAM-9722] added `Reshuffle` step after `CopyIntoStageFn`
Co-authored-by: Kasia Kucharczyk <ka...@polidea.com>
Co-authored-by: pawel.urbanowicz <pa...@polidea.com>
---
CHANGES.md | 1 +
build.gradle | 3 +
sdks/java/io/snowflake/build.gradle | 42 ++
.../beam/sdk/io/snowflake/CloudProvider.java | 32 +
.../apache/beam/sdk/io/snowflake/SnowflakeIO.java | 759 +++++++++++++++++++++
.../sdk/io/snowflake/SnowflakePipelineOptions.java | 133 ++++
.../beam/sdk/io/snowflake/SnowflakeService.java | 36 +
.../sdk/io/snowflake/SnowflakeServiceImpl.java | 90 +++
.../credentials/KeyPairSnowflakeCredentials.java | 81 +++
.../OAuthTokenSnowflakeCredentials.java | 31 +
.../credentials/SnowflakeCredentials.java | 24 +
.../credentials/SnowflakeCredentialsFactory.java | 55 ++
.../UsernamePasswordSnowflakeCredentials.java | 37 +
.../sdk/io/snowflake/credentials/package-info.java | 20 +
.../apache/beam/sdk/io/snowflake/package-info.java | 20 +
.../test/FakeSnowflakeBasicDataSource.java | 298 ++++++++
.../io/snowflake/test/FakeSnowflakeDatabase.java | 81 +++
.../snowflake/test/FakeSnowflakeServiceImpl.java | 64 ++
.../beam/sdk/io/snowflake/test/TestUtils.java | 40 ++
.../beam/sdk/io/snowflake/test/package-info.java | 20 +
.../test/unit/BatchTestPipelineOptions.java | 28 +
.../test/unit/DataSourceConfigurationTest.java | 159 +++++
.../KeyPairSnowflakeCredentialsTest.java | 38 ++
.../OAuthTokenSnowflakeCredentialsTest.java | 46 ++
.../SnowflakeCredentialsFactoryTest.java | 77 +++
.../UsernamePasswordSnowflakeCredentialsTest.java | 50 ++
.../test/unit/read/SnowflakeIOReadTest.java | 278 ++++++++
.../snowflake/src/test/resources/test_rsa_key.p8 | 29 +
settings.gradle | 1 +
29 files changed, 2573 insertions(+)
diff --git a/CHANGES.md b/CHANGES.md
index 107e7dc..5908214 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -57,6 +57,7 @@
## I/Os
* Support for X source added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
+* Support for reading from Snowflake added (Java) ([BEAM-9722](https://issues.apache.org/jira/browse/BEAM-9722)).
## New Features / Improvements
diff --git a/build.gradle b/build.gradle
index c26a020..65363b7 100644
--- a/build.gradle
+++ b/build.gradle
@@ -112,6 +112,9 @@ rat {
"learning/katas/**/task-remote-info.yaml",
"learning/katas/*/IO/**/*.txt",
+ // test p8 file for SnowflakeIO
+ "sdks/java/io/snowflake/src/test/resources/test_rsa_key.p8",
+
// Mockito extensions
"sdks/java/io/amazon-web-services2/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker",
"sdks/java/extensions/ml/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker"
diff --git a/sdks/java/io/snowflake/build.gradle b/sdks/java/io/snowflake/build.gradle
new file mode 100644
index 0000000..c51c034
--- /dev/null
+++ b/sdks/java/io/snowflake/build.gradle
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+plugins { id 'org.apache.beam.module' }
+applyJavaNature(automaticModuleName: 'org.apache.beam.sdk.io.snowflake')
+provideIntegrationTestingDependencies()
+enableJavaPerformanceTesting()
+description = "Apache Beam :: SDKs :: Java :: IO :: Snowflake"
+ext.summary = "IO to read and write on Snowflake."
+dependencies {
+ compile library.java.vendored_guava_26_0_jre
+ compile project(path: ":sdks:java:core", configuration: "shadow")
+ compile project(path: ":sdks:java:extensions:google-cloud-platform-core")
+ compile library.java.slf4j_api
+ compile group: 'net.snowflake', name: 'snowflake-jdbc', version: '3.11.0'
+ compile group: 'com.opencsv', name: 'opencsv', version: '5.0'
+ testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
+ testCompile project(path: ":sdks:java:io:common", configuration: "testRuntime")
+ testCompile project(path: ":sdks:java:testing:test-utils", configuration: "testRuntime")
+ testCompile library.java.avro
+ testCompile library.java.junit
+ testCompile library.java.hamcrest_core
+ testCompile library.java.hamcrest_library
+ testCompile library.java.slf4j_api
+ testRuntimeOnly library.java.hadoop_client
+ testRuntimeOnly library.java.slf4j_jdk14
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
new file mode 100644
index 0000000..404859c
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake;
+
+public enum CloudProvider {
+ GCS("gs://");
+
+ private final String prefix;
+
+ private CloudProvider(String prefix) {
+ this.prefix = prefix;
+ }
+
+ public String getPrefix() {
+ return prefix;
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
new file mode 100644
index 0000000..a67ba32
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
@@ -0,0 +1,759 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake;
+
+import static org.apache.beam.sdk.io.TextIO.readFiles;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.value.AutoValue;
+import com.opencsv.CSVParser;
+import com.opencsv.CSVParserBuilder;
+import java.io.IOException;
+import java.io.Serializable;
+import java.security.PrivateKey;
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.sql.DataSource;
+import net.snowflake.client.jdbc.SnowflakeBasicDataSource;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.io.FileIO;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.fs.MoveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.io.snowflake.credentials.KeyPairSnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.credentials.OAuthTokenSnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.credentials.SnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.credentials.UsernamePasswordSnowflakeCredentials;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Wait;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.display.HasDisplayData;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * IO to read and write data on Snowflake.
+ *
+ * <p>SnowflakeIO uses the <a href="https://docs.snowflake.net/manuals/user-guide/jdbc.html">Snowflake
+ * JDBC</a> driver under the hood, but data isn't read/written using JDBC directly. Instead,
+ * SnowflakeIO uses dedicated <b>COPY</b> operations to read/write data from/to a cloud bucket.
+ * Currently only Google Cloud Storage is supported.
+ *
+ * <p>To configure SnowflakeIO to read/write from your Snowflake instance, you have to provide a
+ * {@link DataSourceConfiguration} using {@link
+ * DataSourceConfiguration#create(SnowflakeCredentials)}, where a {@link SnowflakeCredentials} may
+ * be created using {@link org.apache.beam.sdk.io.snowflake.credentials.SnowflakeCredentialsFactory}.
+ * Additionally, one of {@link DataSourceConfiguration#withServerName(String)} or {@link
+ * DataSourceConfiguration#withUrl(String)} must be used to tell SnowflakeIO which instance to use.
+ * <br>
+ * There are also other options available to configure the connection to Snowflake:
+ *
+ * <ul>
+ * <li>{@link DataSourceConfiguration#withWarehouse(String)} to specify which Warehouse to use
+ * <li>{@link DataSourceConfiguration#withDatabase(String)} to specify which Database to connect
+ * to
+ * <li>{@link DataSourceConfiguration#withSchema(String)} to specify which schema to use
+ * <li>{@link DataSourceConfiguration#withRole(String)} to specify which role to use
+ * <li>{@link DataSourceConfiguration#withLoginTimeout(Integer)} to specify the timeout for the
+ * login
+ * <li>{@link DataSourceConfiguration#withPortNumber(Integer)} to specify custom port of Snowflake
+ * instance
+ * </ul>
+ *
+ * <p>For example:
+ *
+ * <pre>{@code
+ * SnowflakeIO.DataSourceConfiguration dataSourceConfiguration =
+ * SnowflakeIO.DataSourceConfiguration.create(SnowflakeCredentialsFactory.of(options))
+ * .withServerName(options.getServerName())
+ * .withWarehouse(options.getWarehouse())
+ * .withDatabase(options.getDatabase())
+ * .withSchema(options.getSchema());
+ * }</pre>
+ *
+ * <h3>Reading from Snowflake</h3>
+ *
+ * <p>SnowflakeIO.Read returns a bounded collection of {@code T} as a {@code PCollection<T>}. T is
+ * the type returned by the provided {@link CsvMapper}.
+ *
+ * <p>For example:
+ *
+ * <pre>{@code
+ * PCollection<GenericRecord> items = pipeline.apply(
+ * SnowflakeIO.<GenericRecord>read()
+ * .withDataSourceConfiguration(dataSourceConfiguration)
+ * .fromQuery(QUERY)
+ * .withStagingBucketName(stagingBucketName)
+ * .withIntegrationName(integrationName)
+ * .withCsvMapper(...)
+ * .withCoder(...));
+ * }</pre>
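+ *
+ * <p>As an illustrative sketch (the {@code MyRow} type and its constructor are hypothetical),
+ * a {@link CsvMapper} may be as simple as a lambda over the parsed CSV columns:
+ *
+ * <pre>{@code
+ * SnowflakeIO.CsvMapper<MyRow> csvMapper = (String[] parts) -> new MyRow(parts[0], parts[1]);
+ * }</pre>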
+ *
+ * <p><b>Important</b> When reading data from Snowflake, temporary CSV files are created on the
+ * specified stagingBucketName in a directory named `sf_copy_csv_[TIMESTAMP]_[RANDOM CHARS]`. This
+ * directory and all the files are cleaned up automatically by default, but if the pipeline fails
+ * they may remain and will have to be cleaned up manually.
+ */
+public class SnowflakeIO {
+ private static final Logger LOG = LoggerFactory.getLogger(SnowflakeIO.class);
+
+ private static final String CSV_QUOTE_CHAR = "'";
+ /**
+ * Read data from Snowflake.
+ *
+ * @param snowflakeService user-defined {@link SnowflakeService}
+ * @param <T> Type of the data to be read.
+ */
+ public static <T> Read<T> read(SnowflakeService snowflakeService) {
+ return new AutoValue_SnowflakeIO_Read.Builder<T>()
+ .setSnowflakeService(snowflakeService)
+ .build();
+ }
+
+ /**
+ * Read data from Snowflake.
+ *
+ * @param <T> Type of the data to be read.
+ */
+ public static <T> Read<T> read() {
+ return read(new SnowflakeServiceImpl());
+ }
+
+ /**
+ * Interface for a user-defined function mapping parts of a CSV line into T. Used by
+ * SnowflakeIO.Read.
+ *
+ * @param <T> Type of data to be read.
+ */
+ @FunctionalInterface
+ public interface CsvMapper<T> extends Serializable {
+ T mapRow(String[] parts) throws Exception;
+ }
+
+ /** Implementation of {@link #read()}. */
+ @AutoValue
+ public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>> {
+ @Nullable
+ abstract SerializableFunction<Void, DataSource> getDataSourceProviderFn();
+
+ @Nullable
+ abstract String getQuery();
+
+ @Nullable
+ abstract String getTable();
+
+ @Nullable
+ abstract String getIntegrationName();
+
+ @Nullable
+ abstract String getStagingBucketName();
+
+ @Nullable
+ abstract CsvMapper<T> getCsvMapper();
+
+ @Nullable
+ abstract Coder<T> getCoder();
+
+ @Nullable
+ abstract SnowflakeService getSnowflakeService();
+
+ abstract Builder<T> toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> setDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn);
+
+ abstract Builder<T> setQuery(String query);
+
+ abstract Builder<T> setTable(String table);
+
+ abstract Builder<T> setIntegrationName(String integrationName);
+
+ abstract Builder<T> setStagingBucketName(String stagingBucketName);
+
+ abstract Builder<T> setCsvMapper(CsvMapper<T> csvMapper);
+
+ abstract Builder<T> setCoder(Coder<T> coder);
+
+ abstract Builder<T> setSnowflakeService(SnowflakeService snowflakeService);
+
+ abstract Read<T> build();
+ }
+
+ /**
+ * Sets information about the Snowflake server.
+ *
+ * @param config - An instance of {@link DataSourceConfiguration}.
+ */
+ public Read<T> withDataSourceConfiguration(final DataSourceConfiguration config) {
+ if (config.getValidate()) {
+ try {
+ Connection connection = config.buildDatasource().getConnection();
+ connection.close();
+ } catch (SQLException e) {
+ throw new IllegalArgumentException("Invalid DataSourceConfiguration.", e);
+ }
+ }
+ return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
+ }
+
+ /**
+ * Sets a function that will provide a {@link DataSource} at runtime.
+ *
+ * @param dataSourceProviderFn a {@link SerializableFunction}.
+ */
+ public Read<T> withDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn) {
+ return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build();
+ }
+
+ /**
+ * A query to be executed in Snowflake.
+ *
+ * @param query - String with query.
+ */
+ public Read<T> fromQuery(String query) {
+ return toBuilder().setQuery(query).build();
+ }
+
+ /**
+ * A table name to be read in Snowflake.
+ *
+ * @param table - String with the name of the table.
+ */
+ public Read<T> fromTable(String table) {
+ return toBuilder().setTable(table).build();
+ }
+
+ /**
+ * Name of the cloud bucket (currently only GCS) to use as a temporary location for CSVs during
+ * the COPY statement.
+ *
+ * @param stagingBucketName - String with the name of the bucket.
+ */
+ public Read<T> withStagingBucketName(String stagingBucketName) {
+ return toBuilder().setStagingBucketName(stagingBucketName).build();
+ }
+
+ /**
+ * Name of the Storage Integration in Snowflake to be used. See
+ * https://docs.snowflake.com/en/sql-reference/sql/create-storage-integration.html for
+ * reference.
+ *
+ * @param integrationName - String with the name of the Storage Integration.
+ */
+ public Read<T> withIntegrationName(String integrationName) {
+ return toBuilder().setIntegrationName(integrationName).build();
+ }
+
+ /**
+ * User-defined function mapping CSV lines into user data.
+ *
+ * @param csvMapper - an instance of {@link CsvMapper}.
+ */
+ public Read<T> withCsvMapper(CsvMapper<T> csvMapper) {
+ return toBuilder().setCsvMapper(csvMapper).build();
+ }
+
+ /**
+ * A Coder to be used by the output PCollection generated by the source.
+ *
+ * @param coder - an instance of {@link Coder}.
+ */
+ public Read<T> withCoder(Coder<T> coder) {
+ return toBuilder().setCoder(coder).build();
+ }
+
+ @Override
+ public PCollection<T> expand(PBegin input) {
+ // Either table or query is required. If query is present it is used; otherwise the table is
+ // used.
+ checkArgument(
+ getQuery() != null || getTable() != null, "fromTable() or fromQuery() is required");
+ checkArgument(
+ !(getQuery() != null && getTable() != null),
+ "fromTable() and fromQuery() are not allowed together");
+ checkArgument(getCsvMapper() != null, "withCsvMapper() is required");
+ checkArgument(getCoder() != null, "withCoder() is required");
+ checkArgument(getIntegrationName() != null, "withIntegrationName() is required");
+ checkArgument(getStagingBucketName() != null, "withStagingBucketName() is required");
+ checkArgument(
+ (getDataSourceProviderFn() != null),
+ "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
+
+ String tmpDirName = makeTmpDirName();
+ String stagingBucketDir = String.format("%s/%s", getStagingBucketName(), tmpDirName);
+
+ PCollection<Void> emptyCollection = input.apply(Create.of((Void) null));
+
+ PCollection<T> output =
+ emptyCollection
+ .apply(
+ ParDo.of(
+ new CopyIntoStageFn(
+ getDataSourceProviderFn(),
+ getQuery(),
+ getTable(),
+ getIntegrationName(),
+ stagingBucketDir,
+ getSnowflakeService())))
+ .apply(Reshuffle.viaRandomKey())
+ .apply(FileIO.matchAll())
+ .apply(FileIO.readMatches())
+ .apply(readFiles())
+ .apply(ParDo.of(new MapCsvToStringArrayFn()))
+ .apply(ParDo.of(new MapStringArrayToUserDataFn<>(getCsvMapper())));
+
+ output.setCoder(getCoder());
+
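+ // Wait.on(output) defers the cleanup until the staged files have been fully consumed.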
+ emptyCollection
+ .apply(Wait.on(output))
+ .apply(ParDo.of(new CleanTmpFilesFromGcsFn(stagingBucketDir)));
+
+ return output;
+ }
+
+ private String makeTmpDirName() {
+ return String.format(
+ "sf_copy_csv_%s_%s",
+ new SimpleDateFormat("yyyyMMdd_HHmmss").format(new Date()),
+ UUID.randomUUID().toString().subSequence(0, 8) // first 8 chars of UUID should be enough
+ );
+ }
+
+ private static class CopyIntoStageFn extends DoFn<Object, String> {
+ private final SerializableFunction<Void, DataSource> dataSourceProviderFn;
+ private final String query;
+ private final String table;
+ private final String integrationName;
+ private final String stagingBucketDir;
+ private final SnowflakeService snowflakeService;
+
+ private CopyIntoStageFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn,
+ String query,
+ String table,
+ String integrationName,
+ String stagingBucketDir,
+ SnowflakeService snowflakeService) {
+ this.dataSourceProviderFn = dataSourceProviderFn;
+ this.query = query;
+ this.table = table;
+ this.integrationName = integrationName;
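+ // Stage into a unique run_<random> subdirectory so runs sharing a stagingBucketDir don't collide.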
+ this.stagingBucketDir =
+ String.format(
+ "%s/run_%s/", stagingBucketDir, UUID.randomUUID().toString().subSequence(0, 8));
+ this.snowflakeService = snowflakeService;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext context) throws Exception {
+ String output =
+ snowflakeService.copyIntoStage(
+ dataSourceProviderFn, query, table, integrationName, stagingBucketDir);
+
+ context.output(output);
+ }
+ }
+
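+ /** Parses a single CSV line into an array of column values. */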
+ public static class MapCsvToStringArrayFn extends DoFn<String, String[]> {
+ @ProcessElement
+ public void processElement(ProcessContext c) throws IOException {
+ String csvLine = c.element();
+ CSVParser parser = new CSVParserBuilder().withQuoteChar(CSV_QUOTE_CHAR.charAt(0)).build();
+ String[] parts = parser.parseLine(csvLine);
+ c.output(parts);
+ }
+ }
+
+ private static class MapStringArrayToUserDataFn<T> extends DoFn<String[], T> {
+ private final CsvMapper<T> csvMapper;
+
+ public MapStringArrayToUserDataFn(CsvMapper<T> csvMapper) {
+ this.csvMapper = csvMapper;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext context) throws Exception {
+ context.output(csvMapper.mapRow(context.element()));
+ }
+ }
+
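+ /** Deletes all staged files under the given staging directory. */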
+ public static class CleanTmpFilesFromGcsFn extends DoFn<Object, Object> {
+ private final String stagingBucketDir;
+
+ public CleanTmpFilesFromGcsFn(String stagingBucketDir) {
+ this.stagingBucketDir = stagingBucketDir;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext c) throws IOException {
+ String combinedPath = stagingBucketDir + "/**";
+ List<ResourceId> paths =
+ FileSystems.match(combinedPath).metadata().stream()
+ .map(metadata -> metadata.resourceId())
+ .collect(Collectors.toList());
+
+ FileSystems.delete(paths, MoveOptions.StandardMoveOptions.IGNORE_MISSING_FILES);
+ }
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
+ if (getQuery() != null) {
+ builder.add(DisplayData.item("query", getQuery()));
+ }
+ if (getTable() != null) {
+ builder.add(DisplayData.item("table", getTable()));
+ }
+ builder.add(DisplayData.item("integrationName", getIntegrationName()));
+ builder.add(DisplayData.item("stagingBucketName", getStagingBucketName()));
+ builder.add(DisplayData.item("csvMapper", getCsvMapper().getClass().getName()));
+ builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+ if (getDataSourceProviderFn() instanceof HasDisplayData) {
+ ((HasDisplayData) getDataSourceProviderFn()).populateDisplayData(builder);
+ }
+ }
+ }
+
+ /**
+ * A POJO describing a {@link DataSource}, providing all properties needed to create a {@link
+ * DataSource}.
+ */
+ @AutoValue
+ public abstract static class DataSourceConfiguration implements Serializable {
+ @Nullable
+ public abstract String getUrl();
+
+ @Nullable
+ public abstract String getUsername();
+
+ @Nullable
+ public abstract String getPassword();
+
+ @Nullable
+ public abstract PrivateKey getPrivateKey();
+
+ @Nullable
+ public abstract String getOauthToken();
+
+ @Nullable
+ public abstract String getDatabase();
+
+ @Nullable
+ public abstract String getWarehouse();
+
+ @Nullable
+ public abstract String getSchema();
+
+ @Nullable
+ public abstract String getServerName();
+
+ @Nullable
+ public abstract Integer getPortNumber();
+
+ @Nullable
+ public abstract String getRole();
+
+ @Nullable
+ public abstract Integer getLoginTimeout();
+
+ @Nullable
+ public abstract Boolean getSsl();
+
+ @Nullable
+ public abstract Boolean getValidate();
+
+ @Nullable
+ public abstract DataSource getDataSource();
+
+ abstract Builder builder();
+
+ @AutoValue.Builder
+ abstract static class Builder {
+ abstract Builder setUrl(String url);
+
+ abstract Builder setUsername(String username);
+
+ abstract Builder setPassword(String password);
+
+ abstract Builder setPrivateKey(PrivateKey privateKey);
+
+ abstract Builder setOauthToken(String oauthToken);
+
+ abstract Builder setDatabase(String database);
+
+ abstract Builder setWarehouse(String warehouse);
+
+ abstract Builder setSchema(String schema);
+
+ abstract Builder setServerName(String serverName);
+
+ abstract Builder setPortNumber(Integer portNumber);
+
+ abstract Builder setRole(String role);
+
+ abstract Builder setLoginTimeout(Integer loginTimeout);
+
+ abstract Builder setSsl(Boolean ssl);
+
+ abstract Builder setValidate(Boolean validate);
+
+ abstract Builder setDataSource(DataSource dataSource);
+
+ abstract DataSourceConfiguration build();
+ }
+
+ /**
+ * Creates {@link DataSourceConfiguration} from existing instance of {@link DataSource}.
+ *
+ * @param dataSource - an instance of {@link DataSource}.
+ */
+ public static DataSourceConfiguration create(DataSource dataSource) {
+ checkArgument(dataSource instanceof Serializable, "dataSource must be Serializable");
+ return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
+ .setValidate(true)
+ .setDataSource(dataSource)
+ .build();
+ }
+
+ /**
+ * Creates {@link DataSourceConfiguration} from instance of {@link SnowflakeCredentials}.
+ *
+ * @param credentials - an instance of {@link SnowflakeCredentials}.
+ */
+ public static DataSourceConfiguration create(SnowflakeCredentials credentials) {
+ if (credentials instanceof UsernamePasswordSnowflakeCredentials) {
+ return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
+ .setValidate(true)
+ .setUsername(((UsernamePasswordSnowflakeCredentials) credentials).getUsername())
+ .setPassword(((UsernamePasswordSnowflakeCredentials) credentials).getPassword())
+ .build();
+ } else if (credentials instanceof OAuthTokenSnowflakeCredentials) {
+ return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
+ .setValidate(true)
+ .setOauthToken(((OAuthTokenSnowflakeCredentials) credentials).getToken())
+ .build();
+ } else if (credentials instanceof KeyPairSnowflakeCredentials) {
+ return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
+ .setValidate(true)
+ .setUsername(((KeyPairSnowflakeCredentials) credentials).getUsername())
+ .setPrivateKey(((KeyPairSnowflakeCredentials) credentials).getPrivateKey())
+ .build();
+ }
+ throw new IllegalArgumentException(
+ "Can't create DataSourceConfiguration from given credentials");
+ }
+
+ /**
+ * Sets the URL of the Snowflake server, in the following format:
+ * jdbc:snowflake://<account_name>.snowflakecomputing.com
+ *
+ * <p>Either withUrl or withServerName is required.
+ *
+ * @param url - String with URL of the Snowflake server.
+ */
+ public DataSourceConfiguration withUrl(String url) {
+ checkArgument(
+ url.startsWith("jdbc:snowflake://"),
+ "url must have format: jdbc:snowflake://<account_name>.snowflakecomputing.com");
+ checkArgument(
+ url.endsWith("snowflakecomputing.com"),
+ "url must have format: jdbc:snowflake://<account_name>.snowflakecomputing.com");
+ return builder().setUrl(url).build();
+ }
+
+ /**
+ * Sets database to use.
+ *
+ * @param database - String with database name.
+ */
+ public DataSourceConfiguration withDatabase(String database) {
+ return builder().setDatabase(database).build();
+ }
+
+ /**
+ * Sets Snowflake Warehouse to use.
+ *
+ * @param warehouse - String with warehouse name.
+ */
+ public DataSourceConfiguration withWarehouse(String warehouse) {
+ return builder().setWarehouse(warehouse).build();
+ }
+
+ /**
+ * Sets schema to use when connecting to Snowflake.
+ *
+ * @param schema - String with schema name.
+ */
+ public DataSourceConfiguration withSchema(String schema) {
+ return builder().setSchema(schema).build();
+ }
+
+ /**
+ * Sets the name of the Snowflake server. The following format is required:
+ * <account_name>.snowflakecomputing.com
+ *
+ * <p>Either withServerName or withUrl is required.
+ *
+ * @param serverName - String with server name.
+ */
+ public DataSourceConfiguration withServerName(String serverName) {
+ checkArgument(
+ serverName.endsWith("snowflakecomputing.com"),
+ "serverName must be in format <account_name>.snowflakecomputing.com");
+ return builder().setServerName(serverName).build();
+ }
+
+ /**
+ * Sets the port number used to connect to Snowflake.
+ *
+ * @param portNumber - Integer with port number.
+ */
+ public DataSourceConfiguration withPortNumber(Integer portNumber) {
+ return builder().setPortNumber(portNumber).build();
+ }
+
+ /**
+ * Sets user's role to be used when running queries on Snowflake.
+ *
+ * @param role - String with role name.
+ */
+ public DataSourceConfiguration withRole(String role) {
+ return builder().setRole(role).build();
+ }
+
+ /**
+ * Sets the loginTimeout that will be used in {@link SnowflakeBasicDataSource#setLoginTimeout(int)}.
+ *
+ * @param loginTimeout - Integer with timeout value.
+ */
+ public DataSourceConfiguration withLoginTimeout(Integer loginTimeout) {
+ return builder().setLoginTimeout(loginTimeout).build();
+ }
+
+ /** Disables validation of connection parameters prior to pipeline submission. */
+ public DataSourceConfiguration withoutValidation() {
+ return builder().setValidate(false).build();
+ }
+
+ void populateDisplayData(DisplayData.Builder builder) {
+ if (getDataSource() != null) {
+ builder.addIfNotNull(DisplayData.item("dataSource", getDataSource().getClass().getName()));
+ } else {
+ builder.addIfNotNull(DisplayData.item("jdbcUrl", getUrl()));
+ builder.addIfNotNull(DisplayData.item("username", getUsername()));
+ }
+ }
+
+ /** Builds {@link SnowflakeBasicDataSource} based on the current configuration. */
+ public DataSource buildDatasource() {
+ if (getDataSource() == null) {
+ SnowflakeBasicDataSource basicDataSource = new SnowflakeBasicDataSource();
+
+ if (getUrl() != null) {
+ basicDataSource.setUrl(getUrl());
+ }
+ if (getUsername() != null) {
+ basicDataSource.setUser(getUsername());
+ }
+ if (getPassword() != null) {
+ basicDataSource.setPassword(getPassword());
+ }
+ if (getPrivateKey() != null) {
+ basicDataSource.setPrivateKey(getPrivateKey());
+ }
+ if (getDatabase() != null) {
+ basicDataSource.setDatabaseName(getDatabase());
+ }
+ if (getWarehouse() != null) {
+ basicDataSource.setWarehouse(getWarehouse());
+ }
+ if (getSchema() != null) {
+ basicDataSource.setSchema(getSchema());
+ }
+ if (getServerName() != null) {
+ basicDataSource.setServerName(getServerName());
+ }
+ if (getPortNumber() != null) {
+ basicDataSource.setPortNumber(getPortNumber());
+ }
+ if (getRole() != null) {
+ basicDataSource.setRole(getRole());
+ }
+ if (getLoginTimeout() != null) {
+ try {
+ basicDataSource.setLoginTimeout(getLoginTimeout());
+ } catch (SQLException e) {
+ throw new RuntimeException("Failed to setLoginTimeout", e);
+ }
+ }
+ if (getOauthToken() != null) {
+ basicDataSource.setOauthToken(getOauthToken());
+ }
+ return basicDataSource;
+ }
+ return getDataSource();
+ }
+ }
+
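+ /**
+ * Wraps a {@link DataSourceConfiguration} so that each worker JVM builds at most one
+ * {@link DataSource} per configuration (memoized in a static map).
+ */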
+ public static class DataSourceProviderFromDataSourceConfiguration
+ implements SerializableFunction<Void, DataSource>, HasDisplayData {
+ private static final ConcurrentHashMap<DataSourceConfiguration, DataSource> instances =
+ new ConcurrentHashMap<>();
+ private final DataSourceConfiguration config;
+
+ private DataSourceProviderFromDataSourceConfiguration(DataSourceConfiguration config) {
+ this.config = config;
+ }
+
+ public static SerializableFunction<Void, DataSource> of(DataSourceConfiguration config) {
+ return new DataSourceProviderFromDataSourceConfiguration(config);
+ }
+
+ @Override
+ public DataSource apply(Void input) {
+ return instances.computeIfAbsent(config, (config) -> config.buildDatasource());
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ config.populateDisplayData(builder);
+ }
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakePipelineOptions.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakePipelineOptions.java
new file mode 100644
index 0000000..783230e
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakePipelineOptions.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake;
+
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.Validation;
+
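+/**
+ * Pipeline options for SnowflakeIO. With Beam's {@code PipelineOptionsFactory}, each getter below
+ * maps to a command-line flag of the same name, e.g. {@code --serverName} or {@code --username}.
+ */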
+public interface SnowflakePipelineOptions extends PipelineOptions {
+ String BASIC_CONNECTION_INFO_VALIDATION_GROUP = "BASIC_CONNECTION_INFO_GROUP";
+ String AUTH_VALIDATION_GROUP = "AUTH_VALIDATION_GROUP";
+
+ @Description(
+ "Snowflake's JDBC-like URL including account name and region, without any parameters.")
+ @Validation.Required(groups = BASIC_CONNECTION_INFO_VALIDATION_GROUP)
+ String getUrl();
+
+ void setUrl(String url);
+
+ @Description("Server Name - full server name with account, zone and domain.")
+ @Validation.Required(groups = BASIC_CONNECTION_INFO_VALIDATION_GROUP)
+ String getServerName();
+
+ void setServerName(String serverName);
+
+ @Description("Username. Required for username/password and Private Key authentication.")
+ @Validation.Required(groups = AUTH_VALIDATION_GROUP)
+ String getUsername();
+
+ void setUsername(String username);
+
+ @Description("OAuth token. Required for OAuth authentication only.")
+ @Validation.Required(groups = AUTH_VALIDATION_GROUP)
+ String getOauthToken();
+
+ void setOauthToken(String oauthToken);
+
+ @Description("Password. Required for username/password authentication only.")
+ @Default.String("")
+ String getPassword();
+
+ void setPassword(String password);
+
+ @Description("Path to Private Key file. Required for Private Key authentication only.")
+ @Default.String("")
+ String getPrivateKeyPath();
+
+ void setPrivateKeyPath(String privateKeyPath);
+
+ @Description("Private Key's passphrase. Required for Private Key authentication only.")
+ @Default.String("")
+ String getPrivateKeyPassphrase();
+
+ void setPrivateKeyPassphrase(String keyPassphrase);
+
+ @Description("Warehouse to use. Optional.")
+ @Default.String("")
+ String getWarehouse();
+
+ void setWarehouse(String warehouse);
+
+ @Description("Database name to connect to. Optional.")
+ @Default.String("")
+ String getDatabase();
+
+ void setDatabase(String database);
+
+ @Description("Schema to use. Optional.")
+ @Default.String("")
+ String getSchema();
+
+ void setSchema(String schema);
+
+ @Description("Role to use. Optional.")
+ @Default.String("")
+ String getRole();
+
+ void setRole(String role);
+
+ @Description("Authenticator to use. Optional.")
+ @Default.String("")
+ String getAuthenticator();
+
+ void setAuthenticator(String authenticator);
+
+ @Description("Port number. Optional.")
+ @Default.String("")
+ String getPortNumber();
+
+ void setPortNumber(String portNumber);
+
+ @Description("Login timeout. Optional.")
+ @Default.String("")
+ String getLoginTimeout();
+
+ void setLoginTimeout(String loginTimeout);
+
+ @Description("External location name to connect to.")
+ String getExternalLocation();
+
+ void setExternalLocation(String externalLocation);
+
+ @Description("Temporary GCS bucket name")
+ String getStagingBucketName();
+
+ void setStagingBucketName(String stagingBucketName);
+
+ @Description("Storage integration - required in case the external stage is not specified.")
+ String getStorageIntegration();
+
+ void setStorageIntegration(String integration);
+
+ @Description("Stage name. Optional.")
+ String getStage();
+
+ void setStage(String stage);
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeService.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeService.java
new file mode 100644
index 0000000..6375e79
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeService.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake;
+
+import java.io.Serializable;
+import java.sql.SQLException;
+import javax.sql.DataSource;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+
+/** Interface which defines common methods for interacting with Snowflake. */
+public interface SnowflakeService extends Serializable {
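+ // An escaped single quote: Snowflake's FIELD_OPTIONALLY_ENCLOSED_BY uses '' to denote '.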
+ String CSV_QUOTE_CHAR_FOR_COPY = "''";
+
+ String copyIntoStage(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn,
+ String query,
+ String table,
+ String integrationName,
+ String stagingBucketDir)
+ throws SQLException;
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeServiceImpl.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeServiceImpl.java
new file mode 100644
index 0000000..5aaad06
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeServiceImpl.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.function.Consumer;
+import javax.sql.DataSource;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+
+/**
+ * Implementation of {@link org.apache.beam.sdk.io.snowflake.SnowflakeService} used in production.
+ */
+public class SnowflakeServiceImpl implements SnowflakeService {
+ private static final String SNOWFLAKE_GCS_PREFIX = "gcs://";
+
+ @Override
+ public String copyIntoStage(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn,
+ String query,
+ String table,
+ String integrationName,
+ String stagingBucketDir)
+ throws SQLException {
+
+ String from;
+ if (query != null) {
+ // Query must be surrounded with parentheses
+ from = String.format("(%s)", query);
+ } else {
+ from = table;
+ }
+
+ String copyQuery =
+ String.format(
+ "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
+ getProperBucketDir(stagingBucketDir), from, integrationName, CSV_QUOTE_CHAR_FOR_COPY);
+
+ runStatement(copyQuery, getConnection(dataSourceProviderFn), null);
+
+ return stagingBucketDir.concat("*");
+ }
+
+ private static void runStatement(
+ String query, Connection connection, Consumer<ResultSet> resultSetMethod)
+ throws SQLException {
+ PreparedStatement statement = connection.prepareStatement(query);
+ try {
+ if (resultSetMethod != null) {
+ ResultSet resultSet = statement.executeQuery();
+ resultSetMethod.accept(resultSet);
+ } else {
+ statement.execute();
+ }
+ } finally {
+ statement.close();
+ connection.close();
+ }
+ }
+
+ private Connection getConnection(SerializableFunction<Void, DataSource> dataSourceProviderFn)
+ throws SQLException {
+ DataSource dataSource = dataSourceProviderFn.apply(null);
+ return dataSource.getConnection();
+ }
+
+ // Snowflake expects the "gcs://" prefix for GCS, while Beam uses "gs://".
+ private String getProperBucketDir(String bucketDir) {
+ if (bucketDir.contains(CloudProvider.GCS.getPrefix())) {
+ return bucketDir.replace(CloudProvider.GCS.getPrefix(), SNOWFLAKE_GCS_PREFIX);
+ }
+ return bucketDir;
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/KeyPairSnowflakeCredentials.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/KeyPairSnowflakeCredentials.java
new file mode 100644
index 0000000..286ec62
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/KeyPairSnowflakeCredentials.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.credentials;
+
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.security.InvalidKeyException;
+import java.security.KeyFactory;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivateKey;
+import java.security.spec.InvalidKeySpecException;
+import java.security.spec.PKCS8EncodedKeySpec;
+import java.util.Base64;
+import javax.crypto.EncryptedPrivateKeyInfo;
+import javax.crypto.SecretKeyFactory;
+import javax.crypto.spec.PBEKeySpec;
+
+/** POJO for handling Key-Pair authentication against Snowflake. */
+public class KeyPairSnowflakeCredentials implements SnowflakeCredentials {
+ private String username;
+ private PrivateKey privateKey;
+
+ public KeyPairSnowflakeCredentials(
+ String username, String privateKeyPath, String privateKeyPassword) {
+ this.username = username;
+ this.privateKey = getPrivateKey(privateKeyPath, privateKeyPassword);
+ }
+
+ public KeyPairSnowflakeCredentials(String username, PrivateKey privateKey) {
+ this.username = username;
+ this.privateKey = privateKey;
+ }
+
+ private PrivateKey getPrivateKey(String privateKeyPath, String privateKeyPassphrase) {
+ try {
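+ // Strip the PEM armor, decrypt the PKCS#8 key with the passphrase, and rebuild an RSA key.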
+ byte[] keyBytes = Files.readAllBytes(Paths.get(privateKeyPath));
+
+ String encrypted = new String(keyBytes, Charset.defaultCharset());
+ encrypted = encrypted.replace("-----BEGIN ENCRYPTED PRIVATE KEY-----", "");
+ encrypted = encrypted.replace("-----END ENCRYPTED PRIVATE KEY-----", "");
+ EncryptedPrivateKeyInfo pkInfo =
+ new EncryptedPrivateKeyInfo(Base64.getMimeDecoder().decode(encrypted));
+ PBEKeySpec keySpec = new PBEKeySpec(privateKeyPassphrase.toCharArray());
+ SecretKeyFactory pbeKeyFactory = SecretKeyFactory.getInstance(pkInfo.getAlgName());
+ PKCS8EncodedKeySpec encodedKeySpec = pkInfo.getKeySpec(pbeKeyFactory.generateSecret(keySpec));
+
+ KeyFactory keyFactory = KeyFactory.getInstance("RSA");
+ return keyFactory.generatePrivate(encodedKeySpec);
+ } catch (IOException
+ | NoSuchAlgorithmException
+ | InvalidKeySpecException
+ | InvalidKeyException ex) {
+ throw new RuntimeException("Can't create PrivateKey from options", ex);
+ }
+ }
+
+ public String getUsername() {
+ return username;
+ }
+
+ public PrivateKey getPrivateKey() {
+ return privateKey;
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/OAuthTokenSnowflakeCredentials.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/OAuthTokenSnowflakeCredentials.java
new file mode 100644
index 0000000..be102a8
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/OAuthTokenSnowflakeCredentials.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.credentials;
+
+/** POJO for handling OAuth authentication against Snowflake, using a pre-obtained OAuth token. */
+public class OAuthTokenSnowflakeCredentials implements SnowflakeCredentials {
+ private String token;
+
+ public OAuthTokenSnowflakeCredentials(String token) {
+ this.token = token;
+ }
+
+ public String getToken() {
+ return token;
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/SnowflakeCredentials.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/SnowflakeCredentials.java
new file mode 100644
index 0000000..e3abf91
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/SnowflakeCredentials.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.credentials;
+
+/**
+ * Interface for holding credentials. Allows creating {@link
+ * org.apache.beam.sdk.io.snowflake.SnowflakeIO.DataSourceConfiguration}.
+ */
+public interface SnowflakeCredentials {}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/SnowflakeCredentialsFactory.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/SnowflakeCredentialsFactory.java
new file mode 100644
index 0000000..3876c2f
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/SnowflakeCredentialsFactory.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.credentials;
+
+import org.apache.beam.sdk.io.snowflake.SnowflakePipelineOptions;
+
+/**
+ * Factory class for creating implementations of {@link SnowflakeCredentials} from {@link
+ * SnowflakePipelineOptions}.
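+ *
+ * <p>Credentials are resolved in order: OAuth token first, then username/password, then key
+ * pair; a {@link RuntimeException} is thrown when no matching options are set.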
+ */
+public class SnowflakeCredentialsFactory {
+ public static SnowflakeCredentials of(SnowflakePipelineOptions options) {
+ if (oauthOptionsAvailable(options)) {
+ return new OAuthTokenSnowflakeCredentials(options.getOauthToken());
+ } else if (usernamePasswordOptionsAvailable(options)) {
+ return new UsernamePasswordSnowflakeCredentials(options.getUsername(), options.getPassword());
+ } else if (keyPairOptionsAvailable(options)) {
+ return new KeyPairSnowflakeCredentials(
+ options.getUsername(), options.getPrivateKeyPath(), options.getPrivateKeyPassphrase());
+ }
+ throw new RuntimeException("Can't get credentials from Options");
+ }
+
+ private static boolean oauthOptionsAvailable(SnowflakePipelineOptions options) {
+ return options.getOauthToken() != null && !options.getOauthToken().isEmpty();
+ }
+
+ private static boolean usernamePasswordOptionsAvailable(SnowflakePipelineOptions options) {
+ return options.getUsername() != null
+ && !options.getUsername().isEmpty()
+ && !options.getPassword().isEmpty();
+ }
+
+ private static boolean keyPairOptionsAvailable(SnowflakePipelineOptions options) {
+ return options.getUsername() != null
+ && !options.getUsername().isEmpty()
+ && !options.getPrivateKeyPath().isEmpty()
+ && !options.getPrivateKeyPassphrase().isEmpty();
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/UsernamePasswordSnowflakeCredentials.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/UsernamePasswordSnowflakeCredentials.java
new file mode 100644
index 0000000..1d8bdce
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/UsernamePasswordSnowflakeCredentials.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.credentials;
+
+/** POJO for handling username/password authentication against Snowflake. */
+public class UsernamePasswordSnowflakeCredentials implements SnowflakeCredentials {
+ private String username;
+ private String password;
+
+ public UsernamePasswordSnowflakeCredentials(String username, String password) {
+ this.username = username;
+ this.password = password;
+ }
+
+ public String getUsername() {
+ return username;
+ }
+
+ public String getPassword() {
+ return password;
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/package-info.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/package-info.java
new file mode 100644
index 0000000..f76d241
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/credentials/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Credentials for SnowflakeIO. */
+package org.apache.beam.sdk.io.snowflake.credentials;
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/package-info.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/package-info.java
new file mode 100644
index 0000000..9dbcf05
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Snowflake IO transforms. */
+package org.apache.beam.sdk.io.snowflake;
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeBasicDataSource.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeBasicDataSource.java
new file mode 100644
index 0000000..5fc694f
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeBasicDataSource.java
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test;
+
+import java.io.Serializable;
+import java.sql.Array;
+import java.sql.Blob;
+import java.sql.CallableStatement;
+import java.sql.Clob;
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.NClob;
+import java.sql.PreparedStatement;
+import java.sql.SQLClientInfoException;
+import java.sql.SQLException;
+import java.sql.SQLWarning;
+import java.sql.SQLXML;
+import java.sql.Savepoint;
+import java.sql.Statement;
+import java.sql.Struct;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.Executor;
+import net.snowflake.client.jdbc.SnowflakeBasicDataSource;
+
+/**
+ * Fake implementation of {@link net.snowflake.client.jdbc.SnowflakeBasicDataSource} used in tests.
+ */
+public class FakeSnowflakeBasicDataSource extends SnowflakeBasicDataSource implements Serializable {
+ @Override
+ public FakeConnection getConnection() throws SQLException {
+ return new FakeConnection();
+ }
+
+ private class FakeConnection implements Connection {
+
+ @Override
+ public Statement createStatement() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public PreparedStatement prepareStatement(String sql) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public CallableStatement prepareCall(String sql) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public String nativeSQL(String sql) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void setAutoCommit(boolean autoCommit) throws SQLException {}
+
+ @Override
+ public boolean getAutoCommit() throws SQLException {
+ return false;
+ }
+
+ @Override
+ public void commit() throws SQLException {}
+
+ @Override
+ public void rollback() throws SQLException {}
+
+ @Override
+ public void close() throws SQLException {}
+
+ @Override
+ public boolean isClosed() throws SQLException {
+ return false;
+ }
+
+ @Override
+ public DatabaseMetaData getMetaData() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void setReadOnly(boolean readOnly) throws SQLException {}
+
+ @Override
+ public boolean isReadOnly() throws SQLException {
+ return false;
+ }
+
+ @Override
+ public void setCatalog(String catalog) throws SQLException {}
+
+ @Override
+ public String getCatalog() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void setTransactionIsolation(int level) throws SQLException {}
+
+ @Override
+ public int getTransactionIsolation() throws SQLException {
+ return 0;
+ }
+
+ @Override
+ public SQLWarning getWarnings() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void clearWarnings() throws SQLException {}
+
+ @Override
+ public Statement createStatement(int resultSetType, int resultSetConcurrency)
+ throws SQLException {
+ return null;
+ }
+
+ @Override
+ public PreparedStatement prepareStatement(
+ String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency)
+ throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Map<String, Class<?>> getTypeMap() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void setTypeMap(Map<String, Class<?>> map) throws SQLException {}
+
+ @Override
+ public void setHoldability(int holdability) throws SQLException {}
+
+ @Override
+ public int getHoldability() throws SQLException {
+ return 0;
+ }
+
+ @Override
+ public Savepoint setSavepoint() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Savepoint setSavepoint(String name) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void rollback(Savepoint savepoint) throws SQLException {}
+
+ @Override
+ public void releaseSavepoint(Savepoint savepoint) throws SQLException {}
+
+ @Override
+ public Statement createStatement(
+ int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public PreparedStatement prepareStatement(
+ String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability)
+ throws SQLException {
+ return null;
+ }
+
+ @Override
+ public CallableStatement prepareCall(
+ String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability)
+ throws SQLException {
+ return null;
+ }
+
+ @Override
+ public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys)
+ throws SQLException {
+ return null;
+ }
+
+ @Override
+ public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public PreparedStatement prepareStatement(String sql, String[] columnNames)
+ throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Clob createClob() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Blob createBlob() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public NClob createNClob() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public SQLXML createSQLXML() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public boolean isValid(int timeout) throws SQLException {
+ return false;
+ }
+
+ @Override
+ public void setClientInfo(String name, String value) throws SQLClientInfoException {}
+
+ @Override
+ public void setClientInfo(Properties properties) throws SQLClientInfoException {}
+
+ @Override
+ public String getClientInfo(String name) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Properties getClientInfo() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void setSchema(String schema) throws SQLException {}
+
+ @Override
+ public String getSchema() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void abort(Executor executor) throws SQLException {}
+
+ @Override
+ public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {}
+
+ @Override
+ public int getNetworkTimeout() throws SQLException {
+ return 0;
+ }
+
+ @Override
+ public <T> T unwrap(Class<T> iface) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public boolean isWrapperFor(Class<?> iface) throws SQLException {
+ return false;
+ }
+ }
+}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeDatabase.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeDatabase.java
new file mode 100644
index 0000000..5bf8b21
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeDatabase.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import net.snowflake.client.jdbc.SnowflakeSQLException;
+
+/** Fake in-memory implementation of a Snowflake warehouse used in test code. */
+public class FakeSnowflakeDatabase implements Serializable {
+ private static Map<String, List<String>> tables = new HashMap<>();
+
+  private FakeSnowflakeDatabase() {}
+
+ public static void createTable(String table) {
+ FakeSnowflakeDatabase.tables.put(table, Collections.emptyList());
+ }
+
+ public static List<String> getElements(String table) throws SnowflakeSQLException {
+ if (!isTableExist(table)) {
+ throw new SnowflakeSQLException(
+ null, "SQL compilation error: Table does not exist", table, 0);
+ }
+
+ return FakeSnowflakeDatabase.tables.get(table);
+ }
+
+ public static List<String> runQuery(String query) throws SnowflakeSQLException {
+ if (query.startsWith("SELECT * FROM ")) {
+ String tableName = query.replace("SELECT * FROM ", "");
+ return getElements(tableName);
+ }
+ throw new SnowflakeSQLException(null, "SQL compilation error: Invalid query", query, 0);
+ }
+
+ public static List<Long> getElementsAsLong(String table) throws SnowflakeSQLException {
+ List<String> elements = getElements(table);
+ return elements.stream().map(Long::parseLong).collect(Collectors.toList());
+ }
+
+ public static boolean isTableExist(String table) {
+ return FakeSnowflakeDatabase.tables.containsKey(table);
+ }
+
+ public static boolean isTableEmpty(String table) {
+ return FakeSnowflakeDatabase.tables.get(table).isEmpty();
+ }
+
+ public static void createTableWithElements(String table, List<String> rows) {
+ FakeSnowflakeDatabase.tables.put(table, rows);
+ }
+
+ public static void clean() {
+ FakeSnowflakeDatabase.tables = new HashMap<>();
+ }
+
+ public static void truncateTable(String table) {
+ FakeSnowflakeDatabase.createTable(table);
+ }
+}
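
As a usage sketch (mirroring the read tests later in this commit): a test seeds the fake warehouse statically, then reads rows back. Note that runQuery() only understands the "SELECT * FROM <table>" form and throws SnowflakeSQLException for anything else:

    FakeSnowflakeDatabase.createTableWithElements(
        "FAKE_TABLE", Arrays.asList("Paul,51,red", "Jackson,41,green"));
    List<String> rows = FakeSnowflakeDatabase.runQuery("SELECT * FROM FAKE_TABLE");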
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeServiceImpl.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeServiceImpl.java
new file mode 100644
index 0000000..4a62dcd
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeServiceImpl.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.sql.SQLException;
+import java.util.List;
+import javax.sql.DataSource;
+import org.apache.beam.sdk.io.snowflake.SnowflakeService;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+
+/**
+ * Fake implementation of {@link org.apache.beam.sdk.io.snowflake.SnowflakeService} used in tests.
+ */
+public class FakeSnowflakeServiceImpl implements SnowflakeService {
+
+ @Override
+ public String copyIntoStage(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn,
+ String query,
+ String table,
+ String integrationName,
+ String stagingBucketName)
+ throws SQLException {
+
+ if (table != null) {
+ writeToFile(stagingBucketName, FakeSnowflakeDatabase.getElements(table));
+ }
+ if (query != null) {
+ writeToFile(stagingBucketName, FakeSnowflakeDatabase.runQuery(query));
+ }
+
+ return String.format("./%s/*", stagingBucketName);
+ }
+
+ private void writeToFile(String stagingBucketNameTmp, List<String> rows) {
+ Path filePath = Paths.get(String.format("./%s/table.csv.gz", stagingBucketNameTmp));
+ try {
+ Files.createDirectories(filePath.getParent());
+ Files.createFile(filePath);
+ Files.write(filePath, rows);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create files", e);
+ }
+ }
+}
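
For context, this fake stages rows on the local filesystem rather than in a cloud bucket. A call along the following lines (the data source provider, query, and integration name are ignored by the fake, so nulls and placeholders suffice) writes ./BUCKET/table.csv.gz and returns the glob ./BUCKET/*:

    SnowflakeService service = new FakeSnowflakeServiceImpl();
    String stagedFilesGlob =
        service.copyIntoStage(null, null, "FAKE_TABLE", "STORAGE_INTEGRATION", "BUCKET");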
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/TestUtils.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/TestUtils.java
new file mode 100644
index 0000000..aab8d7d
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/TestUtils.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test;
+
+import java.io.File;
+
+public class TestUtils {
+
+ private static final String PRIVATE_KEY_FILE_NAME = "test_rsa_key.p8";
+ private static final String PRIVATE_KEY_PASSPHRASE = "snowflake";
+
+  public static String getPrivateKeyPath(Class<?> klass) {
+ ClassLoader classLoader = klass.getClassLoader();
+ File file = new File(classLoader.getResource(PRIVATE_KEY_FILE_NAME).getFile());
+ return file.getAbsolutePath();
+ }
+
+ public static String getPrivateKeyPassphrase() {
+ return PRIVATE_KEY_PASSPHRASE;
+ }
+}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/package-info.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/package-info.java
new file mode 100644
index 0000000..2e2cf2f
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Snowflake IO tests. */
+package org.apache.beam.sdk.io.snowflake.test;
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/BatchTestPipelineOptions.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/BatchTestPipelineOptions.java
new file mode 100644
index 0000000..3504c45
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/BatchTestPipelineOptions.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit;
+
+import org.apache.beam.sdk.io.snowflake.SnowflakePipelineOptions;
+import org.apache.beam.sdk.options.Description;
+
+public interface BatchTestPipelineOptions extends SnowflakePipelineOptions {
+ @Description("Table name to connect to.")
+ String getTable();
+
+ void setTable(String table);
+}
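
A hedged sketch of how this option would typically be populated, relying on Beam's standard getter-to-flag mapping (the table name is a placeholder):

    BatchTestPipelineOptions opts =
        PipelineOptionsFactory.fromArgs("--table=MY_TABLE")
            .as(BatchTestPipelineOptions.class);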
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/DataSourceConfigurationTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/DataSourceConfigurationTest.java
new file mode 100644
index 0000000..98f2213
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/DataSourceConfigurationTest.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import javax.sql.DataSource;
+import net.snowflake.client.jdbc.SnowflakeBasicDataSource;
+import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
+import org.apache.beam.sdk.io.snowflake.credentials.OAuthTokenSnowflakeCredentials;
+import org.junit.Before;
+import org.junit.Test;
+
+/** Unit tests for {@link org.apache.beam.sdk.io.snowflake.SnowflakeIO.DataSourceConfiguration}. */
+public class DataSourceConfigurationTest {
+
+ private SnowflakeIO.DataSourceConfiguration configuration;
+
+ @Before
+ public void setUp() {
+ configuration =
+ SnowflakeIO.DataSourceConfiguration.create(
+ new OAuthTokenSnowflakeCredentials("some-token"));
+ }
+
+ @Test
+ public void testSettingUrlWithBadPrefix() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> configuration.withUrl("account.snowflakecomputing.com"));
+ }
+
+ @Test
+ public void testSettingUrlWithBadSuffix() {
+ assertThrows(
+ IllegalArgumentException.class, () -> configuration.withUrl("jdbc:snowflake://account"));
+ }
+
+ @Test
+ public void testSettingStringUrl() {
+ String url = "jdbc:snowflake://account.snowflakecomputing.com";
+ configuration = configuration.withUrl(url);
+ assertEquals(url, configuration.getUrl());
+ }
+
+ @Test
+ public void testSettingServerNameWithBadSuffix() {
+ assertThrows(
+ IllegalArgumentException.class, () -> configuration.withServerName("not.properly.ended"));
+ }
+
+ @Test
+ public void testSettingStringServerName() {
+ String serverName = "account.snowflakecomputing.com";
+ configuration = configuration.withServerName(serverName);
+ assertEquals(serverName, configuration.getServerName());
+ }
+
+ @Test
+ public void testSettingStringDatabase() {
+ String database = "dbname";
+ configuration = configuration.withDatabase(database);
+ assertEquals(database, configuration.getDatabase());
+ }
+
+ @Test
+ public void testSettingStringWarehouse() {
+ String warehouse = "warehouse";
+ configuration = configuration.withWarehouse(warehouse);
+ assertEquals(warehouse, configuration.getWarehouse());
+ }
+
+ @Test
+ public void testSettingStringSchema() {
+ String schema = "schema";
+ configuration = configuration.withSchema(schema);
+ assertEquals(schema, configuration.getSchema());
+ }
+
+ @Test
+ public void testSettingStringRole() {
+ String role = "role";
+ configuration = configuration.withRole(role);
+ assertEquals(role, configuration.getRole());
+ }
+
+ @Test
+ public void testSettingStringPortNumber() {
+ Integer portNumber = 1234;
+ configuration = configuration.withPortNumber(portNumber);
+ assertEquals(portNumber, configuration.getPortNumber());
+ }
+
+ @Test
+ public void testSettingStringLoginTimeout() {
+ Integer loginTimeout = 999;
+ configuration = configuration.withLoginTimeout(loginTimeout);
+ assertEquals(loginTimeout, configuration.getLoginTimeout());
+ }
+
+ @Test
+ public void testSettingValidate() {
+ configuration = configuration.withoutValidation();
+ assertEquals(false, configuration.getValidate());
+ }
+
+ @Test
+ public void testDataSourceCreatedFromUrl() {
+ String url = "jdbc:snowflake://account.snowflakecomputing.com";
+ configuration = configuration.withUrl(url);
+
+ DataSource dataSource = configuration.buildDatasource();
+
+ assertEquals(SnowflakeBasicDataSource.class, dataSource.getClass());
+ assertEquals(url, ((SnowflakeBasicDataSource) dataSource).getUrl());
+ }
+
+ @Test
+ public void testDataSourceCreatedFromServerName() {
+ String serverName = "account.snowflakecomputing.com";
+ configuration = configuration.withServerName(serverName);
+
+ DataSource dataSource = configuration.buildDatasource();
+
+ String expectedUrl = "jdbc:snowflake://account.snowflakecomputing.com";
+ assertEquals(SnowflakeBasicDataSource.class, dataSource.getClass());
+ assertEquals(expectedUrl, ((SnowflakeBasicDataSource) dataSource).getUrl());
+ }
+
+ @Test
+ public void testDataSourceCreatedFromServerNameAndPort() {
+ String serverName = "account.snowflakecomputing.com";
+ int portNumber = 1234;
+
+ configuration = configuration.withServerName(serverName);
+ configuration = configuration.withPortNumber(portNumber);
+
+ DataSource dataSource = configuration.buildDatasource();
+ assertEquals(SnowflakeBasicDataSource.class, dataSource.getClass());
+ String expectedUrl = "jdbc:snowflake://account.snowflakecomputing.com:1234";
+ assertEquals(expectedUrl, ((SnowflakeBasicDataSource) dataSource).getUrl());
+ }
+}
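
Taken together, the setters exercised above chain into a single configuration (each withX() returns a new DataSourceConfiguration, as the tests rely on); a sketch with placeholder values:

    SnowflakeIO.DataSourceConfiguration config =
        SnowflakeIO.DataSourceConfiguration.create(
                new OAuthTokenSnowflakeCredentials("some-token"))
            .withServerName("account.snowflakecomputing.com")
            .withDatabase("dbname")
            .withSchema("schema")
            .withWarehouse("warehouse")
            .withRole("role");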
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/KeyPairSnowflakeCredentialsTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/KeyPairSnowflakeCredentialsTest.java
new file mode 100644
index 0000000..b231a2d
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/KeyPairSnowflakeCredentialsTest.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit.credentials;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+import org.apache.beam.sdk.io.snowflake.credentials.KeyPairSnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.test.TestUtils;
+import org.junit.Test;
+
+public class KeyPairSnowflakeCredentialsTest {
+ @Test
+ public void testFilePathConstructor() {
+ KeyPairSnowflakeCredentials credentials =
+ new KeyPairSnowflakeCredentials(
+ "username",
+ TestUtils.getPrivateKeyPath(getClass()),
+ TestUtils.getPrivateKeyPassphrase());
+ assertEquals("username", credentials.getUsername());
+ assertNotNull(credentials.getPrivateKey());
+ }
+}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/OAuthTokenSnowflakeCredentialsTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/OAuthTokenSnowflakeCredentialsTest.java
new file mode 100644
index 0000000..a1dee76
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/OAuthTokenSnowflakeCredentialsTest.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit.credentials;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
+import org.apache.beam.sdk.io.snowflake.credentials.OAuthTokenSnowflakeCredentials;
+import org.junit.Test;
+
+public class OAuthTokenSnowflakeCredentialsTest {
+
+ @Test
+ public void testConstructor() {
+ OAuthTokenSnowflakeCredentials credentials = new OAuthTokenSnowflakeCredentials("token");
+
+ assertEquals("token", credentials.getToken());
+ }
+
+ @Test
+ public void testBuildingDataSource() {
+ OAuthTokenSnowflakeCredentials credentials = new OAuthTokenSnowflakeCredentials("token");
+
+ SnowflakeIO.DataSourceConfiguration configuration =
+ SnowflakeIO.DataSourceConfiguration.create(credentials);
+
+ assertEquals(credentials.getToken(), configuration.getOauthToken());
+ assertTrue(configuration.getValidate());
+ }
+}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/SnowflakeCredentialsFactoryTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/SnowflakeCredentialsFactoryTest.java
new file mode 100644
index 0000000..f9f612d
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/SnowflakeCredentialsFactoryTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit.credentials;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import org.apache.beam.sdk.io.snowflake.SnowflakePipelineOptions;
+import org.apache.beam.sdk.io.snowflake.credentials.KeyPairSnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.credentials.OAuthTokenSnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.credentials.SnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.credentials.SnowflakeCredentialsFactory;
+import org.apache.beam.sdk.io.snowflake.credentials.UsernamePasswordSnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.test.TestUtils;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.junit.Test;
+
+public class SnowflakeCredentialsFactoryTest {
+
+ @Test
+ public void usernamePasswordTest() {
+ SnowflakePipelineOptions options = PipelineOptionsFactory.as(SnowflakePipelineOptions.class);
+ options.setUsername("username");
+ options.setPassword("password");
+
+ SnowflakeCredentials credentials = SnowflakeCredentialsFactory.of(options);
+
+ assertEquals(UsernamePasswordSnowflakeCredentials.class, credentials.getClass());
+ }
+
+ @Test
+ public void oauthTokenTest() {
+ SnowflakePipelineOptions options = PipelineOptionsFactory.as(SnowflakePipelineOptions.class);
+ options.setOauthToken("token");
+
+ SnowflakeCredentials credentials = SnowflakeCredentialsFactory.of(options);
+
+ assertEquals(OAuthTokenSnowflakeCredentials.class, credentials.getClass());
+ }
+
+ @Test
+ public void keyPairTest() {
+ SnowflakePipelineOptions options = PipelineOptionsFactory.as(SnowflakePipelineOptions.class);
+ options.setUsername("username");
+ options.setPrivateKeyPath(TestUtils.getPrivateKeyPath(getClass()));
+ options.setPrivateKeyPassphrase(TestUtils.getPrivateKeyPassphrase());
+
+ SnowflakeCredentials credentials = SnowflakeCredentialsFactory.of(options);
+
+ assertEquals(KeyPairSnowflakeCredentials.class, credentials.getClass());
+ }
+
+ @Test
+ public void emptyOptionsTest() {
+ SnowflakePipelineOptions options = PipelineOptionsFactory.as(SnowflakePipelineOptions.class);
+
+ Exception ex =
+ assertThrows(RuntimeException.class, () -> SnowflakeCredentialsFactory.of(options));
+ assertEquals("Can't get credentials from Options", ex.getMessage());
+ }
+}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/UsernamePasswordSnowflakeCredentialsTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/UsernamePasswordSnowflakeCredentialsTest.java
new file mode 100644
index 0000000..0c7503a
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/credentials/UsernamePasswordSnowflakeCredentialsTest.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit.credentials;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
+import org.apache.beam.sdk.io.snowflake.credentials.UsernamePasswordSnowflakeCredentials;
+import org.junit.Test;
+
+public class UsernamePasswordSnowflakeCredentialsTest {
+
+ @Test
+ public void testConstructor() {
+ UsernamePasswordSnowflakeCredentials credentials =
+ new UsernamePasswordSnowflakeCredentials("username", "password");
+
+ assertEquals("username", credentials.getUsername());
+ assertEquals("password", credentials.getPassword());
+ }
+
+ @Test
+ public void testBuildingDataSource() {
+ UsernamePasswordSnowflakeCredentials credentials =
+ new UsernamePasswordSnowflakeCredentials("username", "password");
+
+ SnowflakeIO.DataSourceConfiguration configuration =
+ SnowflakeIO.DataSourceConfiguration.create(credentials);
+
+ assertEquals("username", configuration.getUsername());
+ assertEquals("password", configuration.getPassword());
+ assertTrue(configuration.getValidate());
+ }
+}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/read/SnowflakeIOReadTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/read/SnowflakeIOReadTest.java
new file mode 100644
index 0000000..e4eda0d
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/read/SnowflakeIOReadTest.java
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit.read;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.generic.GenericRecordBuilder;
+import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
+import org.apache.beam.sdk.coders.AvroCoder;
+import org.apache.beam.sdk.io.AvroGeneratedUser;
+import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
+import org.apache.beam.sdk.io.snowflake.SnowflakeService;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeBasicDataSource;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeDatabase;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeServiceImpl;
+import org.apache.beam.sdk.io.snowflake.test.unit.BatchTestPipelineOptions;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class SnowflakeIOReadTest implements Serializable {
+ public static final String FAKE_TABLE = "FAKE_TABLE";
+ public static final String FAKE_QUERY = "SELECT * FROM FAKE_TABLE";
+
+ private static final BatchTestPipelineOptions options =
+      TestPipeline.testingPipelineOptions().as(BatchTestPipelineOptions.class);
+ @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+ @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+ private static SnowflakeIO.DataSourceConfiguration dataSourceConfiguration;
+ private static SnowflakeService snowflakeService;
+ private static String stagingBucketName;
+ private static String integrationName;
+ private static List<GenericRecord> avroTestData;
+
+ @BeforeClass
+ public static void setup() {
+
+ List<String> testData = Arrays.asList("Paul,51,red", "Jackson,41,green");
+
+ avroTestData =
+ ImmutableList.of(
+ new AvroGeneratedUser("Paul", 51, "red"),
+ new AvroGeneratedUser("Jackson", 41, "green"));
+
+ FakeSnowflakeDatabase.createTableWithElements(FAKE_TABLE, testData);
+
+ options.setServerName("NULL.snowflakecomputing.com");
+ options.setStorageIntegration("STORAGE_INTEGRATION");
+ options.setStagingBucketName("BUCKET");
+
+ stagingBucketName = options.getStagingBucketName();
+ integrationName = options.getStorageIntegration();
+
+ dataSourceConfiguration =
+ SnowflakeIO.DataSourceConfiguration.create(new FakeSnowflakeBasicDataSource())
+ .withServerName(options.getServerName());
+
+ snowflakeService = new FakeSnowflakeServiceImpl();
+ }
+
+ @Test
+ public void testConfigIsMissingStagingBucketName() {
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("withStagingBucketName() is required");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromTable(FAKE_TABLE)
+            .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testConfigIsMissingIntegrationName() {
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("withIntegrationName() is required");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromTable(FAKE_TABLE)
+ .withStagingBucketName(stagingBucketName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testConfigIsMissingCsvMapper() {
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("withCsvMapper() is required");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromTable(FAKE_TABLE)
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testConfigIsMissingCoder() {
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("withCoder() is required");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromTable(FAKE_TABLE)
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper()));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testConfigIsMissingFromTableOrFromQuery() {
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("fromTable() or fromQuery() is required");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testConfigIsMissingDataSourceConfiguration() {
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("withDataSourceConfiguration() or withDataSourceProviderFn() is required");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .fromTable(FAKE_TABLE)
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testConfigContainsFromQueryAndFromTable() {
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("fromTable() and fromQuery() are not allowed together");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromQuery("")
+ .fromTable(FAKE_TABLE)
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testTableDoesntExist() {
+ thrown.expect(PipelineExecutionException.class);
+ thrown.expectMessage("SQL compilation error: Table does not exist");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromTable("NON_EXIST")
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testInvalidQuery() {
+ thrown.expect(PipelineExecutionException.class);
+ thrown.expectMessage("SQL compilation error: Invalid query");
+
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromQuery("BAD_QUERY")
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testReadFromTable() {
+ PCollection<GenericRecord> items =
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromTable(FAKE_TABLE)
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ PAssert.that(items).containsInAnyOrder(avroTestData);
+ pipeline.run();
+ }
+
+ @Test
+ public void testReadFromQuery() {
+ PCollection<GenericRecord> items =
+ pipeline.apply(
+ SnowflakeIO.<GenericRecord>read(snowflakeService)
+ .withDataSourceConfiguration(dataSourceConfiguration)
+ .fromQuery(FAKE_QUERY)
+ .withStagingBucketName(stagingBucketName)
+ .withIntegrationName(integrationName)
+ .withCsvMapper(getCsvMapper())
+ .withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
+
+ PAssert.that(items).containsInAnyOrder(avroTestData);
+ pipeline.run();
+ }
+
+ static SnowflakeIO.CsvMapper<GenericRecord> getCsvMapper() {
+    return parts ->
+        new GenericRecordBuilder(AvroGeneratedUser.getClassSchema())
+            .set("name", String.valueOf(parts[0]))
+            .set("favorite_number", Integer.valueOf(parts[1]))
+            .set("favorite_color", String.valueOf(parts[2]))
+            .build();
+ }
+}
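
As a closing illustration, the CsvMapper/coder pair is not tied to Avro. A minimal sketch (assuming Beam's stock StringUtf8Coder) that maps each staged CSV row back to a comma-joined String:

    PCollection<String> rows =
        pipeline.apply(
            SnowflakeIO.<String>read(snowflakeService)
                .withDataSourceConfiguration(dataSourceConfiguration)
                .fromTable(FAKE_TABLE)
                .withStagingBucketName(stagingBucketName)
                .withIntegrationName(integrationName)
                .withCsvMapper(parts -> String.join(",", parts))
                .withCoder(StringUtf8Coder.of()));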
diff --git a/sdks/java/io/snowflake/src/test/resources/test_rsa_key.p8 b/sdks/java/io/snowflake/src/test/resources/test_rsa_key.p8
new file mode 100644
index 0000000..ee86a8d
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/resources/test_rsa_key.p8
@@ -0,0 +1,29 @@
+-----BEGIN ENCRYPTED PRIVATE KEY-----
+MIIE6TAbBgkqhkiG9w0BBQMwDgQI2sbwjlr6RrcCAggABIIEyPbY/oiy8EH1QY9t
+EVlG2GyZK07bYsVP57PfFWVTJi6gN7G8zxF8vQvPN2fm8w1GJTzyz+ZAdxAXEv6/
+Oe0KmvYbe+YeO0+u+aZah8rnhFgAo1OgMDA8hCKLL98qrGau03TRGZKD6Xxce3nZ
+DizFbCxcRejnJflWUFWgyro45Qnb9jtOop+rnmqHeV/CMP/RLattYsZ0SrcCnr1O
+fIfb7jqNjY94V7xh9O8G1g5YHL3fv8ir3iXLpy7wPjMVHSMGcK092tnAK9/okmiA
+EdzNH/DfwUQ2qZ4gWJJp75b5RwVTZ9uNDQZkQJQLAS8zowtjXhKJ+Zy26ID/EI7u
+H5R5WhUl2ROI2ssjX8I0wFyOMA5guBjw2xL/Y+/eI3dQ+2g6hxN2TpPoKEAzeM1a
+OyrTyyVZ7nJyJ7RC6odRnHE9PLFzpmGv/3YhbSvhAwAnLN2mky+LqSw8hTGVMaIv
+QbnkeqxGGd6miIZgPOrbBj1mRErLdgKkJn9UNYPXKoB+0gEyLGYSkY1TdJ6NjRq8
+5oU+MgU1dnyp73JMiafz4AbFLHrxXG3AG6vww9WBgiV9pFmQSMcuVX8p8rvxJ8H4
+nFQlwiZrvl98xcfGICkoiSKP2fEDir1sdbjpcEY1Rxk3baUqBxZCrRKo+Dz4k4yy
+VjZX8SkXrYgcQNxmDiv59D34QYtKyq8ZxeHB7tVrj8/G1N8WgU8EApR0+yQL+0TP
+8aicPc++9ta62Wv59iqmsQeR+Bdq1/kOTZCMA9QMK7/mJpcn18/5EmbrAUbw/2pf
+TdyLqEXUf3N0zrDUAKUuWSOJMHOEhdcvQwVfyj3zy8O2+aM2PYos3c5pKQ05X/RJ
+6bPl6taxcEsHQByNC0+7JJ1yPxYlqW28uDao4XNkrwhSkhzv0DylAGZNIrwAFkKq
+dXRdCBMijvBB7jKvfPKK0aOlVJ7fdRo6PPoAJaDhmfsd5lbIGFcpwu4Rf2AiJQq3
+7ZzfGSUH/uvTsXW/e+QOQkr8cI4apRjGAuTImoIebZgkU8U9NOjCIwtBZO+KgiRZ
+4cjC4mXgxMAdmmIMXUqy0QGiqgK2IGhpDneC9y2Al1WT/7Sz5au6tvMEIf5yrYuW
+62/LyErTfNiwFGa4gYQ0nQ21ifwrA5bFFCyf4K0NyFRZaKJd/gpnszZtuw9NKxLn
+5Lz10bI/aFPcOyKxgX0DqYHhiJFw0v4uTUIQK7RkDCQ+xAiJJ7c53/piCaqc6+IY
+BRHvKTRK2jIxSlL8In+MxL+hDPXBm1c/NLIAqMogh7u0qYRg3U0V9leH5vZqatHI
+/SDPyMAbcrjuIcg04fjaH2KW/REdPL87heoqH0tH5x0PnQqAUGxmuUm/7FEcoB+Q
+oQOD/KIkZ5Abenmw/VJW11tt9A5dV6d3y+OBaTN2U1ZzT8PaWchjUimsIY3CVxTM
+h5IxI1VTqMy9o/5mkA5ishzdUxh0hReO9NzUx4zgKFuWHAnUqYEkGC1okdSm+DqM
+s3jtYwZcbXhV8USCZJWEyfV7T5/1iXR2/U432e7HN6Wv1uC/GQWafelkKosr2ulG
+Y9Heehs56te3osz62G8Y27gCdZGi+GnysgegiLg7E2Qaep3UGk+Q3h8E+YAyQ0eK
+H8gI6sKRjdIAGuhs7w==
+-----END ENCRYPTED PRIVATE KEY-----
diff --git a/settings.gradle b/settings.gradle
index 0c64230..8a80ffd 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -119,6 +119,7 @@ include ":sdks:java:io:parquet"
include ":sdks:java:io:rabbitmq"
include ":sdks:java:io:redis"
include ":sdks:java:io:solr"
+include ":sdks:java:io:snowflake"
include ":sdks:java:io:thrift"
include ":sdks:java:io:tika"
include ":sdks:java:io:xml"