Good point, Sam.
Here is the modified code, which has:
- comments
- bias term added per Sam's cogent suggestion
- code warnings fixed
- multiple passes through the data to get to convergence
- randomized ordering of examples
- better diagnostic output
The best thing about this code is that it works.
import com.google.common.collect.Lists;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ClassifierExample {
public static class Point {
public int x;
public int y;
public Point(int x, int y) {
this.x = x;
this.y = y;
}
@Override
public boolean equals(Object arg0) {
Point p = (Point) arg0;
return ((this.x == p.x) && (this.y == p.y));
}
@Override
public String toString() {
// TODO Auto-generated method stub
return this.x + " , " + this.y;
}
}
public static void main(String[] args) {
Map<Point, Integer> points = new HashMap<Point,
Integer>();
points.put(new Point(0, 0), 0);
points.put(new Point(1, 1), 0);
points.put(new Point(1, 0), 0);
points.put(new Point(0, 1), 0);
points.put(new Point(2, 2), 0);
points.put(new Point(8, 8), 1);
points.put(new Point(8, 9), 1);
points.put(new Point(9, 8), 1);
points.put(new Point(9, 9), 1);
OnlineLogisticRegression learningAlgo = new
OnlineLogisticRegression(2, 3, new L1());
// this is a really big value which will make the model very
cautious
// for lambda = 0.1, the first example below should be about .83
certain
// for lambda = 0.01, the first example below should be about 0.98
certain
learningAlgo.lambda(0.1);
learningAlgo.learningRate(4);
System.out.println("training model \n");
final List<Point> keys = Lists.newArrayList(points.keySet());
// 200 times through the training data is probably over-kill. It
doesn't matter
// for tiny data. The key here is total number of points seen, not
number of passes.
for (int i = 0; i < 200; i++) {
// randomize training data on each iteration
Collections.shuffle(keys);
for (Point point : keys) {
Vector v = getVector(point);
learningAlgo.train(points.get(point), v);
}
}
learningAlgo.close();
//now classify real data
Vector v = new RandomAccessSparseVector(3);
v.set(0, 0.5);
v.set(1, 0.5);
v.set(2, 1);
Vector r = learningAlgo.classifyFull(v);
System.out.println(r);
System.out.println("ans = ");
System.out.printf("no of categories = %d\n",
learningAlgo.numCategories());
System.out.printf("no of features = %d\n",
learningAlgo.numFeatures());
System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0));
System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1));
v.set(0, 4.5);
v.set(1, 6.5);
v.set(2, 1);
r = learningAlgo.classifyFull(v);
System.out.println("ans = ");
System.out.printf("no of categories = %d\n",
learningAlgo.numCategories());
System.out.printf("no of features = %d\n",
learningAlgo.numFeatures());
System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0));
System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1));
// show how the score varies along a line from 0,0 to 1,1
System.out.printf("\nx\tscore\n");
for (int i = 0; i < 100; i++) {
final double x = 0.0 + i / 10.0;
v.set(0, x);
v.set(1, x);
v.set(2, 1);
r = learningAlgo.classifyFull(v);
System.out.printf("%.2f\t%.3f\n", x, r.get(1));
}
}
public static Vector getVector(Point point) {
Vector v = new DenseVector(3);
v.set(0, point.x);
v.set(1, point.y);
v.set(2, 1);
return v;
}
}
On Wed, Jun 27, 2012 at 11:07 AM, sam wu <[email protected]> wrote:
> I don't think the problem is due to the sample size, if you use 2 features,
> 10 samples might be OK.
>
> The problem is that you didn't include bias(intercept) term, which is
> always 1.
> If you add bias term(1) to your point class, you'll get 0.9999 for cluster
> 0 probability.
>
>
> Sam
>
> On Wed, Jun 27, 2012 at 1:23 AM, Sean Owen <[email protected]> wrote:
>
> > Those are both true; they may not be the issue here.
> >
> > The test point definitely belongs in the first of the two groups you
> > created. Why is the result surprising?
> >
> > On Wed, Jun 27, 2012 at 9:15 AM, Lance Norskog <[email protected]>
> wrote:
> >
> > > Not enough samples. Machine learning algorithms in general do well if
> > > you have large sample sets (hundreds or thousands) from "real" data
> > > sources. The data should have a strong signal but be a little noisy.
> > >
> > > Also: your Point class needs a hashCode() since it does equals(). The
> > > Map class won't work at scale.
> > >
> > > On Wed, Jun 27, 2012 at 1:00 AM, damodar shetyo <
> [email protected]
> > >
> > > wrote:
> > > > I am trying to build a simple model that can group points in 2D
> > space.Am
> > > > training the model by giving few examples.After that i am using the
> > model
> > > > to predict the group in which the any other points may fall.But am
> not
> > > > getting answer as expected.Am i missing something in my code or am i
> > > doing
> > > > something wrong?
> > > >
> > > > public class SimpleClassifier {
> > > >
> > > > public static class Point{
> > > > public int x;
> > > > public int y;
> > > >
> > > > public Point(int x,int y){
> > > > this.x = x;
> > > > this.y = y;
> > > > }
> > > >
> > > > @Override
> > > > public boolean equals(Object arg0) {
> > > > Point p = (Point) arg0;
> > > > return( (this.x == p.x) &&(this.y== p.y));
> > > > }
> > > >
> > > > @Override
> > > > public String toString() {
> > > > // TODO Auto-generated method stub
> > > > return this.x + " , " + this.y ;
> > > > }
> > > > }
> > > > public static void main(String[] args) {
> > > >
> > > > Map<Point,Integer> points = new
> HashMap<SimpleClassifier.Point,
> > > > Integer>();
> > > >
> > > > points.put(new Point(0,0), 0);
> > > > points.put(new Point(1,1), 0);
> > > > points.put(new Point(1,0), 0);
> > > > points.put(new Point(0,1), 0);
> > > > points.put(new Point(2,2), 0);
> > > >
> > > >
> > > > points.put(new Point(8,8), 1);
> > > > points.put(new Point(8,9), 1);
> > > > points.put(new Point(9,8), 1);
> > > > points.put(new Point(9,9), 1);
> > > >
> > > >
> > > > OnlineLogisticRegression learningAlgo = new
> > > > OnlineLogisticRegression();
> > > > learningAlgo = new OnlineLogisticRegression(2, 2, new L1());
> > > > learningAlgo.learningRate(50);
> > > >
> > > > //learningAlgo.alpha(1).stepOffset(1000);
> > > >
> > > > System.out.println("training model \n" );
> > > > for(Point point : points.keySet()){
> > > > Vector v = getVector(point);
> > > > System.out.println(point + " belongs to " +
> > > points.get(point));
> > > > learningAlgo.train(points.get(point), v);
> > > > }
> > > >
> > > > learningAlgo.close();
> > > >
> > > >
> > > > //now classify real data
> > > > Vector v = new RandomAccessSparseVector(2);
> > > > v.set(0, 0.5);
> > > > v.set(1, 0.5);
> > > >
> > > > Vector r = learningAlgo.classifyFull(v);
> > > > System.out.println(r);
> > > >
> > > > System.out.println("ans = " );
> > > > System.out.println("no of categories = " +
> > > > learningAlgo.numCategories());
> > > > System.out.println("no of features = " +
> > > > learningAlgo.numFeatures());
> > > > System.out.println("Probability of cluster 0 = " + r.get(0));
> > > > System.out.println("Probability of cluster 1 = " + r.get(1));
> > > >
> > > > }
> > > >
> > > > public static Vector getVector(Point point){
> > > > Vector v = new DenseVector(2);
> > > > v.set(0, point.x);
> > > > v.set(1, point.y);
> > > >
> > > > return v;
> > > > }
> > > > }
> > > >
> > > > OP
> > > > ans =
> > > > no of categories = 2
> > > > no of features = 2
> > > > Probability of cluster 0 = 3.9580985042775296E-4
> > > > Probability of cluster 1 = 0.9996041901495722
> > > >
> > > > 99 % of times the output show more probability for cluster 1.Why?
> > > >
> > > >
> > > >
> > > > --
> > > > Regards,
> > > > Damodar Shetyo
> > >
> > >
> > >
> > > --
> > > Lance Norskog
> > > [email protected]
> > >
> >
>