Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestDenseVector.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestDenseVector.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestDenseVector.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestDenseVector.java Tue Jun 16 13:28:14 2009 @@ -17,6 +17,9 @@ package org.apache.mahout.matrix; +import java.util.HashMap; +import java.util.Map; + import junit.framework.TestCase; public class TestDenseVector extends TestCase { @@ -31,7 +34,10 @@ public void testAsFormatString() { String formatString = test.asWritableComparable().toString(); - assertEquals("format", "[, 1.1, 2.2, 3.3, ] ", formatString); + assertEquals( + "format", + "{\"class\":\"org.apache.mahout.matrix.DenseVector\",\"vector\":\"{\\\"values\\\":[1.1,2.2,3.3]}\"}", + formatString); } public void testCardinality() { @@ -126,8 +132,8 @@ } } - public void testDecodeFormat() throws Exception { - Vector val = DenseVector.decodeFormat(test.asWritableComparable()); + public void testDecodeVector() throws Exception { + Vector val = AbstractVector.decodeVector(test.asWritableComparable()); for (int i = 0; i < test.cardinality(); i++) assertEquals("get [" + i + ']', test.get(i), val.get(i)); } @@ -350,4 +356,17 @@ } } } + + public void testLabelIndexing() { + Map<String, Integer> bindings = new HashMap<String, Integer>(); + bindings.put("Fee", 0); + bindings.put("Fie", 1); + bindings.put("Foe", 2); + test.setLabelBindings(bindings); + assertEquals("Fee", test.get(0), test.get("Fee")); + assertEquals("Fie", test.get(1), test.get("Fie")); + assertEquals("Foe", test.get(2), test.get("Foe")); + test.set("Fie", 15.3); + assertEquals("Fie", test.get(1), test.get("Fie")); + } }
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java Tue Jun 16 13:28:14 2009 @@ -43,9 +43,12 @@ } public void testAsFormatString() { - assertEquals("format", - "[, [, 2.2, 3.3, ], [, 4.4, 5.5, ], [, 6.6, 7.7, ], ] ", test - .asWritableComparable().toString()); + String string = test + .asWritableComparable().toString(); + Matrix m = AbstractMatrix.decodeMatrix(string); + int[] c = m.cardinality(); + assertEquals("row cardinality", values.length - 2, c[ROW]); + assertEquals("col cardinality", values[0].length - 1, c[COL]); } public void testCardinality() { Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseColumnMatrix.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseColumnMatrix.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseColumnMatrix.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseColumnMatrix.java Tue Jun 16 13:28:14 2009 @@ -33,9 +33,4 @@ return matrix; } - public void testAsFormatString() { - assertEquals("format", "[[, 1.1, 2.2, ], 3.3, 4.4, ], 5.5, 6.6, ], ] ", - test.asWritableComparable().toString()); - } - } Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseMatrix.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseMatrix.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseMatrix.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseMatrix.java Tue Jun 16 13:28:14 2009 @@ -25,7 +25,7 @@ @Override public Matrix matrixFactory(double[][] values) { - int[] cardinality = {values.length, values[0].length}; + int[] cardinality = { values.length, values[0].length }; Matrix matrix = new SparseMatrix(cardinality); for (int row = 0; row < cardinality[ROW]; row++) for (int col = 0; col < cardinality[COL]; col++) @@ -33,11 +33,4 @@ return matrix; } - public void testAsFormatString() { - assertEquals( - "format", - "[s3, [s2, 0:1.1, 1:2.2, ] [s2, 0:3.3, 1:4.4, ] [s2, 0:5.5, 1:6.6, ] ] ", - test.asWritableComparable().toString()); - } - } Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseRowMatrix.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseRowMatrix.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseRowMatrix.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseRowMatrix.java Tue Jun 16 13:28:14 2009 @@ -33,11 +33,6 @@ return matrix; } - public void testAsFormatString() { - assertEquals("format", "[[, 1.1, 2.2, ], 3.3, 4.4, ], 5.5, 6.6, ], ] ", - test.asWritableComparable().toString()); - } - } Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java Tue Jun 16 13:28:14 2009 @@ -1,4 +1,3 @@ - /** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -18,6 +17,9 @@ package org.apache.mahout.matrix; +import java.util.HashMap; +import java.util.Map; + import junit.framework.TestCase; public class TestSparseVector extends TestCase { @@ -39,7 +41,10 @@ public void testAsFormatString() { String formatString = test.asWritableComparable().toString(); - assertEquals("format", "[s5, 1:1.1, 2:2.2, 3:3.3, ] ", formatString); + assertEquals( + "format", + "{\"class\":\"org.apache.mahout.matrix.SparseVector\",\"vector\":\"{\\\"values\\\":{\\\"1\\\":1.1,\\\"2\\\":2.2,\\\"3\\\":3.3},\\\"cardinality\\\":5}\"}", + formatString); } public void testCardinality() { @@ -139,8 +144,8 @@ } } - public void testDecodeFormat() throws Exception { - Vector val = SparseVector.decodeFormat(test.asWritableComparable()); + public void testDecodeVectort() throws Exception { + Vector val = AbstractVector.decodeVector(test.asWritableComparable()); for (int i = 0; i < test.cardinality(); i++) assertEquals("get [" + i + ']', test.get(i), val.get(i)); } @@ -383,4 +388,17 @@ assertEquals("cross[" + row + "][" + col + ']', test.getQuick(row) * test.getQuick(col), result.getQuick(row, col)); } + + public void testLabelIndexing() { + Map<String, Integer> bindings = new HashMap<String, Integer>(); + bindings.put("Fee", 0); + bindings.put("Fie", 1); + bindings.put("Foe", 2); + test.setLabelBindings(bindings); + assertEquals("Fee", test.get(0), test.get("Fee")); + assertEquals("Fie", test.get(1), test.get("Fie")); + assertEquals("Foe", test.get(2), test.get("Foe")); + test.set("Fie", 15.3); + assertEquals("Fie", test.get(1), test.get("Fie")); + } } Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorView.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorView.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorView.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorView.java Tue Jun 16 13:28:14 2009 @@ -25,9 +25,10 @@ private static final int offset = 1; - final double[] values = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; + final double[] values = { 0.0, 1.1, 2.2, 3.3, 4.4, 5.5 }; - final Vector test = new VectorView(new DenseVector(values), offset, cardinality); + final Vector test = new VectorView(new DenseVector(values), offset, + cardinality); public TestVectorView(String name) { super(name); @@ -35,7 +36,8 @@ public void testAsFormatString() { String formatString = test.asWritableComparable().toString(); - assertEquals("format", "[, 2.2, 3.3, 4.4, ] ", formatString); + Vector v = AbstractVector.decodeVector(formatString); + assertEquals("cardinality", test.cardinality(), v.cardinality()); } public void testCardinality() { @@ -200,7 +202,7 @@ assertEquals("cardinality", 3, val.cardinality()); for (int i = 0; i < test.cardinality(); i++) assertEquals("get [" + i + ']', values[offset + i] * values[offset + i], - val.get(i)); + val.get(i)); } public void testTimesVectorCardinality() { @@ -326,6 +328,6 @@ for (int row = 0; row < result.cardinality()[0]; row++) for (int col = 0; col < result.cardinality()[1]; col++) assertEquals("cross[" + row + "][" + col + ']', test.getQuick(row) - * test.getQuick(col), result.getQuick(row, col)); + * test.getQuick(col), result.getQuick(row, col)); } } Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java Tue Jun 16 13:28:14 2009 @@ -17,11 +17,18 @@ package org.apache.mahout.matrix; -import junit.framework.TestCase; - +import java.lang.reflect.Type; import java.util.Date; +import java.util.HashMap; +import java.util.Map; import java.util.Random; +import junit.framework.TestCase; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.reflect.TypeToken; + public class VectorTest extends TestCase { public VectorTest(String s) { @@ -46,19 +53,22 @@ assertTrue("equivalent didn't work", AbstractVector.equivalent(left, right)); assertTrue("equals didn't work", left.equals(right)); right.setQuick(2, 4); - assertTrue("equivalent didn't work", AbstractVector.equivalent(left, right) == false); - assertTrue("equals didn't work", left.equals(right) == false); + assertTrue("equivalent didn't work", + AbstractVector.equivalent(left, right) == false); + assertTrue("equals didn't work", left.equals(right) == false); right = new DenseVector(4); right.setQuick(0, 1); right.setQuick(1, 2); right.setQuick(2, 3); right.setQuick(3, 3); - assertTrue("equivalent didn't work", AbstractVector.equivalent(left, right) == false); + assertTrue("equivalent didn't work", + AbstractVector.equivalent(left, right) == false); assertTrue("equals didn't work", left.equals(right) == false); left = new SparseVector(2); left.setQuick(0, 1); left.setQuick(1, 2); - assertTrue("equivalent didn't work", AbstractVector.equivalent(left, right) == false); + assertTrue("equivalent didn't work", + AbstractVector.equivalent(left, right) == false); assertTrue("equals didn't work", left.equals(right) == false); DenseVector dense = new DenseVector(3); @@ -69,7 +79,8 @@ dense.setQuick(0, 1); dense.setQuick(1, 2); dense.setQuick(2, 3); - assertTrue("equivalent didn't work", AbstractVector.equivalent(dense, right) == true); + assertTrue("equivalent didn't work", AbstractVector + .equivalent(dense, right) == true); assertTrue("equals didn't work", dense.equals(right) == true); SparseVector sparse = new SparseVector(3); @@ -80,23 +91,24 @@ left.setQuick(0, 1); left.setQuick(1, 2); left.setQuick(2, 3); - assertTrue("equivalent didn't work", AbstractVector.equivalent(sparse, left) == true); + assertTrue("equivalent didn't work", AbstractVector + .equivalent(sparse, left) == true); assertTrue("equals didn't work", left.equals(sparse) == true); VectorView v1 = new VectorView(left, 0, 2); VectorView v2 = new VectorView(right, 0, 2); - assertTrue("equivalent didn't work", AbstractVector.equivalent(v1, v2) == true); + assertTrue("equivalent didn't work", + AbstractVector.equivalent(v1, v2) == true); assertTrue("equals didn't work", v1.equals(v2) == true); sparse = new SparseVector(2); sparse.setQuick(0, 1); sparse.setQuick(1, 2); - assertTrue("equivalent didn't work", AbstractVector.equivalent(v1, sparse) == true); + assertTrue("equivalent didn't work", + AbstractVector.equivalent(v1, sparse) == true); assertTrue("equals didn't work", v1.equals(sparse) == true); - } - private static void doTestVectors(Vector left, Vector right) { left.setQuick(0, 1); left.setQuick(1, 2); @@ -110,7 +122,8 @@ System.out.println("Vec: " + formattedString); Vector vec = AbstractVector.decodeVector(formattedString); assertTrue("vec is null and it shouldn't be", vec != null); - assertTrue("Vector could not be decoded from the formatString", AbstractVector.equivalent(vec, left)); + assertTrue("Vector could not be decoded from the formatString", + AbstractVector.equivalent(vec, left)); } public void testNormalize() throws Exception { @@ -132,40 +145,43 @@ assertTrue("norm is not equal to expected", norm.equals(expected)); norm = vec1.normalize(1); - expected.setQuick(0, 1.0/6); - expected.setQuick(1, 2.0/6); - expected.setQuick(2, 3.0/6); + expected.setQuick(0, 1.0 / 6); + expected.setQuick(1, 2.0 / 6); + expected.setQuick(2, 3.0 / 6); assertTrue("norm is not equal to expected", norm.equals(expected)); norm = vec1.normalize(3); - //TODO this is not used + // TODO this is not used expected = vec1.times(vec1).times(vec1); - //double sum = expected.zSum(); - //cube = Math.pow(sum, 1.0/3); - double cube = Math.pow(36, 1.0/3); + // double sum = expected.zSum(); + // cube = Math.pow(sum, 1.0/3); + double cube = Math.pow(36, 1.0 / 3); expected = vec1.divide(cube); - - assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected)); + + assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + + expected.asFormatString(), norm.equals(expected)); norm = vec1.normalize(Double.POSITIVE_INFINITY); - //The max is 3, so we divide by that. - expected.setQuick(0, 1.0/3); - expected.setQuick(1, 2.0/3); - expected.setQuick(2, 3.0/3); - assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected)); + // The max is 3, so we divide by that. + expected.setQuick(0, 1.0 / 3); + expected.setQuick(1, 2.0 / 3); + expected.setQuick(2, 3.0 / 3); + assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + + expected.asFormatString(), norm.equals(expected)); norm = vec1.normalize(0); - //The max is 3, so we divide by that. - expected.setQuick(0, 1.0/3); - expected.setQuick(1, 2.0/3); - expected.setQuick(2, 3.0/3); - assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected)); + // The max is 3, so we divide by that. + expected.setQuick(0, 1.0 / 3); + expected.setQuick(1, 2.0 / 3); + expected.setQuick(2, 3.0 / 3); + assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + + expected.asFormatString(), norm.equals(expected)); try { vec1.normalize(-1); assertTrue(false); } catch (IllegalArgumentException e) { - //expected + // expected } } @@ -215,12 +231,12 @@ } public void testEnumeration() throws Exception { - double[] apriori = {0, 1, 2, 3, 4}; + double[] apriori = { 0, 1, 2, 3, 4 }; - doTestEnumeration(apriori, new VectorView(new DenseVector(new double[]{ - -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), 2, 5)); + doTestEnumeration(apriori, new VectorView(new DenseVector(new double[] { + -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }), 2, 5)); - doTestEnumeration(apriori, new DenseVector(new double[]{0, 1, 2, 3, 4})); + doTestEnumeration(apriori, new DenseVector(new double[] { 0, 1, 2, 3, 4 })); SparseVector sparse = new SparseVector(5); sparse.set(0, 0); @@ -250,7 +266,7 @@ long tRef = t1 - t0; assertTrue(tOpt < tRef); System.out.println("testSparseVectorTimesX tRef=tOpt=" + (tRef - tOpt) - + " ms for 10 iterations"); + + " ms for 10 iterations"); for (int i = 0; i < 50000; i++) assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i)); } @@ -274,7 +290,7 @@ long tRef = t1 - t0; assertTrue(tOpt < tRef); System.out.println("testSparseVectorTimesV tRef=tOpt=" + (tRef - tOpt) - + " ms for 10 iterations"); + + " ms for 10 iterations"); for (int i = 0; i < 50000; i++) assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i)); } @@ -286,4 +302,51 @@ return v1; } + public void testLabelSerializationDense() { + double[] values = { 1.1, 2.2, 3.3 }; + Vector test = new DenseVector(values); + Map<String, Integer> bindings = new HashMap<String, Integer>(); + bindings.put("Fee", 0); + bindings.put("Fie", 1); + bindings.put("Foe", 2); + test.setLabelBindings(bindings); + + Type vectorType = new TypeToken<Vector>() { + }.getType(); + + GsonBuilder builder = new GsonBuilder(); + builder.registerTypeAdapter(vectorType, new JsonVectorAdapter()); + Gson gson = builder.create(); + String json = gson.toJson(test, vectorType); + Vector test1 = gson.fromJson(json, vectorType); + assertEquals("Fee", test.get(0), test1.get("Fee")); + assertEquals("Fie", test.get(1), test1.get("Fie")); + assertEquals("Foe", test.get(2), test1.get("Foe")); + + } + + public void testLabelSerializationSparse() { + double[] values = { 1.1, 2.2, 3.3 }; + Vector test = new SparseVector(3); + for (int i = 0; i < values.length; i++) + test.set(i, values[i]); + Map<String, Integer> bindings = new HashMap<String, Integer>(); + bindings.put("Fee", 0); + bindings.put("Fie", 1); + bindings.put("Foe", 2); + test.setLabelBindings(bindings); + + Type vectorType = new TypeToken<Vector>() { + }.getType(); + + GsonBuilder builder = new GsonBuilder(); + builder.registerTypeAdapter(vectorType, new JsonVectorAdapter()); + Gson gson = builder.create(); + String json = gson.toJson(test, vectorType); + Vector test1 = gson.fromJson(json, vectorType); + assertEquals("Fee", test.get(0), test1.get("Fee")); + assertEquals("Fie", test.get(1), test1.get("Fie")); + assertEquals("Foe", test.get(2), test1.get("Foe")); + } + } Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java Tue Jun 16 13:28:14 2009 @@ -32,6 +32,7 @@ import org.apache.mahout.clustering.dirichlet.models.Model; import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution; import org.apache.mahout.clustering.kmeans.KMeansDriver; +import org.apache.mahout.matrix.AbstractVector; import org.apache.mahout.matrix.DenseVector; import org.apache.mahout.matrix.Vector; @@ -78,7 +79,7 @@ List<Vector> results = new ArrayList<Vector>(); String line; while ((line = r.readLine()) != null) - results.add(DenseVector.decodeFormat(line)); + results.add(AbstractVector.decodeVector(line)); return results; } finally { r.close(); Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java?rev=785206&r1=785205&r2=785206&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java Tue Jun 16 13:28:14 2009 @@ -32,6 +32,7 @@ import org.apache.mahout.clustering.dirichlet.models.NormalModel; import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution; import org.apache.mahout.clustering.kmeans.KMeansDriver; +import org.apache.mahout.matrix.AbstractVector; import org.apache.mahout.matrix.DenseVector; import org.apache.mahout.matrix.Vector; @@ -77,7 +78,7 @@ List<Vector> results = new ArrayList<Vector>(); String line; while ((line = r.readLine()) != null) - results.add(DenseVector.decodeFormat(line)); + results.add(AbstractVector.decodeVector(line)); return results; } finally { r.close();
