Author: dfilimon
Date: Fri May  3 11:07:01 2013
New Revision: 1478723

URL: http://svn.apache.org/r1478723
Log:
Fixes MAHOUT-1180

Modified:
    mahout/trunk/math/src/main/java/org/apache/mahout/math/random/Multinomial.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/random/Multinomial.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/random/Multinomial.java?rev=1478723&r1=1478722&r2=1478723&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/random/Multinomial.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/random/Multinomial.java Fri May  3 11:07:01 2013
@@ -17,18 +17,20 @@
 
 package org.apache.mahout.math.random;
 
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
 import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Multiset;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.list.DoubleArrayList;
 
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-
 /**
  * Multinomial sampler that allows updates to element probabilities.  The basic idea is that sampling is
  * done by using a simple balanced tree.  Probabilities are kept in the tree so that we can navigate to
@@ -66,6 +68,7 @@ public final class Multinomial<T> implem
   }
 
   public void add(T value, double w) {
+    Preconditions.checkNotNull(value);
     Preconditions.checkArgument(!items.containsKey(value));
 
     int n = this.weight.size();
@@ -182,6 +185,18 @@ public final class Multinomial<T> implem
 
   @Override
   public Iterator<T> iterator() {
-    return items.keySet().iterator();
+    return new AbstractIterator<T>() {
+      Iterator<T> valuesIterator = Iterables.skip(values, 1).iterator();
+      @Override
+      protected T computeNext() {
+        while (valuesIterator.hasNext()) {
+          T next = valuesIterator.next();
+          if (items.containsKey(next)) {
+            return next;
+          }
+        }
+        return endOfData();
+      }
+    };
   }
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java?rev=1478723&r1=1478722&r2=1478723&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java Fri May  3 11:07:01 2013
@@ -17,6 +17,10 @@
 
 package org.apache.mahout.math.random;
 
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
@@ -26,10 +30,6 @@ import org.apache.mahout.math.MahoutTest
 import org.junit.Before;
 import org.junit.Test;
 
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-
 public class MultinomialTest extends MahoutTestCase {
     @Override
     @Before
@@ -166,6 +166,27 @@ public class MultinomialTest extends Mah
     }
 
     @Test
+  public void testSetZeroWhileIterating() {
+    Multinomial<Integer> table = new Multinomial<Integer>();
+    for (int i = 0; i < 10000; ++i) {
+      table.add(i, i);
+    }
+    // Setting a sample's weight to 0 removes from the items map.
+    // If that map is used when iterating (it used to be), it will
+    // trigger a ConcurrentModificationException.
+    for (Integer sample : table) {
+      table.set(sample, 0);
+    }
+  }
+
+  @Test(expected=NullPointerException.class)
+  public void testNoNullValuesAllowed() {
+    Multinomial<Integer> table = new Multinomial<Integer>();
+    // No null values should be allowed.
+    table.add(null, 1);
+  }
+
+  @Test
     public void testDeleteAndUpdate() {
         Random rand = RandomUtils.getRandom();
         Multinomial<Integer> table = new Multinomial<Integer>();


Reply via email to