Good point, Sam.

Here is modified code that has:

   - comments
   - bias term added per Sam's cogent suggestion
   - code warnings fixed
   - multiple passes through the data to get to convergence
   - randomized ordering of examples
   - better diagnostic output

The best thing about this code is that it works.

import com.google.common.collect.Lists;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ClassifierExample {

    public static class Point {
        public int x;
        public int y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object arg0) {
            Point p = (Point) arg0;
            return ((this.x == p.x) && (this.y == p.y));
        }

        @Override
        public String toString() {
            // TODO Auto-generated method stub
            return this.x + " , " + this.y;
        }
    }

    public static void main(String[] args) {

        Map<Point, Integer> points = new HashMap<Point,
                Integer>();

        points.put(new Point(0, 0), 0);
        points.put(new Point(1, 1), 0);
        points.put(new Point(1, 0), 0);
        points.put(new Point(0, 1), 0);
        points.put(new Point(2, 2), 0);


        points.put(new Point(8, 8), 1);
        points.put(new Point(8, 9), 1);
        points.put(new Point(9, 8), 1);
        points.put(new Point(9, 9), 1);


        OnlineLogisticRegression learningAlgo = new
OnlineLogisticRegression(2, 3, new L1());
        // this is a really big value which will make the model very
cautious
        // for lambda = 0.1, the first example below should be about .83
certain
        // for lambda = 0.01, the first example below should be about 0.98
certain
        learningAlgo.lambda(0.1);
        learningAlgo.learningRate(4);

        System.out.println("training model  \n");
        final List<Point> keys = Lists.newArrayList(points.keySet());
        // 200 times through the training data is probably over-kill.  It
doesn't matter
        // for tiny data.  The key here is total number of points seen, not
number of passes.
        for (int i = 0; i < 200; i++) {
            // randomize training data on each iteration
            Collections.shuffle(keys);
            for (Point point : keys) {
                Vector v = getVector(point);
                learningAlgo.train(points.get(point), v);
            }
        }
        learningAlgo.close();


        //now classify real data
        Vector v = new RandomAccessSparseVector(3);
        v.set(0, 0.5);
        v.set(1, 0.5);
        v.set(2, 1);

        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);

        System.out.println("ans = ");
        System.out.printf("no of categories = %d\n",
learningAlgo.numCategories());
        System.out.printf("no of features = %d\n",
learningAlgo.numFeatures());
        System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0));
        System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1));

        v.set(0, 4.5);
        v.set(1, 6.5);
        v.set(2, 1);

        r = learningAlgo.classifyFull(v);

        System.out.println("ans = ");
        System.out.printf("no of categories = %d\n",
learningAlgo.numCategories());
        System.out.printf("no of features = %d\n",
learningAlgo.numFeatures());
        System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0));
        System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1));

        // show how the score varies along a line from 0,0 to 1,1
        System.out.printf("\nx\tscore\n");
        for (int i = 0; i < 100; i++) {
            final double x = 0.0 + i / 10.0;
            v.set(0, x);
            v.set(1, x);
            v.set(2, 1);

            r = learningAlgo.classifyFull(v);

            System.out.printf("%.2f\t%.3f\n", x, r.get(1));
        }

    }

    public static Vector getVector(Point point) {
        Vector v = new DenseVector(3);
        v.set(0, point.x);
        v.set(1, point.y);
        v.set(2, 1);

        return v;
    }
}

On Wed, Jun 27, 2012 at 11:07 AM, sam wu <[email protected]> wrote:

