Author: tdunning
Date: Thu Sep 23 00:12:13 2010
New Revision: 1000281

URL: http://svn.apache.org/viewvc?rev=1000281&view=rev
Log:
Fixed OnlineAuc serialization by storing the implementation class name alongside the serialized value

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1000281&r1=1000280&r2=1000281&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java Thu Sep 23 00:12:13 2010
@@ -37,6 +37,7 @@ import org.apache.mahout.math.DenseVecto
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
 
 import java.io.FileWriter;
 import java.io.IOException;
@@ -57,6 +58,7 @@ public final class ModelSerializer {
     gb.registerTypeAdapter(AdaptiveLogisticRegression.class, new 
AdaptiveLogisticRegressionTypeAdapter());
     gb.registerTypeAdapter(Mapping.class, new MappingTypeAdapter());
     gb.registerTypeAdapter(PriorFunction.class, new PriorTypeAdapter());
+    gb.registerTypeAdapter(OnlineAuc.class, new AucTypeAdapter());
     gb.registerTypeAdapter(CrossFoldLearner.class, new 
CrossFoldLearnerTypeAdapter());
     gb.registerTypeAdapter(Vector.class, new VectorTypeAdapter());
     gb.registerTypeAdapter(Matrix.class, new MatrixTypeAdapter());
@@ -146,6 +148,31 @@ public final class ModelSerializer {
     }
   }
 
+  /**
+   * Gson adapter that lets any {@link OnlineAuc} implementation round-trip through JSON when the
+   * declared type is the {@code OnlineAuc} interface. Values are written as
+   * {@code {"class": "<implementation class name>", "value": <implementation serialized by Gson>}}
+   * and the {@code class} entry selects the concrete type when reading the value back.
+   */
+  private static class AucTypeAdapter implements JsonDeserializer<OnlineAuc>, JsonSerializer<OnlineAuc> {
+    /**
+     * Rebuilds an {@link OnlineAuc} from the wrapper object produced by {@link #serialize}.
+     *
+     * @param jsonElement the wrapper object holding {@code class} and {@code value}
+     * @param type the requested type (unused; the concrete type comes from {@code class})
+     * @param jsonDeserializationContext context used to deserialize the wrapped value
+     * @return the deserialized implementation
+     * @throws IllegalStateException if the wrapper lacks {@code class} or {@code value}, or if
+     *     {@code class} does not name a loadable {@link OnlineAuc} implementation
+     */
+    @Override
+    public OnlineAuc deserialize(JsonElement jsonElement,
+                                 Type type,
+                                 JsonDeserializationContext jsonDeserializationContext) {
+      JsonObject x = jsonElement.getAsJsonObject();
+      JsonElement className = x.get("class");
+      if (className == null || !className.isJsonPrimitive() || !x.has("value")) {
+        throw new IllegalStateException(
+            "Can't understand serialized data, expected \"class\" and \"value\" entries: " + x);
+      }
+      String name = className.getAsString();
+      Class<? extends OnlineAuc> implementation;
+      try {
+        // asSubclass makes sure we only ever ask Gson to build an OnlineAuc implementation,
+        // rather than failing later with an unexplained cast at the caller.
+        implementation = Class.forName(name).asSubclass(OnlineAuc.class);
+      } catch (ClassNotFoundException e) {
+        throw new IllegalStateException("Can't understand serialized data, found bad type: " + name, e);
+      } catch (ClassCastException e) {
+        throw new IllegalStateException("Can't understand serialized data, type is not an OnlineAuc: " + name, e);
+      }
+      return jsonDeserializationContext.deserialize(x.get("value"), implementation);
+    }
+
+    /**
+     * Writes {@code auc} as a {@code {"class": ..., "value": ...}} wrapper object.
+     *
+     * @param auc the evaluator to serialize
+     * @param type the declared type (unused; the runtime class is recorded instead)
+     * @param jsonSerializationContext context used to serialize the wrapped value
+     * @return the wrapper object
+     */
+    @Override
+    public JsonElement serialize(OnlineAuc auc,
+                                 Type type,
+                                 JsonSerializationContext jsonSerializationContext) {
+      JsonObject r = new JsonObject();
+      r.add("class", new JsonPrimitive(auc.getClass().getName()));
+      // Serialize against the concrete class: this adapter is registered for OnlineAuc.class
+      // only, so this writes the implementation's own fields instead of re-entering the adapter.
+      r.add("value", jsonSerializationContext.serialize(auc, auc.getClass()));
+      return r;
+    }
+  }
+
   private static class CrossFoldLearnerTypeAdapter implements 
JsonDeserializer<CrossFoldLearner> {
     @Override
     public CrossFoldLearner deserialize(JsonElement jsonElement,
@@ -154,7 +181,7 @@ public final class ModelSerializer {
       CrossFoldLearner r = new CrossFoldLearner();
       JsonObject x = jsonElement.getAsJsonObject();
       r.setRecord(x.get("record").getAsInt());
-      
r.setAucEvaluator(jsonDeserializationContext.<GlobalOnlineAuc>deserialize(x.get("auc"),
 GlobalOnlineAuc.class));
+      
r.setAucEvaluator(jsonDeserializationContext.<OnlineAuc>deserialize(x.get("auc"),
 OnlineAuc.class));
       r.setLogLikelihood(x.get("logLikelihood").getAsDouble());
 
       JsonArray models = x.get("models").getAsJsonArray();

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java?rev=1000281&r1=1000280&r2=1000281&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java Thu Sep 23 00:12:13 2010
@@ -69,7 +69,7 @@ public final class ModelSerializerTest e
       auc1.addSample(0, gen.nextGaussian());
       auc1.addSample(1, gen.nextGaussian() + 1);
     }
-    assertEquals(0.76, auc1.auc(), 0.04);
+    assertEquals(0.76, auc1.auc(), 0.01);
 
     Gson gson = ModelSerializer.gson();
     String s = gson.toJson(auc1);
@@ -87,6 +87,23 @@ public final class ModelSerializerTest e
     }
 
     assertEquals(auc1.auc(), auc2.auc(), 0.01);
+
+    Foo x = new Foo();
+    x.foo = auc1;
+    x.pig = 3.13;
+    x.dog = 42;
+
+    s = gson.toJson(x);
+
+    Foo y = gson.fromJson(s, Foo.class);
+
+    assertEquals(auc1.auc(), y.foo.auc(), 0.01);
+  }
+
+  /**
+   * Minimal holder object for the round-trip check above: it embeds an {@link OnlineAuc}
+   * among ordinary fields so the evaluator is (de)serialized as a nested member.
+   */
+  public static class Foo {
+    /** Declared as the interface type, so Gson resolves it via the adapter registered for OnlineAuc. */
+    OnlineAuc foo;
+
+    /** Arbitrary double payload written next to the evaluator. */
+    double pig;
+
+    /** Arbitrary int payload written next to the evaluator. */
+    int dog;
   }
 
   @Test


Reply via email to