You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jz...@apache.org on 2022/12/20 16:23:17 UTC

[opennlp] branch master updated: OPENNLP-1417 Reduce unchecked assignments and raw types in TrainerFactory (#462)

This is an automated email from the ASF dual-hosted git repository.

jzemerick pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git


The following commit(s) were added to refs/heads/master by this push:
     new f04204ac OPENNLP-1417 Reduce unchecked assignments and raw types in TrainerFactory (#462)
f04204ac is described below

commit f04204ac07ae517d621781157ddb658b2ae8277b
Author: Martin Wiesner <ma...@users.noreply.github.com>
AuthorDate: Tue Dec 20 17:23:12 2022 +0100

    OPENNLP-1417 Reduce unchecked assignments and raw types in TrainerFactory (#462)
    
    - eliminates unchecked assignments and use of raw types in `opennlp.tools.ml.TrainerFactory`
    - simplifies existing, duplicated code
    - introduces a generic type for `EventModelSequenceTrainer` interface
    - adds missing JavaDoc
---
 .../main/java/opennlp/tools/chunker/ChunkerME.java |   2 +-
 .../opennlp/tools/lemmatizer/LemmatizerME.java     |   4 +-
 .../ml/AbstractEventModelSequenceTrainer.java      |   9 +-
 .../tools/ml/EventModelSequenceTrainer.java        |   5 +-
 .../java/opennlp/tools/ml/SequenceTrainer.java     |   4 +-
 .../main/java/opennlp/tools/ml/TrainerFactory.java | 132 +++++++++++++--------
 .../SimplePerceptronSequenceTrainer.java           |   3 +-
 .../java/opennlp/tools/namefind/NameFinderME.java  |   4 +-
 .../java/opennlp/tools/postag/POSTaggerME.java     |   2 +-
 .../java/opennlp/tools/ml/MockSequenceTrainer.java |   5 +-
 10 files changed, 104 insertions(+), 66 deletions(-)

diff --git a/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java b/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java
index e1246fa7..3656aeb9 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java
@@ -182,7 +182,7 @@ public class ChunkerME implements Chunker {
       chunkerModel = trainer.train(es);
     }
     else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
-      SequenceTrainer<ChunkSample> trainer = TrainerFactory.getSequenceModelTrainer(
+      SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
           mlParams, manifestInfoEntries);
 
       // TODO: This will probably cause issue, since the feature generator uses the outcomes array
diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java
index 4a19c516..33c34780 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java
@@ -267,12 +267,12 @@ public class LemmatizerME implements Lemmatizer {
     }
     else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
       LemmaSampleSequenceStream ss = new LemmaSampleSequenceStream(samples, contextGenerator);
-      EventModelSequenceTrainer trainer =
+      EventModelSequenceTrainer<LemmaSample> trainer =
           TrainerFactory.getEventModelSequenceTrainer(params, manifestInfoEntries);
       lemmatizerModel = trainer.train(ss);
     }
     else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
-      SequenceTrainer<LemmaSample> trainer = TrainerFactory.getSequenceModelTrainer(
+      SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
               params, manifestInfoEntries);
 
       // TODO: This will probably cause issue, since the feature generator uses the outcomes array
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java
index 362a0d69..0adad760 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java
@@ -19,19 +19,20 @@ package opennlp.tools.ml;
 
 import java.io.IOException;
 
+import opennlp.tools.ml.model.Event;
 import opennlp.tools.ml.model.MaxentModel;
 import opennlp.tools.ml.model.SequenceStream;
 
 public abstract class AbstractEventModelSequenceTrainer extends AbstractTrainer implements
-    EventModelSequenceTrainer {
+    EventModelSequenceTrainer<Event> {
 
   public AbstractEventModelSequenceTrainer() {
   }
 
-  public abstract MaxentModel doTrain(SequenceStream events)
-      throws IOException;
+  public abstract MaxentModel doTrain(SequenceStream<Event> events) throws IOException;
 
-  public final MaxentModel train(SequenceStream events) throws IOException {
+  @Override
+  public final MaxentModel train(SequenceStream<Event> events) throws IOException {
     validate();
 
     MaxentModel model = doTrain(events);
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java
index a6f07dee..caa98512 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java
@@ -20,19 +20,18 @@ package opennlp.tools.ml;
 import java.io.IOException;
 import java.util.Map;
 
-import opennlp.tools.commons.Sample;
 import opennlp.tools.commons.Trainer;
 import opennlp.tools.ml.model.MaxentModel;
 import opennlp.tools.ml.model.SequenceStream;
 import opennlp.tools.util.TrainingParameters;
 
-public interface EventModelSequenceTrainer extends Trainer {
+public interface EventModelSequenceTrainer<T> extends Trainer {
 
   String SEQUENCE_VALUE = "EventModelSequence";
 
   void init(Map<String, Object> trainParams, Map<String, String> reportMap);
   void init(TrainingParameters trainParams, Map<String, String> reportMap);
 
-  MaxentModel train(SequenceStream<? extends Sample> events) throws IOException;
+  MaxentModel train(SequenceStream<T> events) throws IOException;
 
 }
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
index 263682a0..d9e81b2f 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
@@ -25,7 +25,7 @@ import opennlp.tools.ml.model.SequenceClassificationModel;
 import opennlp.tools.ml.model.SequenceStream;
 import opennlp.tools.util.TrainingParameters;
 
-public interface SequenceTrainer<T> extends Trainer {
+public interface SequenceTrainer extends Trainer {
 
   String SEQUENCE_VALUE = "Sequence";
 
@@ -33,5 +33,5 @@ public interface SequenceTrainer<T> extends Trainer {
   void init(Map<String, String> trainParams, Map<String, String> reportMap);
   void init(TrainingParameters trainParams, Map<String, String> reportMap);
 
-  SequenceClassificationModel<String> train(SequenceStream<T> events) throws IOException;
+  <T> SequenceClassificationModel<String> train(SequenceStream<T> events) throws IOException;
 }
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
index a7604cb4..c1acba40 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
@@ -20,7 +20,6 @@ package opennlp.tools.ml;
 import java.lang.reflect.Constructor;
 import java.util.Map;
 
-import opennlp.tools.commons.Sample;
 import opennlp.tools.commons.Trainer;
 import opennlp.tools.ml.maxent.GISTrainer;
 import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
@@ -31,6 +30,10 @@ import opennlp.tools.util.TrainingParameters;
 import opennlp.tools.util.ext.ExtensionLoader;
 import opennlp.tools.util.ext.ExtensionNotLoadedException;
 
+/**
+ * A factory to initialize {@link Trainer} instances depending on a trainer type
+ * configured via {@link TrainingParameters}.
+ */
 public class TrainerFactory {
 
   public enum TrainerType {
@@ -40,8 +43,11 @@ public class TrainerFactory {
   }
 
   // built-in trainers
-  private static final Map<String, Class> BUILTIN_TRAINERS;
+  private static final Map<String, Class<? extends Trainer>> BUILTIN_TRAINERS;
 
+  /*
+   * Initialize the built-in trainers
+   */
   static {
     BUILTIN_TRAINERS = Map.of(
         GISTrainer.MAXENT_VALUE, GISTrainer.class,
@@ -52,10 +58,12 @@ public class TrainerFactory {
   }
 
   /**
-   * Determines the trainer type based on the ALGORITHM_PARAM value.
+   * Determines the {@link TrainerType} based on the
+   * {@link AbstractTrainer#ALGORITHM_PARAM} value.
    *
-   * @param trainParams - Map of training parameters
-   * @return the trainer type or null if type couldn't be determined.
+   * @param trainParams - A mapping of {@link TrainingParameters training parameters}.
+   *
+   * @return The {@link TrainerType} or {@code null} if the type couldn't be determined.
    */
   public static TrainerType getTrainerType(TrainingParameters trainParams) {
 
@@ -110,70 +118,101 @@ public class TrainerFactory {
     return null;
   }
 
-  public static SequenceTrainer getSequenceModelTrainer(TrainingParameters trainParams,
-      Map<String, String> reportMap) {
+  /**
+   * Retrieves a {@link SequenceTrainer} that fits the given parameters.
+   *
+   * @param trainParams The {@link TrainingParameters} to check for the trainer type.
+   *                    Note: The entry {@link AbstractTrainer#ALGORITHM_PARAM} is used
+   *                    to determine the type.
+   * @param reportMap A {@link Map} that shall be used during initialization of
+   *                  the {@link SequenceTrainer}.
+   *                  
+   * @return A valid {@link SequenceTrainer} for the configured {@code trainParams}.
+   * @throws IllegalArgumentException Thrown if the trainer type could not be determined.
+   */
+  public static SequenceTrainer getSequenceModelTrainer(
+          TrainingParameters trainParams, Map<String, String> reportMap) {
     String trainerType = trainParams.getStringParameter(AbstractTrainer.ALGORITHM_PARAM,null);
 
     if (trainerType != null) {
+      final SequenceTrainer trainer;
       if (BUILTIN_TRAINERS.containsKey(trainerType)) {
-        SequenceTrainer<? extends Sample> trainer = TrainerFactory.
-            <SequenceTrainer>createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
-        trainer.init(trainParams, reportMap);
-        return trainer;
+        trainer = TrainerFactory.createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
       } else {
-        SequenceTrainer<? extends Sample> trainer =
-            ExtensionLoader.instantiateExtension(SequenceTrainer.class, trainerType);
-        trainer.init(trainParams, reportMap);
-        return trainer;
+        trainer = ExtensionLoader.instantiateExtension(SequenceTrainer.class, trainerType);
       }
+      trainer.init(trainParams, reportMap);
+      return trainer;
     }
     else {
       throw new IllegalArgumentException("Trainer type couldn't be determined!");
     }
   }
 
-  public static EventModelSequenceTrainer getEventModelSequenceTrainer(TrainingParameters trainParams,
-      Map<String, String> reportMap) {
+  /**
+   * Retrieves an {@link EventModelSequenceTrainer} that fits the given parameters.
+   *
+   * @param trainParams The {@link TrainingParameters} to check for the trainer type.
+   *                    Note: The entry {@link AbstractTrainer#ALGORITHM_PARAM} is used
+   *                    to determine the type.
+   * @param reportMap A {@link Map} that shall be used during initialization of
+   *                  the {@link EventModelSequenceTrainer}.
+   *
+   * @return A valid {@link EventModelSequenceTrainer} for the configured {@code trainParams}.
+   * @throws IllegalArgumentException Thrown if the trainer type could not be determined.
+   */
+  public static <T> EventModelSequenceTrainer<T> getEventModelSequenceTrainer(
+          TrainingParameters trainParams, Map<String, String> reportMap) {
     String trainerType = trainParams.getStringParameter(AbstractTrainer.ALGORITHM_PARAM,null);
 
     if (trainerType != null) {
+      final EventModelSequenceTrainer<T> trainer;
       if (BUILTIN_TRAINERS.containsKey(trainerType)) {
-        EventModelSequenceTrainer trainer = TrainerFactory.
-            <EventModelSequenceTrainer>createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
-        trainer.init(trainParams, reportMap);
-        return trainer;
+        trainer = TrainerFactory.createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
       } else {
-        EventModelSequenceTrainer trainer =
-            ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class, trainerType);
-        trainer.init(trainParams, reportMap);
-        return trainer;
+        trainer = ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class, trainerType);
       }
+      trainer.init(trainParams, reportMap);
+      return trainer;
     }
     else {
       throw new IllegalArgumentException("Trainer type couldn't be determined!");
     }
   }
 
-  public static EventTrainer getEventTrainer(TrainingParameters trainParams,
-      Map<String, String> reportMap) {
+  /**
+   * Retrieves an {@link EventTrainer} that fits the given parameters.
+   *
+   * @param trainParams The {@link TrainingParameters} to check for the trainer type.
+   *                    Note: The entry {@link AbstractTrainer#ALGORITHM_PARAM} is used
+   *                    to determine the type. If the type is not defined, the
+   *                    {@link GISTrainer#MAXENT_VALUE} will be used.
+   * @param reportMap A {@link Map} that shall be used during initialization of
+   *                  the {@link EventTrainer}.
+   *
+   * @return A valid {@link EventTrainer} for the configured {@code trainParams}.
+   */
+  public static EventTrainer getEventTrainer(
+          TrainingParameters trainParams, Map<String, String> reportMap) {
 
     // if the trainerType is not defined -- use the GISTrainer.
-    String trainerType = 
-        trainParams.getStringParameter(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
+    String trainerType = trainParams.getStringParameter(
+            AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
 
+    final EventTrainer trainer;
     if (BUILTIN_TRAINERS.containsKey(trainerType)) {
-      EventTrainer trainer = TrainerFactory.
-              <EventTrainer>createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
-      trainer.init(trainParams, reportMap);
-      return trainer;
+      trainer = TrainerFactory.createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
     } else {
-      EventTrainer trainer = ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType);
-      trainer.init(trainParams, reportMap);
-      return trainer;
+      trainer = ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType);
     }
-
+    trainer.init(trainParams, reportMap);
+    return trainer;
   }
 
+  /**
+   * @param trainParams The {@link TrainingParameters} to validate. Must not be {@code null}.
+   * @return {@code true} if the {@code trainParams} could be validated, {@code false} otherwise.
+   */
   public static boolean isValid(TrainingParameters trainParams) {
 
     // TODO: Need to validate all parameters correctly ... error prone?!
@@ -202,22 +241,19 @@ public class TrainerFactory {
     return true;
   }
 
-  private static <T extends Trainer> T createBuiltinTrainer(Class<T> trainerClass) {
-    T theTrainer = null;
+  @SuppressWarnings("unchecked")
+  private static <T extends Trainer> T createBuiltinTrainer(Class<? extends Trainer> trainerClass) {
+    Trainer theTrainer = null;
     if (trainerClass != null) {
       try {
-        Constructor<T> contructor = trainerClass.getConstructor();
-        theTrainer = contructor.newInstance();
+        Constructor<? extends Trainer> c = trainerClass.getConstructor();
+        theTrainer = c.newInstance();
       } catch (Exception e) {
-        String msg = "Could not instantiate the "
-            + trainerClass.getCanonicalName()
-            + ". The initialization throw an exception.";
-        System.err.println(msg);
-        e.printStackTrace();
+        String msg = "Could not instantiate the " + trainerClass.getCanonicalName()
+            + ". The initialization threw an exception.";
         throw new IllegalArgumentException(msg, e);
       }
     }
-
-    return theTrainer;
+    return (T) theTrainer;
   }
 }
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
index f168553f..b7246c0a 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
@@ -107,7 +107,8 @@ public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceT
     }
   }
 
-  public AbstractModel doTrain(SequenceStream events) throws IOException {
+  @Override
+  public AbstractModel doTrain(SequenceStream<Event> events) throws IOException {
     int iterations = getIterations();
     int cutoff = getCutoff();
 
diff --git a/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java b/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
index 9417d6f0..a869d029 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
@@ -247,11 +247,11 @@ public class NameFinderME implements TokenNameFinder {
     else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
       NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, factory.createContextGenerator());
 
-      EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(
+      EventModelSequenceTrainer<NameSample> trainer = TrainerFactory.getEventModelSequenceTrainer(
               trainParams, manifestInfoEntries);
       nameFinderModel = trainer.train(ss);
     } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
-      SequenceTrainer<NameSample> trainer = TrainerFactory.getSequenceModelTrainer(
+      SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
               trainParams, manifestInfoEntries);
 
       NameSampleSequenceStream ss =
diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
index f2ecc32f..ca45b0f2 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
@@ -255,7 +255,7 @@ public class POSTaggerME implements POSTagger {
     }
     else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
       POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
-      EventModelSequenceTrainer trainer =
+      EventModelSequenceTrainer<POSSample> trainer =
           TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries);
       posModel = trainer.train(ss);
     }
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
index f3c848c6..502f79a1 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
@@ -24,9 +24,10 @@ import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.SequenceStream;
 import opennlp.tools.util.TrainingParameters;
 
-public class MockSequenceTrainer implements EventModelSequenceTrainer {
+public class MockSequenceTrainer implements EventModelSequenceTrainer<Sample> {
 
-  public AbstractModel train(SequenceStream<? extends Sample> events) {
+  @Override
+  public AbstractModel train(SequenceStream<Sample> events) {
     return null;
   }