> I don't think the problem is due to the sample size, if you use 2 features,
> 10 samples might be OK.
>
> The problem is that you didn't include bias(intercept) term, which is
> always 1.
> If you add bias term(1) to your point class, you'll get 0.9999 for cluster
> 0 probability.
>
>
> Sam
>
> On Wed, Jun 27, 2012 at 1:23 AM, Sean Owen <[email protected]> wrote:
>
> > Those are both true; they may not be the issue here.
> >
> > The test point definitely belongs in the first of the two groups you
> > created. Why is the result surprising?
> >
> > On Wed, Jun 27, 2012 at 9:15 AM, Lance Norskog <[email protected]>
> wrote:
> >
> > > Not enough samples. Machine learning algorithms in general do well if
> > > you have large sample sets (hundreds or thousands) from "real" data
> > > sources. The data should have a strong signal but be a little noisy.
> > >
> > > Also: your Point class needs a hashCode() since it does equals(). The
> > > Map class won't work at scale.
> > >
> > > On Wed, Jun 27, 2012 at 1:00 AM, damodar shetyo <
> [email protected]
> > >
> > > wrote:
> > > > I am trying to build a simple model that can group points in 2D
> > space.Am
> > > > training the model by giving few examples.After that i am using the
> > model
> > > > to predict the group in which the any other points may fall.But am
> not
> > > > getting answer as expected.Am i missing something in my code or am i
> > > doing
> > > > something wrong?
> > > >
> > > >       public class SimpleClassifier {
> > > >
> > > >    public static class Point{
> > > >        public int x;
> > > >        public int y;
> > > >
> > > >        public Point(int x,int y){
> > > >            this.x = x;
> > > >            this.y = y;
> > > >        }
> > > >
> > > >        @Override
> > > >        public boolean equals(Object arg0) {
> > > >            Point p = (Point)  arg0;
> > > >            return( (this.x == p.x) &&(this.y== p.y));
> > > >        }
> > > >
> > > >        @Override
> > > >        public String toString() {
> > > >            // TODO Auto-generated method stub
> > > >            return  this.x + " , " + this.y ;
> > > >        }
> > > >    }
> > > >    public static void main(String[] args) {
> > > >
> > > >        Map<Point,Integer> points = new
> HashMap<SimpleClassifier.Point,
> > > > Integer>();
> > > >
> > > >        points.put(new Point(0,0), 0);
> > > >        points.put(new Point(1,1), 0);
> > > >        points.put(new Point(1,0), 0);
> > > >        points.put(new Point(0,1), 0);
> > > >        points.put(new Point(2,2), 0);
> > > >
> > > >
> > > >        points.put(new Point(8,8), 1);
> > > >        points.put(new Point(8,9), 1);
> > > >        points.put(new Point(9,8), 1);
> > > >        points.put(new Point(9,9), 1);
> > > >
> > > >
> > > >        OnlineLogisticRegression learningAlgo = new
> > > > OnlineLogisticRegression();
> > > >        learningAlgo =  new OnlineLogisticRegression(2, 2, new L1());
> > > >        learningAlgo.learningRate(50);
> > > >
> > > >        //learningAlgo.alpha(1).stepOffset(1000);
> > > >
> > > >        System.out.println("training model  \n" );
> > > >        for(Point point : points.keySet()){
> > > >            Vector v = getVector(point);
> > > >            System.out.println(point  + " belongs to " +
> > > points.get(point));
> > > >            learningAlgo.train(points.get(point), v);
> > > >        }
> > > >
> > > >        learningAlgo.close();
> > > >
> > > >
> > > >        //now classify real data
> > > >        Vector v = new RandomAccessSparseVector(2);
> > > >        v.set(0, 0.5);
> > > >        v.set(1, 0.5);
> > > >
> > > >        Vector r = learningAlgo.classifyFull(v);
> > > >        System.out.println(r);
> > > >
> > > >        System.out.println("ans = " );
> > > >        System.out.println("no of categories = " +
> > > > learningAlgo.numCategories());
> > > >        System.out.println("no of features = " +
> > > > learningAlgo.numFeatures());
> > > >        System.out.println("Probability of cluster 0 = " + r.get(0));
> > > >        System.out.println("Probability of cluster 1 = " + r.get(1));
> > > >
> > > >    }
> > > >
> > > >    public static Vector getVector(Point point){
> > > >        Vector v = new DenseVector(2);
> > > >        v.set(0, point.x);
> > > >        v.set(1, point.y);
> > > >
> > > >        return v;
> > > >    }
> > > > }
> > > >
> > > > OP
> > > > ans =
> > > > no of categories = 2
> > > > no of features = 2
> > > > Probability of cluster 0 = 3.9580985042775296E-4
> > > > Probability of cluster 1 = 0.9996041901495722
> > > >
> > > > 99 % of times the output show more probability for cluster 1.Why?
> > > >
> > > >
> > > >
> > > > --
> > > > Regards,
> > > > Damodar Shetyo
> > >
> > >
> > >
> > > --
> > > Lance Norskog
> > > [email protected]
> > >
> >
>

Reply via email to