damccorm commented on code in PR #34411:
URL: https://github.com/apache/beam/pull/34411#discussion_r2040178411


##########
sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordSchemaTransformProviderTest.java:
##########
@@ -0,0 +1,605 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static org.apache.beam.sdk.io.Compression.AUTO;
+import static org.apache.beam.sdk.io.Compression.DEFLATE;
+import static org.apache.beam.sdk.io.Compression.GZIP;
+import static org.apache.beam.sdk.io.Compression.UNCOMPRESSED;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.CoreMatchers.startsWith;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.in;
+import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.ServiceLoader;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+import org.apache.beam.sdk.io.TFRecordReadSchemaTransformProvider.TFRecordReadSchemaTransform;
+import org.apache.beam.sdk.io.TFRecordWriteSchemaTransformProvider.TFRecordWriteSchemaTransform;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for TFRecord read and write schema transforms. */
+@RunWith(JUnit4.class)
+public class TFRecordSchemaTransformProviderTest {
+
+  /*
+  From https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio_test.py
+  Created by running the following code in Python:
+  >>> import tensorflow as tf
+  >>> import base64
+  >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
+  >>> writer.write('foo')
+  >>> writer.close()
+  >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
+  ...   data = base64.b64encode(f.read())
+  ...   print data
+  */
+  private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g==";
+
+  // Same as above but containing two records ['foo', 'bar']
+  private static final String FOO_BAR_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=";
+  private static final String BAR_FOO_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4=";
+
+  private static final String[] FOO_RECORDS = {"foo"};
+  private static final String[] FOO_BAR_RECORDS = {"foo", "bar"};
+
+  private static final Iterable<String> EMPTY = Collections.emptyList();
+  private static final Iterable<String> SMALL = makeLines(1, 4);
+  private static final Iterable<String> LARGE = makeLines(1000, 4);
+  private static final Iterable<String> LARGE_RECORDS = makeLines(100, 100000);
+
+  @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+  @Rule public TestPipeline readPipeline = TestPipeline.create();
+
+  @Rule public TestPipeline writePipeline = TestPipeline.create();
+
+  @Rule public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testReadInvalidConfigurations() {
+    String filePattern = "foo.*";
+    String compression = "AUTO";
+
+    // Invalid filepattern
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              .setValidate(true)
+              .setCompression(compression)
+              .setFilePattern(filePattern)
+              .build()
+              .validate();
+        });
+
+    // Filepattern unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              .setValidate(true)
+              .setCompression(compression)
+              // .setFilePattern(filePattern) // File pattern is mandatory
+              .build()
+              .validate();
+        });
+
+    // Validate unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              // .setValidate(true) // Validate is mandatory
+              .setCompression(compression)
+              .setFilePattern(filePattern)
+              .build()
+              .validate();
+        });
+
+    // Compression unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              .setValidate(false)
+              // .setCompression(Compression.AUTO) // Compression is mandatory
+              .setFilePattern(filePattern)
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testWriteInvalidConfigurations() throws Exception {
+    String fileName = "foo";
+    String nonExistentPath = "abc";
+    String filenameSuffix = "bar";
+    String shardTemplate = "xyz";
+    String compression = "AUTO";
+    Integer numShards = 10;
+
+    // Invalid outputPrefix
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + nonExistentPath)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              .setNumShards(numShards)
+              .setCompression(compression)
+              .setNoSpilling(true)
+              .build()
+              .validate();
+        });
+
+    // NumShards unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + fileName)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              // .setNumShards(numShards) // NumShards is mandatory
+              .setCompression(compression)
+              .setNoSpilling(true)
+              .build()
+              .validate();
+        });
+
+    // Compression unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + fileName)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              .setNumShards(numShards)
+              // .setCompression(compression) // Compression is mandatory
+              .setNoSpilling(true)
+              .build()
+              .validate();
+        });
+
+    // NoSpilling unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + fileName)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              .setNumShards(numShards)
+              .setCompression(compression)
+              // .setNoSpilling(true) // NoSpilling is mandatory
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testReadBuildTransform() {
+    TFRecordReadSchemaTransformProvider provider = new TFRecordReadSchemaTransformProvider();
+    provider.from(
+        TFRecordReadSchemaTransformConfiguration.builder()
+            .setValidate(false)
+            .setCompression("AUTO")
+            .setFilePattern("foo.*")
+            .build());
+  }
+
+  @Test
+  public void testWriteBuildTransform() {
+    TFRecordWriteSchemaTransformProvider provider = new TFRecordWriteSchemaTransformProvider();
+    provider.from(
+        TFRecordWriteSchemaTransformConfiguration.builder()
+            .setOutputPrefix(tempFolder.getRoot().toPath().toString())
+            .setFilenameSuffix("bar")
+            .setShardTemplate("xyz")
+            .setNumShards(10)
+            .setCompression("UNCOMPRESSED")
+            .setNoSpilling(true)
+            .build());
+  }
+
+  @Test
+  public void testReadFindTransformAndMakeItWork() {
+    ServiceLoader<SchemaTransformProvider> serviceLoader =
+        ServiceLoader.load(SchemaTransformProvider.class);
+    List<SchemaTransformProvider> providers =
+        StreamSupport.stream(serviceLoader.spliterator(), false)
+            .filter(provider -> provider.getClass() == TFRecordReadSchemaTransformProvider.class)
+            .collect(Collectors.toList());
+    SchemaTransformProvider tfrecordProvider = providers.get(0);
+    assertEquals(tfrecordProvider.outputCollectionNames(), Lists.newArrayList("output", "errors"));
+    assertEquals(tfrecordProvider.inputCollectionNames(), Lists.newArrayList());
+
+    assertEquals(
+        Sets.newHashSet("file_pattern", "compression", "validate", 
"error_handling"),
+        tfrecordProvider.configurationSchema().getFields().stream()
+            .map(field -> field.getName())
+            .collect(Collectors.toSet()));
+  }
+
+  @Test
+  public void testWriteFindTransformAndMakeItWork() {
+    ServiceLoader<SchemaTransformProvider> serviceLoader =
+        ServiceLoader.load(SchemaTransformProvider.class);
+    List<SchemaTransformProvider> providers =
+        StreamSupport.stream(serviceLoader.spliterator(), false)
+            .filter(provider -> provider.getClass() == TFRecordWriteSchemaTransformProvider.class)
+            .collect(Collectors.toList());
+    SchemaTransformProvider tfrecordProvider = providers.get(0);
+    assertEquals(tfrecordProvider.outputCollectionNames(), Lists.newArrayList("output", "errors"));
+
+    assertEquals(
+        Sets.newHashSet(
+            "output_prefix",
+            "filename_suffix",
+            "shard_template",
+            "num_shards",
+            "compression",
+            "no_spilling",
+            "error_handling"),
+        tfrecordProvider.configurationSchema().getFields().stream()
+            .map(field -> field.getName())
+            .collect(Collectors.toSet()));
+  }
+
+  /** Tests that the transform from {@link TFRecordReadSchemaTransformProvider} is named correctly. */
+  @Test
+  public void testReadNamed() {
+    readPipeline.enableAbandonedNodeEnforcement(false);
+    PCollectionRowTuple begin = PCollectionRowTuple.empty(readPipeline);
+    SchemaTransform transform =
+        new TFRecordReadSchemaTransformProvider()
+            .from(
+                TFRecordReadSchemaTransformConfiguration.builder()
+                    .setValidate(false)
+                    .setCompression("AUTO")
+                    .setFilePattern("foo.*")
+                    .build());
+
+    PCollectionRowTuple reads = begin.apply(transform);
+    String name = reads.get("output").getName();
+    assertThat(name, startsWith("TFRecordReadSchemaTransformProvider.TFRecordReadSchemaTransform"));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadOne() throws Exception {
+    runTestRead(FOO_RECORD_BASE64, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadTwo() throws Exception {
+    runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteOne() throws Exception {
+    runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteTwo() throws Exception {
+    runTestWrite(FOO_BAR_RECORDS, FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidRecord() throws Exception {
+    expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes.");
+    runTestRead("bar".getBytes(StandardCharsets.UTF_8), new String[0]);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidLengthMask() throws Exception {
+    expectedException.expectCause(hasMessage(containsString("Mismatch of length mask")));
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    data[9] += (byte) 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidDataMask() throws Exception {
+    expectedException.expectCause(hasMessage(containsString("Mismatch of data mask")));
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    data[16] += (byte) 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  private void runTestRead(String base64, String[] expected) throws IOException {
+    runTestRead(BaseEncoding.base64().decode(base64), expected);
+  }
+
+  /** Tests {@link TFRecordReadSchemaTransformProvider}. */
+  private void runTestRead(byte[] data, String[] expected) throws IOException {
+    // Create temp filename
+    File tmpFile =
+        Files.createTempFile(tempFolder.getRoot().toPath(), "file", ".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+    try (FileOutputStream fos = new FileOutputStream(tmpFile)) {
+      fos.write(data);
+    }
+
+    // Create transform provider with configuration data
+    TFRecordReadSchemaTransformProvider provider = new TFRecordReadSchemaTransformProvider();
+    String compression = "AUTO";
+    TFRecordReadSchemaTransformConfiguration configuration =
+        TFRecordReadSchemaTransformConfiguration.builder()
+            .setValidate(true)
+            .setCompression(compression)
+            .setFilePattern(filename)
+            .build();
+    TFRecordReadSchemaTransform transform =
+        (TFRecordReadSchemaTransform) provider.from(configuration);
+
+    // Create PCollectionRowTuples input data and apply transform to read
+    PCollectionRowTuple input = PCollectionRowTuple.empty(readPipeline);
+    PCollectionRowTuple reads = input.apply(transform);
+
+    // Create expected row data
+    Schema schema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));
+    List<Row> row =
+        Arrays.stream(expected)
+            .map(str -> str.getBytes(StandardCharsets.UTF_8))
+            .map(bytes -> Row.withSchema(schema).addValue(bytes).build())
+            .collect(Collectors.toList());
+    PAssert.that(reads.get("output")).containsInAnyOrder(row);
+
+    readPipeline.run().waitUntilFinish();
+  }
+
+  /** Tests {@link TFRecordWriteSchemaTransformProvider}. */
+  private void runTestWrite(String[] elems, String... base64) throws IOException {
+    // Create temp filename
+    File tmpFile =
+        Files.createTempFile(tempFolder.getRoot().toPath(), "file", ".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+
+    // Create beam row schema
+    Schema schema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));
+
+    // Create transform provider with configuration data
+    TFRecordWriteSchemaTransformProvider provider = new TFRecordWriteSchemaTransformProvider();
+    String compression = "UNCOMPRESSED";
+    TFRecordWriteSchemaTransformConfiguration configuration =
+        TFRecordWriteSchemaTransformConfiguration.builder()
+            .setOutputPrefix(filename)
+            .setCompression(compression)
+            .setNumShards(0)
+            .setNoSpilling(true)
+            .build();
+    TFRecordWriteSchemaTransform transform =
+        (TFRecordWriteSchemaTransform) provider.from(configuration);
+
+    // Create Beam row byte data
+    List<Row> rows =
+        Arrays.stream(elems)
+            .map(str -> str.getBytes(StandardCharsets.UTF_8))
+            .map(bytes -> Row.withSchema(schema).addValue(bytes).build())
+            .collect(Collectors.toList());
+
+    // Create PCollection of input Beam Row data on the pipeline and apply the transform
+    PCollection<Row> input = writePipeline.apply(Create.of(rows).withRowSchema(schema));
+    PCollectionRowTuple rowTuple = PCollectionRowTuple.of("input", input);
+    rowTuple.apply(transform);
+
+    // Run pipeline
+    writePipeline.run().waitUntilFinish();
+
+    assertTrue("File should exist", tmpFile.exists());
+    assertTrue("File should have content", tmpFile.length() > 0);
+
+    FileInputStream fis = new FileInputStream(tmpFile);
+    String written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis));
+    assertThat(written, is(in(base64)));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTrip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithEmptyData() throws IOException {
+    runTestRoundTrip(EMPTY, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithOneShards() throws IOException {
+    runTestRoundTrip(LARGE, 1, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithSuffix() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".suffix", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripGzip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, GZIP);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripZlib() throws IOException {
+    runTestRoundTrip(SMALL, 10, ".tfrecords", DEFLATE, DEFLATE);

Review Comment:
   Interestingly, the other compression types succeed...
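
   To separate the compression layer from the transform itself, it may be worth round-tripping a few bytes through `Compression.DEFLATE`'s channel wrappers directly. A minimal sketch (not part of this PR; it assumes the same `writeCompressed`/`readDecompressed` path that TFRecordIO uses):

   ```java
   import java.io.ByteArrayInputStream;
   import java.io.ByteArrayOutputStream;
   import java.io.IOException;
   import java.nio.ByteBuffer;
   import java.nio.channels.Channels;
   import java.nio.channels.ReadableByteChannel;
   import java.nio.channels.WritableByteChannel;
   import java.nio.charset.StandardCharsets;
   import org.apache.beam.sdk.io.Compression;

   public class DeflateChannelRoundTrip {
     public static void main(String[] args) throws IOException {
       byte[] original = "foo".getBytes(StandardCharsets.UTF_8);

       // Compress through Beam's DEFLATE channel wrapper.
       ByteArrayOutputStream compressed = new ByteArrayOutputStream();
       try (WritableByteChannel out =
           Compression.DEFLATE.writeCompressed(Channels.newChannel(compressed))) {
         out.write(ByteBuffer.wrap(original));
       }

       // Decompress the same bytes back through the matching read path.
       ByteArrayOutputStream decompressed = new ByteArrayOutputStream();
       try (ReadableByteChannel in =
           Compression.DEFLATE.readDecompressed(
               Channels.newChannel(new ByteArrayInputStream(compressed.toByteArray())))) {
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         while (in.read(buffer) != -1) {
           buffer.flip();
           decompressed.write(buffer.array(), 0, buffer.limit());
           buffer.clear();
         }
       }

       // Prints "foo" if DEFLATE round-trips cleanly on its own.
       System.out.println(new String(decompressed.toByteArray(), StandardCharsets.UTF_8));
     }
   }
   ```

   If that passes, the mismatch is more likely in how the write and read transforms are configured than in the codec itself.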



##########
sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordSchemaTransformProviderTest.java:
##########
@@ -0,0 +1,605 @@
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripZlib() throws IOException {
+    runTestRoundTrip(SMALL, 10, ".tfrecords", DEFLATE, DEFLATE);

Review Comment:
   This is confusing; I'm not sure why only this mode would fail. But it is consistent across runs...
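
   One hypothesis worth ruling out: DEFLATE is the only mode here that produces a raw zlib stream, whose framing differs from gzip's, so a read path that sniffs headers can treat the two differently. A quick plain-JDK way to see the framing (illustrative only, not code from this PR):

   ```java
   import java.io.ByteArrayOutputStream;
   import java.io.IOException;
   import java.nio.charset.StandardCharsets;
   import java.util.zip.DeflaterOutputStream;
   import java.util.zip.GZIPOutputStream;

   public class StreamFramingCheck {
     public static void main(String[] args) throws IOException {
       byte[] payload = "foo".getBytes(StandardCharsets.UTF_8);

       // zlib/DEFLATE framing: a two-byte header, typically starting with 0x78.
       ByteArrayOutputStream zlib = new ByteArrayOutputStream();
       try (DeflaterOutputStream out = new DeflaterOutputStream(zlib)) {
         out.write(payload);
       }

       // gzip framing: magic bytes 0x1f 0x8b followed by the rest of a ten-byte header.
       ByteArrayOutputStream gzip = new ByteArrayOutputStream();
       try (GZIPOutputStream out = new GZIPOutputStream(gzip)) {
         out.write(payload);
       }

       System.out.printf("zlib first byte:  %02x%n", zlib.toByteArray()[0]);
       System.out.printf("gzip first bytes: %02x %02x%n",
           gzip.toByteArray()[0], gzip.toByteArray()[1]);
     }
   }
   ```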



##########
sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordSchemaTransformProviderTest.java:
##########
@@ -0,0 +1,605 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static org.apache.beam.sdk.io.Compression.AUTO;
+import static org.apache.beam.sdk.io.Compression.DEFLATE;
+import static org.apache.beam.sdk.io.Compression.GZIP;
+import static org.apache.beam.sdk.io.Compression.UNCOMPRESSED;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.CoreMatchers.startsWith;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.in;
+import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.ServiceLoader;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+import 
org.apache.beam.sdk.io.TFRecordReadSchemaTransformProvider.TFRecordReadSchemaTransform;
+import 
org.apache.beam.sdk.io.TFRecordWriteSchemaTransformProvider.TFRecordWriteSchemaTransform;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for TFRecordIO Read and Write transforms. */
+@RunWith(JUnit4.class)
+public class TFRecordSchemaTransformProviderTest {
+
+  /*
+  From 
https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio_test.py
+  Created by running following code in python:
+  >>> import tensorflow as tf
+  >>> import base64
+  >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
+  >>> writer.write('foo')
+  >>> writer.close()
+  >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
+  ...   data = base64.b64encode(f.read())
+  ...   print data
+  */
+  private static final String FOO_RECORD_BASE64 = 
"AwAAAAAAAACwmUkOZm9vYYq+/g==";
+
+  // Same as above but containing two records ['foo', 'bar']
+  private static final String FOO_BAR_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=";
+  private static final String BAR_FOO_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4=";
+
+  private static final String[] FOO_RECORDS = {"foo"};
+  private static final String[] FOO_BAR_RECORDS = {"foo", "bar"};
+
+  private static final Iterable<String> EMPTY = Collections.emptyList();
+  private static final Iterable<String> SMALL = makeLines(1, 4);
+  private static final Iterable<String> LARGE = makeLines(1000, 4);
+  private static final Iterable<String> LARGE_RECORDS = makeLines(100, 100000);
+
+  @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+  @Rule public TestPipeline readPipeline = TestPipeline.create();
+
+  @Rule public TestPipeline writePipeline = TestPipeline.create();
+
+  @Rule public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testReadInvalidConfigurations() {
+    String filePattern = "foo.*";
+    String compression = "AUTO";
+
+    // Invalid filepattern
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              .setValidate(true)
+              .setCompression(compression)
+              .setFilePattern(filePattern)
+              .build()
+              .validate();
+        });
+
+    // Filepattern unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              .setValidate(true)
+              .setCompression(compression)
+              // .setFilePattern(StaticValueProvider.of("vegetable")) File 
pattern is mandatory
+              .build()
+              .validate();
+        });
+
+    // Validate unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              // .setValidate(true) // Validate is mandatory
+              .setCompression(compression)
+              .setFilePattern(filePattern)
+              .build()
+              .validate();
+        });
+
+    // Compression unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordReadSchemaTransformConfiguration.builder()
+              .setValidate(false)
+              // .setCompression(Compression.AUTO) // Compression is mandatory
+              .setFilePattern(filePattern)
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testWriteInvalidConfigurations() throws Exception {
+    String fileName = "foo";
+    String nonExistentPath = "abc";
+    String filenameSuffix = "bar";
+    String shardTemplate = "xyz";
+    String compression = "AUTO";
+    Integer numShards = 10;
+
+    // Invalid outputPrefix
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + 
nonExistentPath)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              .setNumShards(numShards)
+              .setCompression(compression)
+              .setNoSpilling(true)
+              .build()
+              .validate();
+        });
+
+    // NumShards unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + 
fileName)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              // .setNumShards(numShards) // NumShards is mandatory
+              .setCompression(compression)
+              .setNoSpilling(true)
+              .build()
+              .validate();
+        });
+
+    // Compression unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + 
fileName)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              .setNumShards(numShards)
+              // .setCompression(compression) // Compression is mandatory
+              .setNoSpilling(true)
+              .build()
+              .validate();
+        });
+
+    // NoSpilling unset
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          TFRecordWriteSchemaTransformConfiguration.builder()
+              .setOutputPrefix(tempFolder.getRoot().toPath().toString() + 
fileName)
+              .setFilenameSuffix(filenameSuffix)
+              .setShardTemplate(shardTemplate)
+              .setNumShards(numShards)
+              .setCompression(compression)
+              // .setNoSpilling(true) // NoSpilling is mandatory
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testReadBuildTransform() {
+    TFRecordReadSchemaTransformProvider provider = new 
TFRecordReadSchemaTransformProvider();
+    provider.from(
+        TFRecordReadSchemaTransformConfiguration.builder()
+            .setValidate(false)
+            .setCompression("AUTO")
+            .setFilePattern("foo.*")
+            .build());
+  }
+
+  @Test
+  public void testWriteBuildTransform() {
+    TFRecordWriteSchemaTransformProvider provider = new 
TFRecordWriteSchemaTransformProvider();
+    provider.from(
+        TFRecordWriteSchemaTransformConfiguration.builder()
+            .setOutputPrefix(tempFolder.getRoot().toPath().toString())
+            .setFilenameSuffix("bar")
+            .setShardTemplate("xyz")
+            .setNumShards(10)
+            .setCompression("UNCOMPRESSED")
+            .setNoSpilling(true)
+            .build());
+  }
+
+  @Test
+  public void testReadFindTransformAndMakeItWork() {
+    ServiceLoader<SchemaTransformProvider> serviceLoader =
+        ServiceLoader.load(SchemaTransformProvider.class);
+    List<SchemaTransformProvider> providers =
+        StreamSupport.stream(serviceLoader.spliterator(), false)
+            .filter(provider -> provider.getClass() == 
TFRecordReadSchemaTransformProvider.class)
+            .collect(Collectors.toList());
+    SchemaTransformProvider tfrecordProvider = providers.get(0);
+    assertEquals(tfrecordProvider.outputCollectionNames(), 
Lists.newArrayList("output", "errors"));
+    assertEquals(tfrecordProvider.inputCollectionNames(), 
Lists.newArrayList());
+
+    assertEquals(
+        Sets.newHashSet("file_pattern", "compression", "validate", 
"error_handling"),
+        tfrecordProvider.configurationSchema().getFields().stream()
+            .map(field -> field.getName())
+            .collect(Collectors.toSet()));
+  }
+
+  @Test
+  public void testWriteFindTransformAndMakeItWork() {
+    ServiceLoader<SchemaTransformProvider> serviceLoader =
+        ServiceLoader.load(SchemaTransformProvider.class);
+    List<SchemaTransformProvider> providers =
+        StreamSupport.stream(serviceLoader.spliterator(), false)
+            .filter(provider -> provider.getClass() == 
TFRecordWriteSchemaTransformProvider.class)
+            .collect(Collectors.toList());
+    SchemaTransformProvider tfrecordProvider = providers.get(0);
+    assertEquals(tfrecordProvider.outputCollectionNames(), 
Lists.newArrayList("output", "errors"));
+
+    assertEquals(
+        Sets.newHashSet(
+            "output_prefix",
+            "filename_suffix",
+            "shard_template",
+            "num_shards",
+            "compression",
+            "no_spilling",
+            "error_handling"),
+        tfrecordProvider.configurationSchema().getFields().stream()
+            .map(field -> field.getName())
+            .collect(Collectors.toSet()));
+  }
+
+  /** Tests that TFRecordReadSchemaTransformProvider is presented. */
+  @Test
+  public void testReadNamed() {
+    readPipeline.enableAbandonedNodeEnforcement(false);
+    PCollectionRowTuple begin = PCollectionRowTuple.empty(readPipeline);
+    SchemaTransform transform =
+        new TFRecordReadSchemaTransformProvider()
+            .from(
+                TFRecordReadSchemaTransformConfiguration.builder()
+                    .setValidate(false)
+                    .setCompression("AUTO")
+                    .setFilePattern("foo.*")
+                    .build());
+
+    PCollectionRowTuple reads = begin.apply(transform);
+    String name = reads.get("output").getName();
+    assertThat(name, 
startsWith("TFRecordReadSchemaTransformProvider.TFRecordReadSchemaTransform"));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadOne() throws Exception {
+    runTestRead(FOO_RECORD_BASE64, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadTwo() throws Exception {
+    runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteOne() throws Exception {
+    runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteTwo() throws Exception {
+    runTestWrite(FOO_BAR_RECORDS, FOO_BAR_RECORD_BASE64, 
BAR_FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidRecord() throws Exception {
+    expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 
bytes.");
+    runTestRead("bar".getBytes(StandardCharsets.UTF_8), new String[0]);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidLengthMask() throws Exception {
+    expectedException.expectCause(hasMessage(containsString("Mismatch of 
length mask")));
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    data[9] += (byte) 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidDataMask() throws Exception {
+    expectedException.expectCause(hasMessage(containsString("Mismatch of data 
mask")));
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    data[16] += (byte) 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  private void runTestRead(String base64, String[] expected) throws 
IOException {
+    runTestRead(BaseEncoding.base64().decode(base64), expected);
+  }
+
+  /** Tests {@link TFRecordReadSchemaTransformProvider}. */
+  private void runTestRead(byte[] data, String[] expected) throws IOException {
+    // Create temp filename
+    File tmpFile =
+        Files.createTempFile(tempFolder.getRoot().toPath(), "file", 
".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+    try (FileOutputStream fos = new FileOutputStream(tmpFile)) {
+      fos.write(data);
+    }
+
+    // Create transform provider with configuration data
+    TFRecordReadSchemaTransformProvider provider = new 
TFRecordReadSchemaTransformProvider();
+    String compression = "AUTO";
+    TFRecordReadSchemaTransformConfiguration configuration =
+        TFRecordReadSchemaTransformConfiguration.builder()
+            .setValidate(true)
+            .setCompression(compression)
+            .setFilePattern(filename)
+            .build();
+    TFRecordReadSchemaTransform transform =
+        (TFRecordReadSchemaTransform) provider.from(configuration);
+
+    // Create PCollectionRowTuples input data and apply transform to read
+    PCollectionRowTuple input = PCollectionRowTuple.empty(readPipeline);
+    PCollectionRowTuple reads = input.apply(transform);
+
+    // Create expected row data
+    Schema schema = Schema.of(Schema.Field.of("record", 
Schema.FieldType.BYTES));
+    List<Row> row =
+        Arrays.stream(expected)
+            .map(str -> str.getBytes(StandardCharsets.UTF_8))
+            .map(bytes -> Row.withSchema(schema).addValue(bytes).build())
+            .collect(Collectors.toList());
+    PAssert.that(reads.get("output")).containsInAnyOrder(row);
+
+    readPipeline.run().waitUntilFinish();
+  }
+
+  /** Tests {@link TFRecordWriteSchemaTransformProvider}. */
+  private void runTestWrite(String[] elems, String... base64) throws 
IOException {
+    // Create temp filename
+    File tmpFile =
+        Files.createTempFile(tempFolder.getRoot().toPath(), "file", 
".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+
+    // Create beam row schema
+    Schema schema = Schema.of(Schema.Field.of("record", 
Schema.FieldType.BYTES));
+
+    // Create transform provider with configuration data
+    TFRecordWriteSchemaTransformProvider provider = new 
TFRecordWriteSchemaTransformProvider();
+    String compression = "UNCOMPRESSED";
+    TFRecordWriteSchemaTransformConfiguration configuration =
+        TFRecordWriteSchemaTransformConfiguration.builder()
+            .setOutputPrefix(filename)
+            .setCompression(compression)
+            .setNumShards(0)
+            .setNoSpilling(true)
+            .build();
+    TFRecordWriteSchemaTransform transform =
+        (TFRecordWriteSchemaTransform) provider.from(configuration);
+
+    // Create Beam row byte data
+    List<Row> rows =
+        Arrays.stream(elems)
+            .map(str -> str.getBytes(StandardCharsets.UTF_8))
+            .map(bytes -> Row.withSchema(schema).addValue(bytes).build())
+            .collect(Collectors.toList());
+
+    // Create PColleciton input beam row data on pipeline and apply transform
+    PCollection<Row> input = 
writePipeline.apply(Create.of(rows).withRowSchema(schema));
+    PCollectionRowTuple rowTuple = PCollectionRowTuple.of("input", input);
+    rowTuple.apply(transform);
+
+    // Run pipeline
+    writePipeline.run().waitUntilFinish();
+
+    assertTrue("File should exist", tmpFile.exists());
+    assertTrue("File should have content", tmpFile.length() > 0);
+
+    // Read the written file back and check it matches one of the expected encodings
+    try (FileInputStream fis = new FileInputStream(tmpFile)) {
+      String written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis));
+      assertThat(written, is(in(base64)));
+    }
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTrip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithEmptyData() throws IOException {
+    runTestRoundTrip(EMPTY, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithOneShard() throws IOException {
+    runTestRoundTrip(LARGE, 1, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithSuffix() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".suffix", UNCOMPRESSED, UNCOMPRESSED);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripGzip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, GZIP);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripZlib() throws IOException {
+    runTestRoundTrip(SMALL, 10, ".tfrecords", DEFLATE, DEFLATE);

Review Comment:
   It looks like this is currently causing https://github.com/apache/beam/actions/runs/14409115108/job/40414345626?pr=34411 to fail. From the test report:
   
   ```
   java.lang.RuntimeException: java.io.EOFException: Unexpected end of ZLIB input stream
        at org.apache.beam.runners.direct.DirectRunner$DirectPipelineResult.waitUntilFinish(DirectRunner.java:385)
        at org.apache.beam.runners.direct.DirectRunner$DirectPipelineResult.waitUntilFinish(DirectRunner.java:345)
        at org.apache.beam.runners.direct.DirectRunner.run(DirectRunner.java:218)
        at org.apache.beam.runners.direct.DirectRunner.run(DirectRunner.java:67)
        at org.apache.beam.sdk.Pipeline.run(Pipeline.java:325)
        at org.apache.beam.sdk.testing.TestPipeline.run(TestPipeline.java:404)
        at org.apache.beam.sdk.testing.TestPipeline.run(TestPipeline.java:343)
        at org.apache.beam.sdk.io.TFRecordSchemaTransformProviderTest.runTestRoundTrip(TFRecordSchemaTransformProviderTest.java:587)
        at org.apache.beam.sdk.io.TFRecordSchemaTransformProviderTest.runTestRoundTripZlib(TFRecordSchemaTransformProviderTest.java:492)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:59)
        at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
        at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:56)
        at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
        at org.junit.rules.ExpectedException$ExpectedExceptionStatement.evaluate(ExpectedException.java:258)
        at org.apache.beam.sdk.testing.TestPipeline$1.evaluate(TestPipeline.java:331)
        at org.junit.rules.ExternalResource$1.evaluate(ExternalResource.java:54)
        at org.apache.beam.sdk.testing.TestPipeline$1.evaluate(TestPipeline.java:331)
        at org.junit.runners.ParentRunner$3.evaluate(ParentRunner.java:306)
        at org.junit.runners.BlockJUnit4ClassRunner$1.evaluate(BlockJUnit4ClassRunner.java:100)
        at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:366)
        at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:103)
        at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:63)
        at org.junit.runners.ParentRunner$4.run(ParentRunner.java:331)
        at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:79)
        at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:329)
        at org.junit.runners.ParentRunner.access$100(ParentRunner.java:66)
        at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:293)
        at org.junit.runners.ParentRunner$3.evaluate(ParentRunner.java:306)
        at org.junit.runners.ParentRunner.run(ParentRunner.java:413)
        at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.runTestClass(JUnitTestClassExecutor.java:112)
        at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:58)
        at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:40)
        at org.gradle.api.internal.tasks.testing.junit.AbstractJUnitTestClassProcessor.processTestClass(AbstractJUnitTestClassProcessor.java:60)
        at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.processTestClass(SuiteTestClassProcessor.java:52)
        at jdk.internal.reflect.GeneratedMethodAccessor2.invoke(Unknown Source)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
        at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
        at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
        at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
        at com.sun.proxy.$Proxy5.processTestClass(Unknown Source)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker$2.run(TestWorker.java:176)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker.executeAndMaintainThreadName(TestWorker.java:129)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:100)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:60)
        at org.gradle.process.internal.worker.child.ActionExecutionWorker.execute(ActionExecutionWorker.java:56)
        at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:113)
        at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:65)
        at worker.org.gradle.process.internal.worker.GradleWorkerMain.run(GradleWorkerMain.java:69)
        at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)
   Caused by: java.io.EOFException: Unexpected end of ZLIB input stream
        at java.base/java.util.zip.InflaterInputStream.fill(InflaterInputStream.java:245)
        at java.base/java.util.zip.InflaterInputStream.read(InflaterInputStream.java:159)
        at org.apache.beam.repackaged.core.org.apache.commons.compress.compressors.deflate.DeflateCompressorInputStream.read(DeflateCompressorInputStream.java:120)
        at java.base/java.nio.channels.Channels$ReadableByteChannelImpl.read(Channels.java:388)
        at org.apache.beam.sdk.io.TFRecordIO$TFRecordCodec.read(TFRecordIO.java:736)
        at org.apache.beam.sdk.io.TFRecordIO$TFRecordCodec.read(TFRecordIO.java:669)
        at org.apache.beam.sdk.io.TFRecordIO$TFRecordSource$TFRecordReader.readNextRecord(TFRecordIO.java:567)
        at org.apache.beam.sdk.io.CompressedSource$CompressedReader.readNextRecord(CompressedSource.java:453)
        at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.advanceImpl(FileBasedSource.java:542)
        at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.startImpl(FileBasedSource.java:537)
        at org.apache.beam.sdk.io.OffsetBasedSource$OffsetBasedReader.start(OffsetBasedSource.java:252)
        at org.apache.beam.runners.direct.BoundedReadEvaluatorFactory$BoundedReadEvaluator.processElement(BoundedReadEvaluatorFactory.java:150)
        at org.apache.beam.runners.direct.DirectTransformExecutor.processElements(DirectTransformExecutor.java:165)
        at org.apache.beam.runners.direct.DirectTransformExecutor.run(DirectTransformExecutor.java:129)
        at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
        at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
        at java.base/java.lang.Thread.run(Thread.java:829)
   ```


