In the interest of timeliness (I would have to figure out how to use
GitHub with Apache), and given that the code is not that long, I will
just post it here. It is pretty much the code from the examples and
Mahout in Action, with a few modifications. Beware of non-standard
formatting, though.
As to the number of items I have been running through the trainer:
about 200. I am just trying to get my first trainer/evaluator/production
pipeline running end to end before I start loading it up.
package com.zensa.spinn3r.mahout;
import com.google.common.collect.Maps;
import com.zensa.config.ConfigProperties;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelDissector;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class SnomedSGDClassificationTrainer
{
  private static final int FEATURES = 1000;

  private static final FeatureVectorEncoder encoder = new StaticWordValueEncoder( "code" );
  private static final FeatureVectorEncoder bias = new ConstantValueEncoder( "Intercept" );

  private AdaptiveLogisticRegression learningAlgorithm = null;
  private Dictionary msgs = new Dictionary();
  private double averageLL = 0;
  private double averageCorrect = 0;
  private int k = 0;
  private double step = 0;
  private int[] bumps = { 1, 2, 5 };
  private String modelFile = null;

  public void initialize()
  {
    encoder.setProbes( 2 );
    // Binary classifier: 2 categories, L1 prior.
    learningAlgorithm = new AdaptiveLogisticRegression( 2, FEATURES, new L1() );
    learningAlgorithm.setInterval( 800 );
    learningAlgorithm.setAveragingWindow( 500 );
    modelFile = ConfigProperties.getInstance().getProperty( "spinn3r.classifier.model_file" );
  }

  public void train( String msgId, Map<String, Integer> codes ) throws IOException
  {
    int actual = msgs.intern( msgId );
    Vector v = encodeFeatureVector( codes, actual );
    learningAlgorithm.train( actual, v );
    k++;

    // 1-2-5 log-scale cadence for progress reporting and model snapshots.
    int bump = bumps[(int) Math.floor( step ) % bumps.length];
    int scale = (int) Math.pow( 10, Math.floor( step / bumps.length ) );

    State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
    double maxBeta;
    double norm;
    double nonZeros;
    double positive;
    double lambda = 0;
    double mu = 0;
    if( best != null )
    {
      CrossFoldLearner state = best.getPayload().getLearner();
      averageCorrect = state.percentCorrect();
      averageLL = state.logLikelihood();

      OnlineLogisticRegression model = state.getModels().get( 0 );
      // Flush pending regularization updates before inspecting beta.
      model.close();

      Matrix beta = model.getBeta();
      maxBeta = beta.aggregate( Functions.MAX, Functions.ABS );
      norm = beta.aggregate( Functions.PLUS, Functions.ABS );
      nonZeros = beta.aggregate( Functions.PLUS, new DoubleFunction()
      {
        @Override
        public double apply( double v )
        {
          return Math.abs( v ) > 1.0e-6 ? 1 : 0;
        }
      } );
      positive = beta.aggregate( Functions.PLUS, new DoubleFunction()
      {
        @Override
        public double apply( double v )
        {
          return v > 0 ? 1 : 0;
        }
      } );
      lambda = best.getMappedParams()[0];
      mu = best.getMappedParams()[1];
    }
    else
    {
      maxBeta = 0;
      nonZeros = 0;
      positive = 0;
      norm = 0;
    }

    if( k % ( bump * scale ) == 0 )
    {
      if( learningAlgorithm.getBest() != null )
      {
        ModelSerializer.writeBinary( "/Users/tim/spinn3r_data/model/snomed-" + k + ".model",
            learningAlgorithm.getBest().getPayload().getLearner() );
      }
      step += 0.25;
      System.out.printf( "%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t",
          maxBeta, nonZeros, positive, norm, lambda, mu );
      System.out.printf( "%d\t%.3f\t%.2f\n", k, averageLL, averageCorrect * 100 );
    }
  }

  public void finishTraining( String msgId, Map<String, Integer> codes ) throws IOException
  {
    try
    {
      learningAlgorithm.close();
    }
    catch( Exception e )
    {
      System.out.println( "SnomedSGDClassificationTrainer.finishTraining() - learningAlgorithm.close() Error = "
          + e.getMessage() );
      e.printStackTrace();
    }
    dissect( msgId, learningAlgorithm, codes );
    ModelSerializer.writeBinary( modelFile,
        learningAlgorithm.getBest().getPayload().getLearner() );
  }

  private Vector encodeFeatureVector( Map<String, Integer> codes, int actual )
  {
    Vector v = new RandomAccessSparseVector( FEATURES );
    bias.addToVector( "", 1, v );
    for( Map.Entry<String, Integer> entry : codes.entrySet() )
    {
      // Damp raw counts with log(1 + count).
      encoder.addToVector( entry.getKey(), Math.log( 1 + entry.getValue() ), v );
    }
    return v;
  }

  private void dissect( String msgId, AdaptiveLogisticRegression learningAlgorithm,
      Map<String, Integer> codes )
  {
    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
    model.close();

    Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
    ModelDissector md = new ModelDissector();
    encoder.setTraceDictionary( traceDictionary );
    bias.setTraceDictionary( traceDictionary );

    int actual = msgs.intern( msgId );
    traceDictionary.clear();
    Vector v = encodeFeatureVector( codes, actual );
    md.update( v, traceDictionary, model );

    List<ModelDissector.Weight> weights = md.summary( 1000 );
    for( ModelDissector.Weight w : weights )
    {
      System.out.printf( "%s\t%.1f\t%.1f\t%s\t%.1f\t%s\n",
          w.getFeature(), w.getWeight(), w.getCategory( 1 ), w.getWeight( 1 ),
          w.getCategory( 2 ), w.getWeight( 2 ) );
    }
  }
}
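In case the checkpointing arithmetic in train() looks opaque: bump and scale implement the usual 1-2-5 log-scale cadence, so snapshots get written at k = 1, 2, 3, 4, 6, 8, 10, 12, 15, 20, ..., 100, 120, and so on. A standalone sketch of just that logic (class and method names are mine, nothing Mahout-specific in it):

```java
import java.util.ArrayList;
import java.util.List;

// Standalone demonstration of the 1-2-5 log-scale checkpoint cadence
// used in train(): bump cycles through {1, 2, 5}, and scale multiplies
// by 10 each time step has advanced through a full {1, 2, 5} cycle.
public class CheckpointCadence
{
  private static final int[] BUMPS = { 1, 2, 5 };

  // Returns every k in [1, maxK] at which the trainer would write a model.
  public static List<Integer> checkpoints( int maxK )
  {
    List<Integer> out = new ArrayList<>();
    double step = 0;
    for( int k = 1; k <= maxK; k++ )
    {
      int bump = BUMPS[(int) Math.floor( step ) % BUMPS.length];
      int scale = (int) Math.pow( 10, Math.floor( step / BUMPS.length ) );
      if( k % ( bump * scale ) == 0 )
      {
        out.add( k );
        // Four checkpoints per bump value, matching the 0.25 increment.
        step += 0.25;
      }
    }
    return out;
  }

  public static void main( String[] args )
  {
    System.out.println( checkpoints( 300 ) );
  }
}
```

Running main prints the checkpoint sequence for the first 300 training calls; the spacing between saved models grows roughly geometrically, which is why only a handful of snapshot files accumulate even for long runs.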
The End.
Tim
-------- Original Message --------
Subject: Re: AdaptiveLogisticRegression.close()
ArrayIndexOutOfBoundsException
From: Ted Dunning <[email protected]>
Date: Sun, May 01, 2011 7:37 pm
To: [email protected]
Can you put your code on github?
There is a detail that slipped somewhere and I can't guess where it is.
Your constructor is correct for a binary classifier, but I can't say
much else.
How much data, btw, did you pour in?
On Sun, May 1, 2011 at 3:18 AM, Tim Snyder <[email protected]>
wrote:
> I am currently using trunk from April 30, 2011. The code is loosely
> following the SGD training example from Mahout in Action. I have
> instantiated the learner with the purpose of having a binary classifier:
>
> AdaptiveLogisticRegression learningAlgorithm =
>     new AdaptiveLogisticRegression( 2, FEATURES, new L1() );
>
> Everything works fine (i.e. the training) until I get to
> learningAlgorithm.close(), where I get the following exception:
>
> learningAlgorithm.close() Error = java.lang.ArrayIndexOutOfBoundsException
> java.lang.IllegalStateException: java.lang.ArrayIndexOutOfBoundsException
> Exception = null
>   at org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.trainWithBufferedExamples(AdaptiveLogisticRegression.java:144)
>   at org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.close(AdaptiveLogisticRegression.java:196)
>   at com.zensa.spinn3r.mahout.SnomedSDGClassificationTrainer.finishTraining(SnomedSDGClassificationTrainer.java:159)
>   at com.spinn3r.sdg.trainer.Spinn3rSDGTrainer.process(Spinn3rSDGTrainer.java:170)
>   at com.spinn3r.sdg.trainer.Spinn3rSDGTrainer.main(Spinn3rSDGTrainer.java:272)
> Caused by: java.lang.ArrayIndexOutOfBoundsException
>
> If I change the number of categories to 100, the close() works fine. Any
> ideas on how to get around this and have a working binary classifier?
>
> Thanks in advance.
>
> Tim
>
>