Author: cutting
Date: Wed Jul 1 16:53:37 2009
New Revision: 790264
URL: http://svn.apache.org/viewvc?rev=790264&view=rev
Log:
AVRO-28. Add Python support for default values. Contributed by sharad.
Modified:
hadoop/avro/trunk/CHANGES.txt
hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java
hadoop/avro/trunk/src/py/avro/genericio.py
hadoop/avro/trunk/src/py/avro/io.py
hadoop/avro/trunk/src/py/avro/protocol.py
hadoop/avro/trunk/src/py/avro/reflectio.py
hadoop/avro/trunk/src/py/avro/reflectipc.py
hadoop/avro/trunk/src/py/avro/schema.py
hadoop/avro/trunk/src/test/py/testio.py
hadoop/avro/trunk/src/test/py/testioreflect.py
Modified: hadoop/avro/trunk/CHANGES.txt
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/CHANGES.txt?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/CHANGES.txt (original)
+++ hadoop/avro/trunk/CHANGES.txt Wed Jul 1 16:53:37 2009
@@ -46,6 +46,8 @@
AVRO-67. Add per-call RPC metadata to spec. (George Porter via cutting)
+ AVRO-28. Add Python support for default values. (sharad via cutting)
+
IMPROVEMENTS
AVRO-11. Re-implement specific and reflect datum readers and
Modified:
hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java
(original)
+++ hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java
Wed Jul 1 16:53:37 2009
@@ -89,6 +89,7 @@
if (branch.getType() == actual.getType())
switch (branch.getType()) {
case RECORD:
+ case ENUM:
case FIXED:
String name = branch.getName();
if (name == null || name.equals(actual.getName()))
@@ -101,21 +102,21 @@
for (Schema branch : expected.getTypes())
switch (actual.getType()) {
case INT:
- switch (expected.getType()) {
+ switch (branch.getType()) {
case LONG: case FLOAT: case DOUBLE:
- return expected;
+ return branch;
}
break;
case LONG:
- switch (expected.getType()) {
+ switch (branch.getType()) {
case FLOAT: case DOUBLE:
- return expected;
+ return branch;
}
break;
case FLOAT:
- switch (expected.getType()) {
+ switch (branch.getType()) {
case DOUBLE:
- return expected;
+ return branch;
}
break;
}
Modified: hadoop/avro/trunk/src/py/avro/genericio.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/genericio.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/genericio.py (original)
+++ hadoop/avro/trunk/src/py/avro/genericio.py Wed Jul 1 16:53:37 2009
@@ -52,8 +52,8 @@
def _validaterecord(schm, object):
if not isinstance(object, dict):
return False
- for field,fieldschema in schm.getfields():
- if not validate(fieldschema, object.get(field)):
+ for field in schm.getfields().values():
+ if not validate(field.getschema(), object.get(field.getname())):
return False
return True
@@ -98,72 +98,240 @@
class DatumReader(io.DatumReaderBase):
"""DatumReader for generic python objects."""
- def __init__(self, schm=None):
- self.setschema(schm)
+ def __init__(self, actual=None, expected=None):
+ self.setschema(actual)
+ self.__expected = expected
self.__readfn = {
- schema.BOOLEAN : lambda schm, decoder: decoder.readboolean(),
- schema.STRING : lambda schm, decoder: decoder.readutf8(),
- schema.INT : lambda schm, decoder: decoder.readint(),
- schema.LONG : lambda schm, decoder: decoder.readlong(),
- schema.FLOAT : lambda schm, decoder: decoder.readfloat(),
- schema.DOUBLE : lambda schm, decoder: decoder.readdouble(),
- schema.BYTES : lambda schm, decoder: decoder.readbytes(),
- schema.FIXED : lambda schm, decoder:
- (decoder.read(schm.getsize())),
+ schema.BOOLEAN : lambda actual, expected, decoder: decoder.readboolean(),
+ schema.STRING : lambda actual, expected, decoder: decoder.readutf8(),
+ schema.INT : lambda actual, expected, decoder: decoder.readint(),
+ schema.LONG : lambda actual, expected, decoder: decoder.readlong(),
+ schema.FLOAT : lambda actual, expected, decoder: decoder.readfloat(),
+ schema.DOUBLE : lambda actual, expected, decoder: decoder.readdouble(),
+ schema.BYTES : lambda actual, expected, decoder: decoder.readbytes(),
+ schema.FIXED : self.readfixed,
schema.ARRAY : self.readarray,
schema.MAP : self.readmap,
schema.RECORD : self.readrecord,
- schema.ENUM : self.readenum,
- schema.UNION : self.readunion
+ schema.ENUM : self.readenum
+ }
+ self.__skipfn = {
+ schema.BOOLEAN : lambda schm, decoder: decoder.skipboolean(),
+ schema.STRING : lambda schm, decoder: decoder.skiputf8(),
+ schema.INT : lambda schm, decoder: decoder.skipint(),
+ schema.LONG : lambda schm, decoder: decoder.skiplong(),
+ schema.FLOAT : lambda schm, decoder: decoder.skipfloat(),
+ schema.DOUBLE : lambda schm, decoder: decoder.skipdouble(),
+ schema.BYTES : lambda schm, decoder: decoder.skipbytes(),
+ schema.FIXED : self.skipfixed,
+ schema.ARRAY : self.skiparray,
+ schema.MAP : self.skipmap,
+ schema.RECORD : self.skiprecord,
+ schema.ENUM : self.skipenum,
+ schema.UNION : self.skipunion
}
def setschema(self, schm):
- self.__schm = schm
+ self.__actual = schm
def read(self, decoder):
- return self.readdata(self.__schm, decoder)
-
- def readdata(self, schm, decoder):
- if schm.gettype() == schema.NULL:
+ if self.__expected is None:
+ self.__expected = self.__actual
+ return self.readdata(self.__actual, self.__expected, decoder)
+
+ def readdata(self, actual, expected, decoder):
+ if actual.gettype() == schema.UNION:
+ actual = actual.getelementtypes()[int(decoder.readlong())]
+ if expected.gettype() == schema.UNION:
+ expected = self._resolve(actual, expected)
+ if actual.gettype() == schema.NULL:
return None
- fn = self.__readfn.get(schm.gettype())
+ fn = self.__readfn.get(actual.gettype())
+ if fn is not None:
+ return fn(actual, expected, decoder)
+ else:
+ raise schema.AvroException("Unknown type: "+schema.stringval(actual));
+
+ def skipdata(self, schm, decoder):
+ fn = self.__skipfn.get(schm.gettype())
if fn is not None:
return fn(schm, decoder)
else:
- raise AvroException("Unknown type: "+schema.stringval(schm));
+ raise schema.AvroException("Unknown type: "+schema.stringval(schm));
- def readmap(self, schm, decoder):
+ def readfixed(self, actual, expected, decoder):
+ self.__checkname(actual, expected)
+ if actual.getsize() != expected.getsize():
+ self.__raisematchException(actual, expected)
+ return decoder.read(actual.getsize())
+
+ def skipfixed(self, schm, decoder):
+ return decoder.skip(schm.getsize())
+
+ def readmap(self, actual, expected, decoder):
+ if (actual.getvaluetype().gettype() !=
+ expected.getvaluetype().gettype()):
+ self.__raisematchException(actual, expected)
result = dict()
size = decoder.readlong()
if size != 0:
for i in range(0, size):
key = decoder.readutf8()
- result[key] = self.readdata(schm.getvaluetype(), decoder)
+ result[key] = self.readdata(actual.getvaluetype(),
+ expected.getvaluetype(), decoder)
decoder.readlong()
return result
- def readarray(self, schm, decoder):
+ def skipmap(self, schm, decoder):
+ size = decoder.readlong()
+ if size != 0:
+ for i in range(0, size):
+ decoder.skiputf8()
+ self.skipdata(schm.getvaluetype(), decoder)
+ decoder.skiplong()
+
+ def readarray(self, actual, expected, decoder):
+ if (actual.getelementtype().gettype() !=
+ expected.getelementtype().gettype()):
+ self.__raisematchException(actual, expected)
result = list()
size = decoder.readlong()
if size != 0:
for i in range(0, size):
- result.append(self.readdata(schm.getelementtype(), decoder))
+ result.append(self.readdata(actual.getelementtype(),
+ expected.getelementtype(), decoder))
decoder.readlong()
return result
- def readrecord(self, schm, decoder):
- result = dict()
- for field,fieldschema in schm.getfields():
- result[field] = self.readdata(fieldschema, decoder)
- return result
+ def skiparray(self, schm, decoder):
+ size = decoder.readlong()
+ if size != 0:
+ for i in range(0, size):
+ self.skipdata(schm.getelementtype(), decoder)
+ decoder.skiplong()
+
+ def createrecord(self, schm):
+ return dict()
+
+ def addfield(self, record, name, value):
+ record[name] = value
- def readenum(self, schm, decoder):
+ def readrecord(self, actual, expected, decoder):
+ self.__checkname(actual, expected)
+ expectedfields = expected.getfields()
+ record = self.createrecord(actual)
+ size = 0
+ for fieldname, field in actual.getfields().items():
+ if expected == actual:
+ expectedfield = field
+ else:
+ expectedfield = expectedfields.get(fieldname)
+ if expectedfield is None:
+ self.skipdata(field.getschema(), decoder)
+ continue
+ self.addfield(record, fieldname, self.readdata(field.getschema(),
+ expectedfield.getschema(), decoder))
+ size += 1
+ if len(expectedfields) > size: # not all fields set
+ actualfields = actual.getfields()
+ for fieldname, field in expectedfields.items():
+ if not actualfields.has_key(fieldname):
+ defval = field.getdefaultvalue()
+ if defval is not None:
+ self.addfield(record, fieldname,
+ self._defaultfieldvalue(field.getschema(), defval))
+ return record
+
+ def skiprecord(self, schm, decoder):
+ for field in schm.getfields().values():
+ self.skipdata(field.getschema(), decoder)
+
+ def readenum(self, actual, expected, decoder):
+ self.__checkname(actual, expected)
index = decoder.readint()
- return schm.getenumsymbols()[index]
+ return actual.getenumsymbols()[index]
+
+ def skipenum(self, schm, decoder):
+ decoder.skipint()
- def readunion(self, schm, decoder):
+ def skipunion(self, schm, decoder):
index = int(decoder.readlong())
- return self.readdata(schm.getelementtypes()[index], decoder)
+ return self.skipdata(schm.getelementtypes()[index], decoder)
+
+ def _resolve(self, actual, expected):
+ # scan for exact match
+ for branch in expected.getelementtypes():
+ if branch.gettype() == actual.gettype():
+ return branch
+ #scan for match via numeric promotion
+ for branch in expected.getelementtypes():
+ actualtype = actual.gettype()
+ expectedtype = branch.gettype()
+ if actualtype == schema.INT:
+ if (expectedtype == schema.LONG or expectedtype == schema.FLOAT
+ or expectedtype == schema.DOUBLE):
+ return branch
+ elif actualtype == schema.LONG:
+ if (expectedtype == schema.FLOAT or expectedtype == schema.DOUBLE):
+ return branch
+ elif actualtype == schema.FLOAT:
+ if (expectedtype == schema.DOUBLE):
+ return branch
+ self.__raisematchException(actual, expected)
+
+ def __checkname(self, actual, expected):
+ if actual.getname() != expected.getname():
+ self.__raisematchException(actual, expected)
+
+ def __raisematchException(self, actual, expected):
+ raise schema.AvroException("Expected "+schema.stringval(expected)+
+ ", found "+schema.stringval(actual))
+
+ def _defaultfieldvalue(self, schm, defaultnode):
+ if schm.gettype() == schema.RECORD:
+ record = self.createrecord(schm)
+ for field in schm.getfields().values():
+ v = defaultnode.get(field.getname())
+ if v is None:
+ v = field.getdefaultvalue()
+ if v is not None:
+ record[field.getname()] = self._defaultfieldvalue(
+ field.getschema(), v)
+ return record
+ elif schm.gettype() == schema.ENUM:
+ return defaultnode
+ elif schm.gettype() == schema.ARRAY:
+ array = list()
+ for node in defaultnode:
+ array.append(self._defaultfieldvalue(schm.getelementtype(), node))
+ return array
+ elif schm.gettype() == schema.MAP:
+ map = dict()
+ for k,v in defaultnode.items():
+ map[k] = self._defaultfieldvalue(schm.getvaluetype(), v)
+ return map
+ elif schm.gettype() == schema.UNION:
+ return self._defaultfieldvalue(schm.getelementtypes()[0], defaultnode)
+ elif schm.gettype() == schema.FIXED:
+ return defaultnode
+ elif schm.gettype() == schema.STRING:
+ return defaultnode
+ elif schm.gettype() == schema.BYTES:
+ return defaultnode
+ elif schm.gettype() == schema.INT:
+ return int(defaultnode)
+ elif schm.gettype() == schema.LONG:
+ return long(defaultnode)
+ elif schm.gettype() == schema.FLOAT:
+ return float(defaultnode)
+ elif schm.gettype() == schema.DOUBLE:
+ return float(defaultnode)
+ elif schm.gettype() == schema.BOOLEAN:
+ return bool(defaultnode)
+ elif schm.gettype() == schema.NULL:
+ return None
+ else:
+ raise schema.AvroException("Unknown type: "+schema.stringval(schm))
class DatumWriter(io.DatumWriterBase):
"""DatumWriter for generic python objects."""
@@ -233,8 +401,8 @@
def writerecord(self, schm, datum, encoder):
if not isinstance(datum, dict):
raise io.AvroTypeException(schm, datum)
- for field,fieldschema in schm.getfields():
- self.writedata(fieldschema, datum.get(field), encoder)
+ for field in schm.getfields().values():
+ self.writedata(field.getschema(), datum.get(field.getname()), encoder)
def writeunion(self, schm, datum, encoder):
index = self.resolveunion(schm, datum)
Modified: hadoop/avro/trunk/src/py/avro/io.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/io.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/io.py (original)
+++ hadoop/avro/trunk/src/py/avro/io.py Wed Jul 1 16:53:37 2009
@@ -108,6 +108,30 @@
def read(self, len):
return struct.unpack(len.__str__()+'s', self.__reader.read(len))[0]
+ def skipboolean(self):
+ self.skip(1)
+
+ def skipint(self):
+ self.skip(4)
+
+ def skiplong(self):
+ self.skip(8)
+
+ def skipfloat(self):
+ self.skip(4)
+
+ def skipdouble(self):
+ self.skip(8)
+
+ def skipbytes(self):
+ self.skip(self.readlong())
+
+ def skiputf8(self):
+ self.skipbytes()
+
+ def skip(self, len):
+ self.__reader.seek(self.__reader.tell()+len)
+
class Encoder(object):
"""Write leaf values."""
Modified: hadoop/avro/trunk/src/py/avro/protocol.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/protocol.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/protocol.py (original)
+++ hadoop/avro/trunk/src/py/avro/protocol.py Wed Jul 1 16:53:37 2009
@@ -76,11 +76,11 @@
str = cStringIO.StringIO()
str.write("{\"request\": [")
count = 0
- for k,v in self.__request.getfields():
+ for field in self.__request.getfields().values():
str.write("{\"name\": \"")
- str.write(k)
+ str.write(field.getname())
str.write("\", \"type\": ")
- str.write(v.str(self.__proto.gettypes()))
+ str.write(field.getschema().str(self.__proto.gettypes()))
str.write("}")
count+=1
if count < len(self.__request.getfields()):
@@ -158,8 +158,9 @@
fieldtype = field.get("type")
if fieldtype is None:
raise SchemaParseException("No param type: "+field.__str__())
- fields[fieldname] = schema._parse(fieldtype, self.__types)
- request = schema._RecordSchema(list(fields.iteritems()))
+ fields[fieldname] = schema.Field(fieldname,
+ schema._parse(fieldtype, self.__types))
+ request = schema._RecordSchema(fields)
response = schema._parse(res, self.__types)
erorrs = list()
Modified: hadoop/avro/trunk/src/py/avro/reflectio.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/reflectio.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/reflectio.py (original)
+++ hadoop/avro/trunk/src/py/avro/reflectio.py Wed Jul 1 16:53:37 2009
@@ -43,9 +43,9 @@
def _validaterecord(schm, pkgname, object):
if not isinstance(object, gettype(schm, pkgname)):
return False
- for field,fieldschema in schm.getfields():
- data = object.__getattribute__(field)
- if not validate(fieldschema, pkgname, data):
+ for field in schm.getfields().values():
+ data = object.__getattribute__(field.getname())
+ if not validate(field.getschema(), pkgname, data):
return False
return True
@@ -94,8 +94,8 @@
clazz = globals().get(clazzname)
if clazz is None:
clazz = type(str(clazzname),(base,),{})
- for field,fieldschema in recordschm.getfields():
- setattr(clazz, field, None)
+ for field in recordschm.getfields().values():
+ setattr(clazz, field.getname(), None)
globals()[clazzname] = clazz
return clazz
@@ -106,12 +106,12 @@
genericio.DatumReader.__init__(self, schm)
self.__pkgname = pkgname
- def readrecord(self, schm, decoder):
+ def addfield(self, record, name, value):
+ setattr(record, name, value)
+
+ def createrecord(self, schm):
type = gettype(schm, self.__pkgname)
- result = type()
- for field,fieldschema in schm.getfields():
- setattr(result, field, self.readdata(fieldschema, decoder))
- return result
+ return type()
class ReflectDatumWriter(genericio.DatumWriter):
"""DatumWriter for arbitrary python classes."""
@@ -121,8 +121,9 @@
self.__pkgname = pkgname
def writerecord(self, schm, datum, encoder):
- for field,fieldschema in schm.getfields():
- self.writedata(fieldschema, getattr(datum, field), encoder)
+ for field in schm.getfields().values():
+ self.writedata(field.getschema(), getattr(datum, field.getname()),
+ encoder)
def resolveunion(self, schm, datum):
index = 0
Modified: hadoop/avro/trunk/src/py/avro/reflectipc.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/reflectipc.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/reflectipc.py (original)
+++ hadoop/avro/trunk/src/py/avro/reflectipc.py Wed Jul 1 16:53:37 2009
@@ -68,9 +68,8 @@
return reflectio.ReflectDatumReader(self.__pkgname, schm)
def writerequest(self, schm, req, encoder):
- index = 0
for arg in req:
- argschm = schm.getfields()[index][1]
+ argschm = schm.getfields().values()[0].getschema()
genericipc.Requestor.writerequest(self, argschm, arg, encoder)
def readerror(self, schm, decoder):
@@ -91,8 +90,9 @@
def readrequest(self, schm, decoder):
req = list()
- for field, fieldschm in schm.getfields():
- req.append(genericipc.Responder.readrequest(self, fieldschm, decoder))
+ for field in schm.getfields().values():
+ req.append(genericipc.Responder.readrequest(self, field.getschema(),
+ decoder))
return req
def writeerror(self, schm, error, encoder):
Modified: hadoop/avro/trunk/src/py/avro/schema.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/schema.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/schema.py (original)
+++ hadoop/avro/trunk/src/py/avro/schema.py Wed Jul 1 16:53:37 2009
@@ -142,6 +142,25 @@
hash += self.__space.__hash__()
return hash
+class Field(object):
+ def __init__(self, name, schema, defaultvalue=None):
+ self.__name = name
+ self.__schema = schema
+ self.__defaultvalue = defaultvalue
+
+ def getname(self):
+ return self.__name
+
+ def getschema(self):
+ return self.__schema
+
+ def getdefaultvalue(self):
+ return self.__defaultvalue
+
+ def __eq__(self, other, seen={}):
+ return (self.__name == other.__name and
+ self.__schema.__eq__(other.__schema, seen) and
+ self.__defaultvalue == other.__defaultvalue)
class _RecordSchema(NamedSchema):
def __init__(self, fields, name=None, space=None, iserror=False):
@@ -170,11 +189,14 @@
str.write(self.namestring())
str.write("\"fields\": [")
count=0
- for k,v in self.__fields:
+ for field in self.__fields.values():
str.write("{\"name\": \"")
- str.write(k)
+ str.write(field.getname())
str.write("\", \"type\": ")
- str.write(v.str(names))
+ str.write(field.getschema().str(names))
+ if field.getdefaultvalue() is not None:
+ str.write(", \"default\": ")
+ str.write(repr(field.getdefaultvalue()))
str.write("}")
count+=1
if count < len(self.__fields):
@@ -190,8 +212,8 @@
if len(other.__fields) != size:
return False
seen[id(self)] = other
- for i in range(0, size):
- if not self.__fields[i][1].__eq__(other.__fields[i][1], seen):
+ for field in self.__fields.values():
+ if not field.__eq__(other.__fields.get(field.getname()), seen):
return False
return True
else:
@@ -202,8 +224,8 @@
return 0
seen.add(id(self))
hash = NamedSchema.__hash__(self, seen)
- for field, fieldschm in self.__fields:
- hash = hash + fieldschm.__hash__(seen)
+ for field in self.__fields.values():
+ hash = hash + field.getschema().__hash__(seen)
return hash
class _ArraySchema(Schema):
@@ -444,7 +466,7 @@
if name is None:
raise SchemaParseException("No name in schema: "+obj.__str__())
if type == "record" or type == "error":
- fields = list()
+ fields = odict.OrderedDict()
schema = _RecordSchema(fields, name, space, type == "error")
names[name] = schema
fieldsnode = obj.get("fields")
@@ -457,7 +479,9 @@
fieldtype = field.get("type")
if fieldtype is None:
raise SchemaParseException("No field type: "+field.__str__())
- fields.append((fieldname, _parse(fieldtype, names)))
+ defaultval = field.get("default")
+ fields[fieldname] = Field(fieldname, _parse(fieldtype, names),
+ defaultval)
return schema
elif type == "enum":
symbolsnode = obj.get("symbols")
Modified: hadoop/avro/trunk/src/test/py/testio.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/testio.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/test/py/testio.py (original)
+++ hadoop/avro/trunk/src/test/py/testio.py Wed Jul 1 16:53:37 2009
@@ -75,8 +75,8 @@
return map
elif schm.gettype() == schema.RECORD:
m = dict()
- for field, fieldschm in schm.getfields():
- m[field] = self.nextdata(fieldschm, d+1)
+ for field in schm.getfields().values():
+ m[field.getname()] = self.nextdata(field.getschema(), d+1)
return m
elif schm.gettype() == schema.UNION:
datum = self.nextdata(random.choice(schm.getelementtypes()), d)
@@ -107,43 +107,45 @@
self.__assertdata = assertdata
def testNull(self):
- self.check("\"null\"")
+ self.checkdefault("\"null\"", "null", None)
def testBoolean(self):
- self.check("\"boolean\"")
+ self.checkdefault("\"boolean\"", "true", True)
def testString(self):
- self.check("\"string\"")
+ self.checkdefault("\"string\"", "\"foo\"", "foo")
def testBytes(self):
- self.check("\"bytes\"")
+ self.checkdefault("\"bytes\"", "\"foo\"", "foo")
def testInt(self):
- self.check("\"int\"")
+ self.checkdefault("\"int\"", "5", 5)
def testLong(self):
- self.check("\"long\"")
+ self.checkdefault("\"long\"", "9", 9)
def testFloat(self):
- self.check("\"float\"")
+ self.checkdefault("\"float\"", "1.2", float(1.2))
def testDouble(self):
- self.check("\"double\"")
+ self.checkdefault("\"double\"", "1.2", float(1.2))
def testArray(self):
- self.check("{\"type\":\"array\", \"items\": \"long\"}")
+ self.checkdefault("{\"type\":\"array\", \"items\": \"long\"}",
+ "[1]", [1])
def testMap(self):
- self.check("{\"type\":\"map\", \"values\": \"string\"}")
+ self.checkdefault("{\"type\":\"map\", \"values\": \"long\"}",
+ "{\"a\":1}", {unicode("a"):1})
def testRecord(self):
- self.check("{\"type\":\"record\", \"name\":\"Test\"," +
+ self.checkdefault("{\"type\":\"record\", \"name\":\"Test\"," +
"\"fields\":[{\"name\":\"f\", \"type\":" +
- "\"string\"}, {\"name\":\"fb\", \"type\":\"bytes\"}]}")
+ "\"long\"}]}", "{\"f\":11}", {"f" : 11})
def testEnum(self):
- self.check("{\"type\": \"enum\", \"name\":\"Test\","+
- "\"symbols\": [\"A\", \"B\"]}")
+ self.checkdefault("{\"type\": \"enum\", \"name\":\"Test\","+
+ "\"symbols\": [\"A\", \"B\"]}", "\"B\"", "B")
def testRecursive(self):
self.check("{\"type\": \"record\", \"name\": \"Node\", \"fields\": ["
@@ -163,9 +165,11 @@
+"{\"type\": \"record\", \"name\": \"Cons\", \"fields\": ["
+"{\"name\":\"car\", \"type\":\"string\"},"
+"{\"name\":\"cdr\", \"type\":\"string\"}]}]")
+ self.checkdefault("[\"double\", \"long\"]", "1.1", 1.1)
def testFixed(self):
- self.check("{\"type\": \"fixed\", \"name\":\"Test\", \"size\": 1}")
+ self.checkdefault("{\"type\": \"fixed\", \"name\":\"Test\", \"size\": 1}",
+ "\"a\"", "a")
def check(self, string):
schm = schema.parse(string)
@@ -180,6 +184,20 @@
self.checkser(schm, randomdata)
self.checkdatafile(schm)
+ def checkdefault(self, schemajson, defaultjson, defaultvalue):
+ self.check(schemajson)
+ actual = schema.parse("{\"type\":\"record\", \"name\":\"Foo\","
+ + "\"fields\":[]}")
+ expected = schema.parse("{\"type\":\"record\", \"name\":\"Foo\","
+ +"\"fields\":[{\"name\":\"f\", "
+ +"\"type\":"+schemajson+", "
+ +"\"default\":"+defaultjson+"}]}")
+ reader = genericio.DatumReader(actual, expected)
+ record = reader.read(io.Decoder(cStringIO.StringIO()))
+ self.assertEquals(defaultvalue, record.get("f"))
+ #FIXME fix to string for default values
+ #self.assertEquals(expected, schema.parse(schema.stringval(expected)))
+
def checkser(self, schm, randomdata):
datum = randomdata.next()
self.assertTrue(self.__validator(schm, datum))
Modified: hadoop/avro/trunk/src/test/py/testioreflect.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/testioreflect.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/test/py/testioreflect.py (original)
+++ hadoop/avro/trunk/src/test/py/testioreflect.py Wed Jul 1 16:53:37 2009
@@ -29,8 +29,8 @@
if schm.gettype() == schema.RECORD:
clazz = reflectio.gettype(schm, _PKGNAME)
result = clazz()
- for field,fieldschema in schm.getfields():
- result.__setattr__(field, self.nextdata(fieldschema,d))
+ for field in schm.getfields().values():
+ result.__setattr__(field.getname(), self.nextdata(field.getschema(),d))
return result
else:
return testio.RandomData.nextdata(self, schm, d)