Posted to commits@beam.apache.org by kk...@apache.org on 2020/06/24 08:34:45 UTC
[beam] branch master updated: Merge pull request #11794:
[BEAM-9894] Add batch SnowflakeIO.Write to Java SDK
This is an automated email from the ASF dual-hosted git repository.
kkucharczyk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 55ff8b3 Merge pull request #11794: [BEAM-9894] Add batch SnowflakeIO.Write to Java SDK
55ff8b3 is described below
commit 55ff8b3939e2751056a40bd43fd9cf9114491664
Author: purbanow <37...@users.noreply.github.com>
AuthorDate: Wed Jun 24 10:34:16 2020 +0200
Merge pull request #11794: [BEAM-9894] Add batch SnowflakeIO.Write to Java SDK
* [BEAM-9894] Add batch SnowflakeIO.Write to Java SDK
* fix: add missing license headers
* refactor: method names
* fix: make Location.storageIntegrationName as a nullable variable
* fix: remove MapUserDataObjectsArrayFn class
* fix(SnowFlakeIO): removed Parse class + add @Experimental annotation
* refactor(SnowFlakeIO): removed Location class
* fix(SnowFlakeIO): added missing types
* refactor(SnowFlakeIO): removed CSVSink class
* [BEAM-9894] Snowflake write - change Combine step into Reify
Co-authored-by: Kasia Kucharczyk <ka...@polidea.com>
---
CHANGES.md | 1 +
.../apache/beam/sdk/io/snowflake/SnowflakeIO.java | 493 +++++++++++++++++++--
.../sdk/io/snowflake/SnowflakePipelineOptions.java | 18 +-
.../sdk/io/snowflake/SnowflakeServiceImpl.java | 90 ----
.../io/snowflake/{ => enums}/CloudProvider.java | 2 +-
.../WriteDisposition.java} | 19 +-
.../package-info.java} | 16 +-
.../ServiceConfig.java} | 16 +-
.../snowflake/{ => services}/SnowflakeService.java | 19 +-
.../snowflake/services/SnowflakeServiceConfig.java | 93 ++++
.../snowflake/services/SnowflakeServiceImpl.java | 208 +++++++++
.../package-info.java} | 16 +-
.../io/snowflake/test/FakeSnowflakeDatabase.java | 6 +-
.../snowflake/test/FakeSnowflakeServiceImpl.java | 73 ++-
.../beam/sdk/io/snowflake/test/TestUtils.java | 65 +++
...pelineOptions.java => TestPipelineOptions.java} | 2 +-
.../test/unit/read/SnowflakeIOReadTest.java | 62 ++-
.../unit/write/QueryDispositionLocationTest.java | 160 +++++++
.../test/unit/write/SnowflakeIOWriteTest.java | 172 +++++++
19 files changed, 1262 insertions(+), 269 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 4d616b2..a00d8d2 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -91,6 +91,7 @@
* Basic Kafka read/write support for DataflowRunner (Python) ([BEAM-8019](https://issues.apache.org/jira/browse/BEAM-8019)).
* Sources and sinks for Google Healthcare APIs (Java)([BEAM-9468](https://issues.apache.org/jira/browse/BEAM-9468)).
+* Support for writing to Snowflake added (Java) ([BEAM-9894](https://issues.apache.org/jira/browse/BEAM-9894)).
## New Features / Improvements
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
index a67ba32..4895e4e 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
@@ -29,6 +29,7 @@ import java.security.PrivateKey;
import java.sql.Connection;
import java.sql.SQLException;
import java.text.SimpleDateFormat;
+import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.UUID;
@@ -37,26 +38,45 @@ import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.sql.DataSource;
import net.snowflake.client.jdbc.SnowflakeBasicDataSource;
+import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.io.WriteFilesResult;
import org.apache.beam.sdk.io.fs.MoveOptions;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.snowflake.credentials.KeyPairSnowflakeCredentials;
import org.apache.beam.sdk.io.snowflake.credentials.OAuthTokenSnowflakeCredentials;
import org.apache.beam.sdk.io.snowflake.credentials.SnowflakeCredentials;
import org.apache.beam.sdk.io.snowflake.credentials.UsernamePasswordSnowflakeCredentials;
+import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeServiceConfig;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeServiceImpl;
+import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reify;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.Wait;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -112,8 +132,8 @@ import org.slf4j.LoggerFactory;
* SnowflakeIO.<GenericRecord>read()
* .withDataSourceConfiguration(dataSourceConfiguration)
* .fromQuery(QUERY)
- * .withStagingBucketName(stagingBucketName)
- * .withIntegrationName(integrationName)
+ * .withStagingBucketName(...)
+ * .withStorageIntegrationName(...)
* .withCsvMapper(...)
* .withCoder(...));
* }</pre>
@@ -122,11 +142,35 @@ import org.slf4j.LoggerFactory;
* specified stagingBucketName in directory named `sf_copy_csv_[RANDOM CHARS]_[TIMESTAMP]`. This
 * directory and all the files are cleaned up automatically by default, but if the pipeline
 * fails they may remain and will have to be cleaned up manually.
+ *
+ * <h3>Writing to Snowflake</h3>
+ *
+ * <p>SnowflakeIO.Write supports writing records into a database. It writes a {@link PCollection<T>}
+ * to the database by converting each T into a {@link Object[]} via a user-provided {@link
+ * UserDataMapper}.
+ *
+ * <p>For example
+ *
+ * <pre>{@code
+ * items.apply(
+ * SnowflakeIO.<KV<Integer, String>>write()
+ * .withDataSourceConfiguration(dataSourceConfiguration)
+ * .withTable(table)
+ * .withStagingBucketName(...)
+ * .withStorageIntegrationName(...)
+ * .withUserDataMapper(mapper));
+ * }</pre>
+ *
+ * <p><b>Important:</b> When writing to Snowflake, the data is first staged as CSV files in a
+ * directory named 'data' under the specified stagingBucketName, and only then copied into
+ * Snowflake.
*/
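For orientation, a fuller, self-contained sketch of the new write path follows. Server, bucket, integration, and table names are placeholders, and it assumes DataSourceConfiguration.create accepts a SnowflakeCredentials instance, as the credentials imports above suggest:

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
    import org.apache.beam.sdk.io.snowflake.credentials.UsernamePasswordSnowflakeCredentials;
    import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
    import org.apache.beam.sdk.transforms.Create;

    public class SnowflakeWriteExample {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        // Connection details are illustrative placeholders, not working values.
        SnowflakeIO.DataSourceConfiguration dataSourceConfiguration =
            SnowflakeIO.DataSourceConfiguration.create(
                    new UsernamePasswordSnowflakeCredentials("user", "password"))
                .withServerName("account.snowflakecomputing.com")
                .withDatabase("MY_DB")
                .withSchema("PUBLIC");

        p.apply(Create.of(1L, 2L, 3L))
            .apply(
                SnowflakeIO.<Long>write()
                    .withDataSourceConfiguration(dataSourceConfiguration)
                    .withTable("MY_TABLE")
                    .withStagingBucketName("gs://my-staging-bucket")
                    .withStorageIntegrationName("MY_INTEGRATION")
                    // mapRow turns each element into one CSV row.
                    .withUserDataMapper(
                        (SnowflakeIO.UserDataMapper<Long>) v -> new Object[] {v})
                    .withWriteDisposition(WriteDisposition.APPEND));

        p.run().waitUntilFinish();
      }
    }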
+@Experimental
public class SnowflakeIO {
private static final Logger LOG = LoggerFactory.getLogger(SnowflakeIO.class);
private static final String CSV_QUOTE_CHAR = "'";
+ private static final String WRITE_TMP_PATH = "data";
+
/**
* Read data from Snowflake.
*
@@ -159,6 +203,29 @@ public class SnowflakeIO {
T mapRow(String[] parts) throws Exception;
}
+ /**
+ * Interface for a user-defined function that maps a T into an array of Objects. Used by
+ * SnowflakeIO.Write.
+ *
+ * @param <T> Type of data to be written.
+ */
+ @FunctionalInterface
+ public interface UserDataMapper<T> extends Serializable {
+ Object[] mapRow(T element);
+ }
+
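Since it is a functional interface, a mapper can be supplied as a lambda; for instance (types illustrative, mirroring the test utilities below):

    SnowflakeIO.UserDataMapper<KV<String, Long>> mapper =
        (SnowflakeIO.UserDataMapper<KV<String, Long>>)
            record -> new Object[] {record.getKey(), record.getValue()};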
+ /**
+ * Write data to Snowflake via COPY statement.
+ *
+ * @param <T> Type of data to be written.
+ */
+ public static <T> Write<T> write() {
+ return new AutoValue_SnowflakeIO_Write.Builder<T>()
+ .setFileNameTemplate("output*")
+ .setWriteDisposition(WriteDisposition.APPEND)
+ .build();
+ }
+
/** Implementation of {@link #read()}. */
@AutoValue
public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>> {
@@ -172,7 +239,7 @@ public class SnowflakeIO {
abstract String getTable();
@Nullable
- abstract String getIntegrationName();
+ abstract String getStorageIntegrationName();
@Nullable
abstract String getStagingBucketName();
@@ -197,7 +264,7 @@ public class SnowflakeIO {
abstract Builder<T> setTable(String table);
- abstract Builder<T> setIntegrationName(String integrationName);
+ abstract Builder<T> setStorageIntegrationName(String storageIntegrationName);
abstract Builder<T> setStagingBucketName(String stagingBucketName);
@@ -216,15 +283,6 @@ public class SnowflakeIO {
* @param config - An instance of {@link DataSourceConfiguration}.
*/
public Read<T> withDataSourceConfiguration(final DataSourceConfiguration config) {
- if (config.getValidate()) {
- try {
- Connection connection = config.buildDatasource().getConnection();
- connection.close();
- } catch (SQLException e) {
- throw new IllegalArgumentException(
- "Invalid DataSourceConfiguration. Underlying cause: " + e);
- }
- }
return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
}
@@ -272,8 +330,8 @@ public class SnowflakeIO {
*
* @param integrationName - String with the name of the Storage Integration.
*/
- public Read<T> withIntegrationName(String integrationName) {
- return toBuilder().setIntegrationName(integrationName).build();
+ public Read<T> withStorageIntegrationName(String integrationName) {
+ return toBuilder().setStorageIntegrationName(integrationName).build();
}
/**
@@ -296,23 +354,10 @@ public class SnowflakeIO {
@Override
public PCollection<T> expand(PBegin input) {
- // Either table or query is required. If query is present, it's being used, table is used
- // otherwise
- checkArgument(
- getQuery() != null || getTable() != null, "fromTable() or fromQuery() is required");
- checkArgument(
- !(getQuery() != null && getTable() != null),
- "fromTable() and fromQuery() are not allowed together");
- checkArgument(getCsvMapper() != null, "withCsvMapper() is required");
- checkArgument(getCoder() != null, "withCoder() is required");
- checkArgument(getIntegrationName() != null, "withIntegrationName() is required");
- checkArgument(getStagingBucketName() != null, "withStagingBucketName() is required");
- checkArgument(
- (getDataSourceProviderFn() != null),
- "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
+ checkArguments();
String tmpDirName = makeTmpDirName();
- String stagingBucketDir = String.format("%s/%s", getStagingBucketName(), tmpDirName);
+ String stagingBucketDir = String.format("%s/%s/", getStagingBucketName(), tmpDirName);
PCollection<Void> emptyCollection = input.apply(Create.of((Void) null));
@@ -324,7 +369,7 @@ public class SnowflakeIO {
getDataSourceProviderFn(),
getQuery(),
getTable(),
- getIntegrationName(),
+ getStorageIntegrationName(),
stagingBucketDir,
getSnowflakeService())))
.apply(Reshuffle.viaRandomKey())
@@ -339,10 +384,28 @@ public class SnowflakeIO {
emptyCollection
.apply(Wait.on(output))
.apply(ParDo.of(new CleanTmpFilesFromGcsFn(stagingBucketDir)));
-
return output;
}
+ private void checkArguments() {
+ // Either table or query is required. If query is present, it's being used, table is used
+ // otherwise
+
+ checkArgument(getStorageIntegrationName() != null, "withStorageIntegrationName is required");
+ checkArgument(getStagingBucketName() != null, "withStagingBucketName is required");
+
+ checkArgument(
+ getQuery() != null || getTable() != null, "fromTable() or fromQuery() is required");
+ checkArgument(
+ !(getQuery() != null && getTable() != null),
+ "fromTable() and fromQuery() are not allowed together");
+ checkArgument(getCsvMapper() != null, "withCsvMapper() is required");
+ checkArgument(getCoder() != null, "withCoder() is required");
+ checkArgument(
+ (getDataSourceProviderFn() != null),
+ "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
+ }
+
private String makeTmpDirName() {
return String.format(
"sf_copy_csv_%s_%s",
@@ -355,7 +418,7 @@ public class SnowflakeIO {
private final SerializableFunction<Void, DataSource> dataSourceProviderFn;
private final String query;
private final String table;
- private final String integrationName;
+ private final String storageIntegrationName;
private final String stagingBucketDir;
private final SnowflakeService snowflakeService;
@@ -363,13 +426,13 @@ public class SnowflakeIO {
SerializableFunction<Void, DataSource> dataSourceProviderFn,
String query,
String table,
- String integrationName,
+ String storageIntegrationName,
String stagingBucketDir,
SnowflakeService snowflakeService) {
this.dataSourceProviderFn = dataSourceProviderFn;
this.query = query;
this.table = table;
- this.integrationName = integrationName;
+ this.storageIntegrationName = storageIntegrationName;
this.stagingBucketDir =
String.format(
"%s/run_%s/", stagingBucketDir, UUID.randomUUID().toString().subSequence(0, 8));
@@ -378,9 +441,12 @@ public class SnowflakeIO {
@ProcessElement
public void processElement(ProcessContext context) throws Exception {
- String output =
- snowflakeService.copyIntoStage(
- dataSourceProviderFn, query, table, integrationName, stagingBucketDir);
+
+ SnowflakeServiceConfig config =
+ new SnowflakeServiceConfig(
+ dataSourceProviderFn, table, query, storageIntegrationName, stagingBucketDir);
+
+ String output = snowflakeService.read(config);
context.output(output);
}
@@ -437,7 +503,7 @@ public class SnowflakeIO {
if (getTable() != null) {
builder.add(DisplayData.item("table", getTable()));
}
- builder.add(DisplayData.item("integrationName", getIntegrationName()));
+ builder.add(DisplayData.item("storageIntegrationName", getStagingBucketName()));
builder.add(DisplayData.item("stagingBucketName", getStagingBucketName()));
builder.add(DisplayData.item("csvMapper", getCsvMapper().getClass().getName()));
builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
@@ -447,6 +513,346 @@ public class SnowflakeIO {
}
}
+ /** Implementation of {@link #write()}. */
+ @AutoValue
+ public abstract static class Write<T> extends PTransform<PCollection<T>, PDone> {
+ @Nullable
+ abstract SerializableFunction<Void, DataSource> getDataSourceProviderFn();
+
+ @Nullable
+ abstract String getTable();
+
+ @Nullable
+ abstract String getStorageIntegrationName();
+
+ @Nullable
+ abstract String getStagingBucketName();
+
+ @Nullable
+ abstract String getQuery();
+
+ @Nullable
+ abstract String getFileNameTemplate();
+
+ @Nullable
+ abstract WriteDisposition getWriteDisposition();
+
+ @Nullable
+ abstract UserDataMapper getUserDataMapper();
+
+ @Nullable
+ abstract SnowflakeService getSnowflakeService();
+
+ abstract Builder<T> toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> setDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn);
+
+ abstract Builder<T> setTable(String table);
+
+ abstract Builder<T> setStorageIntegrationName(String storageIntegrationName);
+
+ abstract Builder<T> setStagingBucketName(String stagingBucketName);
+
+ abstract Builder<T> setQuery(String query);
+
+ abstract Builder<T> setFileNameTemplate(String fileNameTemplate);
+
+ abstract Builder<T> setUserDataMapper(UserDataMapper userDataMapper);
+
+ abstract Builder<T> setWriteDisposition(WriteDisposition writeDisposition);
+
+ abstract Builder<T> setSnowflakeService(SnowflakeService snowflakeService);
+
+ abstract Write<T> build();
+ }
+
+ /**
+ * Sets connection information for the Snowflake server.
+ *
+ * @param config - An instance of {@link DataSourceConfiguration}.
+ */
+ public Write<T> withDataSourceConfiguration(final DataSourceConfiguration config) {
+ return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
+ }
+
+ /**
+ * Sets a function that provides a {@link DataSource} at runtime.
+ *
+ * @param dataSourceProviderFn a {@link SerializableFunction}.
+ */
+ public Write<T> withDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn) {
+ return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build();
+ }
+
+ /**
+ * Sets the name of the Snowflake table to write to.
+ *
+ * @param table - String with the name of the table.
+ */
+ public Write<T> withTable(String table) {
+ return toBuilder().setTable(table).build();
+ }
+
+ /**
+ * Name of the cloud bucket (currently only GCS is supported) used as a temporary location
+ * for CSV files during the COPY statement.
+ *
+ * @param stagingBucketName - String with the name of the bucket.
+ */
+ public Write<T> withStagingBucketName(String stagingBucketName) {
+ return toBuilder().setStagingBucketName(stagingBucketName).build();
+ }
+
+ /**
+ * Name of the Storage Integration in Snowflake to be used. See
+ * https://docs.snowflake.com/en/sql-reference/sql/create-storage-integration.html for
+ * reference.
+ *
+ * @param integrationName - String with the name of the Storage Integration.
+ */
+ public Write<T> withStorageIntegrationName(String integrationName) {
+ return toBuilder().setStorageIntegrationName(integrationName).build();
+ }
+
+ /**
+ * A query to be executed in Snowflake while copying staged data into the table.
+ *
+ * @param query - String with query.
+ */
+ public Write<T> withQueryTransformation(String query) {
+ return toBuilder().setQuery(query).build();
+ }
+
+ /**
+ * A file name template (prefix) for the CSV files staged on the cloud bucket.
+ *
+ * @param fileNameTemplate - String with template name for files.
+ */
+ public Write<T> withFileNameTemplate(String fileNameTemplate) {
+ return toBuilder().setFileNameTemplate(fileNameTemplate).build();
+ }
+
+ /**
+ * User-defined function mapping each record into an array of objects, later written as a CSV line.
+ *
+ * @param userDataMapper - an instance of {@link UserDataMapper}.
+ */
+ public Write<T> withUserDataMapper(UserDataMapper userDataMapper) {
+ return toBuilder().setUserDataMapper(userDataMapper).build();
+ }
+
+ /**
+ * The disposition to apply during the write-to-table phase.
+ *
+ * @param writeDisposition - an instance of {@link WriteDisposition}.
+ */
+ public Write<T> withWriteDisposition(WriteDisposition writeDisposition) {
+ return toBuilder().setWriteDisposition(writeDisposition).build();
+ }
+
+ /**
+ * The Snowflake service implementation to use. Note: currently {@link SnowflakeServiceImpl}
+ * is the production implementation, with {@link FakeSnowflakeServiceImpl} as its test
+ * counterpart.
+ *
+ * @param snowflakeService - an instance of {@link SnowflakeService}.
+ */
+ public Write<T> withSnowflakeService(SnowflakeService snowflakeService) {
+ return toBuilder().setSnowflakeService(snowflakeService).build();
+ }
+
+ @Override
+ public PDone expand(PCollection<T> input) {
+ checkArguments();
+
+ String stagingBucketDir = String.format("%s/%s/", getStagingBucketName(), WRITE_TMP_PATH);
+
+ PCollection<String> out = write(input, stagingBucketDir);
+ out.setCoder(StringUtf8Coder.of());
+
+ return PDone.in(out.getPipeline());
+ }
+
+ private void checkArguments() {
+ checkArgument(getStagingBucketName() != null, "withStagingBucketName is required");
+
+ checkArgument(getUserDataMapper() != null, "withUserDataMapper() is required");
+
+ checkArgument(
+ (getDataSourceProviderFn() != null),
+ "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
+
+ checkArgument(getTable() != null, "withTable() is required");
+ }
+
+ private PCollection<String> write(PCollection<T> input, String stagingBucketDir) {
+ SnowflakeService snowflakeService =
+ getSnowflakeService() != null ? getSnowflakeService() : new SnowflakeServiceImpl();
+
+ PCollection<String> files = writeFiles(input, stagingBucketDir);
+
+ // Combining PCollection of files as a side input into one list of files
+ ListCoder<String> coder = ListCoder.of(StringUtf8Coder.of());
+ files =
+ (PCollection)
+ files
+ .getPipeline()
+ .apply(
+ Reify.viewInGlobalWindow(
+ (PCollectionView) files.apply(View.asList()), coder));
+
+ return (PCollection)
+ files.apply("Copy files to table", copyToTable(snowflakeService, stagingBucketDir));
+ }
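The Reify step above is the "change Combine step into Reify" noted in the commit message: it materializes the View.asList() side input back into a single element carrying every staged file name, so one COPY statement can reference them all. A minimal standalone sketch of the same pattern:

    import java.util.List;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.coders.ListCoder;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.Reify;
    import org.apache.beam.sdk.transforms.View;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.PCollectionView;

    public class ReifyListDemo {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create();
        PCollection<String> files = p.apply(Create.of("f1.csv", "f2.csv"));
        PCollectionView<List<String>> view = files.apply(View.asList());
        // One output element containing the full list of file names.
        PCollection<List<String>> allFiles =
            p.apply(Reify.viewInGlobalWindow(view, ListCoder.of(StringUtf8Coder.of())));
        p.run().waitUntilFinish();
      }
    }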
+
+ private PCollection<String> writeFiles(PCollection<T> input, String stagingBucketDir) {
+
+ PCollection<String> mappedUserData =
+ input
+ .apply(
+ MapElements.via(
+ new SimpleFunction<T, Object[]>() {
+ @Override
+ public Object[] apply(T element) {
+ return getUserDataMapper().mapRow(element);
+ }
+ }))
+ .apply("Map Objects array to CSV lines", ParDo.of(new MapObjectsArrayToCsvFn()))
+ .setCoder(StringUtf8Coder.of());
+
+ WriteFilesResult filesResult =
+ mappedUserData.apply(
+ "Write files to specified location",
+ FileIO.<String>write()
+ .via(TextIO.sink())
+ .to(stagingBucketDir)
+ .withPrefix(getFileNameTemplate())
+ .withSuffix(".csv")
+ .withCompression(Compression.GZIP));
+
+ return (PCollection)
+ filesResult
+ .getPerDestinationOutputFilenames()
+ .apply("Parse KV filenames to Strings", Values.<String>create());
+ }
+
+ private ParDo.SingleOutput<Object, Object> copyToTable(
+ SnowflakeService snowflakeService, String stagingBucketDir) {
+ return ParDo.of(
+ new CopyToTableFn<>(
+ getDataSourceProviderFn(),
+ getTable(),
+ getQuery(),
+ stagingBucketDir,
+ getStorageIntegrationName(),
+ getWriteDisposition(),
+ snowflakeService));
+ }
+ }
+
+ /** A {@link Combine.CombineFn} that collects quoted strings into a single list. */
+ public static class Concatenate extends Combine.CombineFn<String, List<String>, List<String>> {
+ @Override
+ public List<String> createAccumulator() {
+ return new ArrayList<>();
+ }
+
+ @Override
+ public List<String> addInput(List<String> mutableAccumulator, String input) {
+ mutableAccumulator.add(String.format("'%s'", input));
+ return mutableAccumulator;
+ }
+
+ @Override
+ public List<String> mergeAccumulators(Iterable<List<String>> accumulators) {
+ List<String> result = createAccumulator();
+ for (List<String> accumulator : accumulators) {
+ result.addAll(accumulator);
+ }
+ return result;
+ }
+
+ @Override
+ public List<String> extractOutput(List<String> accumulator) {
+ return accumulator;
+ }
+ }
+
+ /**
+ * Custom DoFn that maps an {@link Object[]} into a CSV line to be saved to Snowflake.
+ *
+ * <p>Adds Snowflake-specific quotations around strings.
+ */
+ private static class MapObjectsArrayToCsvFn extends DoFn<Object[], String> {
+
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ List<Object> csvItems = new ArrayList<>();
+ for (Object o : context.element()) {
+ if (o instanceof String) {
+ String field = (String) o;
+ field = field.replace("'", "''");
+ field = quoteField(field);
+
+ csvItems.add(field);
+ } else {
+ csvItems.add(o);
+ }
+ }
+ context.output(Joiner.on(",").useForNull("").join(csvItems));
+ }
+
+ private String quoteField(String field) {
+ return quoteField(field, CSV_QUOTE_CHAR);
+ }
+
+ private String quoteField(String field, String quotation) {
+ return String.format("%s%s%s", quotation, field, quotation);
+ }
+ }
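To make the quoting rules concrete (embedded single quotes doubled, strings wrapped in the CSV quote character, nulls emitted as empty fields), here is a standalone sketch of the same logic:

    import java.util.StringJoiner;

    public class CsvQuoteDemo {
      public static void main(String[] args) {
        Object[] row = {"O'Hara", 42, null};
        StringJoiner joiner = new StringJoiner(",");
        for (Object o : row) {
          if (o instanceof String) {
            // Double embedded quotes, then wrap in Snowflake's quote character.
            joiner.add("'" + ((String) o).replace("'", "''") + "'");
          } else {
            joiner.add(o == null ? "" : o.toString());
          }
        }
        System.out.println(joiner); // prints: 'O''Hara',42,
      }
    }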
+
+ private static class CopyToTableFn<ParameterT, OutputT> extends DoFn<ParameterT, OutputT> {
+ private final SerializableFunction<Void, DataSource> dataSourceProviderFn;
+ private final String table;
+ private final String query;
+ private final String stagingBucketDir;
+ private final String storageIntegrationName;
+ private final WriteDisposition writeDisposition;
+ private final SnowflakeService snowflakeService;
+
+ CopyToTableFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn,
+ String table,
+ String query,
+ String stagingBucketDir,
+ String storageIntegrationName,
+ WriteDisposition writeDisposition,
+ SnowflakeService snowflakeService) {
+ this.dataSourceProviderFn = dataSourceProviderFn;
+ this.table = table;
+ this.query = query;
+ this.stagingBucketDir = stagingBucketDir;
+ this.storageIntegrationName = storageIntegrationName;
+ this.writeDisposition = writeDisposition;
+ this.snowflakeService = snowflakeService;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext context) throws Exception {
+ SnowflakeServiceConfig config =
+ new SnowflakeServiceConfig(
+ dataSourceProviderFn,
+ (List<String>) context.element(),
+ table,
+ query,
+ writeDisposition,
+ storageIntegrationName,
+ stagingBucketDir);
+ snowflakeService.write(config);
+ }
+ }
+
/**
* A POJO describing a {@link DataSource}, providing all properties allowing to create a {@link
* DataSource}.
@@ -739,6 +1145,15 @@ public class SnowflakeIO {
private final DataSourceConfiguration config;
private DataSourceProviderFromDataSourceConfiguration(DataSourceConfiguration config) {
+ if (config.getValidate()) {
+ try {
+ Connection connection = config.buildDatasource().getConnection();
+ connection.close();
+ } catch (SQLException e) {
+ throw new IllegalArgumentException(
+ "Invalid DataSourceConfiguration. Underlying cause: " + e);
+ }
+ }
this.config = config;
}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakePipelineOptions.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakePipelineOptions.java
index 783230e..bf91e0c 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakePipelineOptions.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakePipelineOptions.java
@@ -111,23 +111,13 @@ public interface SnowflakePipelineOptions extends PipelineOptions {
void setLoginTimeout(String loginTimeout);
- @Description("External location name to connect to.")
- String getExternalLocation();
-
- void setExternalLocation(String externalLocation);
-
- @Description("Temporary GCS bucket name")
+ @Description("Temporary GCS bucket name.")
String getStagingBucketName();
void setStagingBucketName(String stagingBucketName);
- @Description("Storage integration - required in case the external stage is not specified.")
- String getStorageIntegration();
-
- void setStorageIntegration(String integration);
-
- @Description("Stage name. Optional.")
- String getStage();
+ @Description("Storage integration name")
+ String getStorageIntegrationName();
- void setStage(String stage);
+ void setStorageIntegrationName(String storageIntegrationName);
}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeServiceImpl.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeServiceImpl.java
deleted file mode 100644
index 5aaad06..0000000
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeServiceImpl.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.snowflake;
-
-import java.sql.Connection;
-import java.sql.PreparedStatement;
-import java.sql.ResultSet;
-import java.sql.SQLException;
-import java.util.function.Consumer;
-import javax.sql.DataSource;
-import org.apache.beam.sdk.transforms.SerializableFunction;
-
-/**
- * Implemenation of {@link org.apache.beam.sdk.io.snowflake.SnowflakeService} used in production.
- */
-public class SnowflakeServiceImpl implements SnowflakeService {
- private static final String SNOWFLAKE_GCS_PREFIX = "gcs://";
-
- @Override
- public String copyIntoStage(
- SerializableFunction<Void, DataSource> dataSourceProviderFn,
- String query,
- String table,
- String integrationName,
- String stagingBucketDir)
- throws SQLException {
-
- String from;
- if (query != null) {
- // Query must be surrounded with brackets
- from = String.format("(%s)", query);
- } else {
- from = table;
- }
-
- String copyQuery =
- String.format(
- "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
- getProperBucketDir(stagingBucketDir), from, integrationName, CSV_QUOTE_CHAR_FOR_COPY);
-
- runStatement(copyQuery, getConnection(dataSourceProviderFn), null);
-
- return stagingBucketDir.concat("*");
- }
-
- private static void runStatement(String query, Connection connection, Consumer resultSetMethod)
- throws SQLException {
- PreparedStatement statement = connection.prepareStatement(query);
- try {
- if (resultSetMethod != null) {
- ResultSet resultSet = statement.executeQuery();
- resultSetMethod.accept(resultSet);
- } else {
- statement.execute();
- }
- } finally {
- statement.close();
- connection.close();
- }
- }
-
- private Connection getConnection(SerializableFunction<Void, DataSource> dataSourceProviderFn)
- throws SQLException {
- DataSource dataSource = dataSourceProviderFn.apply(null);
- return dataSource.getConnection();
- }
-
- // Snowflake is expecting "gcs://" prefix for GCS and Beam "gs://"
- private String getProperBucketDir(String bucketDir) {
- if (bucketDir.contains(CloudProvider.GCS.getPrefix())) {
- return bucketDir.replace(CloudProvider.GCS.getPrefix(), SNOWFLAKE_GCS_PREFIX);
- }
- return bucketDir;
- }
-}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/CloudProvider.java
similarity index 95%
copy from sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
copy to sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/CloudProvider.java
index 404859c..450ecc7 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/CloudProvider.java
@@ -15,7 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.sdk.io.snowflake;
+package org.apache.beam.sdk.io.snowflake.enums;
public enum CloudProvider {
GCS("gs://");
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/WriteDisposition.java
similarity index 76%
copy from sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
copy to sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/WriteDisposition.java
index 404859c..0c9fb4e 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/WriteDisposition.java
@@ -15,18 +15,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.sdk.io.snowflake;
+package org.apache.beam.sdk.io.snowflake.enums;
-public enum CloudProvider {
- GCS("gs://");
-
- private final String prefix;
-
- private CloudProvider(String prefix) {
- this.prefix = prefix;
- }
-
- public String getPrefix() {
- return prefix;
- }
+/** Write dispositions supported during the write-to-table phase. */
+public enum WriteDisposition {
+ TRUNCATE,
+ APPEND,
+ EMPTY
}
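Usage note: APPEND (the default set by write()) adds rows to the existing table, TRUNCATE clears the table before the COPY, and EMPTY aborts the pipeline unless the target table contains no rows; e.g. .withWriteDisposition(WriteDisposition.TRUNCATE).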
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/package-info.java
similarity index 76%
copy from sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
copy to sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/package-info.java
index 404859c..60feb40 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/enums/package-info.java
@@ -15,18 +15,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.sdk.io.snowflake;
-public enum CloudProvider {
- GCS("gs://");
-
- private final String prefix;
-
- private CloudProvider(String prefix) {
- this.prefix = prefix;
- }
-
- public String getPrefix() {
- return prefix;
- }
-}
+/** Snowflake IO data types. */
+package org.apache.beam.sdk.io.snowflake.enums;
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/ServiceConfig.java
similarity index 76%
copy from sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
copy to sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/ServiceConfig.java
index 404859c..09e1368 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/ServiceConfig.java
@@ -15,18 +15,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.sdk.io.snowflake;
+package org.apache.beam.sdk.io.snowflake.services;
-public enum CloudProvider {
- GCS("gs://");
-
- private final String prefix;
-
- private CloudProvider(String prefix) {
- this.prefix = prefix;
- }
-
- public String getPrefix() {
- return prefix;
- }
-}
+public abstract class ServiceConfig {}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeService.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeService.java
similarity index 67%
rename from sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeService.java
rename to sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeService.java
index 6375e79..16cd3c6 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeService.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeService.java
@@ -15,22 +15,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.sdk.io.snowflake;
+package org.apache.beam.sdk.io.snowflake.services;
import java.io.Serializable;
-import java.sql.SQLException;
-import javax.sql.DataSource;
-import org.apache.beam.sdk.transforms.SerializableFunction;
-/** Interface which defines common methods for interacting with SnowFlake. */
-public interface SnowflakeService extends Serializable {
+/** Interface which defines common methods for interacting with Snowflake. */
+public interface SnowflakeService<T extends ServiceConfig> extends Serializable {
String CSV_QUOTE_CHAR_FOR_COPY = "''";
- String copyIntoStage(
- SerializableFunction<Void, DataSource> dataSourceProviderFn,
- String query,
- String table,
- String integrationName,
- String stagingBucketDir)
- throws SQLException;
+ String read(T config) throws Exception;
+
+ void write(T config) throws Exception;
}
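With this change the service interface is generic over a ServiceConfig subtype, so read and write share a single entry point: SnowflakeServiceImpl and the test FakeSnowflakeServiceImpl both implement SnowflakeService<SnowflakeServiceConfig>.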
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeServiceConfig.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeServiceConfig.java
new file mode 100644
index 0000000..46ad877
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeServiceConfig.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.services;
+
+import java.util.List;
+import javax.sql.DataSource;
+import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+
+public class SnowflakeServiceConfig extends ServiceConfig {
+ private SerializableFunction<Void, DataSource> dataSourceProviderFn;
+
+ private String table;
+ private String query;
+ private String storageIntegrationName;
+ private List<String> filesList;
+
+ private WriteDisposition writeDisposition;
+ private String stagingBucketDir;
+
+ public SnowflakeServiceConfig(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn,
+ String table,
+ String query,
+ String storageIntegration,
+ String stagingBucketDir) {
+ this.dataSourceProviderFn = dataSourceProviderFn;
+ this.table = table;
+ this.query = query;
+ this.storageIntegrationName = storageIntegration;
+ this.stagingBucketDir = stagingBucketDir;
+ }
+
+ public SnowflakeServiceConfig(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn,
+ List<String> filesList,
+ String table,
+ String query,
+ WriteDisposition writeDisposition,
+ String storageIntegrationName,
+ String stagingBucketDir) {
+ this.dataSourceProviderFn = dataSourceProviderFn;
+ this.filesList = filesList;
+ this.table = table;
+ this.query = query;
+ this.writeDisposition = writeDisposition;
+ this.storageIntegrationName = storageIntegrationName;
+ this.stagingBucketDir = stagingBucketDir;
+ }
+
+ public SerializableFunction<Void, DataSource> getDataSourceProviderFn() {
+ return dataSourceProviderFn;
+ }
+
+ public String getTable() {
+ return table;
+ }
+
+ public String getQuery() {
+ return query;
+ }
+
+ public String getStorageIntegrationName() {
+ return storageIntegrationName;
+ }
+
+ public String getStagingBucketDir() {
+ return stagingBucketDir;
+ }
+
+ public List<String> getFilesList() {
+ return filesList;
+ }
+
+ public WriteDisposition getWriteDisposition() {
+ return writeDisposition;
+ }
+}
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeServiceImpl.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeServiceImpl.java
new file mode 100644
index 0000000..6330c18
--- /dev/null
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeServiceImpl.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.services;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.List;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import javax.sql.DataSource;
+import org.apache.beam.sdk.io.snowflake.enums.CloudProvider;
+import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Implementation of {@link SnowflakeService} used in production. */
+public class SnowflakeServiceImpl implements SnowflakeService<SnowflakeServiceConfig> {
+ private static final Logger LOG = LoggerFactory.getLogger(SnowflakeServiceImpl.class);
+ private static final String SNOWFLAKE_GCS_PREFIX = "gcs://";
+
+ @Override
+ public void write(SnowflakeServiceConfig config) throws Exception {
+ copyToTable(config);
+ }
+
+ @Override
+ public String read(SnowflakeServiceConfig config) throws Exception {
+ return copyIntoStage(config);
+ }
+
+ public String copyIntoStage(SnowflakeServiceConfig config) throws SQLException {
+ SerializableFunction<Void, DataSource> dataSourceProviderFn = config.getDataSourceProviderFn();
+ String table = config.getTable();
+ String query = config.getQuery();
+ String storageIntegrationName = config.getStorageIntegrationName();
+ String stagingBucketDir = config.getStagingBucketDir();
+
+ String source;
+ if (query != null) {
+ // The query must be wrapped in parentheses
+ source = String.format("(%s)", query);
+ } else {
+ source = table;
+ }
+
+ String copyQuery =
+ String.format(
+ "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
+ getProperBucketDir(stagingBucketDir),
+ source,
+ storageIntegrationName,
+ CSV_QUOTE_CHAR_FOR_COPY);
+
+ runStatement(copyQuery, getConnection(dataSourceProviderFn), null);
+
+ return stagingBucketDir.concat("*");
+ }
+
+ public void copyToTable(SnowflakeServiceConfig config) throws SQLException {
+
+ SerializableFunction<Void, DataSource> dataSourceProviderFn = config.getDataSourceProviderFn();
+ List<String> filesList = config.getFilesList();
+ String table = config.getTable();
+ String query = config.getQuery();
+ WriteDisposition writeDisposition = config.getWriteDisposition();
+ String storageIntegrationName = config.getStorageIntegrationName();
+ String stagingBucketDir = config.getStagingBucketDir();
+
+ String source;
+ if (query != null) {
+ // The query must be wrapped in parentheses
+ source = String.format("(%s)", query);
+ } else {
+ source = String.format("'%s'", stagingBucketDir);
+ }
+
+ filesList = filesList.stream().map(e -> String.format("'%s'", e)).collect(Collectors.toList());
+ String files = String.join(", ", filesList);
+ files = files.replaceAll(stagingBucketDir, "");
+ DataSource dataSource = dataSourceProviderFn.apply(null);
+
+ prepareTableAccordingWriteDisposition(dataSource, table, writeDisposition);
+
+ if (!storageIntegrationName.isEmpty()) {
+ query =
+ String.format(
+ "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP) STORAGE_INTEGRATION=%s;",
+ table,
+ getProperBucketDir(source),
+ files,
+ CSV_QUOTE_CHAR_FOR_COPY,
+ storageIntegrationName);
+ } else {
+ query =
+ String.format(
+ "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP);",
+ table, source, files, CSV_QUOTE_CHAR_FOR_COPY);
+ }
+
+ runStatement(query, dataSource.getConnection(), null);
+ }
+
+ private void truncateTable(DataSource dataSource, String table) throws SQLException {
+ String query = String.format("TRUNCATE %s;", table);
+ runConnectionWithStatement(dataSource, query, null);
+ }
+
+ private static void checkIfTableIsEmpty(DataSource dataSource, String table) throws SQLException {
+ String selectQuery = String.format("SELECT count(*) FROM %s LIMIT 1;", table);
+ runConnectionWithStatement(
+ dataSource,
+ selectQuery,
+ resultSet -> {
+ assert resultSet != null;
+ checkIfTableIsEmpty((ResultSet) resultSet);
+ });
+ }
+
+ private static void checkIfTableIsEmpty(ResultSet resultSet) {
+ int columnId = 1;
+ try {
+ if (!resultSet.next() || !checkIfTableIsEmpty(resultSet, columnId)) {
+ throw new RuntimeException("Table is not empty. Aborting COPY with disposition EMPTY");
+ }
+ } catch (SQLException e) {
+ throw new RuntimeException("Unable run pipeline with EMPTY disposition.", e);
+ }
+ }
+
+ private static boolean checkIfTableIsEmpty(ResultSet resultSet, int columnId)
+ throws SQLException {
+ // The table is empty when the row count in the first column is zero.
+ return resultSet.getInt(columnId) == 0;
+ }
+
+ private void prepareTableAccordingWriteDisposition(
+ DataSource dataSource, String table, WriteDisposition writeDisposition) throws SQLException {
+ switch (writeDisposition) {
+ case TRUNCATE:
+ truncateTable(dataSource, table);
+ break;
+ case EMPTY:
+ checkIfTableIsEmpty(dataSource, table);
+ break;
+ case APPEND:
+ default:
+ break;
+ }
+ }
+
+ private static void runConnectionWithStatement(
+ DataSource dataSource, String query, Consumer resultSetMethod) throws SQLException {
+ Connection connection = dataSource.getConnection();
+ runStatement(query, connection, resultSetMethod);
+ connection.close();
+ }
+
+ private static void runStatement(String query, Connection connection, Consumer resultSetMethod)
+ throws SQLException {
+ PreparedStatement statement = connection.prepareStatement(query);
+ try {
+ if (resultSetMethod != null) {
+ ResultSet resultSet = statement.executeQuery();
+ resultSetMethod.accept(resultSet);
+ } else {
+ statement.execute();
+ }
+ } finally {
+ statement.close();
+ connection.close();
+ }
+ }
+
+ private Connection getConnection(SerializableFunction<Void, DataSource> dataSourceProviderFn)
+ throws SQLException {
+ DataSource dataSource = dataSourceProviderFn.apply(null);
+ return dataSource.getConnection();
+ }
+
+ // Snowflake expects the "gcs://" prefix for GCS, while Beam uses "gs://"
+ private String getProperBucketDir(String bucketDir) {
+ if (bucketDir.contains(CloudProvider.GCS.getPrefix())) {
+ return bucketDir.replace(CloudProvider.GCS.getPrefix(), SNOWFLAKE_GCS_PREFIX);
+ }
+ return bucketDir;
+ }
+}
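For concreteness, the COPY statement assembled by copyToTable can be previewed outside Beam; all values below are illustrative:

    public class CopyQueryDemo {
      public static void main(String[] args) {
        String table = "MY_TABLE";
        String source = "'gcs://my-staging-bucket/data/'";
        String files = "'output-00000.csv', 'output-00001.csv'";
        String storageIntegrationName = "MY_INTEGRATION";
        // Mirrors the String.format call in SnowflakeServiceImpl.copyToTable;
        // "''" is CSV_QUOTE_CHAR_FOR_COPY, Snowflake's escaped single quote.
        String query =
            String.format(
                "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP) STORAGE_INTEGRATION=%s;",
                table, source, files, "''", storageIntegrationName);
        System.out.println(query);
      }
    }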
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/package-info.java
similarity index 76%
rename from sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
rename to sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/package-info.java
index 404859c..5ec1c30 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/CloudProvider.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/package-info.java
@@ -15,18 +15,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.sdk.io.snowflake;
-public enum CloudProvider {
- GCS("gs://");
-
- private final String prefix;
-
- private CloudProvider(String prefix) {
- this.prefix = prefix;
- }
-
- public String getPrefix() {
- return prefix;
- }
-}
+/** Snowflake IO services and POJOs. */
+package org.apache.beam.sdk.io.snowflake.services;
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeDatabase.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeDatabase.java
index 5bf8b21..320f4fd 100644
--- a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeDatabase.java
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeDatabase.java
@@ -25,7 +25,7 @@ import java.util.Map;
import java.util.stream.Collectors;
import net.snowflake.client.jdbc.SnowflakeSQLException;
-/** Fake implementation of SnowFlake warehouse used in test code. */
+/** Fake implementation of Snowflake warehouse used in test code. */
public class FakeSnowflakeDatabase implements Serializable {
private static Map<String, List<String>> tables = new HashMap<>();
@@ -71,10 +71,6 @@ public class FakeSnowflakeDatabase implements Serializable {
FakeSnowflakeDatabase.tables.put(table, rows);
}
- public static void clean() {
- FakeSnowflakeDatabase.tables = new HashMap<>();
- }
-
public static void truncateTable(String table) {
FakeSnowflakeDatabase.createTable(table);
}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeServiceImpl.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeServiceImpl.java
index 4a62dcd..5164c31 100644
--- a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeServiceImpl.java
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeServiceImpl.java
@@ -22,33 +22,72 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.sql.SQLException;
+import java.util.ArrayList;
import java.util.List;
-import javax.sql.DataSource;
-import org.apache.beam.sdk.io.snowflake.SnowflakeService;
-import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeServiceConfig;
-/**
- * Fake implementation of {@link org.apache.beam.sdk.io.snowflake.SnowflakeService} used in tests.
- */
-public class FakeSnowflakeServiceImpl implements SnowflakeService {
+/** Fake implementation of {@link SnowflakeService} used in tests. */
+public class FakeSnowflakeServiceImpl implements SnowflakeService<SnowflakeServiceConfig> {
@Override
- public String copyIntoStage(
- SerializableFunction<Void, DataSource> dataSourceProviderFn,
- String query,
- String table,
- String integrationName,
- String stagingBucketName)
- throws SQLException {
+ public void write(SnowflakeServiceConfig config) throws Exception {
+ copyToTable(config);
+ }
+
+ @Override
+ public String read(SnowflakeServiceConfig config) throws Exception {
+ return copyIntoStage(config);
+ }
+
+ public String copyIntoStage(SnowflakeServiceConfig config) throws SQLException {
+ String table = config.getTable();
+ String query = config.getQuery();
+
+ String stagingBucketDir = config.getStagingBucketDir();
if (table != null) {
- writeToFile(stagingBucketName, FakeSnowflakeDatabase.getElements(table));
+ writeToFile(stagingBucketDir, FakeSnowflakeDatabase.getElements(table));
}
if (query != null) {
- writeToFile(stagingBucketName, FakeSnowflakeDatabase.runQuery(query));
+ writeToFile(stagingBucketDir, FakeSnowflakeDatabase.runQuery(query));
}
- return String.format("./%s/*", stagingBucketName);
+ return String.format("./%s/*", stagingBucketDir);
+ }
+
+ public void copyToTable(SnowflakeServiceConfig config) throws SQLException {
+ List<String> filesList = config.getFilesList();
+ String table = config.getTable();
+ WriteDisposition writeDisposition = config.getWriteDisposition();
+
+ List<String> rows = new ArrayList<>();
+ for (String file : filesList) {
+ rows.addAll(TestUtils.readGZIPFile(file.replace("'", "")));
+ }
+
+ prepareTableAccordingWriteDisposition(table, writeDisposition);
+
+ FakeSnowflakeDatabase.createTableWithElements(table, rows);
+ }
+
+ private void prepareTableAccordingWriteDisposition(
+ String table, WriteDisposition writeDisposition) throws SQLException {
+ switch (writeDisposition) {
+ case TRUNCATE:
+ FakeSnowflakeDatabase.truncateTable(table);
+ break;
+ case EMPTY:
+ if (!FakeSnowflakeDatabase.isTableEmpty(table)) {
+ throw new RuntimeException("Table is not empty. Aborting COPY with disposition EMPTY");
+ }
+ break;
+ case APPEND:
+ default:
+ break;
+ }
}
private void writeToFile(String stagingBucketNameTmp, List<String> rows) {
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/TestUtils.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/TestUtils.java
index aab8d7d..ec458ea 100644
--- a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/TestUtils.java
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/TestUtils.java
@@ -17,7 +17,23 @@
*/
package org.apache.beam.sdk.io.snowflake.test;
+import java.io.BufferedReader;
import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Stream;
+import java.util.zip.GZIPInputStream;
+import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.values.KV;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,4 +53,53 @@ public class TestUtils {
public static String getPrivateKeyPassphrase() {
return PRIVATE_KEY_PASSPHRASE;
}
+
+ public static void removeTempDir(String dir) {
+ Path path = Paths.get(dir);
+ try (Stream<Path> stream = Files.walk(path)) {
+ stream.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
+ } catch (IOException e) {
+ LOG.info("Not able to remove files");
+ }
+ }
+
+ public static boolean areListsEqual(List<?> expected, List<?> actual) {
+ return expected.size() == actual.size()
+ && expected.containsAll(actual)
+ && actual.containsAll(expected);
+ }
+
+ public static SnowflakeIO.UserDataMapper<KV<String, Long>> getLongCsvMapperKV() {
+ return (SnowflakeIO.UserDataMapper<KV<String, Long>>)
+ recordLine -> new Long[] {recordLine.getValue()};
+ }
+
+ public static SnowflakeIO.UserDataMapper<Long> getLongCsvMapper() {
+ return (SnowflakeIO.UserDataMapper<Long>) recordLine -> new Long[] {recordLine};
+ }
+
+ public static class ParseToKv extends DoFn<Long, KV<String, Long>> {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ KV<String, Long> stringIntKV = KV.of(c.element().toString(), c.element().longValue());
+ c.output(stringIntKV);
+ }
+ }
+
+ public static List<String> readGZIPFile(String file) {
+ List<String> lines = new ArrayList<>();
+ try {
+ GZIPInputStream gzip = new GZIPInputStream(new FileInputStream(file));
+ BufferedReader br = new BufferedReader(new InputStreamReader(gzip, Charset.defaultCharset()));
+
+ String line;
+ while ((line = br.readLine()) != null) {
+ lines.add(line);
+ }
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to read file", e);
+ }
+
+ return lines;
+ }
}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/BatchTestPipelineOptions.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/TestPipelineOptions.java
similarity index 93%
rename from sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/BatchTestPipelineOptions.java
rename to sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/TestPipelineOptions.java
index 3504c45..c10c7cb 100644
--- a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/BatchTestPipelineOptions.java
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/TestPipelineOptions.java
@@ -20,7 +20,7 @@ package org.apache.beam.sdk.io.snowflake.test.unit;
import org.apache.beam.sdk.io.snowflake.SnowflakePipelineOptions;
import org.apache.beam.sdk.options.Description;
-public interface BatchTestPipelineOptions extends SnowflakePipelineOptions {
+public interface TestPipelineOptions extends SnowflakePipelineOptions {
@Description("Table name to connect to.")
String getTable();
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/read/SnowflakeIOReadTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/read/SnowflakeIOReadTest.java
index e4eda0d..6016a66 100644
--- a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/read/SnowflakeIOReadTest.java
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/read/SnowflakeIOReadTest.java
@@ -26,11 +26,11 @@ import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.io.AvroGeneratedUser;
import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
-import org.apache.beam.sdk.io.snowflake.SnowflakeService;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeBasicDataSource;
import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeDatabase;
import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeServiceImpl;
-import org.apache.beam.sdk.io.snowflake.test.unit.BatchTestPipelineOptions;
+import org.apache.beam.sdk.io.snowflake.test.unit.TestPipelineOptions;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.values.PCollection;
@@ -47,16 +47,14 @@ public class SnowflakeIOReadTest implements Serializable {
public static final String FAKE_TABLE = "FAKE_TABLE";
public static final String FAKE_QUERY = "SELECT * FROM FAKE_TABLE";
- private static final BatchTestPipelineOptions options =
- TestPipeline.testingPipelineOptions().as(BatchTestPipelineOptions.class);;
+ private static final TestPipelineOptions options =
+ TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class);
@Rule public final transient TestPipeline pipeline = TestPipeline.create();
@Rule public transient ExpectedException thrown = ExpectedException.none();
private static SnowflakeIO.DataSourceConfiguration dataSourceConfiguration;
private static SnowflakeService snowflakeService;
- private static String stagingBucketName;
- private static String integrationName;
private static List<GenericRecord> avroTestData;
@BeforeClass
@@ -72,12 +70,9 @@ public class SnowflakeIOReadTest implements Serializable {
FakeSnowflakeDatabase.createTableWithElements(FAKE_TABLE, testData);
options.setServerName("NULL.snowflakecomputing.com");
- options.setStorageIntegration("STORAGE_INTEGRATION");
+ options.setStorageIntegrationName("STORAGE_INTEGRATION");
options.setStagingBucketName("BUCKET");
- stagingBucketName = options.getStagingBucketName();
- integrationName = options.getStorageIntegration();
-
dataSourceConfiguration =
SnowflakeIO.DataSourceConfiguration.create(new FakeSnowflakeBasicDataSource())
.withServerName(options.getServerName());
@@ -88,14 +83,13 @@ public class SnowflakeIOReadTest implements Serializable {
@Test
public void testConfigIsMissingStagingBucketName() {
thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("withStagingBucketName() is required");
+ thrown.expectMessage("withStagingBucketName is required");
pipeline.apply(
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromTable(FAKE_TABLE)
- .withIntegrationName(integrationName)
- .withIntegrationName(integrationName)
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -103,15 +97,15 @@ public class SnowflakeIOReadTest implements Serializable {
}
@Test
- public void testConfigIsMissingIntegrationName() {
+ public void testConfigIsMissingStorageIntegration() {
thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("withIntegrationName() is required");
+ thrown.expectMessage("withStorageIntegrationName is required");
pipeline.apply(
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromTable(FAKE_TABLE)
- .withStagingBucketName(stagingBucketName)
+ .withStagingBucketName(options.getStagingBucketName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -127,8 +121,8 @@ public class SnowflakeIOReadTest implements Serializable {
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromTable(FAKE_TABLE)
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
pipeline.run();
@@ -143,8 +137,8 @@ public class SnowflakeIOReadTest implements Serializable {
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromTable(FAKE_TABLE)
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper()));
pipeline.run();
@@ -158,8 +152,8 @@ public class SnowflakeIOReadTest implements Serializable {
pipeline.apply(
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -174,8 +168,8 @@ public class SnowflakeIOReadTest implements Serializable {
pipeline.apply(
SnowflakeIO.<GenericRecord>read(snowflakeService)
.fromTable(FAKE_TABLE)
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -192,8 +186,8 @@ public class SnowflakeIOReadTest implements Serializable {
.withDataSourceConfiguration(dataSourceConfiguration)
.fromQuery("")
.fromTable(FAKE_TABLE)
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -209,8 +203,8 @@ public class SnowflakeIOReadTest implements Serializable {
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromTable("NON_EXIST")
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -226,8 +220,8 @@ public class SnowflakeIOReadTest implements Serializable {
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromQuery("BAD_QUERY")
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -241,8 +235,8 @@ public class SnowflakeIOReadTest implements Serializable {
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromTable(FAKE_TABLE)
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
@@ -257,8 +251,8 @@ public class SnowflakeIOReadTest implements Serializable {
SnowflakeIO.<GenericRecord>read(snowflakeService)
.withDataSourceConfiguration(dataSourceConfiguration)
.fromQuery(FAKE_QUERY)
- .withStagingBucketName(stagingBucketName)
- .withIntegrationName(integrationName)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
.withCsvMapper(getCsvMapper())
.withCoder(AvroCoder.of(AvroGeneratedUser.getClassSchema())));
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/write/QueryDispositionLocationTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/write/QueryDispositionLocationTest.java
new file mode 100644
index 0000000..593f026
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/write/QueryDispositionLocationTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit.write;
+
+import static org.junit.Assert.assertTrue;
+
+import java.sql.SQLException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
+import org.apache.beam.sdk.io.snowflake.SnowflakePipelineOptions;
+import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeBasicDataSource;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeDatabase;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeServiceImpl;
+import org.apache.beam.sdk.io.snowflake.test.TestUtils;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
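+/** Unit tests of SnowflakeIO.Write disposition handling against a fake Snowflake service. */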
+@RunWith(JUnit4.class)
+public class QueryDispositionLocationTest {
+ private static final String FAKE_TABLE = "FAKE_TABLE";
+ private static final String BUCKET_NAME = "BUCKET";
+
+ @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+ @Rule public ExpectedException exceptionRule = ExpectedException.none();
+
+ private static SnowflakePipelineOptions options;
+ private static SnowflakeIO.DataSourceConfiguration dc;
+
+ private static SnowflakeService snowflakeService;
+ private static List<Long> testData;
+
+ @BeforeClass
+ public static void setupAll() {
+ PipelineOptionsFactory.register(SnowflakePipelineOptions.class);
+ options = TestPipeline.testingPipelineOptions().as(SnowflakePipelineOptions.class);
+
+ snowflakeService = new FakeSnowflakeServiceImpl();
+ testData = LongStream.range(0, 100).boxed().collect(Collectors.toList());
+ }
+
+ @Before
+ public void setup() {
+ options.setStagingBucketName(BUCKET_NAME);
+ options.setStorageIntegrationName("STORAGE_INTEGRATION");
+ options.setServerName("NULL.snowflakecomputing.com");
+
+ dc =
+ SnowflakeIO.DataSourceConfiguration.create(new FakeSnowflakeBasicDataSource())
+ .withServerName(options.getServerName());
+ }
+
+ @After
+ public void tearDown() {
+ TestUtils.removeTempDir(BUCKET_NAME);
+ }
+
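+ /** TRUNCATE disposition should clear the target table before the COPY, leaving only the new rows. */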
+ @Test
+ public void writeWithWriteTruncateDispositionSuccess() throws SQLException {
+ FakeSnowflakeDatabase.createTable(FAKE_TABLE);
+
+ pipeline
+ .apply(Create.of(testData))
+ .apply(
+ "Truncate before write",
+ SnowflakeIO.<Long>write()
+ .withDataSourceConfiguration(dc)
+ .withTable(FAKE_TABLE)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
+ .withUserDataMapper(TestUtils.getLongCsvMapper())
+ .withFileNameTemplate("output*")
+ .withWriteDisposition(WriteDisposition.TRUNCATE)
+ .withSnowflakeService(snowflakeService));
+
+ pipeline.run(options).waitUntilFinish();
+
+ List<Long> actualData = FakeSnowflakeDatabase.getElementsAsLong(FAKE_TABLE);
+
+ assertTrue(TestUtils.areListsEqual(testData, actualData));
+ }
+
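+ /** EMPTY disposition must abort the COPY when the target table already contains rows. */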
+ @Test
+ public void writeWithWriteEmptyDispositionWithNotEmptyTableFails() {
+ FakeSnowflakeDatabase.createTableWithElements(FAKE_TABLE, Arrays.asList("NOT_EMPTY"));
+
+ exceptionRule.expect(RuntimeException.class);
+ exceptionRule.expectMessage(
+ "java.lang.RuntimeException: Table is not empty. Aborting COPY with disposition EMPTY");
+
+ pipeline
+ .apply(Create.of(testData))
+ .apply(
+ "Write SnowflakeIO",
+ SnowflakeIO.<Long>write()
+ .withDataSourceConfiguration(dc)
+ .withTable(FAKE_TABLE)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
+ .withUserDataMapper(TestUtils.getLongCsvMapper())
+ .withFileNameTemplate("output*")
+ .withWriteDisposition(WriteDisposition.EMPTY)
+ .withSnowflakeService(snowflakeService));
+
+ pipeline.run(options).waitUntilFinish();
+ }
+
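+ /** EMPTY disposition proceeds normally when the target table has no rows. */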
+ @Test
+ public void writeWithWriteEmptyDispositionWithEmptyTableSuccess() throws SQLException {
+ FakeSnowflakeDatabase.createTable(FAKE_TABLE);
+
+ pipeline
+ .apply(Create.of(testData))
+ .apply(
+ "Write SnowflakeIO",
+ SnowflakeIO.<Long>write()
+ .withDataSourceConfiguration(dc)
+ .withTable(FAKE_TABLE)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
+ .withFileNameTemplate("output*")
+ .withUserDataMapper(TestUtils.getLongCsvMapper())
+ .withWriteDisposition(WriteDisposition.EMPTY)
+ .withSnowflakeService(snowflakeService));
+
+ pipeline.run(options).waitUntilFinish();
+
+ List<Long> actualData = FakeSnowflakeDatabase.getElementsAsLong(FAKE_TABLE);
+
+ assertTrue(TestUtils.areListsEqual(testData, actualData));
+ }
+}
diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/write/SnowflakeIOWriteTest.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/write/SnowflakeIOWriteTest.java
new file mode 100644
index 0000000..9924e6e
--- /dev/null
+++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/unit/write/SnowflakeIOWriteTest.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.snowflake.test.unit.write;
+
+import static org.junit.Assert.assertTrue;
+
+import java.sql.SQLException;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+import net.snowflake.client.jdbc.SnowflakeSQLException;
+import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
+import org.apache.beam.sdk.io.snowflake.SnowflakePipelineOptions;
+import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeBasicDataSource;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeDatabase;
+import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeServiceImpl;
+import org.apache.beam.sdk.io.snowflake.test.TestUtils;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.KV;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
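+/** Unit tests of the batch SnowflakeIO.Write path using a fake Snowflake service and local staging. */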
+@RunWith(JUnit4.class)
+public class SnowflakeIOWriteTest {
+ private static final String FAKE_TABLE = "FAKE_TABLE";
+ private static final String BUCKET_NAME = "BUCKET";
+
+ @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+ @Rule public ExpectedException exceptionRule = ExpectedException.none();
+
+ private static SnowflakePipelineOptions options;
+ private static SnowflakeIO.DataSourceConfiguration dc;
+
+ private static SnowflakeService snowflakeService;
+ private static List<Long> testData;
+
+ @BeforeClass
+ public static void setupAll() {
+ snowflakeService = new FakeSnowflakeServiceImpl();
+ testData = LongStream.range(0, 100).boxed().collect(Collectors.toList());
+ }
+
+ @Before
+ public void setup() {
+ FakeSnowflakeDatabase.createTable(FAKE_TABLE);
+
+ PipelineOptionsFactory.register(SnowflakePipelineOptions.class);
+ options = TestPipeline.testingPipelineOptions().as(SnowflakePipelineOptions.class);
+ options.setStagingBucketName(BUCKET_NAME);
+ options.setStorageIntegrationName("STORAGE_INTEGRATION");
+ options.setServerName("NULL.snowflakecomputing.com");
+
+ dc =
+ SnowflakeIO.DataSourceConfiguration.create(new FakeSnowflakeBasicDataSource())
+ .withServerName(options.getServerName());
+ }
+
+ @After
+ public void tearDown() {
+ TestUtils.removeTempDir(BUCKET_NAME);
+ }
+
+ @Test
+ public void writeToExternalWithIntegrationTest() throws SnowflakeSQLException {
+ pipeline
+ .apply(Create.of(testData))
+ .apply(
+ "Write SnowflakeIO",
+ SnowflakeIO.<Long>write()
+ .withDataSourceConfiguration(dc)
+ .withUserDataMapper(TestUtils.getLongCsvMapper())
+ .withTable(FAKE_TABLE)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
+ .withSnowflakeService(snowflakeService));
+
+ pipeline.run(options).waitUntilFinish();
+
+ List<Long> actualData = FakeSnowflakeDatabase.getElementsAsLong(FAKE_TABLE);
+
+ assertTrue(TestUtils.areListsEqual(testData, actualData));
+ }
+
+ @Test
+ public void writeToExternalWithMapperTest() throws SnowflakeSQLException {
+ pipeline
+ .apply(Create.of(testData))
+ .apply(
+ "External text write IO",
+ SnowflakeIO.<Long>write()
+ .withTable(FAKE_TABLE)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
+ .withDataSourceConfiguration(dc)
+ .withUserDataMapper(TestUtils.getLongCsvMapper())
+ .withSnowflakeService(snowflakeService));
+
+ pipeline.run(options).waitUntilFinish();
+
+ List<Long> actualData = FakeSnowflakeDatabase.getElementsAsLong(FAKE_TABLE);
+
+ assertTrue(TestUtils.areListsEqual(testData, actualData));
+ }
+
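+ /** Verifies that composite element types such as KV can be written with a matching UserDataMapper. */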
+ @Test
+ public void writeToExternalWithKVInput() {
+ pipeline
+ .apply(Create.of(testData))
+ .apply(ParDo.of(new TestUtils.ParseToKv()))
+ .apply(
+ "Write SnowflakeIO",
+ SnowflakeIO.<KV<String, Long>>write()
+ .withDataSourceConfiguration(dc)
+ .withUserDataMapper(TestUtils.getLongCsvMapperKV())
+ .withTable(FAKE_TABLE)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
+ .withSnowflakeService(snowflakeService));
+
+ pipeline.run(options).waitUntilFinish();
+ }
+
+ @Test
+ public void writeToExternalWithTransformationTest() throws SQLException {
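+ // In a Snowflake stage query, t.$1 addresses the first column of the staged CSV files.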
+ String query = "select t.$1 from %s t";
+ pipeline
+ .apply(Create.of(testData))
+ .apply(ParDo.of(new TestUtils.ParseToKv()))
+ .apply(
+ "Write SnowflakeIO",
+ SnowflakeIO.<KV<String, Long>>write()
+ .withTable(FAKE_TABLE)
+ .withStagingBucketName(options.getStagingBucketName())
+ .withStorageIntegrationName(options.getStorageIntegrationName())
+ .withUserDataMapper(TestUtils.getLongCsvMapperKV())
+ .withDataSourceConfiguration(dc)
+ .withQueryTransformation(query)
+ .withSnowflakeService(snowflakeService));
+
+ pipeline.run(options).waitUntilFinish();
+
+ List<Long> actualData = FakeSnowflakeDatabase.getElementsAsLong(FAKE_TABLE);
+
+ assertTrue(TestUtils.areListsEqual(testData, actualData));
+ }
+}