Hector, thank you very much for youir response, I adapted my example:

-- 8< --
public class OLRTest {

    private static final String[] animals = new String[] { "alligator",
"ant",
            "bear", "bee", "bird", "camel", "cat", "cheetah", "chicken",
            "chimpanzee", "cow", "crocodile", "deer", "dog", "dolphin",
"duck",
            "eagle", "elephant", "fish", "fly", "fox", "frog", "giraffe",
            "goat", "goldfish", "hamster", "hippopotamus", "horse",
"kangaroo",
            "kitten", "lion", "lobster", "monkey", "octopus", "owl",
"panda",
            "pig", "puppy", "rabbit", "rat", "scorpion", "seal", "shark",
            "sheep", "snail", "snake", "spider", "squirrel", "tiger",
"turtle",
            "wolf", "zebra" };
    private static final int PASSES = 1;
    private static final int FEATURES = animals.length;
    private static final int CATEGORIES = 2;

    public static void main(String[] args) {
        final OnlineLogisticRegression algorithm = new
OnlineLogisticRegression(
                CATEGORIES, FEATURES, new L1());

        for (int i = 0; i < PASSES; i++) {
            int idx = 0;
            for (String animal : animals) {
                algorithm.train(0, generateVector(animal, idx++));
            }
        }

        algorithm.close();

        testClassify(algorithm, "lion");
        testClassify(algorithm, "rabbit");
        testClassify(algorithm, "xyz");
        testClassify(algorithm, "something");
    }

    private static int findIndex(String[] arr, String value) {
        int index = -1;
        for (int i = 0; i < arr.length; i++) {
            if (arr[i].equals(value)) {
                index = i;
            }
        }
        return index;
    }

    private static void testClassify(final OnlineLogisticRegression
algorithm,
            final String allegedAnimal) {
        System.out.println(allegedAnimal
                + " is an animal with a probability of "
                + algorithm.classifyScalar(generateVector(allegedAnimal,
                        findIndex(animals, allegedAnimal))) * 100 + "%");
    }

    private static Vector generateVector(String animal, int index) {
        final Vector v = new RandomAccessSparseVector(FEATURES);
        for (int i = 0; i < FEATURES; i++) {
            v.set(i, i == index ? 1 : 0);
        }
        return v;
    }
}
-- 8< --

however I still get this output:
-- 8< --
lion is an animal with a probability of 48.08362706423329%
rabbit is an animal with a probability of 48.26422253903691%
xyz is an animal with a probability of 50.0%
something is an animal with a probability of 50.0%
-- 8< --

any ideas or mistakes you can see?

Kind regards,
Joscha

Reply via email to