Author: tdunning
Date: Thu Sep 23 00:12:13 2010
New Revision: 1000281
URL: http://svn.apache.org/viewvc?rev=1000281&view=rev
Log:
Fixed OnlineAuc serialization woes by adding class name
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1000281&r1=1000280&r2=1000281&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
Thu Sep 23 00:12:13 2010
@@ -37,6 +37,7 @@ import org.apache.mahout.math.DenseVecto
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
import java.io.FileWriter;
import java.io.IOException;
@@ -57,6 +58,7 @@ public final class ModelSerializer {
gb.registerTypeAdapter(AdaptiveLogisticRegression.class, new
AdaptiveLogisticRegressionTypeAdapter());
gb.registerTypeAdapter(Mapping.class, new MappingTypeAdapter());
gb.registerTypeAdapter(PriorFunction.class, new PriorTypeAdapter());
+ gb.registerTypeAdapter(OnlineAuc.class, new AucTypeAdapter());
gb.registerTypeAdapter(CrossFoldLearner.class, new
CrossFoldLearnerTypeAdapter());
gb.registerTypeAdapter(Vector.class, new VectorTypeAdapter());
gb.registerTypeAdapter(Matrix.class, new MatrixTypeAdapter());
@@ -146,6 +148,31 @@ public final class ModelSerializer {
}
}
+ private static class AucTypeAdapter implements JsonDeserializer<OnlineAuc>,
JsonSerializer<OnlineAuc> {
+ @Override
+ public OnlineAuc deserialize(JsonElement jsonElement,
+ Type type,
+ JsonDeserializationContext
jsonDeserializationContext) {
+ JsonObject x = jsonElement.getAsJsonObject();
+ try {
+ return jsonDeserializationContext.deserialize(x.get("value"),
Class.forName(x.get("class").getAsString()));
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException("Can't understand serialized data,
found bad type: "
+ + x.get("class").getAsString());
+ }
+ }
+
+ @Override
+ public JsonElement serialize(OnlineAuc auc,
+ Type type,
+ JsonSerializationContext
jsonSerializationContext) {
+ JsonObject r = new JsonObject();
+ r.add("class", new JsonPrimitive(auc.getClass().getName()));
+ r.add("value", jsonSerializationContext.serialize(auc));
+ return r;
+ }
+ }
+
private static class CrossFoldLearnerTypeAdapter implements
JsonDeserializer<CrossFoldLearner> {
@Override
public CrossFoldLearner deserialize(JsonElement jsonElement,
@@ -154,7 +181,7 @@ public final class ModelSerializer {
CrossFoldLearner r = new CrossFoldLearner();
JsonObject x = jsonElement.getAsJsonObject();
r.setRecord(x.get("record").getAsInt());
-
r.setAucEvaluator(jsonDeserializationContext.<GlobalOnlineAuc>deserialize(x.get("auc"),
GlobalOnlineAuc.class));
+
r.setAucEvaluator(jsonDeserializationContext.<OnlineAuc>deserialize(x.get("auc"),
OnlineAuc.class));
r.setLogLikelihood(x.get("logLikelihood").getAsDouble());
JsonArray models = x.get("models").getAsJsonArray();
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java?rev=1000281&r1=1000280&r2=1000281&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
Thu Sep 23 00:12:13 2010
@@ -69,7 +69,7 @@ public final class ModelSerializerTest e
auc1.addSample(0, gen.nextGaussian());
auc1.addSample(1, gen.nextGaussian() + 1);
}
- assertEquals(0.76, auc1.auc(), 0.04);
+ assertEquals(0.76, auc1.auc(), 0.01);
Gson gson = ModelSerializer.gson();
String s = gson.toJson(auc1);
@@ -87,6 +87,23 @@ public final class ModelSerializerTest e
}
assertEquals(auc1.auc(), auc2.auc(), 0.01);
+
+ Foo x = new Foo();
+ x.foo = auc1;
+ x.pig = 3.13;
+ x.dog = 42;
+
+ s = gson.toJson(x);
+
+ Foo y = gson.fromJson(s, Foo.class);
+
+ assertEquals(auc1.auc(), y.foo.auc(), 0.01);
+ }
+
+ public static class Foo { // test fixture: holds an OnlineAuc as a member so it is round-tripped through Gson inside another object
+ OnlineAuc foo; // the evaluator under test; its auc() is compared after the round trip
+ double pig; // extra member; the visible test sets it to 3.13 and does not assert on it
+ int dog; // extra member; the visible test sets it to 42 and does not assert on it
}
@Test