http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java b/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java new file mode 100644 index 0000000..7ccd6a7 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java @@ -0,0 +1,33 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.common.iterator;

import java.util.Iterator;

/**
 * {@link SamplerCase} specialisation that supplies a {@link StableFixedSizeSamplingIterator}
 * as the sampler under test.
 */
public final class TestStableFixedSizeSampler extends SamplerCase {

  /**
   * Builds the sampler under test.
   *
   * @param n      first constructor argument of {@link StableFixedSizeSamplingIterator}
   * @param source iterator the sampler draws its elements from
   * @return a new {@link StableFixedSizeSamplingIterator} over {@code source}
   */
  @Override
  protected Iterator<Integer> createSampler(int n, Iterator<Integer> source) {
    Iterator<Integer> sampler = new StableFixedSizeSamplingIterator<Integer>(n, source);
    return sampler;
  }

  /**
   * Always answers {@code true}.
   * NOTE(review): presumably tells {@link SamplerCase} that samples keep the source order -
   * confirm against the base class.
   */
  @Override
  protected boolean isSorted() {
    final boolean sorted = true;
    return sorted;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java b/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java new file mode 100644 index 0000000..f94d63e --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.mahout.common.lucene;

import org.apache.lucene.analysis.cjk.CJKAnalyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.junit.Test;

import static org.junit.Assert.assertNotNull;

/**
 * Checks that {@link AnalyzerUtils#createAnalyzer} hands back a non-null result when given the
 * fully qualified class name of a Lucene analyzer.
 */
public class AnalyzerUtilsTest {

  /** A {@link StandardAnalyzer} class name must yield a non-null result. */
  @Test
  public void createStandardAnalyzer() throws Exception {
    String analyzerClassName = StandardAnalyzer.class.getName();
    assertNotNull(AnalyzerUtils.createAnalyzer(analyzerClassName));
  }

  /** A {@link CJKAnalyzer} class name must yield a non-null result. */
  @Test
  public void createCJKAnalyzer() throws Exception {
    String analyzerClassName = CJKAnalyzer.class.getName();
    assertNotNull(AnalyzerUtils.createAnalyzer(analyzerClassName));
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java b/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java new file mode 100644 index 0000000..e0bdc98 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.driver;

import org.junit.Test;

/**
 * Tests if MahoutDriver can be run directly through its main method.
 */
public final class MahoutDriverTest {

  /**
   * Calls {@link MahoutDriver#main} with the arguments {@code "canopy"} and {@code "help"}.
   * There is no explicit assertion; the test succeeds when the call returns without throwing.
   */
  @Test
  public void testMain() throws Throwable {
    String[] arguments = {"canopy", "help"};
    MahoutDriver.main(arguments);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java b/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java new file mode 100644 index 0000000..e53db7e --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.ep;

import org.apache.mahout.common.MahoutTestCase;
import org.junit.Test;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

/**
 * Checks that an {@link EvolutionaryProcess} converges on the maximum of a simple quadratic
 * objective.
 */
public final class EvolutionaryProcessTest extends MahoutTestCase {

  /**
   * Runs 20 rounds of evaluate-then-mutate over a 5-parameter state and expects the best value
   * found to be within 0.02 of 0.
   *
   * <p>The objective is {@code -sum(k * (x_k - k)^2)} for {@code k = 1..5}, whose maximum is 0 at
   * {@code x_k = k}.
   */
  @Test
  public void testConverges() throws Exception {
    State<Foo, Double> s0 = new State<Foo, Double>(new double[5], 1);
    s0.setPayload(new Foo());
    EvolutionaryProcess<Foo, Double> ep = new EvolutionaryProcess<Foo, Double>(10, 100, s0);

    State<Foo, Double> best = null;
    try {
      for (int i = 0; i < 20; i++) {
        best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<Double>>() {
          @Override
          public double apply(Payload<Double> payload, double[] params) {
            // 1-based position of the parameter; doubles as both the weight and the target value
            int position = 1;
            double sum = 0;
            for (double x : params) {
              sum += position * (x - position) * (x - position);
              position++;
            }
            return -sum;
          }
        });

        ep.mutatePopulation(3);

        System.out.printf("%10.3f %.3f\n", best.getValue(), best.getOmni());
      }
    } finally {
      // close even when an iteration throws, so the process is never left open by a failing run
      ep.close();
    }

    assertNotNull(best);
    assertEquals(0.0, best.getValue(), 0.02);
  }

  /**
   * Stateless {@link Payload}: copying returns the same instance and every other operation does
   * nothing.
   */
  private static class Foo implements Payload<Double> {
    @Override
    public Foo copy() {
      return this;
    }

    @Override
    public void update(double[] params) {
      // ignore
    }

    @Override
    public void write(DataOutput dataOutput) throws IOException {
      // no-op
    }

    @Override
    public void readFields(DataInput dataInput) throws IOException {
      // no-op
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java b/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java new file mode 100644 index 0000000..226d4b1 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + *
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.math;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Map;

import com.google.common.collect.Maps;
import com.google.common.io.Closeables;
import org.apache.hadoop.io.Writable;
import org.junit.Test;

/**
 * Round-trips sparse, sparse-row and dense matrices (with row and column label bindings) through
 * {@link MatrixWritable} and checks that values and labels survive.
 */
public final class MatrixWritableTest extends MahoutTestCase {

  /** Labels bound to indices 0..4, in that order, on both rows and columns of every fixture. */
  private static final String[] LABELS = {"A", "B", "C", "D", "default"};

  @Test
  public void testSparseMatrixWritable() throws Exception {
    populateAndRoundTrip(new SparseMatrix(5, 5));
  }

  @Test
  public void testSparseRowMatrixWritable() throws Exception {
    populateAndRoundTrip(new SparseRowMatrix(5, 5));
  }

  @Test
  public void testDenseMatrixWritable() throws Exception {
    populateAndRoundTrip(new DenseMatrix(5, 5));
  }

  /**
   * Shared fixture: sets two cells, binds {@link #LABELS} to rows and columns (the same map
   * instance for both), then runs the round-trip check.
   */
  private static void populateAndRoundTrip(Matrix matrix) throws IOException {
    matrix.set(1, 2, 3.0);
    matrix.set(3, 4, 5.0);
    Map<String, Integer> labelToIndex = Maps.newHashMap();
    for (int position = 0; position < LABELS.length; position++) {
      labelToIndex.put(LABELS[position], position);
    }
    matrix.setRowLabelBindings(labelToIndex);
    matrix.setColumnLabelBindings(labelToIndex);
    doTestMatrixWritableEquals(matrix);
  }

  /** Serializes {@code m}, reads it back into a fresh writable and compares the result with {@code m}. */
  private static void doTestMatrixWritableEquals(Matrix m) throws IOException {
    Writable source = new MatrixWritable(m);
    MatrixWritable target = new MatrixWritable();
    writeAndRead(source, target);
    Matrix restored = target.get();
    compareMatrices(m, restored);
    doCheckBindings(restored.getRowLabelBindings());
    doCheckBindings(restored.getColumnLabelBindings());
  }

  /**
   * Asserts equal dimensions, cell values equal within {@code EPSILON}, and matching row and
   * column label bindings (both absent, or same size with the same index for every label).
   */
  private static void compareMatrices(Matrix m, Matrix m2) {
    int rowCount = m.numRows();
    int columnCount = m.numCols();
    assertEquals(rowCount, m2.numRows());
    assertEquals(columnCount, m2.numCols());
    for (int row = 0; row < rowCount; row++) {
      for (int column = 0; column < columnCount; column++) {
        assertEquals(m.get(row, column), m2.get(row, column), EPSILON);
      }
    }

    Map<String, Integer> expectedRows = m.getRowLabelBindings();
    Map<String, Integer> actualRows = m2.getRowLabelBindings();
    assertEquals(expectedRows == null, actualRows == null);
    if (expectedRows != null) {
      // row bindings are additionally expected to label every row
      assertEquals(expectedRows.size(), rowCount);
      assertEquals(expectedRows.size(), actualRows.size());
      for (Map.Entry<String, Integer> binding : expectedRows.entrySet()) {
        assertEquals(binding.getValue(), actualRows.get(binding.getKey()));
      }
    }

    Map<String, Integer> expectedColumns = m.getColumnLabelBindings();
    Map<String, Integer> actualColumns = m2.getColumnLabelBindings();
    assertEquals(expectedColumns == null, actualColumns == null);
    if (expectedColumns != null) {
      assertEquals(expectedColumns.size(), actualColumns.size());
      for (Map.Entry<String, Integer> binding : expectedColumns.entrySet()) {
        assertEquals(binding.getValue(), actualColumns.get(binding.getKey()));
      }
    }
  }

  /** Asserts that every entry of {@link #LABELS} is a key of {@code labels}. */
  private static void doCheckBindings(Map<String,Integer> labels) {
    for (String label : LABELS) {
      assertTrue("Missing label", labels.containsKey(label));
    }
  }

  /**
   * Writes {@code toWrite} into an in-memory buffer and fills {@code toRead} from those bytes.
   * A failure while closing the output is propagated; one while closing the input is swallowed.
   */
  private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    DataOutputStream sink = new DataOutputStream(buffer);
    try {
      toWrite.write(sink);
    } finally {
      Closeables.close(sink, false);
    }

    byte[] serialized = buffer.toByteArray();
    DataInputStream feed = new DataInputStream(new ByteArrayInputStream(serialized));
    try {
      toRead.readFields(feed);
    } finally {
      Closeables.close(feed, true);
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/VarintTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/VarintTest.java b/mr/src/test/java/org/apache/mahout/math/VarintTest.java new file mode 100644 index 0000000..0b1a664 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/VarintTest.java @@ -0,0 +1,189 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.math;

import org.junit.Test;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;

/**
 * Tests {@link Varint}.
 *
 * <p>The round-trip tests write zero, every power of two in range together with its neighbour,
 * and the extreme value of the type, then read the same sequence back. The size tests check the
 * number of bytes each written value adds to the stream.
 */
public final class VarintTest extends MahoutTestCase {

  /** Round-trips 0, {@code 2^k - 1} and {@code 2^k} for k = 0..62, and {@code Long.MAX_VALUE}. */
  @Test
  public void testUnsignedLong() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    Varint.writeUnsignedVarLong(0L, out);
    // "i > 0L" guards against the shift overflowing into the sign bit
    for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) {
      Varint.writeUnsignedVarLong(i - 1, out);
      Varint.writeUnsignedVarLong(i, out);
    }
    Varint.writeUnsignedVarLong(Long.MAX_VALUE, out);

    DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
    assertEquals(0L, Varint.readUnsignedVarLong(in));
    for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) {
      assertEquals(i - 1, Varint.readUnsignedVarLong(in));
      assertEquals(i, Varint.readUnsignedVarLong(in));
    }
    assertEquals(Long.MAX_VALUE, Varint.readUnsignedVarLong(in));
  }

  /** Round-trips non-negative longs through the signed encoding, up to {@code Long.MAX_VALUE}. */
  @Test
  public void testSignedPositiveLong() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    Varint.writeSignedVarLong(0L, out);
    for (long i = 1L; i <= (1L << 61); i <<= 1) {
      Varint.writeSignedVarLong(i - 1, out);
      Varint.writeSignedVarLong(i, out);
    }
    Varint.writeSignedVarLong((1L << 62) - 1, out);
    Varint.writeSignedVarLong((1L << 62), out);
    Varint.writeSignedVarLong(Long.MAX_VALUE, out);

    DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
    assertEquals(0L, Varint.readSignedVarLong(in));
    for (long i = 1L; i <= (1L << 61); i <<= 1) {
      assertEquals(i - 1, Varint.readSignedVarLong(in));
      assertEquals(i, Varint.readSignedVarLong(in));
    }
    assertEquals((1L << 62) - 1, Varint.readSignedVarLong(in));
    assertEquals((1L << 62), Varint.readSignedVarLong(in));
    assertEquals(Long.MAX_VALUE, Varint.readSignedVarLong(in));
  }

  /** Round-trips negative longs through the signed encoding, down to {@code Long.MIN_VALUE}. */
  @Test
  public void testSignedNegativeLong() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    for (long i = -1L; i >= -(1L << 62); i <<= 1) {
      Varint.writeSignedVarLong(i, out);
      Varint.writeSignedVarLong(i + 1, out);
    }
    Varint.writeSignedVarLong(Long.MIN_VALUE, out);
    Varint.writeSignedVarLong(Long.MIN_VALUE + 1, out);

    DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
    for (long i = -1L; i >= -(1L << 62); i <<= 1) {
      assertEquals(i, Varint.readSignedVarLong(in));
      assertEquals(i + 1, Varint.readSignedVarLong(in));
    }
    assertEquals(Long.MIN_VALUE, Varint.readSignedVarLong(in));
    assertEquals(Long.MIN_VALUE + 1, Varint.readSignedVarLong(in));
  }

  /**
   * Round-trips 0, {@code 2^k - 1} and {@code 2^k} for k = 0..30, and {@code Integer.MAX_VALUE},
   * using the int writer and the int reader throughout.
   */
  @Test
  public void testUnsignedInt() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    Varint.writeUnsignedVarInt(0, out);
    // "i > 0" guards against the shift overflowing into the sign bit
    for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) {
      // written with the int writer so that the method named by this test is what gets exercised
      Varint.writeUnsignedVarInt(i - 1, out);
      Varint.writeUnsignedVarInt(i, out);
    }
    Varint.writeUnsignedVarInt(Integer.MAX_VALUE, out);

    DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
    assertEquals(0, Varint.readUnsignedVarInt(in));
    for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) {
      assertEquals(i - 1, Varint.readUnsignedVarInt(in));
      assertEquals(i, Varint.readUnsignedVarInt(in));
    }
    assertEquals(Integer.MAX_VALUE, Varint.readUnsignedVarInt(in));
  }

  /**
   * Round-trips non-negative ints through the signed encoding, up to {@code Integer.MAX_VALUE},
   * using the int writer and the int reader throughout.
   */
  @Test
  public void testSignedPositiveInt() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    Varint.writeSignedVarInt(0, out);
    for (int i = 1; i <= (1 << 29); i <<= 1) {
      // written with the int writer, matching testSignedNegativeInt
      Varint.writeSignedVarInt(i - 1, out);
      Varint.writeSignedVarInt(i, out);
    }
    Varint.writeSignedVarInt((1 << 30) - 1, out);
    Varint.writeSignedVarInt((1 << 30), out);
    Varint.writeSignedVarInt(Integer.MAX_VALUE, out);

    DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
    assertEquals(0, Varint.readSignedVarInt(in));
    for (int i = 1; i <= (1 << 29); i <<= 1) {
      assertEquals(i - 1, Varint.readSignedVarInt(in));
      assertEquals(i, Varint.readSignedVarInt(in));
    }
    assertEquals((1 << 30) - 1, Varint.readSignedVarInt(in));
    assertEquals((1 << 30), Varint.readSignedVarInt(in));
    assertEquals(Integer.MAX_VALUE, Varint.readSignedVarInt(in));
  }

  /** Round-trips negative ints through the signed encoding, down to {@code Integer.MIN_VALUE}. */
  @Test
  public void testSignedNegativeInt() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    for (int i = -1; i >= -(1 << 30); i <<= 1) {
      Varint.writeSignedVarInt(i, out);
      Varint.writeSignedVarInt(i + 1, out);
    }
    Varint.writeSignedVarInt(Integer.MIN_VALUE, out);
    Varint.writeSignedVarInt(Integer.MIN_VALUE + 1, out);

    DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
    for (int i = -1; i >= -(1 << 30); i <<= 1) {
      assertEquals(i, Varint.readSignedVarInt(in));
      assertEquals(i + 1, Varint.readSignedVarInt(in));
    }
    assertEquals(Integer.MIN_VALUE, Varint.readSignedVarInt(in));
    assertEquals(Integer.MIN_VALUE + 1, Varint.readSignedVarInt(in));
  }

