You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/02/20 12:51:53 UTC

svn commit: r1570160 - in /opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag: POSModel.java POSTaggerME.java

Author: joern
Date: Thu Feb 20 11:51:52 2014
New Revision: 1570160

URL: http://svn.apache.org/r1570160
Log:
OPENNLP-641 Added initial sequence classification support

Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java?rev=1570160&r1=1570159&r2=1570160&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java Thu Feb 20 11:51:52 2014
@@ -27,6 +27,7 @@ import java.util.Map;
 import opennlp.tools.dictionary.Dictionary;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.MaxentModel;
+import opennlp.tools.ml.model.SequenceClassificationModel;
 import opennlp.tools.util.BaseToolFactory;
 import opennlp.tools.util.InvalidFormatException;
 import opennlp.tools.util.model.ArtifactSerializer;
@@ -66,6 +67,18 @@ public final class POSModel extends Base
     this(languageCode, posModel, null, new POSTaggerFactory(ngramDict,
         tagDictionary));
   }
+
+  public POSModel(String languageCode, SequenceClassificationModel<String> posModel,
+      Map<String, String> manifestInfoEntries, POSTaggerFactory posFactory) {
+
+    super(COMPONENT_NAME, languageCode, manifestInfoEntries, posFactory);
+
+    if (posModel == null)
+        throw new IllegalArgumentException("The maxentPosModel param must not be null!");
+
+    artifactMap.put(POS_MODEL_ENTRY_NAME, posModel);
+    checkArtifactMap();
+  }
   
   public POSModel(String languageCode, MaxentModel posModel,
       Map<String, String> manifestInfoEntries, POSTaggerFactory posFactory) {
@@ -79,6 +92,10 @@ public final class POSModel extends Base
     checkArtifactMap();
   }
   
+  private void init() {
+    
+  }
+  
   public POSModel(InputStream in) throws IOException, InvalidFormatException {
     super(COMPONENT_NAME, in);
   }
@@ -113,10 +130,25 @@ public final class POSModel extends Base
     }
   }
 
+  // TODO: This should be deprecated for the release ...
   public MaxentModel getPosModel() {
-    return (MaxentModel) artifactMap.get(POS_MODEL_ENTRY_NAME);
+    if (artifactMap.get(POS_MODEL_ENTRY_NAME) instanceof MaxentModel) {
+      return (MaxentModel) artifactMap.get(POS_MODEL_ENTRY_NAME);
+    }
+    else {
+      return null;
+    }
   }
 
+  public SequenceClassificationModel<String> getPosSequenceModel() {
+    if (artifactMap.get(POS_MODEL_ENTRY_NAME) instanceof SequenceClassificationModel) {
+      return (SequenceClassificationModel) artifactMap.get(POS_MODEL_ENTRY_NAME);
+    }
+    else {
+      return null;
+    }
+  }
+  
   /**
    * Retrieves the tag dictionary.
    * 

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java?rev=1570160&r1=1570159&r2=1570160&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java Thu Feb 20 11:51:52 2014
@@ -30,10 +30,13 @@ import java.util.concurrent.atomic.Atomi
 import opennlp.tools.dictionary.Dictionary;
 import opennlp.tools.ml.EventModelSequenceTrainer;
 import opennlp.tools.ml.EventTrainer;
+import opennlp.tools.ml.SequenceTrainer;
 import opennlp.tools.ml.TrainerFactory;
 import opennlp.tools.ml.TrainerFactory.TrainerType;
 import opennlp.tools.ml.model.Event;
 import opennlp.tools.ml.model.MaxentModel;
+import opennlp.tools.ml.model.SequenceClassificationModel;
+import opennlp.tools.namefind.NameSampleSequenceStream;
 import opennlp.tools.ngram.NGramModel;
 import opennlp.tools.util.BeamSearch;
 import opennlp.tools.util.ObjectStream;
@@ -85,27 +88,9 @@ public class POSTaggerME implements POST
 
   private Sequence bestSequence;
 
-  /**
-   * The search object used for search multiple sequences of tags.
-   */
-  protected BeamSearch<String> beam;
+  private SequenceClassificationModel<String> model;
 
