This is an automated email from the ASF dual-hosted git repository.

kojiromike pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git


The following commit(s) were added to refs/heads/master by this push:
     new efb1231  AVRO-2906: Traversal validation (#936)
efb1231 is described below

commit efb12314b5acfea075a533368441c0f5a3b844d4
Author: Cris Ewing <[email protected]>
AuthorDate: Wed Aug 19 14:14:01 2020 -0700

    AVRO-2906: Traversal validation (#936)
    
    * AVRO-2906: Convert validation to a traversal-based approach
    
    Use schema-type specific iterators and validators to allow a
    breadth-first traversal of a full schema, validating each node
    as you go.
    
    The benefit of this approach is that it allows us to pinpoint
    the specific part of the schema that has failed validation.
    Previously, the error message for a large schema printed the entire
    datum and the full schema and effectively said "this is not that";
    the new approach prints the specific sub-schema (and the part of the
    datum) that failed, which makes errors more informative.
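A minimal sketch of the idea follows. It assumes a toy schema format (plain strings and dicts, loosely modeled on Avro JSON) rather than avro's Schema classes, and `shape_ok`, `children`, and `find_invalid` are hypothetical names that are not part of the avro package. Each node is checked only at its top level, its children are queued, and the first failing node is reported:

```python
from collections import deque

# Toy schema format (assumption): primitives are strings, complex types are
# dicts, loosely modeled on Avro JSON. This is not avro's Schema class API.
PRIMITIVE_CHECKS = {
    "null": lambda d: d is None,
    "boolean": lambda d: isinstance(d, bool),
    "long": lambda d: isinstance(d, int) and not isinstance(d, bool),
    "string": lambda d: isinstance(d, str),
}


def shape_ok(schema, datum):
    """Check only the top level of datum against schema; children are not visited."""
    if isinstance(schema, str):
        return PRIMITIVE_CHECKS[schema](datum)
    kind = schema["type"]
    if kind == "record":
        field_names = {f["name"] for f in schema["fields"]}
        return isinstance(datum, dict) and field_names.issuperset(datum)
    if kind == "array":
        return isinstance(datum, list)
    if kind == "map":
        return isinstance(datum, dict) and all(isinstance(k, str) for k in datum)
    return False


def children(schema, datum, path):
    """Yield (schema, datum, path) for each child of a node that passed shape_ok."""
    if isinstance(schema, str):
        return  # primitives have no children
    kind = schema["type"]
    if kind == "record":
        for field in schema["fields"]:
            yield field["type"], datum.get(field["name"]), path + "." + field["name"]
    elif kind == "array":
        for index, item in enumerate(datum):
            yield schema["items"], item, "%s[%d]" % (path, index)
    elif kind == "map":
        for key, value in datum.items():
            yield schema["values"], value, "%s[%r]" % (path, key)


def find_invalid(schema, datum, root="$"):
    """Breadth-first walk; return (path, sub_schema, sub_datum) of the first failure, else None."""
    queue = deque([(schema, datum, root)])  # FIFO queue gives breadth-first order
    while queue:
        node_schema, node_datum, path = queue.popleft()
        if not shape_ok(node_schema, node_datum):
            return path, node_schema, node_datum
        queue.extend(children(node_schema, node_datum, path))
    return None


USER = {
    "type": "record",
    "name": "User",
    "fields": [
        {"name": "id", "type": "long"},
        {"name": "tags", "type": {"type": "array", "items": "string"}},
    ],
}

print(find_invalid(USER, {"id": 1, "tags": ["a", 2]}))  # ('$.tags[1]', 'string', 2)
print(find_invalid(USER, {"id": 1, "tags": ["a"]}))      # None
```

The committed `validate` reports the failing sub-schema and datum through `AvroTypeException`; carrying a path string along with each node, as above, is an extra that the diff does not include.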
    
    A second improvement is that traversing the schema iteratively,
    instead of processing it recursively, uses system resources more
    efficiently: the call stack does not grow with each level of nesting,
    which makes a difference for schemas with many nested parts.
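To illustrate the resource point, the sketch below compares a recursive check (similar in spirit to the removed `_valid` table) with a deque-based one on a deeply nested array schema. The toy schema format and all function names are assumptions for illustration only. The concrete difference it shows is stack depth: the iterative version is not bound by `sys.getrecursionlimit()`.

```python
import sys
from collections import deque


def nested_array_schema(depth):
    """Toy schema (assumption): `depth` levels of arrays wrapping 'long'."""
    schema = "long"
    for _ in range(depth):
        schema = {"type": "array", "items": schema}
    return schema


def nested_datum(depth):
    """A datum matching nested_array_schema(depth): [[...[7]...]]."""
    datum = 7
    for _ in range(depth):
        datum = [datum]
    return datum


def valid_recursive(schema, datum):
    """Recursive check: adds Python stack frames for every nesting level."""
    if schema == "long":
        return isinstance(datum, int)
    return isinstance(datum, list) and all(
        valid_recursive(schema["items"], item) for item in datum)


def valid_iterative(schema, datum):
    """Queue-based check: stack depth stays constant regardless of nesting."""
    queue = deque([(schema, datum)])
    while queue:
        node_schema, node_datum = queue.popleft()
        if node_schema == "long":
            if not isinstance(node_datum, int):
                return False
        elif isinstance(node_datum, list):
            queue.extend((node_schema["items"], item) for item in node_datum)
        else:
            return False
    return True


depth = sys.getrecursionlimit() + 100
schema, datum = nested_array_schema(depth), nested_datum(depth)
print(valid_iterative(schema, datum))  # True
try:
    valid_recursive(schema, datum)
except RecursionError:
    print("recursive check exceeded the recursion limit")
```

This only demonstrates stack usage; it does not measure the memory or speed of the avro implementation itself.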
    
    Make the required changes to pass tests in all supported Python versions.
    
    This commit removes type hints present in the first commit in order to
    allow using the code in older Python versions.
    
    In addition:
      * the use of `str` has been replaced by the compatible `unicode`.
      * the ValidationNode namedtuple has been re-expressed in syntax available
        in all supported Python versions.
      * the use of a custom InvalidEvent exception has been replaced by
        AvroTypeException
      * all specific single-type validators have been replaced by partials of
        _validate_type with a tuple of one or more type objects.
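The intermediate design in the last bullet can be sketched as follows. The diff does not show `_validate_type`, so its signature here is an assumption, as are the `validate_*` names. The `unicode` fallback is a common Python 2/3 compatibility idiom for the `str`/`unicode` change mentioned above:

```python
import functools

# Python 2/3 compatibility idiom: on Python 3 the name `unicode` does not exist,
# so fall back to `str`, which is already the text type there.
try:
    unicode
except NameError:
    unicode = str


def _validate_type(types, datum):
    """Return True if datum is an instance of any type in the `types` tuple.

    Hypothetical signature: the diff does not show this helper.
    """
    return isinstance(datum, types)


# Each former single-type validator becomes a partial with its type tuple bound.
validate_null = functools.partial(_validate_type, (type(None),))
validate_boolean = functools.partial(_validate_type, (bool,))
validate_string = functools.partial(_validate_type, (unicode,))
validate_float = functools.partial(_validate_type, (int, float))

print(validate_string(u"avro"), validate_string(b"avro"))  # True False
print(validate_float(3), validate_float("3"))               # True False
```

Binding the type tuple with `functools.partial` removes the near-identical one-line functions while keeping a uniform one-argument call shape.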
    
    Fix typos and raise StopIteration as suggested in code review
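The `_default_iterator` in the diff relies on the fact that a generator which simply finishes signals `StopIteration` to its consumer. The sketch below (hypothetical function names) shows two Python 2/3-compatible ways to write an empty generator, and why an explicit `raise StopIteration` inside a generator body is best avoided: since Python 3.7 (PEP 479) it is converted into a `RuntimeError`. `yield from ()` is a Python 3.3+ alternative that Python 2.7 support rules out.

```python
def empty_via_loop(_):
    """Works on Python 2 and 3: the loop body never runs, so nothing is yielded."""
    for item in ():
        yield item


def empty_via_return(_):
    """Also works on Python 2 and 3: the unreachable `yield` makes this a generator."""
    return
    yield  # never executed


def empty_via_raise(_):
    """Explicit `raise StopIteration` inside a generator body (not recommended)."""
    raise StopIteration
    yield  # never executed


print(list(empty_via_loop(None)), list(empty_via_return(None)))  # [] []
try:
    list(empty_via_raise(None))
except RuntimeError as exc:
    # Python 3.7+ (PEP 479) converts the StopIteration into RuntimeError.
    print("RuntimeError:", exc)
```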
    
    Move the responsibility for validation to the Schema class.
    
    Each schema subclass will be responsible for its own validation. This
    simplifies the structure of io.py, removes the dict lookup of validators,
    and reduces somewhat the repetition that was in io.py.
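A simplified sketch of the per-class design follows, assuming stand-in classes that mirror only the parts of `avro.schema` shown in the diff (they are not the real classes, and range checks are omitted). Each schema checks its own top level; a union returns the branch that matched, and the hypothetical `validate_tree` driver iterates children using that branch:

```python
from collections import deque


class Schema(object):
    """Simplified stand-in for avro.schema.Schema (not the real class)."""

    def validate(self, datum):
        """Return the schema datum matches at the top level, else None."""
        raise NotImplementedError("Must be implemented by subclasses.")


class PrimitiveSchema(Schema):
    # Validators kept in a class-level dict, as in the diff (subset only).
    _validators = {
        "null": lambda x: x is None,
        "string": lambda x: isinstance(x, str),
        "long": lambda x: isinstance(x, int),
    }

    def __init__(self, type_):
        self.type = type_

    def validate(self, datum):
        check = self._validators.get(self.type, lambda x: False)
        return self if check(datum) else None


class ArraySchema(Schema):
    def __init__(self, items):
        self.type = "array"
        self.items = items

    def validate(self, datum):
        return self if isinstance(datum, list) else None


class UnionSchema(Schema):
    def __init__(self, schemas):
        self.type = "union"
        self.schemas = schemas

    def validate(self, datum):
        # Return the first branch whose top-level check passes.
        for branch in self.schemas:
            if branch.validate(datum) is not None:
                return branch
        return None


def validate_tree(schema, datum):
    """Breadth-first driver: each schema checks its own level, arrays add children."""
    queue = deque([(schema, datum)])
    while queue:
        node_schema, node_datum = queue.popleft()
        matched = node_schema.validate(node_datum)
        if matched is None:
            return False
        if isinstance(matched, ArraySchema):
            # Iterate using the matched branch, not the union itself.
            queue.extend((matched.items, item) for item in node_datum)
    return True


LONGS = ArraySchema(PrimitiveSchema("long"))
STRINGS = ArraySchema(PrimitiveSchema("string"))
optional_longs = UnionSchema([PrimitiveSchema("null"), LONGS])

print(validate_tree(optional_longs, None))      # True
print(validate_tree(optional_longs, [1, 2]))    # True
print(validate_tree(optional_longs, [1, "x"]))  # False
print(validate_tree(UnionSchema([LONGS, STRINGS]), ["a"]))  # False
```

The last call shows a consequence of checking only the top level: the union commits to the first branch whose shape matches, so `["a"]` is rejected even though the second branch would accept it. As far as the diff shows, the committed `UnionSchema.validate` also checks only `branch.validate(datum)`, so it appears to behave the same way for unions of two array (or two record or map) types, whereas the removed `_valid['union']` entry validated each branch fully. This is an observation from reading the diff, not something the commit message states.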
    
    Move validators to a class attribute and update method code.
    
    This is a little cleaner than defining the validators in the middle
    of the method.
    
    Add arg spec docs to docstring for base Schema class.
    
    Clean up mistakes.
    
    * Fix a docstring to be a more accurate statement of reality.
    * Remove an unused import.
    * Remove extra blank lines.
---
 lang/py/avro/io.py     | 217 +++++++++++++++++++++++++++++++------------------
 lang/py/avro/schema.py | 101 ++++++++++++++++++++++-
 2 files changed, 236 insertions(+), 82 deletions(-)

diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py
index d420c20..a476a7d 100644
--- a/lang/py/avro/io.py
+++ b/lang/py/avro/io.py
@@ -41,6 +41,49 @@ uses the following mapping:
   * Schema floats are implemented as float.
   * Schema doubles are implemented as float.
   * Schema booleans are implemented as bool.
+
+Validation:
+
+Schema validation is performed using a breadth-first graph
+traversal. This allows validation exceptions to pinpoint the exact node
+within a complex schema that is problematic, simplifying debugging
+considerably. Because the traversal is iterative rather than recursive,
+it is also less resource-intensive, particularly when validating schemas
+with deeply nested structures.
+
+Components
+==========
+
+Nodes
+-----
+Avro schemas contain many different schema types. Data about the schema
+types is used to validate the data in the corresponding part of a Python
+body (the object to be serialized). A node combines a given schema type
+with the corresponding Python data, as well as an optional "name" to
+identify the specific node. Names are generally the name of a schema
+(for named schemas) or the name of a field (for child nodes of schemas
+with named children like maps and records), or None for schemas whose
+children are not named (like arrays).
+
+Iterators
+---------
+Iterators are generator functions that take a node and return a
+generator which will yield a node for each child datum in the data for
+the current node. If a node is of a type which has no children, then the
+default iterator will immediately exit.
+
+Validators
+----------
+Validators are used to determine if the datum for a given node is valid
+according to the given schema type. Validator functions take a node as
+an argument and return a node if the node datum passes validation. If it
+does not, the validator must return None.
+
+In most cases, the node returned is identical to the node provided (it
+is in fact the same object). However, for union schemas, the
+returned "valid" node will hold the schema that is represented by the
+datum contained. This allows iteration over the child nodes
+in that datum, if there are any.
 """
 
 from __future__ import absolute_import, division, print_function
@@ -48,7 +91,7 @@ from __future__ import absolute_import, division, print_function
 import datetime
 import json
 import struct
-import sys
+from collections import deque, namedtuple
 from decimal import Decimal, getcontext
 from struct import Struct
 
@@ -75,13 +118,6 @@ except NameError:
 # Constants
 #
 
-_DEBUG_VALIDATE_INDENT = 0
-_DEBUG_VALIDATE = False
-
-INT_MIN_VALUE = -(1 << 31)
-INT_MAX_VALUE = (1 << 31) - 1
-LONG_MIN_VALUE = -(1 << 63)
-LONG_MAX_VALUE = (1 << 63) - 1
 
 # TODO(hammer): shouldn't ! be < for little-endian (according to spec?)
 STRUCT_FLOAT = Struct('<f')           # big-endian float
@@ -96,79 +132,103 @@ STRUCT_SIGNED_LONG = Struct('>q')     # big-endian signed long
 #
 
 
-def _is_timezone_aware_datetime(dt):
-    return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None
-
-
-_valid = {
-    'null': lambda s, d: d is None,
-    'boolean': lambda s, d: isinstance(d, bool),
-    'string': lambda s, d: isinstance(d, unicode),
-    'bytes': lambda s, d: ((isinstance(d, bytes)) or
-                           (isinstance(d, Decimal) and
-                            getattr(s, 'logical_type', None) == constants.DECIMAL)),
-    'int': lambda s, d: ((isinstance(d, (int, long))) and (INT_MIN_VALUE <= d <= INT_MAX_VALUE) or
-                         (isinstance(d, datetime.date) and
-                          getattr(s, 'logical_type', None) == constants.DATE) or
-                         (isinstance(d, datetime.time) and
-                          getattr(s, 'logical_type', None) == constants.TIME_MILLIS)),
-    'long': lambda s, d: ((isinstance(d, (int, long))) and (LONG_MIN_VALUE <= d <= LONG_MAX_VALUE) or
-                          (isinstance(d, datetime.time) and
-                           getattr(s, 'logical_type', None) == constants.TIME_MICROS) or
-                          (isinstance(d, datetime.date) and
-                           _is_timezone_aware_datetime(d) and
-                           getattr(s, 'logical_type', None) in (constants.TIMESTAMP_MILLIS,
-                                                                constants.TIMESTAMP_MICROS))),
-    'float': lambda s, d: isinstance(d, (int, long, float)),
-    'fixed': lambda s, d: ((isinstance(d, bytes) and len(d) == s.size) or
-                           (isinstance(d, Decimal) and
-                            getattr(s, 'logical_type', None) == constants.DECIMAL)),
-    'enum': lambda s, d: d in s.symbols,
-
-    'array': lambda s, d: isinstance(d, list) and all(validate(s.items, item) for item in d),
-    'map': lambda s, d: (isinstance(d, dict) and all(isinstance(key, unicode) for key in d) and
-                         all(validate(s.values, value) for value in d.values())),
-    'union': lambda s, d: any(validate(branch, d) for branch in s.schemas),
-    'record': lambda s, d: (isinstance(d, dict) and
-                            all(validate(f.type, d.get(f.name)) for f in s.fields) and
-                            {f.name for f in s.fields}.issuperset(d.keys())),
-}
-_valid['double'] = _valid['float']
-_valid['error_union'] = _valid['union']
-_valid['error'] = _valid['request'] = _valid['record']
+ValidationNode = namedtuple("ValidationNode", ['schema', 'datum', 'name'])
+
 
+def validate(expected_schema, datum, raise_on_error=False):
+    """Return True if the provided datum is valid for the expected schema
 
-def validate(expected_schema, datum):
-    """Determines if a python datum is an instance of a schema.
+    If raise_on_error is passed and True, then raise a validation error
+    with specific information about the error encountered in validation.
 
-    Args:
-      expected_schema: Schema to validate against.
-      datum: Datum to validate.
-    Returns:
-      True if the datum is an instance of the schema.
+    :param expected_schema: An avro schema type object representing the schema against
+                            which the datum will be validated.
+    :param datum: The datum to be validated: a Python dictionary or some other supported type.
+    :param raise_on_error: True if an AvroTypeException should be raised immediately when a
+                           validation problem is encountered.
+    :raises: AvroTypeException if datum is invalid and raise_on_error is True
+    :returns: True if datum is valid for expected_schema, False if not.
     """
-    global _DEBUG_VALIDATE_INDENT
-    global _DEBUG_VALIDATE
-    expected_type = expected_schema.type
-    name = getattr(expected_schema, 'name', '')
-    if name:
-        name = ' ' + name
-    if expected_type in ('array', 'map', 'union', 'record'):
-        if _DEBUG_VALIDATE:
-            print('{!s}{!s}{!s}: {!s} {{'.format(' ' * _DEBUG_VALIDATE_INDENT, expected_schema.type, name, type(datum).__name__), file=sys.stderr)
-            _DEBUG_VALIDATE_INDENT += 2
-            if datum is not None and not datum:
-                print('{!s}<Empty>'.format(' ' * _DEBUG_VALIDATE_INDENT), file=sys.stderr)
-        result = _valid[expected_type](expected_schema, datum)
-        if _DEBUG_VALIDATE:
-            _DEBUG_VALIDATE_INDENT -= 2
-            print('{!s}}} -> {!s}'.format(' ' * _DEBUG_VALIDATE_INDENT, result), file=sys.stderr)
-    else:
-        result = _valid[expected_type](expected_schema, datum)
-        if _DEBUG_VALIDATE:
-            print('{!s}{!s}{!s}: {!s} -> {!s}'.format(' ' * _DEBUG_VALIDATE_INDENT,
-                  expected_schema.type, name, type(datum).__name__, result), file=sys.stderr)
-    return result
+    # use a FIFO queue to process schema nodes breadth first.
+    nodes = deque()
+    nodes.append(ValidationNode(expected_schema, datum, getattr(expected_schema, "name", None)))
+
+    while nodes:
+        current_node = nodes.popleft()
+
+        # _validate_node returns the node for iteration if it is valid. Or it returns None
+        # if current_node.schema.type in {'array', 'map', 'record'}:
+        validated_schema = current_node.schema.validate(current_node.datum)
+        if validated_schema:
+            valid_node = ValidationNode(validated_schema, current_node.datum, current_node.name)
+        else:
+            valid_node = None
+        # else:
+        #     valid_node = _validate_node(current_node)
+
+        if valid_node is not None:
+            # if there are children of this node to append, do so.
+            for child_node in _iterate_node(valid_node):
+                nodes.append(child_node)
+        else:
+            # the current node was not valid.
+            if raise_on_error:
+                raise avro.errors.AvroTypeException(current_node.schema, current_node.datum)
+            else:
+                # preserve the prior validation behavior of returning false when there are problems.
+                return False
+
+    return True
+
+
+def _iterate_node(node):
+    for item in _ITERATORS.get(node.schema.type, _default_iterator)(node):
+        yield ValidationNode(*item)
+
+
+#############
+# Iteration #
+#############
+
+def _default_iterator(_):
+    """Immediately raise StopIteration.
+
+    This exists to prevent problems with iteration over unsupported container types.
+
+    More efficient approaches are not possible due to support for Python 2.7
+    """
+    for item in ():
+        yield item
+
+
+def _record_iterator(node):
+    """Yield each child node of the provided record node."""
+    schema, datum, name = node
+    for field in schema.fields:
+        yield ValidationNode(field.type, datum.get(field.name), field.name)  # type: ignore
+
+
+def _array_iterator(node):
+    """Yield each child node of the provided array node."""
+    schema, datum, name = node
+    for item in datum:  # type: ignore
+        yield ValidationNode(schema.items, item, name)
+
+
+def _map_iterator(node):
+    """Yield each child node of the provided map node."""
+    schema, datum, _ = node
+    child_schema = schema.values
+    for child_name, child_datum in datum.items():  # type: ignore
+        yield ValidationNode(child_schema, child_datum, child_name)
+
+
+_ITERATORS = {
+    'record': _record_iterator,
+    'array': _array_iterator,
+    'map': _map_iterator,
+}
+_ITERATORS['error'] = _ITERATORS['request'] = _ITERATORS['record']
 
 
 #
@@ -954,8 +1014,7 @@ class DatumWriter(object):
                               set_writers_schema)
 
     def write(self, datum, encoder):
-        if not validate(self.writers_schema, datum):
-            raise avro.errors.AvroTypeException(self.writers_schema, datum)
+        validate(self.writers_schema, datum, raise_on_error=True)
         self.write_data(self.writers_schema, datum, encoder)
 
     def write_data(self, writers_schema, datum, encoder):
diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py
index de66894..5aa9977 100644
--- a/lang/py/avro/schema.py
+++ b/lang/py/avro/schema.py
@@ -42,11 +42,13 @@ A schema may be one of:
 
 from __future__ import absolute_import, division, print_function
 
+import datetime
 import json
 import math
 import re
 import sys
 import warnings
+from decimal import Decimal
 
 import avro.errors
 from avro import constants
@@ -122,6 +124,11 @@ VALID_FIELD_SORT_ORDERS = (
     'ignore',
 )
 
+INT_MIN_VALUE = -(1 << 31)
+INT_MAX_VALUE = (1 << 31) - 1
+LONG_MIN_VALUE = -(1 << 63)
+LONG_MAX_VALUE = (1 << 63) - 1
+
 
 def validate_basename(basename):
     """Raise InvalidName if the given basename is not a valid name."""
@@ -131,11 +138,15 @@ def validate_basename(basename):
                 "does not match the pattern {!s}".format(
                     basename, _BASE_NAME_PATTERN.pattern))
 
+
+def _is_timezone_aware_datetime(dt):
+    return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None
+
+
 #
 # Base Classes
 #
 
-
 class Schema(object):
     """Base class for all Schema classes."""
     _props = None
@@ -202,6 +213,18 @@ class Schema(object):
         """
         raise NotImplemented("Must be implemented by subclasses.")
 
+    def validate(self, datum):
+        """Returns the appropriate schema object if datum is valid for that schema, else None.
+
+        Validation concerns only shape and type of data in the top level of the current schema.
+        In most cases, the returned schema object will be self. However, for UnionSchema objects,
+        the returned Schema will be the first branch schema for which validation passes.
+
+        @arg datum: The data to be checked for validity according to this schema
+        @return Optional[Schema]
+        """
+        raise Exception("Must be implemented by subclasses.")
+
 
 class Name(object):
     """Class to describe Avro name."""
@@ -475,6 +498,17 @@ class Field(object):
 class PrimitiveSchema(Schema):
     """Valid primitive types are in PRIMITIVE_TYPES."""
 
+    _validators = {
+        'null': lambda x: x is None,
+        'boolean': lambda x: isinstance(x, bool),
+        'string': lambda x: isinstance(x, unicode),
+        'bytes': lambda x: isinstance(x, bytes),
+        'int': lambda x: isinstance(x, int) and INT_MIN_VALUE <= x <= INT_MAX_VALUE,
+        'long': lambda x: isinstance(x, int) and LONG_MIN_VALUE <= x <= LONG_MAX_VALUE,
+        'float': lambda x: isinstance(x, (int, float)),
+        'double': lambda x: isinstance(x, (int, float)),
+    }
+
     def __init__(self, type, other_props=None):
         # Ensure valid ctor args
         if type not in PRIMITIVE_TYPES:
@@ -503,6 +537,15 @@ class PrimitiveSchema(Schema):
         else:
             return self.props
 
+    def validate(self, datum):
+        """Return self if datum is a valid representation of this type of primitive schema, else None
+
+        @arg datum: The data to be checked for validity according to this schema
+        @return Schema object or None
+        """
+        validator = self._validators.get(self.type, lambda x: False)
+        return self if validator(datum) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -525,6 +568,10 @@ class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema):
     def to_json(self, names=None):
         return self.props
 
+    def validate(self, datum):
+        """Return self if datum is a Decimal object, else None."""
+        return self if isinstance(datum, Decimal) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -565,6 +612,10 @@ class FixedSchema(NamedSchema):
             names.names[self.fullname] = self
             return names.prune_namespace(self.props)
 
+    def validate(self, datum):
+        """Return self if datum is a valid representation of this schema, else None."""
+        return self if isinstance(datum, bytes) and len(datum) == self.size else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -588,6 +639,10 @@ class FixedDecimalSchema(FixedSchema, DecimalLogicalSchema):
     def to_json(self, names=None):
         return self.props
 
+    def validate(self, datum):
+        """Return self if datum is a Decimal object, else None."""
+        return self if isinstance(datum, Decimal) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -637,6 +692,10 @@ class EnumSchema(NamedSchema):
             names.names[self.fullname] = self
             return names.prune_namespace(self.props)
 
+    def validate(self, datum):
+        """Return self if datum is a valid member of this Enum, else None."""
+        return self if datum in self.symbols else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -656,7 +715,7 @@ class ArraySchema(Schema):
         else:
             try:
                 items_schema = make_avsc_object(items, names)
-            except SchemaParseException as e:
+            except avro.errors.SchemaParseException as e:
                 fail_msg = 'Items schema (%s) not a valid Avro schema: %s (known names: %s)' % (items, e, names.names.keys())
                 raise avro.errors.SchemaParseException(fail_msg)
 
@@ -681,6 +740,10 @@ class ArraySchema(Schema):
         to_dump['items'] = item_schema.to_json(names)
         return to_dump
 
+    def validate(self, datum):
+        """Return self if datum is a valid representation of this schema, else None."""
+        return self if isinstance(datum, list) else None
+
     def __eq__(self, that):
         to_cmp = json.loads(str(self))
         return to_cmp == json.loads(str(that))
@@ -697,7 +760,7 @@ class MapSchema(Schema):
         else:
             try:
                 values_schema = make_avsc_object(values, names)
-            except SchemaParseException:
+            except avro.errors.SchemaParseException:
                 raise
             except Exception:
                 raise avro.errors.SchemaParseException('Values schema is not a valid Avro schema.')
@@ -722,6 +785,10 @@ class MapSchema(Schema):
         to_dump['values'] = self.get_prop('values').to_json(names)
         return to_dump
 
+    def validate(self, datum):
+        """Return self if datum is a valid representation of this schema, else None."""
+        return self if isinstance(datum, dict) and all(isinstance(key, unicode) for key in datum) else None
+
     def __eq__(self, that):
         to_cmp = json.loads(str(self))
         return to_cmp == json.loads(str(that))
@@ -780,6 +847,12 @@ class UnionSchema(Schema):
             to_dump.append(schema.to_json(names))
         return to_dump
 
+    def validate(self, datum):
+        """Return the first branch schema of which datum is a valid example, else None."""
+        for branch in self.schemas:
+            if branch.validate(datum) is not None:
+                return branch
+
     def __eq__(self, that):
         to_cmp = json.loads(str(self))
         return to_cmp == json.loads(str(that))
@@ -901,6 +974,10 @@ class RecordSchema(NamedSchema):
         to_dump['fields'] = [f.to_json(names) for f in self.fields]
         return to_dump
 
+    def validate(self, datum):
+        """Return self if datum is a valid representation of this schema, else None"""
+        return self if isinstance(datum, dict) and {f.name for f in self.fields}.issuperset(datum.keys()) else None
+
     def __eq__(self, that):
         to_cmp = json.loads(str(self))
         return to_cmp == json.loads(str(that))
@@ -918,6 +995,10 @@ class DateSchema(LogicalSchema, PrimitiveSchema):
     def to_json(self, names=None):
         return self.props
 
+    def validate(self, datum):
+        """Return self if datum is a valid date object, else None."""
+        return self if isinstance(datum, datetime.date) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -934,6 +1015,10 @@ class TimeMillisSchema(LogicalSchema, PrimitiveSchema):
     def to_json(self, names=None):
         return self.props
 
+    def validate(self, datum):
+        """Return self if datum is a valid representation of this schema, else None."""
+        return self if isinstance(datum, datetime.time) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -950,6 +1035,10 @@ class TimeMicrosSchema(LogicalSchema, PrimitiveSchema):
     def to_json(self, names=None):
         return self.props
 
+    def validate(self, datum):
+        """Return self if datum is a valid representation of this schema, else None."""
+        return self if isinstance(datum, datetime.time) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -966,6 +1055,9 @@ class TimestampMillisSchema(LogicalSchema, PrimitiveSchema):
     def to_json(self, names=None):
         return self.props
 
+    def validate(self, datum):
+        return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
@@ -982,6 +1074,9 @@ class TimestampMicrosSchema(LogicalSchema, PrimitiveSchema):
     def to_json(self, names=None):
         return self.props
 
+    def validate(self, datum):
+        return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None
+
     def __eq__(self, that):
         return self.props == that.props
 
