Fixing avro deserializer (using ReflectDatumReader) to be able to read arrays 
in InstanceData objects


Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/3fbfc071
Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/3fbfc071
Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/3fbfc071

Branch: refs/heads/master
Commit: 3fbfc071ef6eaa37bd93e20fe68ec0d2d61bd0e8
Parents: e0d5b31
Author: Jakub Jankowski <[email protected]>
Authored: Tue May 16 14:11:36 2017 +0200
Committer: nkourtellis <[email protected]>
Committed: Fri Jul 21 21:12:18 2017 +0300

----------------------------------------------------------------------
 .../samoa/streams/kafka/KafkaAvroMapper.java    |  11 +-
 .../streams/kafka/avro/SamoaDatumReader.java    | 115 +++++++++++++++++++
 samoa-api/src/main/resources/kafka.avsc         |  53 ++++++++-
 .../kafka/AvroSerializerDeserializerTest.java   |  70 +++++++++++
 4 files changed, 238 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/3fbfc071/samoa-api/src/main/java/org/apache/samoa/streams/kafka/KafkaAvroMapper.java
----------------------------------------------------------------------
diff --git 
a/samoa-api/src/main/java/org/apache/samoa/streams/kafka/KafkaAvroMapper.java 
b/samoa-api/src/main/java/org/apache/samoa/streams/kafka/KafkaAvroMapper.java
index afbc002..a045bed 100644
--- 
a/samoa-api/src/main/java/org/apache/samoa/streams/kafka/KafkaAvroMapper.java
+++ 
b/samoa-api/src/main/java/org/apache/samoa/streams/kafka/KafkaAvroMapper.java
@@ -21,21 +21,18 @@ import java.io.File;
 import java.io.IOException;
 
 import org.apache.avro.Schema;
-import org.apache.avro.generic.GenericDatumWriter;
 import org.apache.avro.io.BinaryEncoder;
 import org.apache.avro.io.DatumReader;
 import org.apache.avro.io.DatumWriter;
 import org.apache.avro.io.Decoder;
 import org.apache.avro.io.DecoderFactory;
-import org.apache.avro.io.Encoder;
 import org.apache.avro.io.EncoderFactory;
 import org.apache.avro.reflect.ReflectData;
-import org.apache.avro.reflect.ReflectDatumReader;
 import org.apache.avro.reflect.ReflectDatumWriter;
-import org.apache.avro.specific.SpecificDatumReader;
 import org.apache.avro.specific.SpecificDatumWriter;
 import org.apache.avro.specific.SpecificRecord;
 import org.apache.samoa.learners.InstanceContentEvent;
