This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ef576f09f8 [BEAM-10265] Display error message if trying to infer 
recursive schema from POJO/Proto objects (#17477)
2ef576f09f8 is described below

commit 2ef576f09f80be5f3d7776c88531b72d473f89a6
Author: Andrei Gurau <[email protected]>
AuthorDate: Mon May 2 17:31:01 2022 -0400

    [BEAM-10265] Display error message if trying to infer recursive schema from 
POJO/Proto objects (#17477)
    
    * [BEAM-10265]: Prevent circular or self-referencing schemas from being 
created
    
    * Added protobuf schema validation for circular references
    
    * added comment for hashset
    
    * removed unused imports
    
    * moved logic statements around for detecting circular references
    
    * code styling issues fixed
    
    * fixed logic for where HashSet gets instantiated
    
    * removed print statements
    
    * made sure python extension folder compiled successfully
    
    * created private helper methods and fixed unit test nits for 
sdks/java/core module
    
    * applied PR comments to protobuf module
    
    * converted to using HashMap, to allow duplicate schemas to not be 
reinferred
    
    * fixed comments
    
    * fixed spacing
    
    * fixed another comment
    
    * fixed javadocs
    
    * used generic map type, removed javadoc, and made method synchronized
    
    * removed nullness warning
    
    * added unit test for getSchema using an Empty Schema
    
    * added Nullable field instead of using EMPTY_SCHEMA
---
 .../sdk/schemas/utils/StaticSchemaInference.java   |  51 ++++++--
 .../beam/sdk/schemas/JavaFieldSchemaTest.java      |  44 +++++++
 .../beam/sdk/schemas/utils/POJOUtilsTest.java      |   9 ++
 .../apache/beam/sdk/schemas/utils/TestPOJOs.java   | 129 +++++++++++++++++++++
 .../extensions/protobuf/ProtoSchemaTranslator.java |  42 +++++--
 .../protobuf/ProtoSchemaTranslatorTest.java        |  51 ++++++++
 .../sdk/extensions/protobuf/TestProtoSchemas.java  |   4 +
 .../src/test/proto/proto3_schema_messages.proto    |  15 +++
 8 files changed, 328 insertions(+), 17 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
index c681aadc353..103405037be 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
@@ -24,6 +24,7 @@ import java.math.BigDecimal;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
@@ -88,21 +89,47 @@ public class StaticSchemaInference {
    */
   public static Schema schemaFromClass(
       Class<?> clazz, FieldValueTypeSupplier fieldValueTypeSupplier) {
+    return schemaFromClass(clazz, fieldValueTypeSupplier, new HashMap<Class, 
Schema>());
+  }
+
+  private static Schema schemaFromClass(
+      Class<?> clazz,
+      FieldValueTypeSupplier fieldValueTypeSupplier,
+      Map<Class, Schema> alreadyVisitedSchemas) {
+    if (alreadyVisitedSchemas.containsKey(clazz)) {
+      Schema existingSchema = alreadyVisitedSchemas.get(clazz);
+      if (existingSchema == null) {
+        throw new IllegalArgumentException(
+            "Cannot infer schema with a circular reference. Class: " + 
clazz.getTypeName());
+      }
+      return existingSchema;
+    }
+    alreadyVisitedSchemas.put(clazz, null);
     Schema.Builder builder = Schema.builder();
     for (FieldValueTypeInformation type : fieldValueTypeSupplier.get(clazz)) {
-      Schema.FieldType fieldType = fieldFromType(type.getType(), 
fieldValueTypeSupplier);
+      Schema.FieldType fieldType =
+          fieldFromType(type.getType(), fieldValueTypeSupplier, 
alreadyVisitedSchemas);
       if (type.isNullable()) {
         builder.addNullableField(type.getName(), fieldType);
       } else {
         builder.addField(type.getName(), fieldType);
       }
     }
-    return builder.build();
+    Schema generatedSchema = builder.build();
+    alreadyVisitedSchemas.replace(clazz, generatedSchema);
+    return generatedSchema;
   }
 
   /** Map a Java field type to a Beam Schema FieldType. */
   public static Schema.FieldType fieldFromType(
       TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) {
+    return fieldFromType(type, fieldValueTypeSupplier, new HashMap<Class, 
Schema>());
+  }
+
+  private static Schema.FieldType fieldFromType(
+      TypeDescriptor type,
+      FieldValueTypeSupplier fieldValueTypeSupplier,
+      Map<Class, Schema> alreadyVisitedSchemas) {
     FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType());
     if (primitiveType != null) {
       return primitiveType;
@@ -122,7 +149,8 @@ public class StaticSchemaInference {
         return FieldType.BYTES;
       } else {
         // Otherwise this is an array type.
-        return FieldType.array(fieldFromType(component, 
fieldValueTypeSupplier));
+        return FieldType.array(
+            fieldFromType(component, fieldValueTypeSupplier, 
alreadyVisitedSchemas));
       }
     } else if (type.isSubtypeOf(TypeDescriptor.of(Map.class))) {
       TypeDescriptor<Collection<?>> map = type.getSupertype(Map.class);
@@ -130,8 +158,12 @@ public class StaticSchemaInference {
         ParameterizedType ptype = (ParameterizedType) map.getType();
         java.lang.reflect.Type[] params = ptype.getActualTypeArguments();
         checkArgument(params.length == 2);
-        FieldType keyType = fieldFromType(TypeDescriptor.of(params[0]), 
fieldValueTypeSupplier);
-        FieldType valueType = fieldFromType(TypeDescriptor.of(params[1]), 
fieldValueTypeSupplier);
+        FieldType keyType =
+            fieldFromType(
+                TypeDescriptor.of(params[0]), fieldValueTypeSupplier, 
alreadyVisitedSchemas);
+        FieldType valueType =
+            fieldFromType(
+                TypeDescriptor.of(params[1]), fieldValueTypeSupplier, 
alreadyVisitedSchemas);
         checkArgument(
             keyType.getTypeName().isPrimitiveType(),
             "Only primitive types can be map keys. type: " + 
keyType.getTypeName());
@@ -154,16 +186,19 @@ public class StaticSchemaInference {
         // TODO: should this be AbstractCollection?
         if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) {
           return FieldType.array(
-              fieldFromType(TypeDescriptor.of(params[0]), 
fieldValueTypeSupplier));
+              fieldFromType(
+                  TypeDescriptor.of(params[0]), fieldValueTypeSupplier, 
alreadyVisitedSchemas));
         } else {
           return FieldType.iterable(
-              fieldFromType(TypeDescriptor.of(params[0]), 
fieldValueTypeSupplier));
+              fieldFromType(
+                  TypeDescriptor.of(params[0]), fieldValueTypeSupplier, 
alreadyVisitedSchemas));
         }
       } else {
         throw new RuntimeException("Cannot infer schema from unparameterized 
collection.");
       }
     } else {
-      return FieldType.row(schemaFromClass(type.getRawType(), 
fieldValueTypeSupplier));
+      return FieldType.row(
+          schemaFromClass(type.getRawType(), fieldValueTypeSupplier, 
alreadyVisitedSchemas));
     }
   }
 }
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java
index 90a4c2e4a9f..67f99b0683c 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java
@@ -53,6 +53,7 @@ import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType;
 import org.apache.beam.sdk.schemas.utils.SchemaTestUtils;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.AnnotatedSimplePojo;
