You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/02/20 12:51:53 UTC
svn commit: r1570160 - in
/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag:
POSModel.java POSTaggerME.java
Author: joern
Date: Thu Feb 20 11:51:52 2014
New Revision: 1570160
URL: http://svn.apache.org/r1570160
Log:
OPENNLP-641 Added initial sequence classification support
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java?rev=1570160&r1=1570159&r2=1570160&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java Thu Feb 20 11:51:52 2014
@@ -27,6 +27,7 @@ import java.util.Map;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.MaxentModel;
+import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.util.BaseToolFactory;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.model.ArtifactSerializer;
@@ -66,6 +67,18 @@ public final class POSModel extends Base
this(languageCode, posModel, null, new POSTaggerFactory(ngramDict,
tagDictionary));
}
+
+ /**
+ * Creates a POS model that is backed by a {@link SequenceClassificationModel}
+ * instead of a {@link MaxentModel}.
+ *
+ * @param languageCode the language code of the model
+ * @param posModel the sequence classification model, must not be null
+ * @param manifestInfoEntries additional manifest entries, handed to the base model
+ * @param posFactory the {@link POSTaggerFactory} handed to the base model
+ * @throws IllegalArgumentException if {@code posModel} is null
+ */
+ public POSModel(String languageCode, SequenceClassificationModel<String> posModel,
+ Map<String, String> manifestInfoEntries, POSTaggerFactory posFactory) {
+
+ super(COMPONENT_NAME, languageCode, manifestInfoEntries, posFactory);
+
+ if (posModel == null) {
+ throw new IllegalArgumentException("The posModel param must not be null!");
+ }
+
+ // Stored under the same artifact entry a MaxentModel would use;
+ // getPosModel() and getPosSequenceModel() tell the two apart by type.
+ artifactMap.put(POS_MODEL_ENTRY_NAME, posModel);
+ checkArtifactMap();
+ }
public POSModel(String languageCode, MaxentModel posModel,
Map<String, String> manifestInfoEntries, POSTaggerFactory posFactory) {
@@ -79,6 +92,10 @@ public final class POSModel extends Base
checkArtifactMap();
}
+ /**
+ * Loads a POS model from the given input stream.
+ *
+ * @param in the stream to read the model from
+ * @throws IOException if reading the model fails (propagated from the base class loader)
+ * @throws InvalidFormatException if the base class rejects the stream content as a model
+ */
public POSModel(InputStream in) throws IOException, InvalidFormatException {
super(COMPONENT_NAME, in);
}
@@ -113,10 +130,25 @@ public final class POSModel extends Base
}
}
+ /**
+ * Retrieves the maxent model backing this POS model.
+ *
+ * @return the {@link MaxentModel}, or {@code null} if the model entry holds
+ * something else, such as a {@link SequenceClassificationModel}
+ */
+ // TODO: This should be deprecated for the release ...
public MaxentModel getPosModel() {
- return (MaxentModel) artifactMap.get(POS_MODEL_ENTRY_NAME);
+ Object artifact = artifactMap.get(POS_MODEL_ENTRY_NAME);
+ return artifact instanceof MaxentModel ? (MaxentModel) artifact : null;
}
+ /**
+ * Retrieves the sequence classification model backing this POS model.
+ *
+ * @return the {@link SequenceClassificationModel}, or {@code null} if the model
+ * entry holds something else, such as a {@link MaxentModel}
+ */
+ public SequenceClassificationModel<String> getPosSequenceModel() {
+ Object artifact = artifactMap.get(POS_MODEL_ENTRY_NAME);
+ if (!(artifact instanceof SequenceClassificationModel)) {
+ return null;
+ }
+ // The type argument is erased at runtime, so this cast cannot be checked.
+ // NOTE(review): assumes every stored sequence model uses String outcomes, as the
+ // constructor taking SequenceClassificationModel<String> guarantees; confirm
+ // for models loaded from a stream.
+ @SuppressWarnings("unchecked")
+ SequenceClassificationModel<String> seqModel =
+ (SequenceClassificationModel<String>) artifact;
+ return seqModel;
+ }
+
/**
* Retrieves the tag dictionary.
*
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java?rev=1570160&r1=1570159&r2=1570160&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java Thu Feb 20 11:51:52 2014
@@ -30,10 +30,13 @@ import java.util.concurrent.atomic.Atomi
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.ml.EventModelSequenceTrainer;
import opennlp.tools.ml.EventTrainer;
+import opennlp.tools.ml.SequenceTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.TrainerFactory.TrainerType;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
+import opennlp.tools.ml.model.SequenceClassificationModel;
+import opennlp.tools.namefind.NameSampleSequenceStream;
import opennlp.tools.ngram.NGramModel;
import opennlp.tools.util.BeamSearch;
import opennlp.tools.util.ObjectStream;
@@ -85,27 +88,9 @@ public class POSTaggerME implements POST
private Sequence bestSequence;
- /**
- * The search object used for search multiple sequences of tags.
- */
- protected BeamSearch<String> beam;
+ private SequenceClassificationModel<String> model;
+
+ private SequenceValidator<String> sequenceValidator;
/**
* Constructor that overrides the {@link SequenceValidator} from the model.
*
* @deprecated use {@link #POSTaggerME(POSModel, int, int)} instead. The model
* knows which {@link SequenceValidator} to use.
*/
+ @Deprecated
public POSTaggerME(POSModel model, int beamSize, int cacheSize, SequenceValidator<String> sequenceValidator) {
- POSTaggerFactory factory = model.getFactory();
- posModel = model.getPosModel();
- model.getTagDictionary();
- contextGen = factory.getPOSContextGenerator(beamSize);
- tagDictionary = factory.getTagDictionary();
- size = beamSize;
- beam = new BeamSearch<String>(size, contextGen, posModel,
- sequenceValidator, cacheSize);
+ // Kept for backward compatibility: run the regular constructor, then replace
+ // the validator it obtains from the model's factory with the caller's one.
+ this(model, beamSize, cacheSize);
+ this.sequenceValidator = sequenceValidator;
}
/**
* Initializes the current instance with the provided
@@ -120,8 +105,16 @@ public class POSTaggerME implements POST
contextGen = factory.getPOSContextGenerator(beamSize);
tagDictionary = factory.getTagDictionary();
size = beamSize;
- beam = new BeamSearch<String>(size, contextGen, posModel,
- factory.getSequenceValidator(), cacheSize);
+
+ sequenceValidator = factory.getSequenceValidator();
+
+ if (model.getPosModel() != null) {
+ this.model = new opennlp.tools.ml.BeamSearch<String>(beamSize,
+ model.getPosModel(), cacheSize);
+ }
+ else {
+ this.model = model.getPosSequenceModel();
+ }
}
/**
@@ -145,7 +138,7 @@ public class POSTaggerME implements POST
@Deprecated
public List<String> tag(List<String> sentence) {
- bestSequence = beam.bestSequence(sentence.toArray(new String[sentence.size()]), null);
+ // The context generator and validator are now passed on every call instead of
+ // being bound into a BeamSearch instance at construction time.
+ bestSequence = model.bestSequence(sentence.toArray(new String[sentence.size()]), null, contextGen, sequenceValidator);
return bestSequence.getOutcomes();
}
@@ -154,7 +147,7 @@ public class POSTaggerME implements POST
}
+ /**
+ * Tags the given sentence with the best tag sequence found by the model.
+ *
+ * @param sentence the tokens to tag
+ * @param additionaContext additional context forwarded unchanged to the model's
+ * {@code bestSequence}
+ * @return the outcomes of the best sequence, as an array
+ */
public String[] tag(String[] sentence, Object[] additionaContext) {
- bestSequence = beam.bestSequence(sentence, additionaContext);
+ bestSequence = model.bestSequence(sentence, additionaContext, contextGen, sequenceValidator);
List<String> t = bestSequence.getOutcomes();
return t.toArray(new String[t.size()]);
}
@@ -168,7 +161,8 @@ public class POSTaggerME implements POST
* @return At most the specified number of taggings for the specified sentence.
*/
public String[][] tag(int numTaggings, String[] sentence) {
- Sequence[] bestSequences = beam.bestSequences(numTaggings, sentence,null);
+ Sequence[] bestSequences = model.bestSequences(numTaggings, sentence, null,
+ contextGen, sequenceValidator);
String[][] tags = new String[bestSequences.length][];
for (int si=0;si<tags.length;si++) {
List<String> t = bestSequences[si].getOutcomes();
@@ -179,7 +173,8 @@ public class POSTaggerME implements POST
@Deprecated
public Sequence[] topKSequences(List<String> sentence) {
- return beam.bestSequences(size, sentence.toArray(new String[sentence.size()]), null);
+ // Requests up to 'size' sequences; 'size' is the beam size given to the constructor.
+ return model.bestSequences(size, sentence.toArray(new String[sentence.size()]), null,
+ contextGen, sequenceValidator);
}
public Sequence[] topKSequences(String[] sentence) {
@@ -187,7 +182,7 @@ public class POSTaggerME implements POST
}
+ /**
+ * Returns the best tag sequences for the sentence, requesting up to
+ * {@code size} (the beam size) of them from the model.
+ *
+ * @param sentence the tokens to tag
+ * @param additionaContext additional context forwarded unchanged to the model
+ * @return the sequences returned by the model's {@code bestSequences}
+ */
public Sequence[] topKSequences(String[] sentence, Object[] additionaContext) {
- return beam.bestSequences(size, sentence, additionaContext);
+ return model.bestSequences(size, sentence, additionaContext, contextGen, sequenceValidator);
}
/**
@@ -259,8 +254,8 @@ public class POSTaggerME implements POST
TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings());
- MaxentModel posModel;
-
+ MaxentModel posModel = null;
+ SequenceClassificationModel<String> seqPosModel = null;
if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
ObjectStream<Event> es = new POSSampleEventStream(samples, contextGenerator);
@@ -274,6 +269,15 @@ public class POSTaggerME implements POST
manifestInfoEntries);
posModel = trainer.train(ss);
}
+ else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
+ SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
+ trainParams.getSettings(), manifestInfoEntries);
+
+ // TODO: This will probably cause issue, since the feature generator uses the outcomes array
+
+ POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
+ seqPosModel = trainer.train(ss);
+ }
else {
throw new IllegalArgumentException("Trainer type is not supported: " + trainerType);
}