+import org.apache.samoa.streams.kafka.avro.SamoaDatumReader;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -91,7 +88,7 @@ public class KafkaAvroMapper implements 
KafkaDeserializer<InstanceContentEvent>,
        public static <V> byte[] avroSerialize(final Class<V> cls, final V v) {
                ByteArrayOutputStream bout = new ByteArrayOutputStream();
                try {
-                       Schema schema = new Schema.Parser().parse(new 
File("C:/java/avro/kafka.avsc"));
+                       Schema schema = new 
Schema.Parser().parse(KafkaAvroMapper.class.getResourceAsStream("/kafka.avsc"));
                        DatumWriter<V> writer;
 
                        if (v instanceof SpecificRecord) {
@@ -123,9 +120,9 @@ public class KafkaAvroMapper implements 
KafkaDeserializer<InstanceContentEvent>,
        public static <V> V avroDeserialize(byte[] avroBytes, Class<V> clazz) {
                V ret = null;
                try {
-                       Schema schema = new Schema.Parser().parse(new 
File("C:/java/avro/kafka.avsc"));
+                       Schema schema = new 
Schema.Parser().parse(KafkaAvroMapper.class.getResourceAsStream("/kafka.avsc"));
                        ByteArrayInputStream in = new 
ByteArrayInputStream(avroBytes);
-                       DatumReader<V> reader = new 
ReflectDatumReader<>(schema);
+                       DatumReader<V> reader = new SamoaDatumReader<>(schema);
                        
                        Decoder decoder = 
DecoderFactory.get().directBinaryDecoder(in, null);
                        

http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/3fbfc071/samoa-api/src/main/java/org/apache/samoa/streams/kafka/avro/SamoaDatumReader.java
----------------------------------------------------------------------
diff --git 
a/samoa-api/src/main/java/org/apache/samoa/streams/kafka/avro/SamoaDatumReader.java
 
b/samoa-api/src/main/java/org/apache/samoa/streams/kafka/avro/SamoaDatumReader.java
new file mode 100644
index 0000000..b7a18aa
--- /dev/null
+++ 
b/samoa-api/src/main/java/org/apache/samoa/streams/kafka/avro/SamoaDatumReader.java
@@ -0,0 +1,115 @@
+package org.apache.samoa.streams.kafka.avro;
+
+import java.io.IOException;
+
+import org.apache.avro.AvroRuntimeException;
+import org.apache.avro.Schema;
+import org.apache.avro.Schema.Field;
+import org.apache.avro.generic.GenericData.Array;
+import org.apache.avro.generic.IndexedRecord;
+import org.apache.avro.io.ResolvingDecoder;
+import org.apache.avro.reflect.ReflectData;
+import org.apache.avro.reflect.ReflectDatumReader;
+import org.apache.avro.specific.SpecificRecordBase;
+import org.apache.samoa.instances.DenseInstanceData;
+import org.apache.samoa.instances.SingleClassInstanceData;
+import org.apache.samoa.instances.SparseInstanceData;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * DatumReader used to read objects built with InstanceData classes
+ * @author Jakub Jankowski
+ *
+ * @param <T>
+ */
+public class SamoaDatumReader<T> extends ReflectDatumReader<T> {
+
+       private static Logger logger = 
LoggerFactory.getLogger(SamoaDatumReader.class);
+
+       public SamoaDatumReader() {
+               super();
+       }
+
+       /** Construct for reading instances of a class. */
+       public SamoaDatumReader(Class<T> c) {
+               super(c);
+       }
+
+       /** Construct where the writer's and reader's schemas are the same. */
+       public SamoaDatumReader(Schema root) {
+               super(root);
+       }
+
+       /** Construct given writer's and reader's schema. */
+       public SamoaDatumReader(Schema writer, Schema reader) {
+               super(writer, reader);
+       }
+
+       /** Construct given writer's and reader's schema and the data model. */
+       public SamoaDatumReader(Schema writer, Schema reader, ReflectData data) 
{
+               super(writer, reader, data);
+       }
+
+       /** Construct given a {@link ReflectData}. */
+       public SamoaDatumReader(ReflectData data) {
+               super(data);
+       }
+
+       @Override
+       /**
+        * Called to read a record instance. Overridden to read InstanceData.
+        */
+       protected Object readRecord(Object old, Schema expected, 
ResolvingDecoder in) throws IOException {
+               Object r = getData().newRecord(old, expected);
+               Object state = null;
+
+               for (Field f : in.readFieldOrder()) {
+                       int pos = f.pos();
+                       String name = f.name();
+                       Object oldDatum = null;
+                       if (r instanceof DenseInstanceData) {
+                               r = readDenseInstanceData(r, f, oldDatum, in, 
state);
+                       } else if (r instanceof SparseInstanceData) {
+                               r = readSparseInstanceData(r, f, oldDatum, in, 
state);
+                       } else
+                               readField(r, f, oldDatum, in, state);
+               }
+               
+               return r;
+       }
+
+       private Object readDenseInstanceData(Object record, Field f, Object 
oldDatum, ResolvingDecoder in, Object state)
+                       throws IOException {
+               if (f.name().equals("attributeValues")) {
+                       Array atributes = (Array) read(oldDatum, f.schema(), 
in);
+                       double[] atributesArr = new double[atributes.size()];
+                       for (int i = 0; i < atributes.size(); i++) {
+                               atributesArr[i] = (double) atributes.get(i);
+                       }
+                       return new DenseInstanceData(atributesArr);
+               }
+               return null;
+       }
+       
+       private Object readSparseInstanceData(Object record, Field f, Object 
oldDatum, ResolvingDecoder in, Object state)
+                       throws IOException {
+               if(f.name().equals("attributeValues")) {
+                       Array atributes = (Array) read(oldDatum, f.schema(), 
in);
+                       double[] atributesArr = new double[atributes.size()];
+                       for (int i = 0; i < atributes.size(); i++) 
+                               atributesArr[i] = (double) atributes.get(i);
+                       
((SparseInstanceData)record).setAttributeValues(atributesArr);
+               }
+               if(f.name().equals("indexValues")) {
+                       Array indexValues = (Array) read(oldDatum, f.schema(), 
in);
+                       int[] indexValuesArr = new int[indexValues.size()];
+                       for (int i = 0; i < indexValues.size(); i++) {
+                               indexValuesArr[i] = (int) indexValues.get(i);
+                       }
+                       
((SparseInstanceData)record).setIndexValues(indexValuesArr);
+               }
+               return record;
+       }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/3fbfc071/samoa-api/src/main/resources/kafka.avsc
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/resources/kafka.avsc 
b/samoa-api/src/main/resources/kafka.avsc
index c21e153..f5f12cf 100644
--- a/samoa-api/src/main/resources/kafka.avsc
+++ b/samoa-api/src/main/resources/kafka.avsc
@@ -1,11 +1,31 @@
 [
 {
+  "namespace": "org.apache.samoa.streams.kafka.temp",
+  "type": "record",
+  "name": "BurrTest",
+  "fields": [
+               {"name":"name", "type": "string"},
+               {"name":"atrs", "type": {"type": "array", "items": "string"}},
+               {"name":"nums", "type": {"type": "array", "items": "int"}},
+               {"name":"list", "type": {"type": "array", "items": "string"}}
+       ]
+},
+{
+  "namespace": "org.apache.samoa.instances",
+  "type": "record",
+  "name": "Instance",
+  "fields": [
+       ]
+},
+{
+  "namespace": "org.apache.samoa.instances",
   "type": "record",
   "name": "InstanceData",
   "fields": [
        ]
 },
 {
+  "namespace": "org.apache.samoa.instances",
   "type": "record",
   "name": "SingleClassInstanceData",
   "fields": [
@@ -13,6 +33,7 @@
        ]
 },
 {
+  "namespace": "org.apache.samoa.instances",
   "type": "record",
   "name": "DenseInstanceData",
   "fields": [
@@ -20,6 +41,7 @@
        ]
 },
 {
+  "namespace": "org.apache.samoa.instances",
   "type": "record",
   "name": "SparseInstanceData",
   "fields": [
@@ -29,32 +51,55 @@
        ]
 },
 {
+  "namespace": "org.apache.samoa.instances",
+  "type": "record",
+  "name": "SingleLabelInstance",
+  "fields": [
+               {"name": "weight", "type": "double"},
+               {"name": "instanceData", "type": ["null", 
"org.apache.samoa.instances.InstanceData", 
"org.apache.samoa.instances.DenseInstanceData", 
"org.apache.samoa.instances.SparseInstanceData", 
"org.apache.samoa.instances.SingleClassInstanceData"]},
+               {"name": "classData", "type": ["null", 
"org.apache.samoa.instances.InstanceData", 
"org.apache.samoa.instances.DenseInstanceData", 
"org.apache.samoa.instances.SparseInstanceData", 
"org.apache.samoa.instances.SingleClassInstanceData"]}
+       ]
+},
+{
+  "namespace": "org.apache.samoa.instances",
+  "type": "record",
+  "name": "DenseInstance",
+  "fields": [
+               {"name": "weight", "type": "double"},
+               {"name": "instanceData", "type": ["null", 
"org.apache.samoa.instances.InstanceData", 
"org.apache.samoa.instances.DenseInstanceData", 
"org.apache.samoa.instances.SparseInstanceData", 
"org.apache.samoa.instances.SingleClassInstanceData"]},
+               {"name": "classData", "type": ["null", 
"org.apache.samoa.instances.InstanceData", 
"org.apache.samoa.instances.DenseInstanceData", 
"org.apache.samoa.instances.SparseInstanceData", 
"org.apache.samoa.instances.SingleClassInstanceData"]}
+       ]
+},
+{
+  "namespace": "org.apache.samoa.core",
   "type": "record",
   "name": "SerializableInstance",
   "fields": [
                {"name": "weight", "type": "double"},
-               {"name": "instanceData", "type": ["null", "InstanceData", 
"DenseInstanceData", "SparseInstanceData", "SingleClassInstanceData"]},
-               {"name": "classData", "type": "InstanceData"}
+               {"name": "instanceData", "type": ["null", 
"org.apache.samoa.instances.InstanceData", 
"org.apache.samoa.instances.DenseInstanceData", 
"org.apache.samoa.instances.SparseInstanceData", 
"org.apache.samoa.instances.SingleClassInstanceData"]},
+               {"name": "classData", "type": ["null", 
"org.apache.samoa.instances.InstanceData", 
"org.apache.samoa.instances.DenseInstanceData", 
"org.apache.samoa.instances.SparseInstanceData", 
"org.apache.samoa.instances.SingleClassInstanceData"]}
        ]
 },
 {
+  "namespace": "org.apache.samoa.learners",
   "type": "record",
   "name": "InstanceContent",
   "fields": [
                {"name": "instanceIndex", "type": "long"},
                {"name": "classifierIndex", "type": "int"},
                {"name": "evaluationIndex", "type": "int"},
-               {"name":"instance", "type":"SerializableInstance"},
+               {"name":"instance", 
"type":"org.apache.samoa.core.SerializableInstance"},
                {"name": "isTraining", "type": "boolean"},
                {"name": "isTesting", "type": "boolean"},
                {"name": "isLast", "type": "boolean"}
        ]
 },
 {
+ "namespace": "org.apache.samoa.learners",
  "type": "record",
  "name": "InstanceContentEvent",
  "fields": [
-     {"name": "instanceContent", "type": "InstanceContent"}
+     {"name": "instanceContent", "type": 
"org.apache.samoa.learners.InstanceContent"}
  ]
 }
 ]

http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/3fbfc071/samoa-api/src/test/java/org/apache/samoa/streams/kafka/AvroSerializerDeserializerTest.java
----------------------------------------------------------------------
diff --git 
a/samoa-api/src/test/java/org/apache/samoa/streams/kafka/AvroSerializerDeserializerTest.java
 
b/samoa-api/src/test/java/org/apache/samoa/streams/kafka/AvroSerializerDeserializerTest.java
new file mode 100644
index 0000000..1a1a718
--- /dev/null
+++ 
b/samoa-api/src/test/java/org/apache/samoa/streams/kafka/AvroSerializerDeserializerTest.java
@@ -0,0 +1,70 @@
+package org.apache.samoa.streams.kafka;
+
+import static org.junit.Assert.assertTrue;
+
+import java.util.Random;
+import java.util.logging.Logger;
+
+import org.apache.samoa.instances.InstancesHeader;
+import org.apache.samoa.learners.InstanceContentEvent;
+import org.apache.samoa.streams.kafka.KafkaAvroMapper;
+import org.junit.Test;
+
+public class AvroSerializerDeserializerTest {
+
+       private Logger logger = 
Logger.getLogger(AvroSerializerDeserializerTest.class.getName());
+       public AvroSerializerDeserializerTest() {}
+       
+       @Test
+       public void testAvroSerialize() {
+               Random r = new Random();
+        InstancesHeader header = TestUtilsForKafka.generateHeader(10);
+        InstanceContentEvent eventToSerialize = TestUtilsForKafka.getData(r, 
10, header);
+               byte[] data = 
KafkaAvroMapper.avroSerialize(InstanceContentEvent.class, eventToSerialize);
+               
+               InstanceContentEvent eventDeserialized = 
KafkaAvroMapper.avroDeserialize(data, InstanceContentEvent.class);
+               
+               assertTrue("Serialized and deserialized event", 
isEqual(eventToSerialize, eventDeserialized));
+               
+       }
+       
+       public boolean isEqual(InstanceContentEvent a, InstanceContentEvent b) {
+               if(a.getClassId() != b.getClassId()) {
+                       logger.info("a.getClassId() != b.getClassId(): " + 
(a.getClassId() != b.getClassId()));
+                       return false;
+               }
+               if(a.isLastEvent() != b.isLastEvent()) {
+                       logger.info("a.isLastEvent() != b.isLastEvent(): " + 
(a.isLastEvent() != b.isLastEvent()));
+                       return false;
+               }
+               if(a.isTesting() != b.isTesting()) {
+                       logger.info("a.isTesting() != b.isTesting(): " + 
(a.isTesting() != b.isTesting()));
+                       return false;
+               }
+               if(a.isTraining() != b.isTraining()) {
+                       logger.info("a.isTraining() != b.isTraining(): " + 
(a.isTraining() != b.isTraining()));
+                       return false;
+               }
+               if(a.getClassifierIndex() != b.getClassifierIndex()) {
+                       logger.info("a.getClassifierIndex() != 
b.getClassifierIndex(): " + (a.getClassifierIndex() != b.getClassifierIndex()));
+                       return false;
+               }
+               if(a.getEvaluationIndex() != b.getEvaluationIndex()) {
+                       logger.info("a.getEvaluationIndex() != 
b.getEvaluationIndex(): " + (a.getEvaluationIndex() != b.getEvaluationIndex()));
+                       return false;
+               }
+               if(a.getInstanceIndex() != b.getInstanceIndex()) {
+                       logger.info("a.getInstanceIndex() != 
b.getInstanceIndex(): " + (a.getInstanceIndex() != b.getInstanceIndex()));
+                       return false;
+               }
+               
if(!a.getInstance().toString().equals(b.getInstance().toString())) {
+                       logger.info("a.getInstance().toString()!= 
b.getInstance().toString(): " + (a.getInstance().toString()!= 
b.getInstance().toString()));
+                       logger.info("a.toString(): " + 
a.getInstance().toString());
+                       logger.info("b.toString(): " + 
b.getInstance().toString());
+                       return false;
+               }
+               
+               return true;
+       }
+
+}

Reply via email to