+import org.apache.beam.sdk.schemas.utils.TestPOJOs.FirstCircularNestedPOJO;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedArrayPOJO;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedArraysPOJO;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedMapPOJO;
@@ -67,6 +68,7 @@ import 
org.apache.beam.sdk.schemas.utils.TestPOJOs.PojoWithEnum.Color;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.PojoWithIterable;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.PojoWithNestedArray;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.PrimitiveArrayPOJO;
+import org.apache.beam.sdk.schemas.utils.TestPOJOs.SelfNestedPOJO;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.SimplePOJO;
 import org.apache.beam.sdk.schemas.utils.TestPOJOs.StaticCreationSimplePojo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -728,4 +730,46 @@ public class JavaFieldSchemaTest {
         thrown.getMessage(),
         containsString("zero-argument constructor"));
   }
+
+  @Test
+  public void testSelfNestedPOJOThrows() throws NoSuchSchemaException {
+    SchemaRegistry registry = SchemaRegistry.createDefault();
+
+    IllegalArgumentException thrown =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> {
+              registry.getSchema(SelfNestedPOJO.class);
+            });
+
+    assertThat(
+        "Message should suggest not using a circular schema reference.",
+        thrown.getMessage(),
+        containsString("circular reference"));
+    assertThat(
+        "Message should suggest which class has circular schema reference.",
+        thrown.getMessage(),
+        containsString("TestPOJOs$SelfNestedPOJO"));
+  }
+
+  @Test
+  public void testCircularNestedPOJOThrows() throws NoSuchSchemaException {
+    SchemaRegistry registry = SchemaRegistry.createDefault();
+
+    IllegalArgumentException thrown =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> {
+              registry.getSchema(FirstCircularNestedPOJO.class);
+            });
+
+    assertThat(
+        "Message should suggest not using a circular schema reference.",
+        thrown.getMessage(),
+        containsString("circular reference"));
+    assertThat(
+        "Message should suggest which class has circular schema reference.",
+        thrown.getMessage(),
+        containsString("TestPOJOs$FirstCircularNestedPOJO"));
+  }
 }
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java
index cff25bff1f2..67f372644f0 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java
@@ -21,6 +21,7 @@ import static 
org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_ARRAY_POJO_SCHE
 import static 
