kojiromike commented on a change in pull request #979:
URL: https://github.com/apache/avro/pull/979#discussion_r522233036
##########
File path: lang/py/avro/compatibility.py
##########
@@ -0,0 +1,318 @@
+from copy import copy
+from enum import Enum
+from typing import Dict, List, Optional, Set, cast
+
+from avro.schema import ArraySchema, EnumSchema, Field, FixedSchema,
MapSchema, NamedSchema, RecordSchema, Schema, UnionSchema
+
+
+class SchemaType(str, Enum):
+ ARRAY = "array"
+ BOOLEAN = "boolean"
+ BYTES = "bytes"
+ DOUBLE = "double"
+ ENUM = "enum"
+ FIXED = "fixed"
+ FLOAT = "float"
+ INT = "int"
+ LONG = "long"
+ MAP = "map"
+ NULL = "null"
+ RECORD = "record"
+ STRING = "string"
+ UNION = "union"
+
+
+class SchemaCompatibilityType(Enum):
+ compatible = "compatible"
+ incompatible = "incompatible"
+ recursion_in_progress = "recursion_in_progress"
+
+
+class SchemaIncompatibilityType(Enum):
+ name_mismatch = "name_mismatch"
+ fixed_size_mismatch = "fixed_size_mismatch"
+ missing_enum_symbols = "missing_enum_symbols"
+ reader_field_missing_default_value = "reader_field_missing_default_value"
+ type_mismatch = "type_mismatch"
+ missing_union_branch = "missing_union_branch"
+
+
+class AvroRuntimeException(Exception):
+ pass
+
+
+class SchemaCompatibilityResult:
+ def __init__(
+ self,
+ compatibility: SchemaCompatibilityType =
SchemaCompatibilityType.recursion_in_progress,
+ incompatibilities: List[SchemaIncompatibilityType] = None,
+ messages: Optional[Set[str]] = None,
+ locations: Optional[Set[str]] = None,
+ ):
+ self.locations = locations if locations else {"/"}
+ self.messages = messages if messages else set()
+ self.compatibility = compatibility
+ self.incompatibilities = incompatibilities or []
+
+ def merged_with(self, that):
+ that = cast(SchemaCompatibilityResult, that)
+ merged = copy(self.incompatibilities)
+ merged.extend(copy(that.incompatibilities))
+ if self.compatibility is SchemaCompatibilityType.compatible:
+ compat = that.compatibility
+ messages = that.messages
+ locations = that.locations
+ else:
+ compat = self.compatibility
+ messages = self.messages.union(that.messages)
+ locations = self.locations.union(that.locations)
+ return SchemaCompatibilityResult(
+ compatibility=compat, incompatibilities=merged, messages=messages,
locations=locations
+ )
+
+ @staticmethod
+ def compatible():
+ return SchemaCompatibilityResult(SchemaCompatibilityType.compatible)
+
+ @staticmethod
+ def incompatible(incompat_type: SchemaIncompatibilityType, message: str,
location: List[str]):
+ locations = "/".join(location)
+ if len(location) > 1:
+ locations = locations[1:]
+ ret = SchemaCompatibilityResult(
+ compatibility=SchemaCompatibilityType.incompatible,
+ incompatibilities=[incompat_type],
+ locations={locations},
+ messages={message},
+ )
+ return ret
+
+ def __str__(self):
+ return f"{self.compatibility}: {self.messages}"
+
+
+class ReaderWriter:
+ def __init__(self, reader: Schema, writer: Schema) -> None:
+ self.reader, self.writer = reader, writer
+
+ def __hash__(self) -> SchemaType.INT:
+ return id(self.reader) ^ id(self.writer)
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other, ReaderWriter):
+ return False
+ return self.reader is other.reader and self.writer is other.writer
+
+
+class ReaderWriterCompatibilityChecker:
+ ROOT_REFERENCE_TOKEN = "/"
+
+ def __init__(self):
+ self.memoize_map: Dict[ReaderWriter, SchemaCompatibilityResult] = {}
+
+ def get_compatibility(
+ self,
+ reader: Schema,
+ writer: Schema,
+ reference_token: str = ROOT_REFERENCE_TOKEN,
+ location: Optional[List[str]] = None
+ ) -> SchemaCompatibilityResult:
+ if location is None:
+ location = []
+ location.append(reference_token)
+ pair = ReaderWriter(reader, writer)
+ if pair in self.memoize_map:
+ result = self.memoize_map[pair]
+ if result.compatibility is
SchemaCompatibilityType.recursion_in_progress:
+ result = SchemaCompatibilityResult.compatible()
+ else:
+ self.memoize_map[pair] = SchemaCompatibilityResult()
+ result = self.calculate_compatibility(reader, writer, location)
+ self.memoize_map[pair] = result
+ location.pop()
+ return result
+
+ # pylint: disable=too-many-return-statements
+ def calculate_compatibility(
+ self,
+ reader: Schema,
+ writer: Schema,
+ location: List[str],
+ ) -> SchemaCompatibilityResult:
+ assert reader is not None
+ assert writer is not None
+ result = SchemaCompatibilityResult.compatible()
+ if reader.type == writer.type:
+ if reader.type in {
+ SchemaType.NULL, SchemaType.BOOLEAN, SchemaType.INT,
+ SchemaType.LONG, SchemaType.FLOAT, SchemaType.DOUBLE,
+ SchemaType.BYTES, SchemaType.STRING
+ }:
Review comment:
This set of primitive types should probably be a module-level constant
rather than a literal written inline here.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org