You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jz...@apache.org on 2022/12/20 16:23:17 UTC
[opennlp] branch master updated: OPENNLP-1417 Reduce unchecked assignments and raw types in TrainerFactory (#462)
This is an automated email from the ASF dual-hosted git repository.
jzemerick pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/master by this push:
new f04204ac OPENNLP-1417 Reduce unchecked assignments and raw types in TrainerFactory (#462)
f04204ac is described below
commit f04204ac07ae517d621781157ddb658b2ae8277b
Author: Martin Wiesner <ma...@users.noreply.github.com>
AuthorDate: Tue Dec 20 17:23:12 2022 +0100
OPENNLP-1417 Reduce unchecked assignments and raw types in TrainerFactory (#462)
- eliminates unchecked assignments and uses of raw types in `opennlp.tools.ml.TrainerFactory`
- simplifies existing duplicated code
- introduces a generic type parameter for the `EventModelSequenceTrainer` interface
- adds missing Javadoc
---
.../main/java/opennlp/tools/chunker/ChunkerME.java | 2 +-
.../opennlp/tools/lemmatizer/LemmatizerME.java | 4 +-
.../ml/AbstractEventModelSequenceTrainer.java | 9 +-
.../tools/ml/EventModelSequenceTrainer.java | 5 +-
.../java/opennlp/tools/ml/SequenceTrainer.java | 4 +-
.../main/java/opennlp/tools/ml/TrainerFactory.java | 132 +++++++++++++--------
.../SimplePerceptronSequenceTrainer.java | 3 +-
.../java/opennlp/tools/namefind/NameFinderME.java | 4 +-
.../java/opennlp/tools/postag/POSTaggerME.java | 2 +-
.../java/opennlp/tools/ml/MockSequenceTrainer.java | 5 +-
10 files changed, 104 insertions(+), 66 deletions(-)
diff --git a/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java b/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java
index e1246fa7..3656aeb9 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/chunker/ChunkerME.java
@@ -182,7 +182,7 @@ public class ChunkerME implements Chunker {
chunkerModel = trainer.train(es);
}
else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
- SequenceTrainer<ChunkSample> trainer = TrainerFactory.getSequenceModelTrainer(
+ SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
mlParams, manifestInfoEntries);
// TODO: This will probably cause issue, since the feature generator uses the outcomes array
diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java
index 4a19c516..33c34780 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java
@@ -267,12 +267,12 @@ public class LemmatizerME implements Lemmatizer {
}
else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
LemmaSampleSequenceStream ss = new LemmaSampleSequenceStream(samples, contextGenerator);
- EventModelSequenceTrainer trainer =
+ EventModelSequenceTrainer<LemmaSample> trainer =
TrainerFactory.getEventModelSequenceTrainer(params, manifestInfoEntries);
lemmatizerModel = trainer.train(ss);
}
else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
- SequenceTrainer<LemmaSample> trainer = TrainerFactory.getSequenceModelTrainer(
+ SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
params, manifestInfoEntries);
// TODO: This will probably cause issue, since the feature generator uses the outcomes array
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java
index 362a0d69..0adad760 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java
@@ -19,19 +19,20 @@ package opennlp.tools.ml;
import java.io.IOException;
+import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceStream;
public abstract class AbstractEventModelSequenceTrainer extends AbstractTrainer implements
- EventModelSequenceTrainer {
+ EventModelSequenceTrainer<Event> {
public AbstractEventModelSequenceTrainer() {
}
- public abstract MaxentModel doTrain(SequenceStream events)
- throws IOException;
+ public abstract MaxentModel doTrain(SequenceStream<Event> events) throws IOException;
- public final MaxentModel train(SequenceStream events) throws IOException {
+ @Override
+ public final MaxentModel train(SequenceStream<Event> events) throws IOException {
validate();
MaxentModel model = doTrain(events);
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java
index a6f07dee..caa98512 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/EventModelSequenceTrainer.java
@@ -20,19 +20,18 @@ package opennlp.tools.ml;
import java.io.IOException;
import java.util.Map;
-import opennlp.tools.commons.Sample;
import opennlp.tools.commons.Trainer;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.util.TrainingParameters;
-public interface EventModelSequenceTrainer extends Trainer {
+public interface EventModelSequenceTrainer<T> extends Trainer {
String SEQUENCE_VALUE = "EventModelSequence";
void init(Map<String, Object> trainParams, Map<String, String> reportMap);
void init(TrainingParameters trainParams, Map<String, String> reportMap);
- MaxentModel train(SequenceStream<? extends Sample> events) throws IOException;
+ MaxentModel train(SequenceStream<T> events) throws IOException;
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
index 263682a0..d9e81b2f 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
@@ -25,7 +25,7 @@ import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.util.TrainingParameters;
-public interface SequenceTrainer<T> extends Trainer {
+public interface SequenceTrainer extends Trainer {
String SEQUENCE_VALUE = "Sequence";
@@ -33,5 +33,5 @@ public interface SequenceTrainer<T> extends Trainer {
void init(Map<String, String> trainParams, Map<String, String> reportMap);
void init(TrainingParameters trainParams, Map<String, String> reportMap);
- SequenceClassificationModel<String> train(SequenceStream<T> events) throws IOException;
+ <T> SequenceClassificationModel<String> train(SequenceStream<T> events) throws IOException;
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
index a7604cb4..c1acba40 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
@@ -20,7 +20,6 @@ package opennlp.tools.ml;
import java.lang.reflect.Constructor;
import java.util.Map;
-import opennlp.tools.commons.Sample;
import opennlp.tools.commons.Trainer;
import opennlp.tools.ml.maxent.GISTrainer;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
@@ -31,6 +30,10 @@ import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.ext.ExtensionLoader;
import opennlp.tools.util.ext.ExtensionNotLoadedException;
+/**
+ * A factory to initialize {@link Trainer} instances depending on a trainer type
+ * configured via {@link TrainingParameters}.
+ */
public class TrainerFactory {
public enum TrainerType {
@@ -40,8 +43,11 @@ public class TrainerFactory {
}
// built-in trainers
- private static final Map<String, Class> BUILTIN_TRAINERS;
+ private static final Map<String, Class<? extends Trainer>> BUILTIN_TRAINERS;
+ /*
+ * Initialize the built-in trainers
+ */
static {
BUILTIN_TRAINERS = Map.of(
GISTrainer.MAXENT_VALUE, GISTrainer.class,
@@ -52,10 +58,12 @@ public class TrainerFactory {
}
/**
- * Determines the trainer type based on the ALGORITHM_PARAM value.
+ * Determines the {@link TrainerType} based on the
+ * {@link AbstractTrainer#ALGORITHM_PARAM} value.
*
- * @param trainParams - Map of training parameters
- * @return the trainer type or null if type couldn't be determined.
+ * @param trainParams - A mapping of {@link TrainingParameters training parameters}.
+ *
+ * @return The {@link TrainerType} or {@code null} if the type couldn't be determined.
*/
public static TrainerType getTrainerType(TrainingParameters trainParams) {
@@ -110,70 +118,101 @@ public class TrainerFactory {
return null;
}
- public static SequenceTrainer getSequenceModelTrainer(TrainingParameters trainParams,
- Map<String, String> reportMap) {
+ /**
+ * Retrieves a {@link SequenceTrainer} that fits the given parameters.
+ *
+ * @param trainParams The {@link TrainingParameters} to check for the trainer type.
+ * Note: The entry {@link AbstractTrainer#ALGORITHM_PARAM} is used
+ * to determine the type.
+ * @param reportMap A {@link Map} that shall be used during initialization of
+ * the {@link SequenceTrainer}.
+ *
+ * @return A valid {@link SequenceTrainer} for the configured {@code trainParams}.
+ * @throws IllegalArgumentException Thrown if the trainer type could not be determined.
+ */
+ public static SequenceTrainer getSequenceModelTrainer(
+ TrainingParameters trainParams, Map<String, String> reportMap) {
String trainerType = trainParams.getStringParameter(AbstractTrainer.ALGORITHM_PARAM,null);
if (trainerType != null) {
+ final SequenceTrainer trainer;
if (BUILTIN_TRAINERS.containsKey(trainerType)) {
- SequenceTrainer<? extends Sample> trainer = TrainerFactory.
- <SequenceTrainer>createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
- trainer.init(trainParams, reportMap);
- return trainer;
+ trainer = TrainerFactory.createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
} else {
- SequenceTrainer<? extends Sample> trainer =
- ExtensionLoader.instantiateExtension(SequenceTrainer.class, trainerType);
- trainer.init(trainParams, reportMap);
- return trainer;
+ trainer = ExtensionLoader.instantiateExtension(SequenceTrainer.class, trainerType);
}
+ trainer.init(trainParams, reportMap);
+ return trainer;
}
else {
throw new IllegalArgumentException("Trainer type couldn't be determined!");
}
}
- public static EventModelSequenceTrainer getEventModelSequenceTrainer(TrainingParameters trainParams,
- Map<String, String> reportMap) {
+ /**
+ * Retrieves an {@link EventModelSequenceTrainer} that fits the given parameters.
+ *
+ * @param trainParams The {@link TrainingParameters} to check for the trainer type.
+ * Note: The entry {@link AbstractTrainer#ALGORITHM_PARAM} is used
+ * to determine the type.
+ * @param reportMap A {@link Map} that shall be used during initialization of
+ * the {@link EventModelSequenceTrainer}.
+ *
+ * @return A valid {@link EventModelSequenceTrainer} for the configured {@code trainParams}.
+ * @throws IllegalArgumentException Thrown if the trainer type could not be determined.
+ */
+ public static <T> EventModelSequenceTrainer<T> getEventModelSequenceTrainer(
+ TrainingParameters trainParams, Map<String, String> reportMap) {
String trainerType = trainParams.getStringParameter(AbstractTrainer.ALGORITHM_PARAM,null);
if (trainerType != null) {
+ final EventModelSequenceTrainer<T> trainer;
if (BUILTIN_TRAINERS.containsKey(trainerType)) {
- EventModelSequenceTrainer trainer = TrainerFactory.
- <EventModelSequenceTrainer>createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
- trainer.init(trainParams, reportMap);
- return trainer;
+ trainer = TrainerFactory.createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
} else {
- EventModelSequenceTrainer trainer =
- ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class, trainerType);
- trainer.init(trainParams, reportMap);
- return trainer;
+ trainer = ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class, trainerType);
}
+ trainer.init(trainParams, reportMap);
+ return trainer;
}
else {
throw new IllegalArgumentException("Trainer type couldn't be determined!");
}
}
- public static EventTrainer getEventTrainer(TrainingParameters trainParams,
- Map<String, String> reportMap) {
+ /**
+ * Retrieves an {@link EventTrainer} that fits the given parameters.
+ *
+ * @param trainParams The {@link TrainingParameters} to check for the trainer type.
+ * Note: The entry {@link AbstractTrainer#ALGORITHM_PARAM} is used
+ * to determine the type. If the type is not defined, the
+ * {@link GISTrainer#MAXENT_VALUE} will be used.
+ * @param reportMap A {@link Map} that shall be used during initialization of
+ * the {@link EventTrainer}.
+ *
+ * @return A valid {@link EventTrainer} for the configured {@code trainParams}.
+ */
+ public static EventTrainer getEventTrainer(
+ TrainingParameters trainParams, Map<String, String> reportMap) {
// if the trainerType is not defined -- use the GISTrainer.
- String trainerType =
- trainParams.getStringParameter(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
+ String trainerType = trainParams.getStringParameter(
+ AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
+ final EventTrainer trainer;
if (BUILTIN_TRAINERS.containsKey(trainerType)) {
- EventTrainer trainer = TrainerFactory.
- <EventTrainer>createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
- trainer.init(trainParams, reportMap);
- return trainer;
+ trainer = TrainerFactory.createBuiltinTrainer(BUILTIN_TRAINERS.get(trainerType));
} else {
- EventTrainer trainer = ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType);
- trainer.init(trainParams, reportMap);
- return trainer;
+ trainer = ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType);
}
-
+ trainer.init(trainParams, reportMap);
+ return trainer;
}
+ /**
+ * @param trainParams The {@link TrainingParameters} to validate. Must not be {@code null}.
+ * @return {@code true} if the {@code trainParams} could be validated, {@code false} otherwise.
+ */
public static boolean isValid(TrainingParameters trainParams) {
// TODO: Need to validate all parameters correctly ... error prone?!
@@ -202,22 +241,19 @@ public class TrainerFactory {
return true;
}
- private static <T extends Trainer> T createBuiltinTrainer(Class<T> trainerClass) {
- T theTrainer = null;
+ @SuppressWarnings("unchecked")
+ private static <T extends Trainer> T createBuiltinTrainer(Class<? extends Trainer> trainerClass) {
+ Trainer theTrainer = null;
if (trainerClass != null) {
try {
- Constructor<T> contructor = trainerClass.getConstructor();
- theTrainer = contructor.newInstance();
+ Constructor<? extends Trainer> c = trainerClass.getConstructor();
+ theTrainer = c.newInstance();
} catch (Exception e) {
- String msg = "Could not instantiate the "
- + trainerClass.getCanonicalName()
- + ". The initialization throw an exception.";
- System.err.println(msg);
- e.printStackTrace();
+ String msg = "Could not instantiate the " + trainerClass.getCanonicalName()
+ + ". The initialization threw an exception.";
throw new IllegalArgumentException(msg, e);
}
}
-
- return theTrainer;
+ return (T) theTrainer;
}
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
index f168553f..b7246c0a 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
@@ -107,7 +107,8 @@ public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceT
}
}
- public AbstractModel doTrain(SequenceStream events) throws IOException {
+ @Override
+ public AbstractModel doTrain(SequenceStream<Event> events) throws IOException {
int iterations = getIterations();
int cutoff = getCutoff();
diff --git a/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java b/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
index 9417d6f0..a869d029 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
@@ -247,11 +247,11 @@ public class NameFinderME implements TokenNameFinder {
else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, factory.createContextGenerator());
- EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(
+ EventModelSequenceTrainer<NameSample> trainer = TrainerFactory.getEventModelSequenceTrainer(
trainParams, manifestInfoEntries);
nameFinderModel = trainer.train(ss);
} else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
- SequenceTrainer<NameSample> trainer = TrainerFactory.getSequenceModelTrainer(
+ SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
trainParams, manifestInfoEntries);
NameSampleSequenceStream ss =
diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
index f2ecc32f..ca45b0f2 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
@@ -255,7 +255,7 @@ public class POSTaggerME implements POSTagger {
}
else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
- EventModelSequenceTrainer trainer =
+ EventModelSequenceTrainer<POSSample> trainer =
TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries);
posModel = trainer.train(ss);
}
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
index f3c848c6..502f79a1 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
@@ -24,9 +24,10 @@ import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.util.TrainingParameters;
-public class MockSequenceTrainer implements EventModelSequenceTrainer {
+public class MockSequenceTrainer implements EventModelSequenceTrainer<Sample> {
- public AbstractModel train(SequenceStream<? extends Sample> events) {
+ @Override
+ public AbstractModel train(SequenceStream<Sample> events) {
return null;
}