You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by cl...@apache.org on 2013/07/18 19:27:52 UTC
svn commit: r1504557 - in
/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal:
ae/TimeAnnotator.java eval/EvaluationOfTimeSpans.java
Author: clin
Date: Thu Jul 18 17:27:52 2013
New Revision: 1504557
URL: http://svn.apache.org/r1504557
Log:
Added feature selection for TimeAnnotator. Feature selection can be turned off by setting --featureSelectionThreshold 0 (the default).
Any positive value of --featureSelectionThreshold enables the feature selection process and is used as the Chi-squared threshold.
Modified:
ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java
ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java
Modified: ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java?rev=1504557&r1=1504556&r2=1504557&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java (original)
+++ ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java Thu Jul 18 17:27:52 2013
@@ -19,11 +19,15 @@
package org.apache.ctakes.temporal.ae;
import java.io.File;
+import java.io.IOException;
+import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import org.apache.ctakes.temporal.ae.feature.ParseSpanFeatureExtractor;
import org.apache.ctakes.temporal.ae.feature.TimeWordTypeExtractor;
+import org.apache.ctakes.temporal.ae.feature.selection.Chi2FeatureSelection;
+import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
import org.apache.ctakes.typesystem.type.syntax.BaseToken;
import org.apache.ctakes.typesystem.type.textsem.TimeMention;
import org.apache.ctakes.typesystem.type.textspan.Segment;
@@ -35,7 +39,7 @@ import org.apache.uima.cas.CASException;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.classifier.CleartkAnnotator;
-import org.cleartk.classifier.DataWriter;
+//import org.cleartk.classifier.DataWriter;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.chunking.BIOChunking;
@@ -51,151 +55,203 @@ import org.cleartk.classifier.feature.ex
import org.cleartk.classifier.jar.DefaultDataWriterFactory;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.GenericJarClassifierFactory;
+import org.uimafit.descriptor.ConfigurationParameter;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.util.JCasUtil;
public class TimeAnnotator extends TemporalEntityAnnotator_ImplBase {
- public static final String TIMEX_VIEW = "TimexView";
-
- public static AnalysisEngineDescription createDataWriterDescription(
- Class<? extends DataWriter<String>> dataWriterClass,
- File outputDirectory) throws ResourceInitializationException {
- return AnalysisEngineFactory.createPrimitiveDescription(
- TimeAnnotator.class,
- CleartkAnnotator.PARAM_IS_TRAINING,
- true,
- DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
- dataWriterClass,
- DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
- outputDirectory);
- }
-
- public static AnalysisEngineDescription createAnnotatorDescription(File modelDirectory)
- throws ResourceInitializationException {
- return AnalysisEngineFactory.createPrimitiveDescription(
- TimeAnnotator.class,
- CleartkAnnotator.PARAM_IS_TRAINING,
- false,
- GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
- new File(modelDirectory, "model.jar"));
- }
-
- protected List<SimpleFeatureExtractor> tokenFeatureExtractors;
-
- protected List<CleartkExtractor> contextFeatureExtractors;
-
-// protected List<SimpleFeatureExtractor> parseFeatureExtractors;
- protected ParseSpanFeatureExtractor parseExtractor;
-
- private BIOChunking<BaseToken, TimeMention> timeChunking;
-
- @Override
- public void initialize(UimaContext context) throws ResourceInitializationException {
- super.initialize(context);
-
- // define chunking
- this.timeChunking = new BIOChunking<BaseToken, TimeMention>(BaseToken.class, TimeMention.class);
-
- CombinedExtractor allExtractors = new CombinedExtractor(
- new CoveredTextExtractor(),
- new CharacterCategoryPatternExtractor(PatternType.REPEATS_MERGED),
- new CharacterCategoryPatternExtractor(PatternType.ONE_PER_CHAR),
- new TypePathExtractor(BaseToken.class, "partOfSpeech"),
- new TimeWordTypeExtractor());
-
-// CombinedExtractor parseExtractors = new CombinedExtractor(
-// new ParseSpanFeatureExtractor()
-// );
- this.tokenFeatureExtractors = new ArrayList<SimpleFeatureExtractor>();
- this.tokenFeatureExtractors.add(allExtractors);
-
- this.contextFeatureExtractors = new ArrayList<CleartkExtractor>();
- this.contextFeatureExtractors.add(new CleartkExtractor(
- BaseToken.class,
- allExtractors,
- new Preceding(3),
- new Following(3)));
-// this.parseFeatureExtractors = new ArrayList<ParseSpanFeatureExtractor>();
-// this.parseFeatureExtractors.add(new ParseSpanFeatureExtractor());
- parseExtractor = new ParseSpanFeatureExtractor();
- }
-
- @Override
- public void process(JCas jCas, Segment segment) throws AnalysisEngineProcessException {
- // classify tokens within each sentence
- for (Sentence sentence : JCasUtil.selectCovered(jCas, Sentence.class, segment)) {
- List<BaseToken> tokens = JCasUtil.selectCovered(jCas, BaseToken.class, sentence);
-
- // during training, the list of all outcomes for the tokens
- List<String> outcomes;
- if (this.isTraining()) {
- List<TimeMention> times = JCasUtil.selectCovered(jCas, TimeMention.class, sentence);
- outcomes = this.timeChunking.createOutcomes(jCas, tokens, times);
- }
- // during prediction, the list of outcomes predicted so far
- else {
- outcomes = new ArrayList<String>();
- }
-
- // extract features for all tokens
- int tokenIndex = -1;
- for (BaseToken token : tokens) {
- ++tokenIndex;
-
- List<Feature> features = new ArrayList<Feature>();
- // features from token attributes
- for (SimpleFeatureExtractor extractor : this.tokenFeatureExtractors) {
- features.addAll(extractor.extract(jCas, token));
- }
- // features from surrounding tokens
- for (CleartkExtractor extractor : this.contextFeatureExtractors) {
- features.addAll(extractor.extractWithin(jCas, token, sentence));
- }
- // features from previous classifications
- int nPreviousClassifications = 2;
- for (int i = nPreviousClassifications; i > 0; --i) {
- int index = tokenIndex - i;
- String previousOutcome = index < 0 ? "O" : outcomes.get(index);
- features.add(new Feature("PreviousOutcome_" + i, previousOutcome));
- }
- //add segment ID as a features:
- features.add(new Feature("SegmentID", segment.getId()));
-
- // features from dominating parse tree
-// for(SimpleFeatureExtractor extractor : this.parseFeatureExtractors){
- BaseToken startToken = token;
- for(int i = tokenIndex-1; i >= 0; --i){
- String outcome = outcomes.get(i);
- if(outcome.equals("O")){
- break;
- }
- startToken = tokens.get(i);
- }
- features.addAll(parseExtractor.extract(jCas, startToken.getBegin(), token.getEnd()));
-// }
- // if training, write to data file
- if (this.isTraining()) {
- String outcome = outcomes.get(tokenIndex);
- this.dataWriter.write(new Instance<String>(outcome, features));
- }
-
- // if predicting, add prediction to outcomes
- else {
- outcomes.add(this.classifier.classify(features));
- }
- }
-
- // during prediction, convert chunk labels to times and add them to the CAS
- if (!this.isTraining()) {
- JCas timexCas;
- try {
- timexCas = jCas.getView(TIMEX_VIEW);
- } catch (CASException e) {
- throw new AnalysisEngineProcessException(e);
- }
- this.timeChunking.createChunks(timexCas, tokens, outcomes);
- }
- }
- }
+ public static final String PARAM_FEATURE_SELECTION_THRESHOLD = "WhetherToDoFeatureSelection";
+
+ @ConfigurationParameter(
+ name = PARAM_FEATURE_SELECTION_THRESHOLD,
+ mandatory = false,
+ description = "the Chi-squared threshold at which features should be removed")
+ protected Float featureSelectionThreshold = 0f;
+
+ public static final String PARAM_FEATURE_SELECTION_URI = "FeatureSelectionURI";
+
+ @ConfigurationParameter(
+ mandatory = false,
+ name = PARAM_FEATURE_SELECTION_URI,
+ description = "provides a URI where the feature selection data will be written")
+ protected URI featureSelectionURI;
+
+ public static final String TIMEX_VIEW = "TimexView";
+
+ public static AnalysisEngineDescription createDataWriterDescription(
+ Class<?> dataWriterClass,
+ File outputDirectory,
+ float featureSelect) throws ResourceInitializationException {
+ return AnalysisEngineFactory.createPrimitiveDescription(
+ TimeAnnotator.class,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ true,
+ DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+ dataWriterClass,
+ DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+ outputDirectory,
+ TimeAnnotator.PARAM_FEATURE_SELECTION_THRESHOLD,
+ featureSelect);
+ }
+
+ public static AnalysisEngineDescription createAnnotatorDescription(File modelDirectory)
+ throws ResourceInitializationException {
+ return AnalysisEngineFactory.createPrimitiveDescription(
+ TimeAnnotator.class,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ false,
+ GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+ new File(modelDirectory, "model.jar"),
+ TimeAnnotator.PARAM_FEATURE_SELECTION_URI,
+ TimeAnnotator.createFeatureSelectionURI(modelDirectory));
+ }
+
+ protected List<SimpleFeatureExtractor> tokenFeatureExtractors;
+
+ protected List<CleartkExtractor> contextFeatureExtractors;
+
+ // protected List<SimpleFeatureExtractor> parseFeatureExtractors;
+ protected ParseSpanFeatureExtractor parseExtractor;
+
+ private BIOChunking<BaseToken, TimeMention> timeChunking;
+
+ private FeatureSelection<String> featureSelection;
+
+ private static final String FEATURE_SELECTION_NAME = "SelectNeighborFeatures";
+
+ public static FeatureSelection<String> createFeatureSelection(double threshold) {
+ return new Chi2FeatureSelection<String>(TimeAnnotator.FEATURE_SELECTION_NAME, threshold);
+ }
+
+ public static URI createFeatureSelectionURI(File outputDirectoryName) {
+ return new File(outputDirectoryName, FEATURE_SELECTION_NAME + "_Chi2_extractor.dat").toURI();
+ }
+
+ @Override
+ public void initialize(UimaContext context) throws ResourceInitializationException {
+ super.initialize(context);
+
+ // define chunking
+ this.timeChunking = new BIOChunking<BaseToken, TimeMention>(BaseToken.class, TimeMention.class);
+
+ CombinedExtractor allExtractors = new CombinedExtractor(
+ new CoveredTextExtractor(),
+ new CharacterCategoryPatternExtractor(PatternType.REPEATS_MERGED),
+ new CharacterCategoryPatternExtractor(PatternType.ONE_PER_CHAR),
+ new TypePathExtractor(BaseToken.class, "partOfSpeech"),
+ new TimeWordTypeExtractor());
+
+ // CombinedExtractor parseExtractors = new CombinedExtractor(
+ // new ParseSpanFeatureExtractor()
+ // );
+ this.tokenFeatureExtractors = new ArrayList<SimpleFeatureExtractor>();
+ this.tokenFeatureExtractors.add(allExtractors);
+
+ this.contextFeatureExtractors = new ArrayList<CleartkExtractor>();
+ this.contextFeatureExtractors.add(new CleartkExtractor(
+ BaseToken.class,
+ allExtractors,
+ new Preceding(3),
+ new Following(3)));
+ // this.parseFeatureExtractors = new ArrayList<ParseSpanFeatureExtractor>();
+ // this.parseFeatureExtractors.add(new ParseSpanFeatureExtractor());
+ parseExtractor = new ParseSpanFeatureExtractor();
+
+ //initialize feature selection
+ if (featureSelectionThreshold == 0) {
+ this.featureSelection = null;
+ } else {
+ this.featureSelection = TimeAnnotator.createFeatureSelection(this.featureSelectionThreshold);
+
+ if (this.featureSelectionURI != null) {
+ try {
+ this.featureSelection.load(this.featureSelectionURI);
+ } catch (IOException e) {
+ throw new ResourceInitializationException(e);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void process(JCas jCas, Segment segment) throws AnalysisEngineProcessException {
+ // classify tokens within each sentence
+ for (Sentence sentence : JCasUtil.selectCovered(jCas, Sentence.class, segment)) {
+ List<BaseToken> tokens = JCasUtil.selectCovered(jCas, BaseToken.class, sentence);
+
+ // during training, the list of all outcomes for the tokens
+ List<String> outcomes;
+ if (this.isTraining()) {
+ List<TimeMention> times = JCasUtil.selectCovered(jCas, TimeMention.class, sentence);
+ outcomes = this.timeChunking.createOutcomes(jCas, tokens, times);
+ }
+ // during prediction, the list of outcomes predicted so far
+ else {
+ outcomes = new ArrayList<String>();
+ }
+
+ // extract features for all tokens
+ int tokenIndex = -1;
+ for (BaseToken token : tokens) {
+ ++tokenIndex;
+
+ List<Feature> features = new ArrayList<Feature>();
+ // features from token attributes
+ for (SimpleFeatureExtractor extractor : this.tokenFeatureExtractors) {
+ features.addAll(extractor.extract(jCas, token));
+ }
+ // features from surrounding tokens
+ for (CleartkExtractor extractor : this.contextFeatureExtractors) {
+ features.addAll(extractor.extractWithin(jCas, token, sentence));
+ }
+ // features from previous classifications
+ int nPreviousClassifications = 2;
+ for (int i = nPreviousClassifications; i > 0; --i) {
+ int index = tokenIndex - i;
+ String previousOutcome = index < 0 ? "O" : outcomes.get(index);
+ features.add(new Feature("PreviousOutcome_" + i, previousOutcome));
+ }
+ // add the segment ID as a feature:
+ features.add(new Feature("SegmentID", segment.getId()));
+
+ // features from dominating parse tree
+ // for(SimpleFeatureExtractor extractor : this.parseFeatureExtractors){
+ BaseToken startToken = token;
+ for(int i = tokenIndex-1; i >= 0; --i){
+ String outcome = outcomes.get(i);
+ if(outcome.equals("O")){
+ break;
+ }
+ startToken = tokens.get(i);
+ }
+ features.addAll(parseExtractor.extract(jCas, startToken.getBegin(), token.getEnd()));
+ // }
+
+ // apply feature selection, if necessary
+ if (this.featureSelection != null) {
+ features = this.featureSelection.transform(features);
+ }
+
+ // if training, write to data file
+ if (this.isTraining()) {
+ String outcome = outcomes.get(tokenIndex);
+ this.dataWriter.write(new Instance<String>(outcome, features));
+ }else {// if predicting, add prediction to outcomes
+ outcomes.add(this.classifier.classify(features));
+ }
+ }
+
+ // during prediction, convert chunk labels to times and add them to the CAS
+ if (!this.isTraining()) {
+ JCas timexCas;
+ try {
+ timexCas = jCas.getView(TIMEX_VIEW);
+ } catch (CASException e) {
+ throw new AnalysisEngineProcessException(e);
+ }
+ this.timeChunking.createChunks(timexCas, tokens, outcomes);
+ }
+ }
+ }
}
Modified: ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java?rev=1504557&r1=1504556&r2=1504557&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java (original)
+++ ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java Thu Jul 18 17:27:52 2013
@@ -29,6 +29,7 @@ import org.apache.ctakes.temporal.ae.CRF
import org.apache.ctakes.temporal.ae.ConstituencyBasedTimeAnnotator;
import org.apache.ctakes.temporal.ae.MetaTimeAnnotator;
import org.apache.ctakes.temporal.ae.TimeAnnotator;
+import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
import org.apache.ctakes.typesystem.type.textsem.TimeMention;
import org.apache.ctakes.typesystem.type.textspan.Segment;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
@@ -37,7 +38,11 @@ import org.apache.uima.jcas.tcas.Annotat
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.classifier.CleartkAnnotator;
import org.cleartk.classifier.CleartkSequenceAnnotator;
+import org.cleartk.classifier.Instance;
+//import org.cleartk.classifier.DataWriter;
import org.cleartk.classifier.crfsuite.CRFSuiteStringOutcomeDataWriter;
+import org.cleartk.classifier.feature.transform.InstanceDataWriter;
+import org.cleartk.classifier.feature.transform.InstanceStream;
import org.cleartk.classifier.jar.DefaultDataWriterFactory;
import org.cleartk.classifier.jar.DefaultSequenceDataWriterFactory;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
@@ -53,144 +58,183 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
public class EvaluationOfTimeSpans extends EvaluationOfAnnotationSpans_ImplBase {
- public static void main(String[] args) throws Exception {
- Options options = CliFactory.parseArguments(Options.class, args);
- List<Integer> patientSets = options.getPatients().getList();
- List<Integer> trainItems = THYMEData.getTrainPatientSets(patientSets);
- List<Integer> devItems = THYMEData.getDevPatientSets(patientSets);
-
- // specify the annotator classes to use
- List<Class<? extends JCasAnnotator_ImplBase>> annotatorClasses = Lists.newArrayList();
- annotatorClasses.add(BackwardsTimeAnnotator.class);
- annotatorClasses.add(TimeAnnotator.class);
- annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
- annotatorClasses.add(CRFTimeAnnotator.class);
- annotatorClasses.add(MetaTimeAnnotator.class);
- Map<Class<? extends JCasAnnotator_ImplBase>, String[]> annotatorTrainingArguments = Maps.newHashMap();
- annotatorTrainingArguments.put(BackwardsTimeAnnotator.class, new String[]{"-c", "0.1"});
- annotatorTrainingArguments.put(TimeAnnotator.class, new String[]{"-c", "0.1"});
- annotatorTrainingArguments.put(ConstituencyBasedTimeAnnotator.class, new String[]{"-c", "0.1"});
- annotatorTrainingArguments.put(CRFTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
- annotatorTrainingArguments.put(MetaTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
-
- // run one evaluation per annotator class
- final Map<Class<?>, AnnotationStatistics<?>> annotatorStats = Maps.newHashMap();
- for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : annotatorClasses) {
- EvaluationOfTimeSpans evaluation = new EvaluationOfTimeSpans(
- new File("target/eval/time-spans"),
- options.getRawTextDirectory(),
- options.getXMLDirectory(),
- options.getXMLFormat(),
- options.getXMIDirectory(),
- options.getTreebankDirectory(),
- annotatorClass,
- options.getPrintOverlappingSpans(),
- annotatorTrainingArguments.get(annotatorClass));
- evaluation.prepareXMIsFor(patientSets);
- String name = String.format("%s.errors", annotatorClass.getSimpleName());
- evaluation.setLogging(Level.FINE, new File("target/eval", name));
- AnnotationStatistics<String> stats = evaluation.trainAndTest(trainItems, devItems);
- annotatorStats.put(annotatorClass, stats);
- }
-
- // allow ordering of models by F1
- Ordering<Class<? extends JCasAnnotator_ImplBase>> byF1 = Ordering.natural().onResultOf(
- new Function<Class<? extends JCasAnnotator_ImplBase>, Double>() {
- @Override
- public Double apply(
- Class<? extends JCasAnnotator_ImplBase> annotatorClass) {
- return annotatorStats.get(annotatorClass).f1();
- }
- });
-
- // print out models, ordered by F1
- for (Class<?> annotatorClass : byF1.sortedCopy(annotatorClasses)) {
- System.err.printf("===== %s =====\n", annotatorClass.getSimpleName());
- System.err.println(annotatorStats.get(annotatorClass));
- }
- }
-
- private Class<? extends JCasAnnotator_ImplBase> annotatorClass;
-
- private String[] trainingArguments;
-
- public EvaluationOfTimeSpans(
- File baseDirectory,
- File rawTextDirectory,
- File xmlDirectory,
- XMLFormat xmlFormat,
- File xmiDirectory,
- File treebankDirectory,
- Class<? extends JCasAnnotator_ImplBase> annotatorClass,
- boolean printOverlapping,
- String[] trainingArguments) {
- super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, xmiDirectory, treebankDirectory, TimeMention.class);
- this.annotatorClass = annotatorClass;
- this.trainingArguments = trainingArguments;
- this.printOverlapping = printOverlapping;
- }
-
- @Override
- protected AnalysisEngineDescription getDataWriterDescription(File directory)
- throws ResourceInitializationException {
- if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
- return MetaTimeAnnotator.getDataWriterDescription(CRFSuiteStringOutcomeDataWriter.class, directory);
- }else if(CleartkAnnotator.class.isAssignableFrom(this.annotatorClass)){
- return AnalysisEngineFactory.createPrimitiveDescription(
- this.annotatorClass,
- CleartkAnnotator.PARAM_IS_TRAINING,
- true,
- DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
- LIBLINEARStringOutcomeDataWriter.class,
- DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
- this.getModelDirectory(directory));
- }else if(CleartkSequenceAnnotator.class.isAssignableFrom(this.annotatorClass)){
- return AnalysisEngineFactory.createPrimitiveDescription(
- this.annotatorClass,
- CleartkSequenceAnnotator.PARAM_IS_TRAINING,
- true,
- DefaultSequenceDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
- CRFSuiteStringOutcomeDataWriter.class,
- DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
- this.getModelDirectory(directory));
- }else{
- throw new ResourceInitializationException("Annotator class was not recognized as an acceptable class!", new Object[]{});
- }
- }
-
- @Override
- protected void trainAndPackage(File directory) throws Exception {
- JarClassifierBuilder.trainAndPackage(this.getModelDirectory(directory), this.trainingArguments);
- }
-
- @Override
- protected AnalysisEngineDescription getAnnotatorDescription(File directory)
- throws ResourceInitializationException {
- if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
- return MetaTimeAnnotator.getAnnotatorDescription(directory);
- }
- return AnalysisEngineFactory.createPrimitiveDescription(
- this.annotatorClass,
- CleartkAnnotator.PARAM_IS_TRAINING,
- false,
- GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
- new File(this.getModelDirectory(directory), "model.jar"));
- }
-
- @Override
- protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas, Segment segment) {
- return selectExact(jCas, TimeMention.class, segment);
- }
-
- @Override
- protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas, Segment segment) {
- return selectExact(jCas, TimeMention.class, segment);
- }
-
- private File getModelDirectory(File directory) {
- return new File(directory, this.annotatorClass.getSimpleName());
- }
+ static interface Options extends Evaluation_ImplBase.Options {
+
+ @Option(longName = "featureSelectionThreshold", defaultValue = "0")
+ public float getFeatureSelectionThreshold();
+ }
+
+ public static void main(String[] args) throws Exception {
+ Options options = CliFactory.parseArguments(Options.class, args);
+ List<Integer> patientSets = options.getPatients().getList();
+ List<Integer> trainItems = THYMEData.getTrainPatientSets(patientSets);
+ List<Integer> devItems = THYMEData.getDevPatientSets(patientSets);
+
+ // specify the annotator classes to use
+ List<Class<? extends JCasAnnotator_ImplBase>> annotatorClasses = Lists.newArrayList();
+ annotatorClasses.add(BackwardsTimeAnnotator.class);
+ annotatorClasses.add(TimeAnnotator.class);
+ annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
+ annotatorClasses.add(CRFTimeAnnotator.class);
+ annotatorClasses.add(MetaTimeAnnotator.class);
+ Map<Class<? extends JCasAnnotator_ImplBase>, String[]> annotatorTrainingArguments = Maps.newHashMap();
+ annotatorTrainingArguments.put(BackwardsTimeAnnotator.class, new String[]{"-c", "0.1"});
+ annotatorTrainingArguments.put(TimeAnnotator.class, new String[]{"-c", "0.1"});
+ annotatorTrainingArguments.put(ConstituencyBasedTimeAnnotator.class, new String[]{"-c", "0.1"});
+ annotatorTrainingArguments.put(CRFTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
+ annotatorTrainingArguments.put(MetaTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
+
+ // run one evaluation per annotator class
+ final Map<Class<?>, AnnotationStatistics<?>> annotatorStats = Maps.newHashMap();
+ for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : annotatorClasses) {
+ EvaluationOfTimeSpans evaluation = new EvaluationOfTimeSpans(
+ new File("target/eval/time-spans"),
+ options.getRawTextDirectory(),
+ options.getXMLDirectory(),
+ options.getXMLFormat(),
+ options.getXMIDirectory(),
+ options.getTreebankDirectory(),
+ options.getFeatureSelectionThreshold(),
+ annotatorClass,
+ options.getPrintOverlappingSpans(),
+ annotatorTrainingArguments.get(annotatorClass));
+ evaluation.prepareXMIsFor(patientSets);
+ String name = String.format("%s.errors", annotatorClass.getSimpleName());
+ evaluation.setLogging(Level.FINE, new File("target/eval", name));
+ AnnotationStatistics<String> stats = evaluation.trainAndTest(trainItems, devItems);
+ annotatorStats.put(annotatorClass, stats);
+ }
+
+ // allow ordering of models by F1
+ Ordering<Class<? extends JCasAnnotator_ImplBase>> byF1 = Ordering.natural().onResultOf(
+ new Function<Class<? extends JCasAnnotator_ImplBase>, Double>() {
+ @Override
+ public Double apply(
+ Class<? extends JCasAnnotator_ImplBase> annotatorClass) {
+ return annotatorStats.get(annotatorClass).f1();
+ }
+ });
+
+ // print out models, ordered by F1
+ for (Class<?> annotatorClass : byF1.sortedCopy(annotatorClasses)) {
+ System.err.printf("===== %s =====\n", annotatorClass.getSimpleName());
+ System.err.println(annotatorStats.get(annotatorClass));
+ }
+ }
+
+ private Class<? extends JCasAnnotator_ImplBase> annotatorClass;
+
+ private String[] trainingArguments;
+
+ private float featureSelectionThreshold;
+
+ public EvaluationOfTimeSpans(
+ File baseDirectory,
+ File rawTextDirectory,
+ File xmlDirectory,
+ XMLFormat xmlFormat,
+ File xmiDirectory,
+ File treebankDirectory,
+ float featureSelectionThreshold,
+ Class<? extends JCasAnnotator_ImplBase> annotatorClass,
+ boolean printOverlapping,
+ String[] trainingArguments) {
+ super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, xmiDirectory, treebankDirectory, TimeMention.class);
+ this.annotatorClass = annotatorClass;
+ this.featureSelectionThreshold = featureSelectionThreshold;
+ this.trainingArguments = trainingArguments;
+ this.printOverlapping = printOverlapping;
+ }
+
+ @Override
+ protected AnalysisEngineDescription getDataWriterDescription(File directory)
+ throws ResourceInitializationException {
+ if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
+ return MetaTimeAnnotator.getDataWriterDescription(CRFSuiteStringOutcomeDataWriter.class, directory);
+ }else if(CleartkAnnotator.class.isAssignableFrom(this.annotatorClass)){
+ //limit feature selection only to TimeAnnotator
+ if("org.apache.ctakes.temporal.ae.TimeAnnotator".equals(this.annotatorClass.getName())){
+ Class<?> dataWriterClass = this.featureSelectionThreshold > 0f
+ ? InstanceDataWriter.class
+ : LIBLINEARStringOutcomeDataWriter.class;
+ return TimeAnnotator.createDataWriterDescription(
+ dataWriterClass,
+ this.getModelDirectory(directory),
+ this.featureSelectionThreshold);
+ }
+ return AnalysisEngineFactory.createPrimitiveDescription(
+ this.annotatorClass,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ true,
+ DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+ LIBLINEARStringOutcomeDataWriter.class,
+ DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+ this.getModelDirectory(directory));
+
+ }else if(CleartkSequenceAnnotator.class.isAssignableFrom(this.annotatorClass)){
+ return AnalysisEngineFactory.createPrimitiveDescription(
+ this.annotatorClass,
+ CleartkSequenceAnnotator.PARAM_IS_TRAINING,
+ true,
+ DefaultSequenceDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+ CRFSuiteStringOutcomeDataWriter.class,
+ DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+ this.getModelDirectory(directory));
+ }else{
+ throw new ResourceInitializationException("Annotator class was not recognized as an acceptable class!", new Object[]{});
+ }
+ }
+
+ @Override
+ protected void trainAndPackage(File directory) throws Exception {
+ if (this.featureSelectionThreshold > 0 && "org.apache.ctakes.temporal.ae.TimeAnnotator".equals(this.annotatorClass.getName()) ) {
+ // Load the instances that were written to the model directory during data writing
+ Iterable<Instance<String>> instances = InstanceStream.loadFromDirectory(this.getModelDirectory(directory));
+ // Train the Chi-squared feature selection on the loaded instances and save it
+ FeatureSelection<String> featureSelection = TimeAnnotator.createFeatureSelection(this.featureSelectionThreshold);
+ featureSelection.train(instances);
+ featureSelection.save(TimeAnnotator.createFeatureSelectionURI(this.getModelDirectory(directory)));
+ // now re-write the instances, with feature selection applied, using the LIBLINEAR data writer
+ LIBLINEARStringOutcomeDataWriter dataWriter = new LIBLINEARStringOutcomeDataWriter(this.getModelDirectory(directory));
+ for (Instance<String> instance : instances) {
+ dataWriter.write(featureSelection.transform(instance));
+ }
+ dataWriter.finish();
+ }
+ JarClassifierBuilder.trainAndPackage(this.getModelDirectory(directory), this.trainingArguments);
+ }
+
+ @Override
+ protected AnalysisEngineDescription getAnnotatorDescription(File directory)
+ throws ResourceInitializationException {
+ if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
+ return MetaTimeAnnotator.getAnnotatorDescription(directory);
+ }else if("org.apache.ctakes.temporal.ae.TimeAnnotator".equals(this.annotatorClass.getName() )){
+ return TimeAnnotator.createAnnotatorDescription(this.getModelDirectory(directory));
+ }
+ return AnalysisEngineFactory.createPrimitiveDescription(
+ this.annotatorClass,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ false,
+ GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+ new File(this.getModelDirectory(directory), "model.jar"));
+ }
+
+ @Override
+ protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas, Segment segment) {
+ return selectExact(jCas, TimeMention.class, segment);
+ }
+
+ @Override
+ protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas, Segment segment) {
+ return selectExact(jCas, TimeMention.class, segment);
+ }
+
+ private File getModelDirectory(File directory) {
+ return new File(directory, this.annotatorClass.getSimpleName());
+ }
}