org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_COLLECTION_POJO_SCHEMA;
 import static 
org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_MAP_POJO_SCHEMA;
 import static org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_POJO_SCHEMA;
+import static 
org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_POJO_WITH_SIMPLE_POJO_SCHEMA;
 import static 
org.apache.beam.sdk.schemas.utils.TestPOJOs.POJO_WITH_BOXED_FIELDS_SCHEMA;
 import static 
org.apache.beam.sdk.schemas.utils.TestPOJOs.POJO_WITH_BYTE_ARRAY_SCHEMA;
 import static 
org.apache.beam.sdk.schemas.utils.TestPOJOs.PRIMITIVE_ARRAY_POJO_SCHEMA;
@@ -85,6 +86,14 @@ public class POJOUtilsTest {
     SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_SCHEMA, schema);
   }
 
+  @Test
+  public void testNestedPOJOWithSimplePOJO() {
+    Schema schema =
+        POJOUtils.schemaFromPojoClass(
+            TestPOJOs.NestedPOJOWithSimplePOJO.class, 
JavaFieldTypeSupplier.INSTANCE);
+    
SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_WITH_SIMPLE_POJO_SCHEMA, 
schema);
+  }
+
   @Test
   public void testPrimitiveArray() {
     Schema schema =
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java
index 5e32519afc1..0e1b6b07dc0 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java
@@ -1106,4 +1106,133 @@ public class TestPOJOs {
       this.user = user;
     }
   }
