You are viewing a plain text version of this content. The canonical link for it is here.
Posted to solr-user@lucene.apache.org by damodar shetyo <ak...@gmail.com> on 2012/06/28 14:19:02 UTC

simple OnlineLogisticRegression classication example using mahout

I am trying to build a simple model that can group points in 2D space.Am
training the model by giving few examples.After that i am using the model
to predict the group in which the any other points may fall.But am not
getting answer as expected.Am i missing something in my code or am i doing
something wrong?

       public class SimpleClassifier {

    public static class Point{
        public int x;
        public int y;

        public Point(int x,int y){
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object arg0) {
            Point p = (Point)  arg0;
            return( (this.x == p.x) &&(this.y== p.y));
        }

        @Override
        public String toString() {
            // TODO Auto-generated method stub
            return  this.x + " , " + this.y ;
        }
    }
    public static void main(String[] args) {

        Map<Point,Integer> points = new HashMap<SimpleClassifier.Point,
Integer>();

        points.put(new Point(0,0), 0);
        points.put(new Point(1,1), 0);
        points.put(new Point(1,0), 0);
        points.put(new Point(0,1), 0);
        points.put(new Point(2,2), 0);


        points.put(new Point(8,8), 1);
        points.put(new Point(8,9), 1);
        points.put(new Point(9,8), 1);
        points.put(new Point(9,9), 1);


        OnlineLogisticRegression learningAlgo = new
OnlineLogisticRegression();
        learningAlgo =  new OnlineLogisticRegression(2, 2, new L1());
        learningAlgo.learningRate(50);

        //learningAlgo.alpha(1).stepOffset(1000);

        System.out.println("training model  \n" );
        for(Point point : points.keySet()){
            Vector v = getVector(point);
            System.out.println(point  + " belongs to " + points.get(point));
            learningAlgo.train(points.get(point), v);
        }

        learningAlgo.close();


        //now classify real data
        Vector v = new RandomAccessSparseVector(2);
        v.set(0, 0.5);
        v.set(1, 0.5);

        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);

        System.out.println("ans = " );
        System.out.println("no of categories = " +
learningAlgo.numCategories());
        System.out.println("no of features = " +
learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + r.get(0));
        System.out.println("Probability of cluster 1 = " + r.get(1));

    }

    public static Vector getVector(Point point){
        Vector v = new DenseVector(2);
        v.set(0, point.x);
        v.set(1, point.y);

        return v;
    }
}

OP
ans =
no of categories = 2
no of features = 2
Probability of cluster 0 = 3.9580985042775296E-4
Probability of cluster 1 = 0.9996041901495722

99 % of times the output show more probability for cluster 1.Why?



-- 
Regards,
Damodar Shetyo




-- 
Regards,
Damodar Shetyo