RyanSkraba commented on a change in pull request #11794:
URL: https://github.com/apache/beam/pull/11794#discussion_r443687101
##########
File path: sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
##########
@@ -447,6 +513,346 @@ public void populateDisplayData(DisplayData.Builder builder) {
}
}
+ /** Implementation of {@link #write()}. */
+ @AutoValue
+  public abstract static class Write<T> extends PTransform<PCollection<T>, PDone> {
+ @Nullable
+ abstract SerializableFunction<Void, DataSource> getDataSourceProviderFn();
+
+ @Nullable
+ abstract String getTable();
+
+ @Nullable
+ abstract String getStorageIntegrationName();
+
+ @Nullable
+ abstract String getStagingBucketName();
+
+ @Nullable
+ abstract String getQuery();
+
+ @Nullable
+ abstract String getFileNameTemplate();
+
+ @Nullable
+ abstract WriteDisposition getWriteDisposition();
+
+ @Nullable
+ abstract UserDataMapper getUserDataMapper();
+
+ @Nullable
+ abstract SnowflakeService getSnowflakeService();
+
+ abstract Builder<T> toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> setDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn);
+
+ abstract Builder<T> setTable(String table);
+
+      abstract Builder<T> setStorageIntegrationName(String storageIntegrationName);
+
+ abstract Builder<T> setStagingBucketName(String stagingBucketName);
+
+ abstract Builder<T> setQuery(String query);
+
+ abstract Builder<T> setFileNameTemplate(String fileNameTemplate);
+
+ abstract Builder<T> setUserDataMapper(UserDataMapper userDataMapper);
+
+      abstract Builder<T> setWriteDisposition(WriteDisposition writeDisposition);
+
+      abstract Builder<T> setSnowflakeService(SnowflakeService snowflakeService);
+
+ abstract Write<T> build();
+ }
+
+ /**
+     * Sets information about the Snowflake server.
+     *
+     * @param config - An instance of {@link DataSourceConfiguration}.
+     */
+    public Write<T> withDataSourceConfiguration(final DataSourceConfiguration config) {
+      return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
+ }
+
+ /**
+     * Sets a function that will provide {@link DataSourceConfiguration} at runtime.
+ *
+ * @param dataSourceProviderFn a {@link SerializableFunction}.
+ */
+ public Write<T> withDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn) {
+ return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build();
+ }
+
+ /**
+ * A table name to be written in Snowflake.
+ *
+ * @param table - String with the name of the table.
+ */
+ public Write<T> withTable(String table) {
+ return toBuilder().setTable(table).build();
+ }
+
+ /**
+     * Name of the cloud bucket (GCS for now) to use as a tmp location for CSVs during the COPY statement.
+ *
+ * @param stagingBucketName - String with the name of the bucket.
+ */
+ public Write<T> withStagingBucketName(String stagingBucketName) {
+ return toBuilder().setStagingBucketName(stagingBucketName).build();
+ }
+
+ /**
+     * Name of the Storage Integration in Snowflake to be used. See
+     * https://docs.snowflake.com/en/sql-reference/sql/create-storage-integration.html for
+     * reference.
+     *
+     * @param integrationName - String with the name of the Storage Integration.
+ */
+ public Write<T> withStorageIntegrationName(String integrationName) {
+ return toBuilder().setStorageIntegrationName(integrationName).build();
+ }
+
+ /**
+ * A query to be executed in Snowflake.
+ *
+ * @param query - String with query.
+ */
+ public Write<T> withQueryTransformation(String query) {
+ return toBuilder().setQuery(query).build();
+ }
+
+ /**
+ * A template name for files saved to GCP.
+ *
+ * @param fileNameTemplate - String with template name for files.
+ */
+ public Write<T> withFileNameTemplate(String fileNameTemplate) {
+ return toBuilder().setFileNameTemplate(fileNameTemplate).build();
+ }
+
+ /**
+ * User-defined function mapping user data into CSV lines.
+ *
+ * @param userDataMapper - an instance of {@link UserDataMapper}.
+ */
+ public Write<T> withUserDataMapper(UserDataMapper userDataMapper) {
+ return toBuilder().setUserDataMapper(userDataMapper).build();
+ }
+
+ /**
+     * A disposition to be used during the write-to-table phase.
+ *
+ * @param writeDisposition - an instance of {@link WriteDisposition}.
+ */
+ public Write<T> withWriteDisposition(WriteDisposition writeDisposition) {
+ return toBuilder().setWriteDisposition(writeDisposition).build();
+ }
+
+ /**
+     * A Snowflake service to be used. Note: currently we have {@link
+     * SnowflakeServiceImpl} with a corresponding {@link FakeSnowflakeServiceImpl} used for testing.
+ *
+ * @param snowflakeService - an instance of {@link SnowflakeService}.
+ */
+ public Write<T> withSnowflakeService(SnowflakeService snowflakeService) {
+ return toBuilder().setSnowflakeService(snowflakeService).build();
+ }
+
+ @Override
+ public PDone expand(PCollection<T> input) {
+ checkArguments();
+
+      String stagingBucketDir = String.format("%s/%s/", getStagingBucketName(), WRITE_TMP_PATH);
+
+ PCollection<String> out = write(input, stagingBucketDir);
+ out.setCoder(StringUtf8Coder.of());
+
+ return PDone.in(out.getPipeline());
+ }
+
+ private void checkArguments() {
+      checkArgument(getStagingBucketName() != null, "withStagingBucketName is required");
+
+      checkArgument(getUserDataMapper() != null, "withUserDataMapper() is required");
+
+      checkArgument(
+          (getDataSourceProviderFn() != null),
+          "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
+
+ checkArgument(getTable() != null, "withTable() is required");
+ }
+
+    private PCollection<String> write(PCollection<T> input, String stagingBucketDir) {
+      SnowflakeService snowflakeService =
+          getSnowflakeService() != null ? getSnowflakeService() : new SnowflakeServiceImpl();
+
+ PCollection<String> files = writeFiles(input, stagingBucketDir);
+
+ // Combining PCollection of files as a side input into one list of files
+ ListCoder<String> coder = ListCoder.of(StringUtf8Coder.of());
+ files =
+ (PCollection)
+ files
+ .getPipeline()
+ .apply(
+ Reify.viewInGlobalWindow(
+                      (PCollectionView) files.apply(View.asList()), coder));
+
+      return (PCollection)
+          files.apply("Copy files to table", copyToTable(snowflakeService, stagingBucketDir));
+ }
+
+    private PCollection<String> writeFiles(PCollection<T> input, String stagingBucketDir) {
+
+ PCollection<String> mappedUserData =
+ input
+ .apply(
+ MapElements.via(
+ new SimpleFunction<T, Object[]>() {
+ @Override
+ public Object[] apply(T element) {
+ return getUserDataMapper().mapRow(element);
+ }
+ }))
+              .apply("Map Objects array to CSV lines", ParDo.of(new MapObjectsArrayToCsvFn()))
+ .setCoder(StringUtf8Coder.of());
+
+ WriteFilesResult filesResult =
+ mappedUserData.apply(
+ "Write files to specified location",
+ FileIO.<String>write()
+ .via(TextIO.sink())
+ .to(stagingBucketDir)
+ .withPrefix(getFileNameTemplate())
+ .withSuffix(".csv")
+ .withCompression(Compression.GZIP));
+
+ return (PCollection)
+ filesResult
+ .getPerDestinationOutputFilenames()
+ .apply("Parse KV filenames to Strings", Values.<String>create());
+ }
+
+ private ParDo.SingleOutput<Object, Object> copyToTable(
+ SnowflakeService snowflakeService, String stagingBucketDir) {
+ return ParDo.of(
+ new CopyToTableFn<>(
+ getDataSourceProviderFn(),
+ getTable(),
+ getQuery(),
+ stagingBucketDir,
+ getStorageIntegrationName(),
+ getWriteDisposition(),
+ snowflakeService));
+ }
+ }
+
+  public static class Concatenate extends Combine.CombineFn<String, List<String>, List<String>> {
+ @Override
+ public List<String> createAccumulator() {
+ return new ArrayList<>();
+ }
+
+ @Override
+    public List<String> addInput(List<String> mutableAccumulator, String input) {
+ mutableAccumulator.add(String.format("'%s'", input));
+ return mutableAccumulator;
+ }
+
+ @Override
+    public List<String> mergeAccumulators(Iterable<List<String>> accumulators) {
+ List<String> result = createAccumulator();
+ for (List<String> accumulator : accumulators) {
+ result.addAll(accumulator);
+ }
+ return result;
+ }
+
+ @Override
+ public List<String> extractOutput(List<String> accumulator) {
+ return accumulator;
+ }
+ }
+
+ /**
+   * Custom DoFn that maps an {@link Object[]} into a CSV line to be saved to Snowflake.
+ *
+ * <p>Adds Snowflake-specific quotations around strings.
+ */
+ private static class MapObjectsArrayToCsvFn extends DoFn<Object[], String> {
+
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ List<Object> csvItems = new ArrayList<>();
+ for (Object o : context.element()) {
+ if (o instanceof String) {
+ String field = (String) o;
+ field = field.replace("'", "''");
+ field = quoteField(field);
+
+ csvItems.add(field);
+ } else {
+ csvItems.add(o);
+ }
+ }
+ context.output(Joiner.on(",").useForNull("").join(csvItems));
Review comment:
This is where there's an implicit `.toString()` on the objects in the array -- this is still pretty dangerous for many non-primitive classes!
I can think of a couple of solutions: (1) use something other than CSV as the file format for copying into Snowflake, or (2) add a warning to the data mapper doc that the `toString()` results have to be coherent!
For example, if my data mapper function returns `new Object[] {Arrays.asList(1,',',"\n")}`, it's almost certainly going to break this function.
I can see this happening, for example, if a user thinks that returning a JsonObject will insert the JSON as a string into that column.
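To make the failure concrete, here's a minimal standalone sketch (the class name is made up, and it only reproduces the implicit toString-then-join step, not the actual DoFn):

```java
import java.util.Arrays;
import java.util.stream.Collectors;

public class ToStringCsvPitfall {
  public static void main(String[] args) {
    // Hypothetical mapper output: the user expects a single column, but the value
    // is a List whose toString() contains a comma and a newline.
    Object[] row = new Object[] {Arrays.asList(1, ',', "\n")};

    // The same implicit toString()-then-join idea as MapObjectsArrayToCsvFn:
    // non-String fields are stringified and joined verbatim, without any quoting.
    String csvLine = Arrays.stream(row)
        .map(String::valueOf)
        .collect(Collectors.joining(","));

    // Prints "[1, ,, " followed by a real line break and "]" -- the extra commas
    // and the newline mean the COPY into Snowflake sees a malformed record.
    System.out.println(csvLine);
  }
}
```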
##########
File path: sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
##########
@@ -447,6 +494,346 @@ public void populateDisplayData(DisplayData.Builder builder) {
}
}
+ /** Implementation of {@link #write()}. */
+ @AutoValue
+  public abstract static class Write<T> extends PTransform<PCollection<T>, PDone> {
+ @Nullable
+ abstract SerializableFunction<Void, DataSource> getDataSourceProviderFn();
+
+ @Nullable
+ abstract String getTable();
+
+ @Nullable
+ abstract String getQuery();
+
+ @Nullable
+ abstract Location getLocation();
+
+ @Nullable
+ abstract String getFileNameTemplate();
+
+ @Nullable
+ abstract WriteDisposition getWriteDisposition();
+
+ @Nullable
+ abstract UserDataMapper getUserDataMapper();
+
+ @Nullable
+ abstract SnowflakeService getSnowflakeService();
+
+ abstract Builder<T> toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> setDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn);
+
+ abstract Builder<T> setTable(String table);
+
+ abstract Builder<T> setQuery(String query);
+
+ abstract Builder<T> setLocation(Location location);
+
+ abstract Builder<T> setFileNameTemplate(String fileNameTemplate);
+
+ abstract Builder<T> setUserDataMapper(UserDataMapper userDataMapper);
+
+      abstract Builder<T> setWriteDisposition(WriteDisposition writeDisposition);
+
+      abstract Builder<T> setSnowflakeService(SnowflakeService snowflakeService);
+
+ abstract Write<T> build();
+ }
+
+ /**
+     * Sets information about the Snowflake server.
+     *
+     * @param config - An instance of {@link DataSourceConfiguration}.
+     */
+    public Write<T> withDataSourceConfiguration(final DataSourceConfiguration config) {
+      return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
+ }
+
+ /**
+     * Sets a function that will provide {@link DataSourceConfiguration} at runtime.
+ *
+ * @param dataSourceProviderFn a {@link SerializableFunction}.
+ */
+ public Write<T> withDataSourceProviderFn(
+ SerializableFunction<Void, DataSource> dataSourceProviderFn) {
+ return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build();
+ }
+
+ /**
+ * A table name to be written in Snowflake.
+ *
+ * @param table - String with the name of the table.
+ */
+ public Write<T> to(String table) {
+ return toBuilder().setTable(table).build();
+ }
+
+ /**
+ * A query to be executed in Snowflake.
+ *
+ * @param query - String with query.
+ */
+ public Write<T> withQueryTransformation(String query) {
+ return toBuilder().setQuery(query).build();
+ }
+
+ /**
+     * A location object which contains connection config between Snowflake and GCP.
+ *
+ * @param location - an instance of {@link Location}.
+ */
+ public Write<T> via(Location location) {
+ return toBuilder().setLocation(location).build();
+ }
+
+ /**
+ * A template name for files saved to GCP.
+ *
+ * @param fileNameTemplate - String with template name for files.
+ */
+ public Write<T> withFileNameTemplate(String fileNameTemplate) {
+ return toBuilder().setFileNameTemplate(fileNameTemplate).build();
+ }
+
+ /**
+ * User-defined function mapping user data into CSV lines.
+ *
+ * @param userDataMapper - an instance of {@link UserDataMapper}.
+ */
+ public Write<T> withUserDataMapper(UserDataMapper userDataMapper) {
+ return toBuilder().setUserDataMapper(userDataMapper).build();
+ }
+
+ /**
+     * A disposition to be used during the write-to-table phase.
+ *
+ * @param writeDisposition - an instance of {@link WriteDisposition}.
+ */
+ public Write<T> withWriteDisposition(WriteDisposition writeDisposition) {
+ return toBuilder().setWriteDisposition(writeDisposition).build();
+ }
+
+ /**
+     * A Snowflake service to be used. Note: currently we have {@link
+     * SnowflakeServiceImpl} with a corresponding {@link FakeSnowflakeServiceImpl} used for testing.
+ *
+ * @param snowflakeService - an instance of {@link SnowflakeService}.
+ */
+ public Write<T> withSnowflakeService(SnowflakeService snowflakeService) {
+ return toBuilder().setSnowflakeService(snowflakeService).build();
+ }
+
+ @Override
+ public PDone expand(PCollection<T> input) {
+ Location loc = getLocation();
+ checkArguments(loc);
+
+      String stagingBucketDir = String.format("%s/%s/", loc.getStagingBucketName(), WRITE_TMP_PATH);
+
+ PCollection out = write(input, stagingBucketDir);
+ out.setCoder(StringUtf8Coder.of());
+
+ return PDone.in(out.getPipeline());
+ }
+
+ private void checkArguments(Location loc) {
+ checkArgument(loc != null, "via() is required");
+ checkArgument(
+ loc.getStorageIntegrationName() != null,
+ "location with storageIntegrationName is required");
+ checkArgument(
+          loc.getStorageIntegrationName() != null,
+          "location with storageIntegrationName is required");
+      checkArgument(
+          loc.getStagingBucketName() != null, "location with stagingBucketName is required");
+
+      checkArgument(getUserDataMapper() != null, "withUserDataMapper() is required");
+
+      checkArgument(
+          (getDataSourceProviderFn() != null),
+          "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
+
+ checkArgument(getTable() != null, "withTable() is required");
+ }
+
+ private PCollection write(PCollection input, String stagingBucketDir) {
+ SnowflakeService snowflakeService =
+          getSnowflakeService() != null ? getSnowflakeService() : new SnowflakeServiceImpl();
+
+ PCollection files = writeFiles(input, stagingBucketDir);
+
+ files =
+ (PCollection)
+              files.apply("Create list of files to copy", Combine.globally(new Concatenate()));
+
+      return (PCollection)
+          files.apply("Copy files to table", copyToTable(snowflakeService, stagingBucketDir));
+ }
+
+    private PCollection writeFiles(PCollection<T> input, String stagingBucketDir) {
+ class Parse extends DoFn<KV<T, String>, String> {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ c.output(c.element().getValue());
+ }
+ }
+
+ PCollection mappedUserData =
+ input
+ .apply(
+ "Map user data to Objects array",
+                  ParDo.of(new MapUserDataObjectsArrayFn<T>(getUserDataMapper())))
+              .apply("Map Objects array to CSV lines", ParDo.of(new MapObjectsArrayToCsvFn()))
+ .setCoder(StringUtf8Coder.of());
+
+ WriteFilesResult filesResult =
+ (WriteFilesResult)
+ mappedUserData.apply(
+ "Write files to specified location",
+ FileIO.write()
+ .via((FileIO.Sink) new CSVSink())
+ .to(stagingBucketDir)
+ .withPrefix(getFileNameTemplate())
+ .withSuffix(".csv")
+                      .withPrefix(UUID.randomUUID().toString().subSequence(0, 8).toString())
+ .withCompression(Compression.GZIP));
+
+ return (PCollection)
+ filesResult
+ .getPerDestinationOutputFilenames()
+ .apply("Parse KV filenames to Strings", ParDo.of(new Parse()));
+ }
+
+ private ParDo.SingleOutput<Object, Object> copyToTable(
+ SnowflakeService snowflakeService, String stagingBucketDir) {
+ return ParDo.of(
+ new CopyToTableFn<>(
+ getDataSourceProviderFn(),
+ getTable(),
+ getQuery(),
+ stagingBucketDir,
+ getLocation(),
+ getWriteDisposition(),
+ snowflakeService));
+ }
+ }
+
+  public static class Concatenate extends Combine.CombineFn<String, List<String>, List<String>> {
+ @Override
+ public List<String> createAccumulator() {
+ return new ArrayList<>();
+ }
+
+ @Override
+    public List<String> addInput(List<String> mutableAccumulator, String input) {
+ mutableAccumulator.add(String.format("'%s'", input));
+ return mutableAccumulator;
+ }
+
+ @Override
+    public List<String> mergeAccumulators(Iterable<List<String>> accumulators) {
+ List<String> result = createAccumulator();
+ for (List<String> accumulator : accumulators) {
+ result.addAll(accumulator);
+ }
+ return result;
+ }
+
+ @Override
+ public List<String> extractOutput(List<String> accumulator) {
+ return accumulator;
+ }
+ }
+
+ private static class MapUserDataObjectsArrayFn<T> extends DoFn<T, Object[]> {
+ private final UserDataMapper<T> csvMapper;
+
+ public MapUserDataObjectsArrayFn(UserDataMapper<T> csvMapper) {
+ this.csvMapper = csvMapper;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext context) throws Exception {
+ context.output(csvMapper.mapRow(context.element()));
+ }
+ }
+
+ /**
+   * Custom DoFn that maps an {@link Object[]} into a CSV line to be saved to Snowflake.
+ *
+ * <p>Adds Snowflake-specific quotations around strings.
+ */
+ private static class MapObjectsArrayToCsvFn extends DoFn<Object[], String> {
+
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ List<Object> csvItems = new ArrayList<>();
+ for (Object o : context.element()) {
+ if (o instanceof String) {
+ String field = (String) o;
+ field = field.replace("'", "''");
+ field = quoteField(field);
+
+ csvItems.add(field);
+ } else {
+ csvItems.add(o);
Review comment:
The line above, `field = field.replace("'", "''");`? It looks like the `''` escaping style. If you confirm it works with Snowflake, I'll trust you!
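For reference, a tiny standalone sketch of what that escaping plus the quoting step produces (the helper name is hypothetical, not the PR's `quoteField`):

```java
public class SnowflakeQuoteEscapeSketch {
  // Hypothetical helper mirroring the two steps above: double any embedded
  // single quotes, then wrap the whole field in single quotes.
  static String escapeAndQuote(String field) {
    return "'" + field.replace("'", "''") + "'";
  }

  public static void main(String[] args) {
    // "O'Reilly" becomes 'O''Reilly' -- the '' style the comment refers to.
    System.out.println(escapeAndQuote("O'Reilly"));
  }
}
```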
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]