+
+  /** A POJO containing itself as a nested class. * */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class SelfNestedPOJO {
+    public SelfNestedPOJO nested;
+
+    public SelfNestedPOJO(SelfNestedPOJO nested) {
+      this.nested = nested;
+    }
+
+    public SelfNestedPOJO() {}
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      SelfNestedPOJO that = (SelfNestedPOJO) o;
+      return Objects.equals(nested, that.nested);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(nested);
+    }
+  }
+
+  /**
+   * A POJO containing a circular reference back to itself through the 
accompanying POJO below. *
+   */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class FirstCircularNestedPOJO {
+    public SecondCircularNestedPOJO nested;
+
+    public FirstCircularNestedPOJO(SecondCircularNestedPOJO nested) {
+      this.nested = nested;
+    }
+
+    public FirstCircularNestedPOJO() {}
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      FirstCircularNestedPOJO that = (FirstCircularNestedPOJO) o;
+      return Objects.equals(nested, that.nested);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(nested);
+    }
+  }
+
+  /**
+   * A POJO containing a circular reference back to itself through the 
accompanying POJO above. *
+   */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class SecondCircularNestedPOJO {
+    public FirstCircularNestedPOJO nested;
+
+    public SecondCircularNestedPOJO(FirstCircularNestedPOJO nested) {
+      this.nested = nested;
+    }
+
+    public SecondCircularNestedPOJO() {}
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      SecondCircularNestedPOJO that = (SecondCircularNestedPOJO) o;
+      return Objects.equals(nested, that.nested);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(nested);
+    }
+  }
+
+  /** A POJO containing a nested class, along with a SimplePOJO. * */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class NestedPOJOWithSimplePOJO {
+    public NestedPOJO nested;
+    public SimplePOJO simplePojo;
+
+    public NestedPOJOWithSimplePOJO(NestedPOJO nested, SimplePOJO simplePojo) {
+      this.nested = nested;
+      this.simplePojo = simplePojo;
+    }
+
+    public NestedPOJOWithSimplePOJO() {}
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      NestedPOJOWithSimplePOJO that = (NestedPOJOWithSimplePOJO) o;
+      return Objects.equals(nested, that.nested);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(nested);
+    }
+  }
+
+  /** The schema for {@link NestedPOJOWithSimplePOJO}. * */
+  public static final Schema NESTED_POJO_WITH_SIMPLE_POJO_SCHEMA =
+      Schema.builder()
+          .addRowField("nested", NESTED_POJO_SCHEMA)
+          .addRowField("simplePojo", SIMPLE_POJO_SCHEMA)
+          .build();
 }
diff --git 
a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java
 
b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java
index ef46b59ced9..84f890d065a 100644
--- 
a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java
+++ 
b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java
@@ -24,6 +24,7 @@ import com.google.protobuf.Descriptors.EnumValueDescriptor;
 import com.google.protobuf.Descriptors.FieldDescriptor;
 import com.google.protobuf.Descriptors.OneofDescriptor;
 import com.google.protobuf.Message;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -45,6 +46,7 @@ import org.apache.beam.sdk.schemas.logicaltypes.OneOfType;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
  * This class provides utilities for inferring a Beam schema from a protocol 