-  /**
-   * Constructor that overrides the {@link SequenceValidator} from the model.
-   * 
-   * @deprecated use {@link #POSTaggerME(POSModel, int, int)} instead. The model
-   *             knows which {@link SequenceValidator} to use.
-   */
-  public POSTaggerME(POSModel model, int beamSize, int cacheSize, SequenceValidator<String> sequenceValidator) {
-    POSTaggerFactory factory = model.getFactory();
-    posModel = model.getPosModel();
-    model.getTagDictionary();
-    contextGen = factory.getPOSContextGenerator(beamSize);
-    tagDictionary = factory.getTagDictionary();
-    size = beamSize;
-    beam = new BeamSearch<String>(size, contextGen, posModel,
-        sequenceValidator, cacheSize);
-  }
+  private SequenceValidator<String> sequenceValidator;
   
   /**
    * Initializes the current instance with the provided
@@ -120,8 +105,16 @@ public class POSTaggerME implements POST
     contextGen = factory.getPOSContextGenerator(beamSize);
     tagDictionary = factory.getTagDictionary();
     size = beamSize;
-    beam = new BeamSearch<String>(size, contextGen, posModel,
-        factory.getSequenceValidator(), cacheSize);
+    
+    sequenceValidator = factory.getSequenceValidator();
+    
+    if (model.getPosModel() != null) {
+      this.model = new opennlp.tools.ml.BeamSearch<String>(beamSize,
+          model.getPosModel(), cacheSize);
+    }
+    else {
+      this.model = model.getPosSequenceModel();
+    }
   }
   
   /**
@@ -145,7 +138,7 @@ public class POSTaggerME implements POST
 
   @Deprecated
   public List<String> tag(List<String> sentence) {
-    bestSequence = beam.bestSequence(sentence.toArray(new String[sentence.size()]), null);
+    bestSequence = model.bestSequence(sentence.toArray(new String[sentence.size()]), null, contextGen, sequenceValidator);
     return bestSequence.getOutcomes();
   }
 
@@ -154,7 +147,7 @@ public class POSTaggerME implements POST
   }
 
   public String[] tag(String[] sentence, Object[] additionaContext) {
-    bestSequence = beam.bestSequence(sentence, additionaContext);
+    bestSequence = model.bestSequence(sentence, additionaContext, contextGen, sequenceValidator);
     List<String> t = bestSequence.getOutcomes();
     return t.toArray(new String[t.size()]);
   }
@@ -168,7 +161,8 @@ public class POSTaggerME implements POST
    * @return At most the specified number of taggings for the specified sentence.
    */
   public String[][] tag(int numTaggings, String[] sentence) {
-    Sequence[] bestSequences = beam.bestSequences(numTaggings, sentence,null);
+    Sequence[] bestSequences = model.bestSequences(numTaggings, sentence, null,
+        contextGen, sequenceValidator);
     String[][] tags = new String[bestSequences.length][];
     for (int si=0;si<tags.length;si++) {
       List<String> t = bestSequences[si].getOutcomes();
@@ -179,7 +173,8 @@ public class POSTaggerME implements POST
 
   @Deprecated
   public Sequence[] topKSequences(List<String> sentence) {
-    return beam.bestSequences(size, sentence.toArray(new String[sentence.size()]), null);
+    return model.bestSequences(size, sentence.toArray(new String[sentence.size()]), null,
+        contextGen, sequenceValidator);
   }
 
   public Sequence[] topKSequences(String[] sentence) {
@@ -187,7 +182,7 @@ public class POSTaggerME implements POST
   }
 
   public Sequence[] topKSequences(String[] sentence, Object[] additionaContext) {
-    return beam.bestSequences(size, sentence, additionaContext);
+    return model.bestSequences(size, sentence, additionaContext, contextGen, sequenceValidator);
   }
 
   /**
@@ -259,8 +254,8 @@ public class POSTaggerME implements POST
     
     TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings());
     
-    MaxentModel posModel;
-    
+    MaxentModel posModel = null;
+    SequenceClassificationModel<String> seqPosModel = null;
     if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
       ObjectStream<Event> es = new POSSampleEventStream(samples, contextGenerator);
       
@@ -274,6 +269,15 @@ public class POSTaggerME implements POST
           manifestInfoEntries);
       posModel = trainer.train(ss);
     }
+    else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
+      SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
+          trainParams.getSettings(), manifestInfoEntries);
+      
+      // TODO: This will probably cause issues, since the feature generator uses the outcomes array
+      
+      POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
+      seqPosModel = trainer.train(ss);
+    }
     else {
       throw new IllegalArgumentException("Trainer type is not supported: " + trainerType);  
     }