  /** Expects writing {@code 2^exponent} unsigned to add {@code 1 + exponent / 7} bytes. */
  @Test
  public void testUnsignedSize() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    int expectedSize = 0;
    for (int exponent = 0; exponent <= 62; exponent++) {
      Varint.writeUnsignedVarLong(1L << exponent, out);
      expectedSize += 1 + exponent / 7;
      assertEquals(expectedSize, baos.size());
    }
  }

  /**
   * Expects writing {@code 2^exponent}, and likewise {@code -(2^exponent) - 1}, in the signed
   * encoding to add {@code 1 + (exponent + 1) / 7} bytes.
   */
  @Test
  public void testSignedSize() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutput out = new DataOutputStream(baos);
    int expectedSize = 0;
    for (int exponent = 0; exponent <= 61; exponent++) {
      Varint.writeSignedVarLong(1L << exponent, out);
      expectedSize += 1 + ((exponent + 1) / 7);
      assertEquals(expectedSize, baos.size());
    }
    for (int exponent = 0; exponent <= 61; exponent++) {
      Varint.writeSignedVarLong(-(1L << exponent) - 1, out);
      expectedSize += 1 + ((exponent + 1) / 7);
      assertEquals(expectedSize, baos.size());
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java b/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java new file mode 100644 index 0000000..60fb8b4 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java @@ -0,0 +1,123 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied.
See the License for the specific language governing permissions and limitations under + * the License. + */

package org.apache.mahout.math;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.hadoop.io.Writable;
import org.apache.mahout.math.Vector.Element;
import org.junit.Test;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import com.carrotsearch.randomizedtesting.annotations.Repeat;
import com.google.common.io.Closeables;

/**
 * Round-trips randomly filled vectors of several implementations through
 * {@link VectorWritable} and checks that the vector read back equals the one written.
 */
public final class VectorWritableTest extends RandomizedTest {
  private static final int MAX_VECTOR_SIZE = 100;

  /**
   * Fills {@code v} with random content in place.
   *
   * <p>A random number of {@code set} calls store random doubles at random indices, then every
   * non-zero element whose index is a multiple of {@code max(2, size / 4)} is reset to 0.
   *
   * @param v vector to fill; modified in place
   */
  public void createRandom(Vector v) {
    int size = randomInt(v.size() - 1);
    for (int i = 0; i < size; ++i) {
      v.set(randomInt(v.size() - 1), randomDouble());
    }

    int zeros = Math.max(2, size / 4);
    for (Element e : v.nonZeroes()) {
      if (e.index() % zeros == 0) {
        e.set(0.0);
      }
    }
  }

  @Test
  @Repeat(iterations = 20)
  public void testViewSequentialAccessSparseVectorWritable() throws Exception {
    Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE);
    createRandom(v);
    Vector view = new VectorView(v, 0, v.size());
    doTestVectorWritableEquals(view);
  }

  @Test
  @Repeat(iterations = 20)
  public void testSequentialAccessSparseVectorWritable() throws Exception {
    Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE);
    createRandom(v);
    doTestVectorWritableEquals(v);
  }

  @Test
  @Repeat(iterations = 20)
  public void testRandomAccessSparseVectorWritable() throws Exception {
    Vector v = new RandomAccessSparseVector(MAX_VECTOR_SIZE);
    createRandom(v);
    doTestVectorWritableEquals(v);
  }

  @Test
  @Repeat(iterations = 20)
  public void testDenseVectorWritable() throws Exception {
    Vector v = new DenseVector(MAX_VECTOR_SIZE);
    createRandom(v);
    doTestVectorWritableEquals(v);
  }

  @Test
  @Repeat(iterations = 20)
  public void testNamedVectorWritable() throws Exception {
    Vector v = new DenseVector(MAX_VECTOR_SIZE);
    v = new NamedVector(v, "Victor");
    createRandom(v);
    doTestVectorWritableEquals(v);
  }

  /**
   * Serializes {@code v}, reads it back into a fresh {@link VectorWritable} and asserts equality.
   * For a {@link NamedVector} the result must also be a {@link NamedVector} with the same name.
   */
  private static void doTestVectorWritableEquals(Vector v) throws IOException {
    Writable vectorWritable = new VectorWritable(v);
    VectorWritable vectorWritable2 = new VectorWritable();
    writeAndRead(vectorWritable, vectorWritable2);
    Vector v2 = vectorWritable2.get();
    if (v instanceof NamedVector) {
      assertTrue(v2 instanceof NamedVector);
      NamedVector nv = (NamedVector) v;
      NamedVector nv2 = (NamedVector) v2;
      assertEquals(nv.getName(), nv2.getName());
      assertEquals("Victor", nv.getName());
    }
    assertEquals(v, v2);
  }

  /**
   * Writes {@code toWrite} into an in-memory buffer and fills {@code toRead} from those bytes.
   * A failure while closing the output is propagated; one while closing the input is swallowed.
   */
  private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(baos);
    try {
      toWrite.write(dos);
    } finally {
      Closeables.close(dos, false);
    }

    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
    DataInputStream dis = new DataInputStream(bais);
    try {
      toRead.readFields(dis);
    } finally {
      // close the input stream opened above (the output stream was already closed)
      Closeables.close(dis, true);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java b/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java new file mode 100644 index 0000000..a23f7b4 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java @@ -0,0 +1,236 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.math.hadoop;

import java.io.IOException;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Locale;

import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.easymock.EasyMock;
import org.easymock.IArgumentMatcher;
import org.junit.Assert;

/**
 * a collection of small helper methods useful for unit-testing mathematical operations
 */
public final class MathHelper {

  private MathHelper() {}

  /**
   * convenience method to create a {@link Vector.Element}
   *
   * @param index position reported by the element
   * @param value initial value of the element
   */
  public static Vector.Element elem(int index, double value) {
    return new ElementToCheck(index, value);
  }

  /**
   * a simple implementation of {@link Vector.Element}: a fixed index with a mutable value,
   * not backed by any vector
   */
  static class ElementToCheck implements Vector.Element {
    private final int index;
    private double value;

    ElementToCheck(int index, double value) {
      this.index = index;
      this.value = value;
    }
    @Override
    public double get() {
      return value;
    }
    @Override
    public int index() {
      return index;
    }
    @Override
    public void set(double value) {
      this.value = value;
    }
  }

  /**
   * applies an {@link IArgumentMatcher} to a {@link VectorWritable} that checks whether all elements are included
   *
   * <p>The matcher accepts an argument only if it is a {@link VectorWritable} whose vector
   * satisfies {@link #consistsOf(Vector, Vector.Element...)} for {@code elements}.
   *
   * @return always {@code null}; the matcher is registered through
   *         {@link EasyMock#reportMatcher(IArgumentMatcher)}
   */
  public static VectorWritable vectorMatches(final Vector.Element... elements) {
    EasyMock.reportMatcher(new IArgumentMatcher() {
      @Override
      public boolean matches(Object argument) {
        if (argument instanceof VectorWritable) {
          Vector v = ((VectorWritable) argument).get();
          return consistsOf(v, elements);
        }
        return false;
      }

      @Override
      public void appendTo(StringBuffer buffer) {}
    });
    return null;
  }

  /**
   * checks whether the {@link Vector} is equivalent to the set of {@link Vector.Element}s
   *
   * <p>Equivalent means: the vector has exactly {@code elements.length} entries that are neither
   * 0 nor NaN, and at each element's index the vector's value is within
   * {@link MahoutTestCase#EPSILON} of the element's value.
   */
  public static boolean consistsOf(Vector vector, Vector.Element... elements) {
    if (elements.length != numberOfNoNZeroNonNaNElements(vector)) {
      return false;
    }
    for (Vector.Element element : elements) {
      if (Math.abs(element.get() - vector.get(element.index())) > MahoutTestCase.EPSILON) {
        return false;
      }
    }
    return true;
  }

  /**
   * returns the number of elements in the {@link Vector} that are neither 0 nor NaN
   */
  public static int numberOfNoNZeroNonNaNElements(Vector vector) {
    int elementsInVector = 0;
    for (Element currentElement : vector.nonZeroes()) {
      if (!Double.isNaN(currentElement.get())) {
        elementsInVector++;
      }
    }
    return elementsInVector;
  }

  /**
   * read a {@link Matrix} from a {@code SequenceFile<IntWritable,VectorWritable>}
   *
   * <p>Each record's key is used as the row index and the non-zero entries of its vector are
   * copied into a {@link DenseMatrix} of the given dimensions.
   *
   * @throws IllegalStateException if the file contains no record
   */
  public static Matrix readMatrix(Configuration conf, Path path, int rows, int columns) {
    boolean readOneRow = false;
    Matrix matrix = new DenseMatrix(rows, columns);
    for (Pair<IntWritable,VectorWritable> record :
        new SequenceFileIterable<IntWritable,VectorWritable>(path, true, conf)) {
      IntWritable key = record.getFirst();
      VectorWritable value = record.getSecond();
      readOneRow = true;
      int row = key.get();
      for (Element element : value.get().nonZeroes()) {
        matrix.set(row, element.index(), element.get());
      }
    }
    if (!readOneRow) {
      throw new IllegalStateException("Not a single row read!");
    }
    return matrix;
  }

  /**
   * read the rows of a {@code SequenceFile<IntWritable,VectorWritable>} into a map from the
   * record key to the record's vector
   *
   * @throws IllegalStateException if the file contains no record
   */
  public static OpenIntObjectHashMap<Vector> readMatrixRows(Configuration conf, Path path) {
    boolean readOneRow = false;
    OpenIntObjectHashMap<Vector> rows = new OpenIntObjectHashMap<Vector>();
    for (Pair<IntWritable,VectorWritable> record :
        new SequenceFileIterable<IntWritable,VectorWritable>(path, true, conf)) {
      IntWritable key = record.getFirst();
      readOneRow = true;
      rows.put(key.get(), record.getSecond().get());
    }
    if (!readOneRow) {
      throw new IllegalStateException("Not a single row read!");
    }
    return rows;
  }

  /**
   * write a two-dimensional double array to a {@code SequenceFile<IntWritable,VectorWritable>}
   *
   * <p>Row {@code n} of {@code entries} becomes one record with key {@code n} and a
   * {@link RandomAccessSparseVector} value. The writer is closed even if writing fails.
   */
  public static void writeDistributedRowMatrix(double[][] entries, FileSystem fs, Configuration conf, Path path)
    throws IOException {
    SequenceFile.Writer writer = null;
    try {
      writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
      for (int n = 0; n < entries.length; n++) {
        Vector v = new RandomAccessSparseVector(entries[n].length);
        for (int m = 0; m < entries[n].length; m++) {
          v.setQuick(m, entries[n][m]);
        }
        writer.append(new IntWritable(n), new VectorWritable(v));
      }
    } finally {
      Closeables.close(writer, false);
    }
  }

  /**
   * asserts that both matrices have the same dimensions and that every pair of corresponding
   * cells differs by at most {@link MahoutTestCase#EPSILON}
   */
  public static void assertMatrixEquals(Matrix expected, Matrix actual) {
    Assert.assertEquals(expected.numRows(), actual.numRows());
    // compare against the expected matrix; comparing actual with itself can never fail
    Assert.assertEquals(expected.numCols(), actual.numCols());
    for (int row = 0; row < expected.numRows(); row++) {
      for (int col = 0; col < expected.numCols(); col ++) {
        Assert.assertEquals("Non-matching values in [" + row + ',' + col + ']',
                            expected.get(row, col), actual.get(row, col), MahoutTestCase.EPSILON);
      }
    }
  }

  /**
   * formats a vector as a bracketed, tab-separated list with two decimals per entry;
   * NaN entries are printed as {@code " - "} and non-negative entries get a leading space
   */
  public static String nice(Vector v) {
    if (!v.isSequentialAccess()) {
      // copy into a DenseVector before iterating when v is not sequential-access
      v = new DenseVector(v);
    }

    DecimalFormat df = new DecimalFormat("0.00", DecimalFormatSymbols.getInstance(Locale.ENGLISH));

    StringBuilder buffer = new StringBuilder("[");
    String separator = "";
    for (Vector.Element e : v.all()) {
      buffer.append(separator);
      if (Double.isNaN(e.get())) {
        buffer.append("  -  ");
      } else {
        if (e.get() >= 0) {
          buffer.append(' ');
        }
        buffer.append(df.format(e.get()));
      }
      separator = "\t";
    }
    buffer.append(" ]");
    return buffer.toString();
  }

  /**
   * formats a matrix as one {@link #nice(Vector)} line per row
   */
  public static String nice(Matrix matrix) {
    StringBuilder info = new StringBuilder();
    for (int n = 0; n < matrix.numRows(); n++) {
      info.append(nice(matrix.viewRow(n))).append('\n');
    }
    return info.toString();
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java b/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java new file mode 100644 index 0000000..13da38a --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java @@ -0,0 +1,395 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.mahout.math.hadoop;

import java.io.IOException;
import java.util.Iterator;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Job;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.decomposer.SolverTest;
import org.apache.mahout.math.function.Functions;
import org.junit.Test;

import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import com.google.common.collect.Maps;

/**
 * Tests {@link DistributedRowMatrix} operations (transpose, column means, products)
 * and the configuration handling of the jobs behind them. The public
 * {@code random*DistributedMatrix} factories are also called from other test classes.
 */
public final class TestDistributedRowMatrix extends MahoutTestCase {
  public static final String TEST_PROPERTY_KEY = "test.property.key";
  public static final String TEST_PROPERTY_VALUE = "test.property.value";

  /**
   * Asserts that two row collections agree row by row.
   * Rows are read from both iterators in lock-step and stored by row index; reading
   * stops as soon as either side is exhausted, so surplus rows on the longer side
   * are not compared. Only indices collected from {@code m} are checked.
   *
   * @param m              the reference rows
   * @param mtt            the rows under test
   * @param errorTolerance exclusive upper bound for the squared distance of two rows
   */
  private static void assertEquals(VectorIterable m, VectorIterable mtt, double errorTolerance) {
    Iterator<MatrixSlice> mIt = m.iterateAll();
    Iterator<MatrixSlice> mttIt = mtt.iterateAll();
    Map<Integer, Vector> mMap = Maps.newHashMap();
    Map<Integer, Vector> mttMap = Maps.newHashMap();
    while (mIt.hasNext() && mttIt.hasNext()) {
      MatrixSlice ms = mIt.next();
      mMap.put(ms.index(), ms.vector());
      MatrixSlice mtts = mttIt.next();
      mttMap.put(mtts.index(), mtts.vector());
    }
    for (Map.Entry<Integer, Vector> entry : mMap.entrySet()) {
      Integer key = entry.getKey();
      Vector value = entry.getValue();
      if (value == null || mttMap.get(key) == null) {
        // a row missing on one side is accepted only if the other side is missing or all-zero
        assertTrue(value == null || value.norm(2) == 0);
        assertTrue(mttMap.get(key) == null || mttMap.get(key).norm(2) == 0);
      } else {
        assertTrue(
            value.getDistanceSquared(mttMap.get(key)) < errorTolerance);
      }
    }
  }

  /** Transposing a matrix twice must give back the original rows. */
  @Test
  public void testTranspose() throws Exception {
    DistributedRowMatrix m = randomDistributedMatrix(10, 9, 5, 4, 1.0, false);
    m.setConf(getConfiguration());
    DistributedRowMatrix mt = m.transpose();
    mt.setConf(getConfiguration());

    Path tmpPath = getTestTempDirPath();
    m.setOutputTempPathString(tmpPath.toString());
    // relative child name: a leading '/' would make Hadoop resolve the child against the
    // file system root instead of the test temp dir
    Path tmpOutPath = new Path(tmpPath, "tmpOutTranspose");
    mt.setOutputTempPathString(tmpOutPath.toString());
    HadoopUtil.delete(getConfiguration(), tmpOutPath);
    DistributedRowMatrix mtt = mt.transpose();
    assertEquals(m, mtt, EPSILON);
  }

  /**
   * Column means of the distributed matrix must match the means computed in memory.
   * NOTE(review): relies on the in-memory matrix and the distributed one, generated
   * with the same parameters, holding identical values -- confirm in SolverTest.
   */
  @Test
  public void testMatrixColumnMeansJob() throws Exception {
    Matrix m =
        SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
    DistributedRowMatrix dm =
        randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
    dm.setConf(getConfiguration());

    Vector expected = new DenseVector(50);
    for (int i = 0; i < m.numRows(); i++) {
      expected.assign(m.viewRow(i), Functions.PLUS);
    }
    expected.assign(Functions.DIV, m.numRows());
    Vector actual = dm.columnMeans("DenseVector");
    assertEquals(0.0, expected.getDistanceSquared(actual), EPSILON);
  }

  /** Column means of a matrix with zero columns. */
  @Test
  public void testNullMatrixColumnMeansJob() throws Exception {
    Matrix m =
        SolverTest.randomSequentialAccessSparseMatrix(100, 90, 0, 0, 1.0);
    DistributedRowMatrix dm =
        randomDistributedMatrix(100, 90, 0, 0, 1.0, false);
    dm.setConf(getConfiguration());

    Vector expected = new DenseVector(0);
    for (int i = 0; i < m.numRows(); i++) {
      expected.assign(m.viewRow(i), Functions.PLUS);
    }
    expected.assign(Functions.DIV, m.numRows());
    Vector actual = dm.columnMeans();
    assertEquals(0.0, expected.getDistanceSquared(actual), EPSILON);
  }

  /** {@code dm.times(v)} must match the in-memory {@code m.times(v)}. */
  @Test
  public void testMatrixTimesVector() throws Exception {
    Vector v = new RandomAccessSparseVector(50);
    v.assign(1.0);
    Matrix m = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
    DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
    dm.setConf(getConfiguration());

    Vector expected = m.times(v);
    Vector actual = dm.times(v);
    assertEquals(0.0, expected.getDistanceSquared(actual), EPSILON);
  }

  /** {@code dm.timesSquared(v)} must match the in-memory {@code m.timesSquared(v)}. */
  @Test
  public void testMatrixTimesSquaredVector() throws Exception {
    Vector v = new RandomAccessSparseVector(50);
    v.assign(1.0);
    Matrix m = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
    DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
    dm.setConf(getConfiguration());

    Vector expected = m.timesSquared(v);
    Vector actual = dm.timesSquared(v);
    assertEquals(0.0, expected.getDistanceSquared(actual), 1.0e-9);
  }

  /** {@code distA.times(distB)} is compared against {@code inputA.transpose().times(inputB)}. */
  @Test
  public void testMatrixTimesMatrix() throws Exception {
    Matrix inputA = SolverTest.randomSequentialAccessSparseMatrix(20, 19, 15, 5, 10.0);
    Matrix inputB = SolverTest.randomSequentialAccessSparseMatrix(20, 13, 25, 10, 5.0);
    Matrix expected = inputA.transpose().times(inputB);

    DistributedRowMatrix distA = randomDistributedMatrix(20, 19, 15, 5, 10.0, false, "distA");
    distA.setConf(getConfiguration());
    DistributedRowMatrix distB = randomDistributedMatrix(20, 13, 25, 10, 5.0, false, "distB");
    distB.setConf(getConfiguration());
    DistributedRowMatrix product = distA.times(distB);

    assertEquals(expected, product, EPSILON);
  }

  /** A caller-supplied configuration must be carried into the multiplication job conf. */
  @Test
  public void testMatrixMultiplactionJobConfBuilder() throws Exception {
    Configuration initialConf = createInitialConf();

    Path baseTmpDirPath = getTestTempDirPath("testpaths");
    Path aPath = new Path(baseTmpDirPath, "a");
    Path bPath = new Path(baseTmpDirPath, "b");
    Path outPath = new Path(baseTmpDirPath, "out");

    Configuration mmJobConf = MatrixMultiplicationJob.createMatrixMultiplyJobConf(aPath, bPath, outPath, 10);
    Configuration mmCustomJobConf = MatrixMultiplicationJob.createMatrixMultiplyJobConf(initialConf,
                                                                                        aPath,
                                                                                        bPath,
                                                                                        outPath,
                                                                                        10);

    assertNull(mmJobConf.get(TEST_PROPERTY_KEY));
    assertEquals(TEST_PROPERTY_VALUE, mmCustomJobConf.get(TEST_PROPERTY_KEY));
  }

  /** A caller-supplied configuration must be carried into the transpose job. */
  @Test
  public void testTransposeJobConfBuilder() throws Exception {
    Configuration initialConf = createInitialConf();

    Path baseTmpDirPath = getTestTempDirPath("testpaths");
    Path inputPath = new Path(baseTmpDirPath, "input");
    Path outputPath = new Path(baseTmpDirPath, "output");

    Configuration transposeJobConf = TransposeJob.buildTransposeJob(inputPath, outputPath, 10).getConfiguration();

    Configuration transposeCustomJobConf = TransposeJob.buildTransposeJob(initialConf, inputPath, outputPath, 10)
                                                       .getConfiguration();

    assertNull(transposeJobConf.get(TEST_PROPERTY_KEY));
    assertEquals(TEST_PROPERTY_VALUE, transposeCustomJobConf.get(TEST_PROPERTY_KEY));
  }

  /**
   * Every {@link TimesSquaredJob} factory overload taking a configuration must carry
   * the custom property, while the overloads without one must not see it.
   */
  @Test
  public void testTimesSquaredJobConfBuilders() throws Exception {
    Configuration initialConf = createInitialConf();

    Path baseTmpDirPath = getTestTempDirPath("testpaths");
    Path inputPath = new Path(baseTmpDirPath, "input");
    Path outputPath = new Path(baseTmpDirPath, "output");

    Vector v = new RandomAccessSparseVector(50);
    v.assign(1.0);

    Job timesSquaredJob1 = TimesSquaredJob.createTimesSquaredJob(v, inputPath, outputPath);
    Job customTimesSquaredJob1 = TimesSquaredJob.createTimesSquaredJob(initialConf, v, inputPath, outputPath);

    assertNull(timesSquaredJob1.getConfiguration().get(TEST_PROPERTY_KEY));
    assertEquals(TEST_PROPERTY_VALUE, customTimesSquaredJob1.getConfiguration().get(TEST_PROPERTY_KEY));

    Job timesJob = TimesSquaredJob.createTimesJob(v, 50, inputPath, outputPath);
    Job customTimesJob = TimesSquaredJob.createTimesJob(initialConf, v, 50, inputPath, outputPath);

    assertNull(timesJob.getConfiguration().get(TEST_PROPERTY_KEY));
    assertEquals(TEST_PROPERTY_VALUE, customTimesJob.getConfiguration().get(TEST_PROPERTY_KEY));

    Job timesSquaredJob2 = TimesSquaredJob.createTimesSquaredJob(v, inputPath, outputPath,
        TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);

    Job customTimesSquaredJob2 = TimesSquaredJob.createTimesSquaredJob(initialConf, v, inputPath,
        outputPath, TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);

    assertNull(timesSquaredJob2.getConfiguration().get(TEST_PROPERTY_KEY));
    assertEquals(TEST_PROPERTY_VALUE, customTimesSquaredJob2.getConfiguration().get(TEST_PROPERTY_KEY));

    Job timesSquaredJob3 = TimesSquaredJob.createTimesSquaredJob(v, 50, inputPath, outputPath,
        TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);

    Job customTimesSquaredJob3 = TimesSquaredJob.createTimesSquaredJob(initialConf,
        v, 50, inputPath, outputPath, TimesSquaredJob.TimesSquaredMapper.class,
        TimesSquaredJob.VectorSummingReducer.class);

    assertNull(timesSquaredJob3.getConfiguration().get(TEST_PROPERTY_KEY));
    assertEquals(TEST_PROPERTY_VALUE, customTimesSquaredJob3.getConfiguration().get(TEST_PROPERTY_KEY));
  }

  /**
   * {@code times(v)} must leave the output temp dir empty by default and keep exactly
   * one directory there once {@link DistributedRowMatrix#KEEP_TEMP_FILES} is set.
   */
  @Test
  public void testTimesVectorTempDirDeletion() throws Exception {
    Configuration conf = getConfiguration();
    Vector v = new RandomAccessSparseVector(50);
    v.assign(1.0);
    DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
    dm.setConf(conf);

    Path outputPath = dm.getOutputTempPath();
    FileSystem fs = outputPath.getFileSystem(conf);

    deleteContentsOfPath(conf, outputPath);

    assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);

    Vector result1 = dm.times(v);

    assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);

    deleteContentsOfPath(conf, outputPath);
    assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);

    conf.setBoolean(DistributedRowMatrix.KEEP_TEMP_FILES, true);
    dm.setConf(conf);

    Vector result2 = dm.times(v);

    FileStatus[] outputStatuses = fs.listStatus(outputPath);
    assertEquals(1, outputStatuses.length);
    Path outputTempPath = outputStatuses[0].getPath();
    Path inputVectorPath = new Path(outputTempPath, TimesSquaredJob.INPUT_VECTOR);
    Path outputVectorPath = new Path(outputTempPath, TimesSquaredJob.OUTPUT_VECTOR_FILENAME);
    assertEquals(1, fs.listStatus(inputVectorPath, PathFilters.logsCRCFilter()).length);
    assertEquals(1, fs.listStatus(outputVectorPath, PathFilters.logsCRCFilter()).length);

    assertEquals(0.0, result1.getDistanceSquared(result2), EPSILON);
  }

  /**
   * Same temp-dir contract as {@link #testTimesVectorTempDirDeletion()}, exercised
   * through {@code timesSquared(v)}.
   */
  @Test
  public void testTimesSquaredVectorTempDirDeletion() throws Exception {
    Configuration conf = getConfiguration();
    Vector v = new RandomAccessSparseVector(50);
    v.assign(1.0);
    DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
    // use the same configuration object as the rest of the test, like the times(v) variant does
    dm.setConf(conf);

    Path outputPath = dm.getOutputTempPath();
    FileSystem fs = outputPath.getFileSystem(conf);

    deleteContentsOfPath(conf, outputPath);

    assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);

    Vector result1 = dm.timesSquared(v);

    assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);

    deleteContentsOfPath(conf, outputPath);
    assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);

    conf.setBoolean(DistributedRowMatrix.KEEP_TEMP_FILES, true);
    dm.setConf(conf);

    Vector result2 = dm.timesSquared(v);

    FileStatus[] outputStatuses = fs.listStatus(outputPath);
    assertEquals(1, outputStatuses.length);
    Path outputTempPath = outputStatuses[0].getPath();
    Path inputVectorPath = new Path(outputTempPath, TimesSquaredJob.INPUT_VECTOR);
    Path outputVectorPath = new Path(outputTempPath, TimesSquaredJob.OUTPUT_VECTOR_FILENAME);
    assertEquals(1, fs.listStatus(inputVectorPath, PathFilters.logsCRCFilter()).length);
    assertEquals(1, fs.listStatus(outputVectorPath, PathFilters.logsCRCFilter()).length);

    assertEquals(0.0, result1.getDistanceSquared(result2), EPSILON);
  }

  /**
   * Returns the test configuration with {@link #TEST_PROPERTY_KEY} set to
   * {@link #TEST_PROPERTY_VALUE}.
   */
  public Configuration createInitialConf() throws IOException {
    Configuration initialConf = getConfiguration();
    initialConf.set(TEST_PROPERTY_KEY, TEST_PROPERTY_VALUE);
    return initialConf;
  }

  /** Recursively deletes every entry directly below {@code path}; {@code path} itself is kept. */
  private static void deleteContentsOfPath(Configuration conf, Path path) throws Exception {
    FileSystem fs = path.getFileSystem(conf);

    FileStatus[] statuses = HadoopUtil.listStatus(fs, path);
    for (FileStatus status : statuses) {
      fs.delete(status.getPath(), true);
    }
  }

  /**
   * Same as the seven-argument overload, stored below the {@code "testdata"} temp directory.
   */
  public DistributedRowMatrix randomDistributedMatrix(int numRows,
                                                      int nonNullRows,
                                                      int numCols,
                                                      int entriesPerRow,
                                                      double entryMean,
                                                      boolean isSymmetric) throws IOException {
    return randomDistributedMatrix(numRows, nonNullRows, numCols, entriesPerRow, entryMean, isSymmetric, "testdata");
  }

  /**
   * Builds a matrix with {@link SolverTest#randomHierarchicalMatrix} and stores it
   * below the temp directory named by {@code baseTmpDirSuffix}.
   */
  public DistributedRowMatrix randomDenseHierarchicalDistributedMatrix(int numRows,
                                                                       int numCols,
                                                                       boolean isSymmetric,
                                                                       String baseTmpDirSuffix)
    throws IOException {
    Path baseTmpDirPath = getTestTempDirPath(baseTmpDirSuffix);
    Matrix c = SolverTest.randomHierarchicalMatrix(numRows, numCols, isSymmetric);
    return saveToFs(c, baseTmpDirPath);
  }

  /**
   * Builds a matrix with {@link SolverTest#randomSequentialAccessSparseMatrix} and
   * stores it below the temp directory named by {@code baseTmpDirSuffix}.
   *
   * @param isSymmetric when true the generated matrix {@code c} is replaced by
   *                    {@code c * c^T} before it is stored
   */
  public DistributedRowMatrix randomDistributedMatrix(int numRows,
                                                      int nonNullRows,
                                                      int numCols,
                                                      int entriesPerRow,
                                                      double entryMean,
                                                      boolean isSymmetric,
                                                      String baseTmpDirSuffix) throws IOException {
    Path baseTmpDirPath = getTestTempDirPath(baseTmpDirSuffix);
    Matrix c = SolverTest.randomSequentialAccessSparseMatrix(numRows, nonNullRows, numCols, entriesPerRow, entryMean);
    if (isSymmetric) {
      c = c.times(c.transpose());
    }
    return saveToFs(c, baseTmpDirPath);
  }

  /**
   * Writes the rows of {@code m} to {@code distMatrix/part-00000} below
   * {@code baseTmpDirPath} and returns a {@link DistributedRowMatrix} over the
   * {@code distMatrix} directory, with {@code tmpOut} as its temp output path and a
   * copy of the test configuration.
   */
  private DistributedRowMatrix saveToFs(final Matrix m, Path baseTmpDirPath) throws IOException {
    Configuration conf = getConfiguration();
    FileSystem fs = FileSystem.get(baseTmpDirPath.toUri(), conf);

    // adapt the matrix' row slices to the Iterable<VectorWritable> expected by the writer
    ClusteringTestUtils.writePointsToFile(new Iterable<VectorWritable>() {
      @Override
      public Iterator<VectorWritable> iterator() {
        return Iterators.transform(m.iterator(), new Function<MatrixSlice,VectorWritable>() {
          @Override
          public VectorWritable apply(MatrixSlice input) {
            return new VectorWritable(input.vector());
          }
        });
      }
    }, true, new Path(baseTmpDirPath, "distMatrix/part-00000"), fs, conf);

    DistributedRowMatrix distMatrix = new DistributedRowMatrix(new Path(baseTmpDirPath, "distMatrix"),
                                                               new Path(baseTmpDirPath, "tmpOut"),
                                                               m.numRows(),
                                                               m.numCols());
    distMatrix.setConf(new Configuration(conf));

    return distMatrix;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java new file mode 100644 index 0000000..ac01c28 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java @@ -0,0 +1,132 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.math.hadoop.decomposer;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.decomposer.SolverTest;
import org.apache.mahout.math.decomposer.lanczos.LanczosState;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;
import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
import org.junit.Before;

import java.io.File;
import java.io.IOException;

/**
 * Exercises {@link DistributedLanczosSolver} on a symmetric and an asymmetric random
 * corpus. The only test entry point is currently commented out (see the TODO at the
 * bottom of the class), so nothing in here runs as part of the suite.
 */
@Deprecated
public final class TestDistributedLanczosSolver extends MahoutTestCase {

  /** Appended to the state directory name; bumped after every solver run. */
  private int counter;
  private DistributedRowMatrix symCorpus;
  private DistributedRowMatrix asymCorpus;

  /** Creates the symmetric and the asymmetric corpus, each in its own temp directory. */
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    File symDir = getTestTempDir("symTestData");
    File asymDir = getTestTempDir("asymTestData");
    TestDistributedRowMatrix symFactory = new TestDistributedRowMatrix();
    symCorpus = symFactory.randomDistributedMatrix(100, 90, 80, 2, 10.0, true, symDir.getAbsolutePath());
    TestDistributedRowMatrix asymFactory = new TestDistributedRowMatrix();
    asymCorpus = asymFactory.randomDistributedMatrix(100, 90, 80, 2, 10.0, false, asymDir.getAbsolutePath());
  }

  /** Directory-name suffix distinguishing the symmetric from the asymmetric run. */
  private static String suf(boolean symmetric) {
    if (symmetric) {
      return "_sym";
    }
    return "_asym";
  }

  /** Picks the corpus matching the requested symmetry. */
  private DistributedRowMatrix getCorpus(boolean symmetric) {
    if (symmetric) {
      return symCorpus;
    }
    return asymCorpus;
  }

  /** State directory for the current value of {@link #counter}. */
  private Path stateDir(boolean symmetric) throws IOException {
    return new Path(getTestTempDirPath(), "lanczosStateDir" + suf(symmetric) + counter);
  }

  /*
  private LanczosState doTestDistributedLanczosSolver(boolean symmetric,
      int desiredRank) throws IOException {
    return doTestDistributedLanczosSolver(symmetric, desiredRank, true);
  }
   */

  /**
   * Runs the solver once, asserts the resulting basis is orthonormal and checks the
   * first {@code desiredRank/2} right singular vectors with a tolerance of 0.1.
   *
   * @param symmetric       which corpus to solve, and passed on to the solver
   * @param desiredRank     rank requested from the solver
   * @param hdfsBackedState whether to use an {@link HdfsBackedLanczosState} instead
   *                        of a plain {@link LanczosState}
   * @return the solver state after the run
   */
  private LanczosState doTestDistributedLanczosSolver(boolean symmetric,
                                                      int desiredRank,
                                                      boolean hdfsBackedState) throws IOException {
    DistributedRowMatrix corpus = getCorpus(symmetric);
    Configuration conf = getConfiguration();
    corpus.setConf(conf);
    DistributedLanczosSolver solver = new DistributedLanczosSolver();
    Vector startVector = DistributedLanczosSolver.getInitialVector(corpus);
    LanczosState state;
    if (!hdfsBackedState) {
      state = new LanczosState(corpus, desiredRank, startVector);
    } else {
      HdfsBackedLanczosState hdfsState =
          new HdfsBackedLanczosState(corpus, desiredRank, startVector, stateDir(symmetric));
      hdfsState.setConf(conf);
      state = hdfsState;
    }
    solver.solve(state, desiredRank, symmetric);
    SolverTest.assertOrthonormal(state);
    for (int k = 0; k < desiredRank / 2; k++) {
      SolverTest.assertEigen(k, state.getRightSingularVector(k), corpus, 0.1, symmetric);
    }
    counter++;
    return state;
  }

  /**
   * Solves at rank 10, then again at rank 20 using a new state object over the same
   * state directory, and compares the normalized basis vectors of that second run
   * with those of a single in-memory run at rank 20.
   */
  public void doTestResumeIteration(boolean symmetric) throws IOException {
    DistributedRowMatrix corpus = getCorpus(symmetric);
    Configuration conf = getConfiguration();
    corpus.setConf(conf);
    DistributedLanczosSolver firstSolver = new DistributedLanczosSolver();
    int initialRank = 10;
    Vector startVector = DistributedLanczosSolver.getInitialVector(corpus);
    HdfsBackedLanczosState firstState =
        new HdfsBackedLanczosState(corpus, initialRank, startVector, stateDir(symmetric));
    firstSolver.solve(firstState, initialRank, symmetric);

    int resumedRank = initialRank * 2;
    // counter is unchanged here, so this state points at the same directory as the first one
    HdfsBackedLanczosState resumedState =
        new HdfsBackedLanczosState(corpus, resumedRank, startVector, stateDir(symmetric));
    DistributedLanczosSolver secondSolver = new DistributedLanczosSolver();
    secondSolver.solve(resumedState, resumedRank, symmetric);

    LanczosState allAtOnceState = doTestDistributedLanczosSolver(symmetric, resumedRank, false);
    for (int k = 0; k < resumedState.getIterationNumber(); k++) {
      Vector resumed = resumedState.getBasisVector(k).normalize();
      Vector reference = allAtOnceState.getBasisVector(k).normalize();
      double diff = resumed.minus(reference).norm(2);
      assertTrue("basis " + k + " is too long: " + diff, diff < 0.1);
    }
    counter++;
  }

  // TODO when this can be made to run in under 20 minutes, re-enable
  /*
  @Test
  public void testDistributedLanczosSolver() throws Exception {
    doTestDistributedLanczosSolver(true, 30);
    doTestDistributedLanczosSolver(false, 30);
    doTestResumeIteration(true);
    doTestResumeIteration(false);
  }
  */

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java new file mode 100644 index 0000000..5dfb328 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java @@ -0,0 +1,190 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.math.hadoop.decomposer;

import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;
import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.Arrays;

/**
 * Drives {@link DistributedLanczosSolver} through its command-line job, once writing
 * raw eigenvectors and once with the {@code --cleansvd} verification step enabled.
 */
@Deprecated
public final class TestDistributedLanczosSolverCLI extends MahoutTestCase {
  private static final Logger log = LoggerFactory.getLogger(TestDistributedLanczosSolverCLI.class);

  /**
   * Command line for a 10x9, non-symmetric input stored in {@code distMatrix} below
   * {@code testData}; the last option/value pair differs between the two tests.
   */
  private static String[] solverArgs(Path testData, Path output, Path tmp, String rank,
                                     String lastOption, String lastValue) {
    return new String[] {
        "-i", new Path(testData, "distMatrix").toString(),
        "-o", output.toString(),
        "--tempDir", tmp.toString(),
        "--numRows", "10",
        "--numCols", "9",
        "--rank", rank,
        "--symmetric", "false",
        lastOption, lastValue
    };
  }

  /** Runs a new solver job instance with the test configuration and the given arguments. */
  private void runSolverJob(String[] args) throws Exception {
    ToolRunner.run(getConfiguration(), new DistributedLanczosSolver().new DistributedLanczosSolverJob(), args);
  }

  /** True if some candidate lies within a relative distance of 0.1 of {@code d}. */
  private static boolean hasCloseMatch(double d, Iterable<Double> candidates) {
    for (double candidate : candidates) {
      if (Math.abs((d - candidate) / d) < 0.1) {
        return true;
      }
    }
    return false;
  }

  /**
   * Runs the job at rank 6 and then at rank 7 with the same working directory and
   * expects seven raw eigenvectors in the second output.
   */
  @Test
  public void testDistributedLanczosSolverCLI() throws Exception {
    Path testData = getTestTempDirPath("testdata");
    DistributedRowMatrix corpus =
        new TestDistributedRowMatrix().randomDenseHierarchicalDistributedMatrix(10, 9, false,
            testData.toString());
    corpus.setConf(getConfiguration());
    Path firstOutput = getTestTempDirPath("output");
    Path firstTmp = getTestTempDirPath("tmp");
    Path workingDir = getTestTempDirPath("working");
    runSolverJob(solverArgs(testData, firstOutput, firstTmp, "6", "--workingDir", workingDir.toString()));

    Path secondOutput = getTestTempDirPath("output2");
    Path secondTmp = getTestTempDirPath("tmp2");
    runSolverJob(solverArgs(testData, secondOutput, secondTmp, "7", "--workingDir", workingDir.toString()));

    Path rawEigenvectors = new Path(secondOutput, DistributedLanczosSolver.RAW_EIGENVECTORS);
    Matrix eigenVectors = new DenseMatrix(7, corpus.numCols());
    Configuration conf = getConfiguration();

    int vectorsRead = 0;
    for (VectorWritable value : new SequenceFileValueIterable<VectorWritable>(rawEigenvectors, conf)) {
      eigenVectors.assignRow(vectorsRead, value.get());
      vectorsRead++;
    }
    assertEquals("number of eigenvectors", 7, vectorsRead);
  }

  /**
   * Runs the job with {@code --cleansvd} at rank 6 and at rank 7, then checks that
   * the eigenvectors and eigenvalues of the first run show up again in the second.
   */
  @Test
  public void testDistributedLanczosSolverEVJCLI() throws Exception {
    Path testData = getTestTempDirPath("testdata");
    DistributedRowMatrix corpus = new TestDistributedRowMatrix()
        .randomDenseHierarchicalDistributedMatrix(10, 9, false, testData.toString());
    corpus.setConf(getConfiguration());
    Path firstOutput = getTestTempDirPath("output");
    Path firstTmp = getTestTempDirPath("tmp");
    runSolverJob(solverArgs(testData, firstOutput, firstTmp, "6", "--cleansvd", "true"));

    Path oldCleanPath = new Path(firstOutput, EigenVerificationJob.CLEAN_EIGENVECTORS);
    Matrix oldVectors = new DenseMatrix(6, corpus.numCols());
    Collection<Double> oldValues = Lists.newArrayList();

    Path secondOutput = getTestTempDirPath("output2");
    Path secondTmp = getTestTempDirPath("tmp2");
    runSolverJob(solverArgs(testData, secondOutput, secondTmp, "7", "--cleansvd", "true"));
    Path newCleanPath = new Path(secondOutput, EigenVerificationJob.CLEAN_EIGENVECTORS);
    Matrix newVectors = new DenseMatrix(7, corpus.numCols());
    Configuration conf = getConfiguration();
    Collection<Double> newValues = Lists.newArrayList();

    // first run: keep every vector, but only eigenvalues whose cos-angle error is below 1.0e-3
    int oldCount = 0;
    for (VectorWritable value : new SequenceFileValueIterable<VectorWritable>(oldCleanPath, conf)) {
      NamedVector named = (NamedVector) value.get();
      oldVectors.assignRow(oldCount, named);
      log.info(named.getName());
      if (EigenVector.getCosAngleError(named.getName()) < 1.0e-3) {
        oldValues.add(EigenVector.getEigenValue(named.getName()));
      }
      oldCount++;
    }
    assertEquals("number of clean eigenvectors", 3, oldCount);

    // second run: keep every vector and every eigenvalue
    int newCount = 0;
    for (VectorWritable value : new SequenceFileValueIterable<VectorWritable>(newCleanPath, conf)) {
      NamedVector named = (NamedVector) value.get();
      log.info(named.getName());
      newVectors.assignRow(newCount, named);
      newValues.add(EigenVector.getEigenValue(named.getName()));
      newCount++;
    }

    // an old eigenvector counts as found when its dot product with some new one exceeds 0.9
    Collection<Integer> oldEigensFound = Lists.newArrayList();
    for (int oldRow = 0; oldRow < oldVectors.numRows(); oldRow++) {
      Vector oldEigen = oldVectors.viewRow(oldRow);
      if (oldEigen == null) {
        break;
      }
      for (int newRow = 0; newRow < newVectors.numRows(); newRow++) {
        Vector newEigen = newVectors.viewRow(newRow);
        if (newEigen == null) {
          continue;
        }
        if (oldEigen.dot(newEigen) > 0.9) {
          oldEigensFound.add(oldRow);
          break;
        }
      }
    }
    assertEquals("the number of new eigenvectors", 5, newCount);

    Collection<Double> oldEigenValuesNotFound = Lists.newArrayList();
    for (double oldValue : oldValues) {
      if (!hasCloseMatch(oldValue, newValues)) {
        oldEigenValuesNotFound.add(oldValue);
      }
    }
    assertEquals("number of old eigenvalues not found: "
                 + Arrays.toString(oldEigenValuesNotFound.toArray(new Double[oldEigenValuesNotFound.size()])),
                 0, oldEigenValuesNotFound.size());
    assertEquals("did not find enough old eigenvectors", 3, oldEigensFound.size());
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java new file mode 100644 index 0000000..bb2c373 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java @@ -0,0 +1,238 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.similarity; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.clustering.ClusteringTestUtils; +import org.apache.mahout.common.DummyOutputCollector; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.StringTuple; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.EuclideanDistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +public class TestVectorDistanceSimilarityJob extends MahoutTestCase { + + private FileSystem fs; + + private static final double[][] REFERENCE = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, + { 4, 5 }, { 5, 5 } }; + + private static final double[][] SEEDS = { { 1, 1 }, { 10, 10 } }; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + fs = FileSystem.get(getConfiguration()); + } + + @Test + public void testVectorDistanceMapper() throws Exception { + Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable>.Context context = + 
EasyMock.createMock(Mapper.Context.class); + StringTuple tuple = new StringTuple(); + tuple.add("foo"); + tuple.add("123"); + context.write(tuple, new DoubleWritable(Math.sqrt(2.0))); + tuple = new StringTuple(); + tuple.add("foo2"); + tuple.add("123"); + context.write(tuple, new DoubleWritable(1)); + + EasyMock.replay(context); + + Vector vector = new RandomAccessSparseVector(2); + vector.set(0, 2); + vector.set(1, 2); + + VectorDistanceMapper mapper = new VectorDistanceMapper(); + setField(mapper, "measure", new EuclideanDistanceMeasure()); + Collection<NamedVector> seedVectors = Lists.newArrayList(); + Vector seed1 = new RandomAccessSparseVector(2); + seed1.set(0, 1); + seed1.set(1, 1); + Vector seed2 = new RandomAccessSparseVector(2); + seed2.set(0, 2); + seed2.set(1, 1); + + seedVectors.add(new NamedVector(seed1, "foo")); + seedVectors.add(new NamedVector(seed2, "foo2")); + setField(mapper, "seedVectors", seedVectors); + + mapper.map(new IntWritable(123), new VectorWritable(vector), context); + + EasyMock.verify(context); + } + + @Test + public void testVectorDistanceInvertedMapper() throws Exception { + Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = + EasyMock.createMock(Mapper.Context.class); + Vector expectVec = new DenseVector(new double[]{Math.sqrt(2.0), 1.0}); + context.write(new Text("other"), new VectorWritable(expectVec)); + EasyMock.replay(context); + Vector vector = new NamedVector(new RandomAccessSparseVector(2), "other"); + vector.set(0, 2); + vector.set(1, 2); + + VectorDistanceInvertedMapper mapper = new VectorDistanceInvertedMapper(); + setField(mapper, "measure", new EuclideanDistanceMeasure()); + Collection<NamedVector> seedVectors = Lists.newArrayList(); + Vector seed1 = new RandomAccessSparseVector(2); + seed1.set(0, 1); + seed1.set(1, 1); + Vector seed2 = new RandomAccessSparseVector(2); + seed2.set(0, 2); + seed2.set(1, 1); + + seedVectors.add(new NamedVector(seed1, "foo")); + seedVectors.add(new 
NamedVector(seed2, "foo2")); + setField(mapper, "seedVectors", seedVectors); + + mapper.map(new IntWritable(123), new VectorWritable(vector), context); + + EasyMock.verify(context); + + } + + @Test + public void testRun() throws Exception { + Path input = getTestTempDirPath("input"); + Path output = getTestTempDirPath("output"); + Path seedsPath = getTestTempDirPath("seeds"); + + List<VectorWritable> points = getPointsWritable(REFERENCE); + List<VectorWritable> seeds = getPointsWritable(SEEDS); + + Configuration conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf); + + String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(), + optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION), + output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), + EuclideanDistanceMeasure.class.getName() }; + + ToolRunner.run(getConfiguration(), new VectorDistanceSimilarityJob(), args); + + int expectedOutputSize = SEEDS.length * REFERENCE.length; + int outputSize = Iterables.size(new SequenceFileIterable<StringTuple, DoubleWritable>(new Path(output, + "part-m-00000"), conf)); + assertEquals(expectedOutputSize, outputSize); + } + + @Test + public void testMaxDistance() throws Exception { + + Path input = getTestTempDirPath("input"); + Path output = getTestTempDirPath("output"); + Path seedsPath = getTestTempDirPath("seeds"); + + List<VectorWritable> points = getPointsWritable(REFERENCE); + List<VectorWritable> seeds = getPointsWritable(SEEDS); + + Configuration conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf); + + double maxDistance = 10; + + String[] args = { 
optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(), + optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION), + output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), + EuclideanDistanceMeasure.class.getName(), + optKey(VectorDistanceSimilarityJob.MAX_DISTANCE), String.valueOf(maxDistance) }; + + ToolRunner.run(getConfiguration(), new VectorDistanceSimilarityJob(), args); + + int outputSize = 0; + + for (Pair<StringTuple, DoubleWritable> record : new SequenceFileIterable<StringTuple, DoubleWritable>( + new Path(output, "part-m-00000"), conf)) { + outputSize++; + assertTrue(record.getSecond().get() <= maxDistance); + } + + assertEquals(14, outputSize); + } + + @Test + public void testRunInverted() throws Exception { + Path input = getTestTempDirPath("input"); + Path output = getTestTempDirPath("output"); + Path seedsPath = getTestTempDirPath("seeds"); + List<VectorWritable> points = getPointsWritable(REFERENCE); + List<VectorWritable> seeds = getPointsWritable(SEEDS); + Configuration conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf); + String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(), + optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION), + output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), + EuclideanDistanceMeasure.class.getName(), + optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v" + }; + ToolRunner.run(getConfiguration(), new VectorDistanceSimilarityJob(), args); + + DummyOutputCollector<Text, VectorWritable> collector = new DummyOutputCollector<Text, VectorWritable>(); + + for (Pair<Text, VectorWritable> record : new SequenceFileIterable<Text, VectorWritable>( + new Path(output, "part-m-00000"), conf)) { + 
collector.collect(record.getFirst(), record.getSecond()); + } + assertEquals(REFERENCE.length, collector.getData().size()); + for (Map.Entry<Text, List<VectorWritable>> entry : collector.getData().entrySet()) { + assertEquals(SEEDS.length, entry.getValue().iterator().next().get().size()); + } + } + + private static List<VectorWritable> getPointsWritable(double[][] raw) { + List<VectorWritable> points = Lists.newArrayList(); + for (double[] fr : raw) { + Vector vec = new RandomAccessSparseVector(fr.length); + vec.assign(fr); + points.add(new VectorWritable(vec)); + } + return points; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java new file mode 100644 index 0000000..5d64f90 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java @@ -0,0 +1,214 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.similarity.cooccurrence; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.hadoop.MathHelper; +import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.TanimotoCoefficientSimilarity; +import org.apache.mahout.math.map.OpenIntIntHashMap; +import org.junit.Test; + +import java.io.File; + +public class RowSimilarityJobTest extends MahoutTestCase { + + /** + * integration test with a tiny data set + * + * <pre> + * + * input matrix: + * + * 1, 0, 1, 1, 0 + * 0, 0, 1, 1, 0 + * 0, 0, 0, 0, 1 + * + * similarity matrix (via tanimoto): + * + * 1, 0.666, 0 + * 0.666, 1, 0 + * 0, 0, 1 + * </pre> + * @throws Exception + */ + @Test + public void toyIntegration() throws Exception { + + File inputFile = getTestTempFile("rows"); + File outputDir = getTestTempDir("output"); + outputDir.delete(); + File tmpDir = getTestTempDir("tmp"); + + Configuration conf = getConfiguration(); + Path inputPath = new Path(inputFile.getAbsolutePath()); + FileSystem fs = FileSystem.get(inputPath.toUri(), conf); + + MathHelper.writeDistributedRowMatrix(new double[][] { + new double[] { 1, 0, 1, 1, 0 }, + new double[] { 0, 0, 1, 1, 0 }, + new double[] { 0, 0, 0, 0, 1 } }, + fs, conf, inputPath); + + RowSimilarityJob rowSimilarityJob = new RowSimilarityJob(); + rowSimilarityJob.setConf(conf); + rowSimilarityJob.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "--numberOfColumns", String.valueOf(5), "--similarityClassname", TanimotoCoefficientSimilarity.class.getName(), + "--tempDir", tmpDir.getAbsolutePath() }); + + + OpenIntIntHashMap observationsPerColumn = + Vectors.readAsIntMap(new 
Path(tmpDir.getAbsolutePath(), "observationsPerColumn.bin"), conf); + assertEquals(4, observationsPerColumn.size()); + assertEquals(1, observationsPerColumn.get(0)); + assertEquals(2, observationsPerColumn.get(2)); + assertEquals(2, observationsPerColumn.get(3)); + assertEquals(1, observationsPerColumn.get(4)); + + Matrix similarityMatrix = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "part-r-00000"), 3, 3); + + assertNotNull(similarityMatrix); + assertEquals(3, similarityMatrix.numCols()); + assertEquals(3, similarityMatrix.numRows()); + + assertEquals(1.0, similarityMatrix.get(0, 0), EPSILON); + assertEquals(1.0, similarityMatrix.get(1, 1), EPSILON); + assertEquals(1.0, similarityMatrix.get(2, 2), EPSILON); + assertEquals(0.0, similarityMatrix.get(2, 0), EPSILON); + assertEquals(0.0, similarityMatrix.get(2, 1), EPSILON); + assertEquals(0.0, similarityMatrix.get(0, 2), EPSILON); + assertEquals(0.0, similarityMatrix.get(1, 2), EPSILON); + assertEquals(0.666666, similarityMatrix.get(0, 1), EPSILON); + assertEquals(0.666666, similarityMatrix.get(1, 0), EPSILON); + } + + @Test + public void toyIntegrationMaxSimilaritiesPerRow() throws Exception { + + File inputFile = getTestTempFile("rows"); + File outputDir = getTestTempDir("output"); + outputDir.delete(); + File tmpDir = getTestTempDir("tmp"); + + Configuration conf = getConfiguration(); + Path inputPath = new Path(inputFile.getAbsolutePath()); + FileSystem fs = FileSystem.get(inputPath.toUri(), conf); + + MathHelper.writeDistributedRowMatrix(new double[][]{ + new double[] { 1, 0, 1, 1, 0, 1 }, + new double[] { 0, 1, 1, 1, 1, 1 }, + new double[] { 1, 1, 0, 1, 0, 0 } }, + fs, conf, inputPath); + + RowSimilarityJob rowSimilarityJob = new RowSimilarityJob(); + rowSimilarityJob.setConf(conf); + rowSimilarityJob.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "--numberOfColumns", String.valueOf(6), "--similarityClassname", 
TanimotoCoefficientSimilarity.class.getName(), + "--maxSimilaritiesPerRow", String.valueOf(1), "--excludeSelfSimilarity", String.valueOf(true), + "--tempDir", tmpDir.getAbsolutePath() }); + + Matrix similarityMatrix = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "part-r-00000"), 3, 3); + + assertNotNull(similarityMatrix); + assertEquals(3, similarityMatrix.numCols()); + assertEquals(3, similarityMatrix.numRows()); + + assertEquals(0.0, similarityMatrix.get(0, 0), EPSILON); + assertEquals(0.5, similarityMatrix.get(0, 1), EPSILON); + assertEquals(0.0, similarityMatrix.get(0, 2), EPSILON); + + assertEquals(0.5, similarityMatrix.get(1, 0), EPSILON); + assertEquals(0.0, similarityMatrix.get(1, 1), EPSILON); + assertEquals(0.0, similarityMatrix.get(1, 2), EPSILON); + + assertEquals(0.4, similarityMatrix.get(2, 0), EPSILON); + assertEquals(0.0, similarityMatrix.get(2, 1), EPSILON); + assertEquals(0.0, similarityMatrix.get(2, 2), EPSILON); + } + + @Test + public void toyIntegrationWithThreshold() throws Exception { + + + File inputFile = getTestTempFile("rows"); + File outputDir = getTestTempDir("output"); + outputDir.delete(); + File tmpDir = getTestTempDir("tmp"); + + Configuration conf = getConfiguration(); + Path inputPath = new Path(inputFile.getAbsolutePath()); + FileSystem fs = FileSystem.get(inputPath.toUri(), conf); + + MathHelper.writeDistributedRowMatrix(new double[][]{ + new double[] { 1, 0, 1, 1, 0, 1 }, + new double[] { 0, 1, 1, 1, 1, 1 }, + new double[] { 1, 1, 0, 1, 0, 0 } }, + fs, conf, inputPath); + + RowSimilarityJob rowSimilarityJob = new RowSimilarityJob(); + rowSimilarityJob.setConf(conf); + rowSimilarityJob.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "--numberOfColumns", String.valueOf(6), "--similarityClassname", TanimotoCoefficientSimilarity.class.getName(), + "--excludeSelfSimilarity", String.valueOf(true), "--threshold", String.valueOf(0.5), + "--tempDir", 
tmpDir.getAbsolutePath() }); + + Matrix similarityMatrix = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "part-r-00000"), 3, 3); + + assertNotNull(similarityMatrix); + assertEquals(3, similarityMatrix.numCols()); + assertEquals(3, similarityMatrix.numRows()); + + assertEquals(0.0, similarityMatrix.get(0, 0), EPSILON); + assertEquals(0.5, similarityMatrix.get(0, 1), EPSILON); + assertEquals(0.0, similarityMatrix.get(0, 2), EPSILON); + + assertEquals(0.5, similarityMatrix.get(1, 0), EPSILON); + assertEquals(0.0, similarityMatrix.get(1, 1), EPSILON); + assertEquals(0.0, similarityMatrix.get(1, 2), EPSILON); + + assertEquals(0.0, similarityMatrix.get(2, 0), EPSILON); + assertEquals(0.0, similarityMatrix.get(2, 1), EPSILON); + assertEquals(0.0, similarityMatrix.get(2, 2), EPSILON); + } + + @Test + public void testVectorDimensions() throws Exception { + + File inputFile = getTestTempFile("rows"); + + Configuration conf = getConfiguration(); + Path inputPath = new Path(inputFile.getAbsolutePath()); + FileSystem fs = FileSystem.get(inputPath.toUri(), conf); + + MathHelper.writeDistributedRowMatrix(new double[][] { + new double[] { 1, 0, 1, 1, 0, 1 }, + new double[] { 0, 1, 1, 1, 1, 1 }, + new double[] { 1, 1, 0, 1, 0, 0 } }, + fs, conf, inputPath); + + RowSimilarityJob rowSimilarityJob = new RowSimilarityJob(); + rowSimilarityJob.setConf(conf); + + int numberOfColumns = rowSimilarityJob.getDimensions(inputPath); + + assertEquals(6, numberOfColumns); + } +}