buffer.
@@ -139,6 +141,13 @@ class ProtoSchemaTranslator {
   /** Option prefix for options on fields. */
   public static final String SCHEMA_OPTION_FIELD_PREFIX = 
"beam:option:proto:field:";
 
+  /**
+   * A HashMap containing the sentinel values (null values) of schemas in the 
process of being
+   * inferred, to prevent circular references.
+   */
+  private static Map<Descriptors.Descriptor, @Nullable Schema> 
alreadyVisitedSchemas =
+      new HashMap<Descriptors.Descriptor, @Nullable Schema>();
+
   /** Attach a proto field number to a type. */
   static Field withFieldNumber(Field field, int number) {
     return field.withOptions(
@@ -150,12 +159,22 @@ class ProtoSchemaTranslator {
     return field.getOptions().getValue(SCHEMA_OPTION_META_NUMBER);
   }
 
-  /** Return a Beam scheam representing a proto class. */
+  /** Return a Beam schema representing a proto class. */
   static Schema getSchema(Class<? extends Message> clazz) {
     return getSchema(ProtobufUtil.getDescriptorForClass(clazz));
   }
 
-  static Schema getSchema(Descriptors.Descriptor descriptor) {
+  static synchronized Schema getSchema(Descriptors.Descriptor descriptor) {
+    if (alreadyVisitedSchemas.containsKey(descriptor)) {
+      @Nullable Schema existingSchema = alreadyVisitedSchemas.get(descriptor);
+      if (existingSchema == null) {
+        throw new IllegalArgumentException(
+            "Cannot infer schema with a circular reference. Proto Field: "
+                + descriptor.getFullName());
+      }
+      return existingSchema;
+    }
+    alreadyVisitedSchemas.put(descriptor, null);
     /* OneOfComponentFields refers to the field number in the protobuf where 
the component subfields
      * are. This is needed to prevent double inclusion of the component 
fields.*/
     Set<Integer> oneOfComponentFields = Sets.newHashSet();
@@ -199,13 +218,18 @@ class ProtoSchemaTranslator {
         }
       }
     }
-    return Schema.builder()
-        .addFields(fields)
-        .setOptions(
-            getSchemaOptions(descriptor)
-                .setOption(
-                    SCHEMA_OPTION_META_TYPE_NAME, FieldType.STRING, 
descriptor.getFullName()))
-        .build();
+
+    Schema generatedSchema =
+        Schema.builder()
+            .addFields(fields)
+            .setOptions(
+                getSchemaOptions(descriptor)
+                    .setOption(
+                        SCHEMA_OPTION_META_TYPE_NAME, FieldType.STRING, 
descriptor.getFullName()))
+            .build();
+    alreadyVisitedSchemas.put(descriptor, generatedSchema);
+
+    return generatedSchema;
   }
 
   private static FieldType beamFieldTypeFromProtoField(
diff --git 
a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java
 
b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java
index f478a9403bc..44ddfc5ab3f 100644
--- 
a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java
+++ 
b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java
@@ -17,13 +17,17 @@
  */
 package org.apache.beam.sdk.extensions.protobuf;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.containsString;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptor;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -111,6 +115,53 @@ public class ProtoSchemaTranslatorTest {
         
ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.WktMessage.class));
   }
 
+  @Test
+  public void testEmptySchema() {
+    assertEquals(
+        TestProtoSchemas.EMPTY_SCHEMA,
+        ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.Empty.class));
+  }
+
+  @Test
+  public void testSelfNestedProtoThrows() {
+    IllegalArgumentException thrown =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> {
+              new ProtoMessageSchema()
+                  
.schemaFor(TypeDescriptor.of(Proto3SchemaMessages.SelfNested.class));
+            });
+
+    assertThat(
+        "Message should suggest not using a circular schema reference.",
+        thrown.getMessage(),
+        containsString("circular reference"));
+    assertThat(
+        "Message should suggest which class has circular schema reference.",
+        thrown.getMessage(),
+        containsString("proto3_schema_messages.SelfNested"));
+  }
+
+  @Test
+  public void testCircularNestedProtoThrows() {
+    IllegalArgumentException thrown =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> {
+              new ProtoMessageSchema()
+                  
.schemaFor(TypeDescriptor.of(Proto3SchemaMessages.FirstCircularNested.class));
+            });
+
+    assertThat(
+        "Message should suggest not using a circular schema reference.",
+        thrown.getMessage(),
+        containsString("circular reference"));
+    assertThat(
+        "Message should suggest which class has circular schema reference.",
+        thrown.getMessage(),
+        containsString("proto3_schema_messages.FirstCircularNested"));
+  }
+
   @Test
   public void testOptionalNestedSchema() {
     assertEquals(
diff --git 
a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java
 
b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java
index 40055d05ec6..8ed1ba8cfa7 100644
--- 
a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java
+++ 
b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java
@@ -668,4 +668,8 @@ class TestProtoSchemas {
   // A sample instance of the proto.
   static final RequiredNested REQUIRED_NESTED =
       RequiredNested.newBuilder().setNested(REQUIRED_PRIMITIVE_PROTO).build();
+
+  // The schema for the Empty proto.
+  static final Schema EMPTY_SCHEMA =
+      
Schema.builder().setOptions(withTypeName("proto3_schema_messages.Empty")).build();
 }
diff --git 
a/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto 
b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto
index 02748649506..946cd99320e 100644
--- a/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto
+++ b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto
@@ -194,4 +194,19 @@ message OptionMessage {
         (proto3_schema_options.field_option_repeated_message) = {
             single_int64: 88
         }];
+}
+
+message SelfNested {
+  SelfNested nested = 1;
+}
+
+message FirstCircularNested {
+  SecondCircularNested nested = 1;
+}
+
+message SecondCircularNested {
+  FirstCircularNested nested = 1;
+}
+
+message Empty {
 }
\ No newline at end of file

Reply via email to