Author: ssc
Date: Mon Mar 11 11:04:58 2013
New Revision: 1455095
URL: http://svn.apache.org/r1455095
Log:
MAHOUT-1019 VectorDistanceSimilarityJob
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java?rev=1455095&r1=1455094&r2=1455095&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java Mon Mar 11 11:04:58 2013
@@ -36,6 +36,8 @@ public final class VectorDistanceMapper
private DistanceMeasure measure;
private List<NamedVector> seedVectors;
+ private boolean usesThreshold = false;
+ private double maxDistance;
@Override
protected void map(WritableComparable<?> key, VectorWritable value, Context
context)
@@ -47,12 +49,15 @@ public final class VectorDistanceMapper
} else {
keyName = key.toString();
}
+
for (NamedVector seedVector : seedVectors) {
double distance = measure.distance(seedVector, valVec);
- StringTuple outKey = new StringTuple();
- outKey.add(seedVector.getName());
- outKey.add(keyName);
- context.write(outKey, new DoubleWritable(distance));
+ if (!usesThreshold || distance <= maxDistance) {
+ StringTuple outKey = new StringTuple();
+ outKey.add(seedVector.getName());
+ outKey.add(keyName);
+ context.write(outKey, new DoubleWritable(distance));
+ }
}
}
@@ -60,8 +65,15 @@ public final class VectorDistanceMapper
protected void setup(Context context) throws IOException,
InterruptedException {
super.setup(context);
Configuration conf = context.getConfiguration();
- measure =
-
ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY),
DistanceMeasure.class);
+
+ String maxDistanceParam =
conf.get(VectorDistanceSimilarityJob.MAX_DISTANCE);
+ if (maxDistanceParam != null) {
+ usesThreshold = true;
+ maxDistance = Double.parseDouble(maxDistanceParam);
+ }
+
+ measure =
ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY),
+ DistanceMeasure.class);
measure.configure(conf);
seedVectors = SeedVectorUtil.loadSeedVectors(conf);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java?rev=1455095&r1=1455094&r2=1455095&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java Mon Mar 11 11:04:58 2013
@@ -36,6 +36,8 @@ import org.apache.mahout.common.distance
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.VectorWritable;
+import com.google.common.base.Preconditions;
+
import java.io.IOException;
/**
@@ -48,6 +50,7 @@ public class VectorDistanceSimilarityJob
public static final String SEEDS_PATH_KEY = "seedsPath";
public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
public static final String OUT_TYPE_KEY = "outType";
+ public static final String MAX_DISTANCE = "maxDistance";
public static void main(String[] args) throws Exception {
ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(),
args);
@@ -60,11 +63,13 @@ public class VectorDistanceSimilarityJob
addOutputOption();
addOption(DefaultOptionCreator.distanceMeasureOption().create());
addOption(SEEDS, "s", "The set of vectors to compute distances against.
Must fit in memory on the mapper");
+ addOption(MAX_DISTANCE, "mx", "set an upper-bound on distance (double)
such that any pair of vectors with a" +
+ " distance greater than this value is ignored in the output. Ignored
for non pairwise output!");
addOption(DefaultOptionCreator.overwriteOption().create());
- addOption(OUT_TYPE_KEY, "ot",
- "[pw|v] -- Define the output style: pairwise, the default, (pw)
or vector (v). Pairwise is a "
- + "tuple of <seed, other, distance>, vector is <other,
<Vector of size the number of seeds>>.",
- "pw");
+ addOption(OUT_TYPE_KEY, "ot", "[pw|v] -- Define the output style:
pairwise, the default, (pw) or vector (v). " +
+ "Pairwise is a tuple of <seed, other, distance>, vector is <other,
<Vector of size the number of seeds>>.",
+ "pw");
+
if (parseArguments(args) == null) {
return -1;
}
@@ -83,12 +88,19 @@ public class VectorDistanceSimilarityJob
if (getConf() == null) {
setConf(new Configuration());
}
- String outType = getOption(OUT_TYPE_KEY);
- if (outType == null) {
- outType = "pw";
+ String outType = getOption(OUT_TYPE_KEY, "pw");
+
+ Double maxDistance = null;
+
+ if ("pw".equals(outType)) {
+ String maxDistanceArg = getOption(MAX_DISTANCE);
+ if (maxDistanceArg != null) {
+ maxDistance = Double.parseDouble(maxDistanceArg);
+ Preconditions.checkArgument(maxDistance > 0d, "value for " +
MAX_DISTANCE + " must be greater than zero");
+ }
}
- run(getConf(), input, seeds, output, measure, outType);
+ run(getConf(), input, seeds, output, measure, outType, maxDistance);
return 0;
}
@@ -98,6 +110,18 @@ public class VectorDistanceSimilarityJob
Path output,
DistanceMeasure measure, String outType)
throws IOException, ClassNotFoundException, InterruptedException {
+ run(conf, input, seeds, output, measure, outType, null);
+ }
+
+ public static void run(Configuration conf,
+ Path input,
+ Path seeds,
+ Path output,
+ DistanceMeasure measure, String outType, Double maxDistance)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ if (maxDistance != null) {
+ conf.set(MAX_DISTANCE, String.valueOf(maxDistance));
+ }
conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
conf.set(SEEDS_PATH_KEY, seeds.toString());
Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + "
input: " + input);
@@ -119,7 +143,6 @@ public class VectorDistanceSimilarityJob
throw new IllegalArgumentException("Invalid outType specified: " +
outType);
}
-
job.setNumReduceTasks(0);
FileInputFormat.addInputPath(job, input);
FileOutputFormat.setOutputPath(job, output);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java?rev=1455095&r1=1455094&r2=1455095&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java Mon Mar 11 11:04:58 2013
@@ -17,6 +17,7 @@
package org.apache.mahout.math.hadoop.similarity;
+import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
@@ -50,14 +51,19 @@ import java.util.List;
import java.util.Map;
public class TestVectorDistanceSimilarityJob extends MahoutTestCase {
+
private FileSystem fs;
+ private static final double[][] REFERENCE = { { 1, 1 }, { 2, 1 }, { 1, 2 },
{ 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 },
+ { 4, 5 }, { 5, 5 } };
+
+ private static final double[][] SEEDS = { { 1, 1 }, { 10, 10 } };
+
@Override
@Before
public void setUp() throws Exception {
super.setUp();
- Configuration conf = new Configuration();
- fs = FileSystem.get(conf);
+ fs = FileSystem.get(new Configuration());
}
@Test
@@ -96,7 +102,6 @@ public class TestVectorDistanceSimilarit
mapper.map(new IntWritable(123), new VectorWritable(vector), context);
EasyMock.verify(context);
-
}
@Test
@@ -130,39 +135,65 @@ public class TestVectorDistanceSimilarit
}
- private static final double[][] REFERENCE = {
- {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5,
5}
- };
-
- private static final double[][] SEEDS = {
- {1, 1}, {10, 10}
- };
-
@Test
public void testRun() throws Exception {
Path input = getTestTempDirPath("input");
Path output = getTestTempDirPath("output");
Path seedsPath = getTestTempDirPath("seeds");
+
List<VectorWritable> points = getPointsWritable(REFERENCE);
List<VectorWritable> seeds = getPointsWritable(SEEDS);
+
Configuration conf = new Configuration();
ClusteringTestUtils.writePointsToFile(points, true, new Path(input,
"file1"), fs, conf);
ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath,
"part-seeds"), fs, conf);
- String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION),
input.toString(),
- optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(),
optKey(DefaultOptionCreator.OUTPUT_OPTION),
- output.toString(),
optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
EuclideanDistanceMeasure.class.getName()
- };
+
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
input.toString(),
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(),
optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(),
optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName() };
+
ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(),
args);
- int expect = SEEDS.length * REFERENCE.length;
- DummyOutputCollector<StringTuple, DoubleWritable> collector =
- new DummyOutputCollector<StringTuple, DoubleWritable>();
- //
- for (Pair<StringTuple, DoubleWritable> record :
- new SequenceFileIterable<StringTuple, DoubleWritable>(
- new Path(output, "part-m-00000"), conf)) {
- collector.collect(record.getFirst(), record.getSecond());
+
+ int expectedOutputSize = SEEDS.length * REFERENCE.length;
+ int outputSize = Iterables.size(new SequenceFileIterable<StringTuple,
DoubleWritable>(new Path(output,
+ "part-m-00000"), conf));
+ assertEquals(expectedOutputSize, outputSize);
+ }
+
+ @Test
+ public void testMaxDistance() throws Exception {
+
+ Path input = getTestTempDirPath("input");
+ Path output = getTestTempDirPath("output");
+ Path seedsPath = getTestTempDirPath("seeds");
+
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+ List<VectorWritable> seeds = getPointsWritable(SEEDS);
+
+ Configuration conf = new Configuration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(input,
"file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath,
"part-seeds"), fs, conf);
+
+ double maxDistance = 10;
+
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
input.toString(),
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(),
optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(),
optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(VectorDistanceSimilarityJob.MAX_DISTANCE),
String.valueOf(maxDistance) };
+
+ ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(),
args);
+
+ int outputSize = 0;
+
+ for (Pair<StringTuple, DoubleWritable> record : new
SequenceFileIterable<StringTuple, DoubleWritable>(
+ new Path(output, "part-m-00000"), conf)) {
+ outputSize++;
+ assertTrue(record.getSecond().get() <= maxDistance);
}
- assertEquals(expect, collector.getData().size());
+
+ assertEquals(14, outputSize);
}
@Test
@@ -176,18 +207,17 @@ public class TestVectorDistanceSimilarit
ClusteringTestUtils.writePointsToFile(points, true, new Path(input,
"file1"), fs, conf);
ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath,
"part-seeds"), fs, conf);
String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION),
input.toString(),
- optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(),
optKey(DefaultOptionCreator.OUTPUT_OPTION),
- output.toString(),
optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
EuclideanDistanceMeasure.class.getName(),
- optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v"
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(),
optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(),
optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v"
};
ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(),
args);
- DummyOutputCollector<Text, VectorWritable> collector =
- new DummyOutputCollector<Text, VectorWritable>();
- //
- for (Pair<Text, VectorWritable> record :
- new SequenceFileIterable<Text, VectorWritable>(
- new Path(output, "part-m-00000"), conf)) {
+ DummyOutputCollector<Text, VectorWritable> collector = new
DummyOutputCollector<Text, VectorWritable>();
+
+ for (Pair<Text, VectorWritable> record : new SequenceFileIterable<Text,
VectorWritable>(
+ new Path(output, "part-m-00000"), conf)) {
collector.collect(record.getFirst(), record.getSecond());
}
assertEquals(REFERENCE.length, collector.getData().size());
@@ -196,7 +226,7 @@ public class TestVectorDistanceSimilarit
}
}
- public static List<VectorWritable> getPointsWritable(double[][] raw) {
+ private List<VectorWritable> getPointsWritable(double[][] raw) {
List<VectorWritable> points = Lists.newArrayList();
for (double[] fr : raw) {
Vector vec = new RandomAccessSparseVector(fr.length);