Modified: mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java?rev=1053787&r1=1053786&r2=1053787&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java Thu Dec 30 02:30:03 2010 @@ -20,19 +20,23 @@ package org.apache.mahout.ep; import org.apache.mahout.common.MahoutTestCase; import org.junit.Test; +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + public final class EvolutionaryProcessTest extends MahoutTestCase { @Test public void testConverges() throws Exception { - State<Foo> s0 = new State<Foo>(new double[5], 1); + State<Foo, Double> s0 = new State<Foo, Double>(new double[5], 1); s0.setPayload(new Foo()); - EvolutionaryProcess<Foo> ep = new EvolutionaryProcess<Foo>(10, 100, s0); + EvolutionaryProcess<Foo, Double> ep = new EvolutionaryProcess<Foo, Double>(10, 100, s0); - State<Foo> best = null; - for (int i = 0; i < 20 ; i++) { - best = ep.parallelDo(new EvolutionaryProcess.Function<Foo>() { + State<Foo, Double> best = null; + for (int i = 0; i < 20; i++) { + best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<Double>>() { @Override - public double apply(Foo payload, double[] params) { + public double apply(Payload<Double> payload, double[] params) { int i = 1; double sum = 0; for (double x : params) { @@ -52,7 +56,7 @@ public final class EvolutionaryProcessTe assertEquals(0.0, best.getValue(), 0.02); } - private static class Foo implements Payload<Foo> { + private static class Foo implements Payload<Double> { @Override public Foo copy() { return this; @@ -62,5 +66,15 @@ public final class EvolutionaryProcessTe public void update(double[] params) { // ignore } + + @Override + public void write(DataOutput dataOutput) throws IOException { + // no-op + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + // no-op + } } }
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1053787&r1=1053786&r2=1053787&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Thu Dec 30 02:30:03 2010 @@ -159,7 +159,7 @@ public final class TrainNewsGroups { int k = 0; double step = 0; int[] bumps = {1, 2, 5}; - for (File file : files.subList(0, 10000)) { + for (File file : files.subList(0, 3000)) { String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); @@ -170,7 +170,7 @@ public final class TrainNewsGroups { int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); - State<AdaptiveLogisticRegression.Wrapper> best = learningAlgorithm.getBest(); + State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); double maxBeta; double nonZeros; double positive; @@ -214,7 +214,7 @@ public final class TrainNewsGroups { } if (k % (bump * scale) == 0) { if (learningAlgorithm.getBest() != null) { - ModelSerializer.writeJson("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner()); + ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); } step += 0.25; @@ -227,7 +227,7 @@ public final class TrainNewsGroups { dissect(leakType, newsGroups, learningAlgorithm, files); System.out.println("exiting main"); - ModelSerializer.writeJson("/tmp/news-group.model", learningAlgorithm); + ModelSerializer.writeBinary("/tmp/news-group.model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); List<Integer> counts = Lists.newArrayList(); System.out.printf("Word counts\n"); Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/SparseMatrix.java URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/SparseMatrix.java?rev=1053787&r1=1053786&r2=1053787&view=diff ============================================================================== --- mahout/trunk/math/src/main/java/org/apache/mahout/math/SparseMatrix.java (original) +++ mahout/trunk/math/src/main/java/org/apache/mahout/math/SparseMatrix.java Thu Dec 30 02:30:03 2010 @@ -17,10 +17,10 @@ package org.apache.mahout.math; -import java.util.Map; - import org.apache.mahout.math.map.OpenIntObjectHashMap; +import java.util.Map; + /** Doubly sparse matrix. Implemented as a Map of RandomAccessSparseVector rows */ public class SparseMatrix extends AbstractMatrix { private OpenIntObjectHashMap<Vector> rows; @@ -55,6 +55,13 @@ public class SparseMatrix extends Abstra this.cardinality = cardinality.clone(); this.rows = new OpenIntObjectHashMap<Vector>(); } + + /** + * Construct a matrix with specified number of rows and columns. + */ + public SparseMatrix(int rows, int columns) { + this(new int[]{rows, columns}); + } @Override public Matrix clone() {
