You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@ctakes.apache.org by st...@apache.org on 2012/11/28 15:03:48 UTC
svn commit: r1414692 - in
/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor:
eval/RelationExtractorEvaluation.java pipelines/RelationExtractorTrain.java
Author: stevenbethard
Date: Wed Nov 28 14:03:47 2012
New Revision: 1414692
URL: http://svn.apache.org/viewvc?rev=1414692&view=rev
Log:
Removes irrelevant relations from relation extraction evaluation (e.g. don't include degree_of relations when evaluating location_of)
Modified:
incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java
Modified: incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java?rev=1414692&r1=1414691&r2=1414692&view=diff
==============================================================================
--- incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java (original)
+++ incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java Wed Nov 28 14:03:47 2012
@@ -51,6 +51,7 @@ import org.cleartk.eval.Evaluation_ImplB
import org.cleartk.util.Options_ImplBase;
import org.kohsuke.args4j.Option;
import org.uimafit.component.JCasAnnotator_ImplBase;
+import org.uimafit.descriptor.ConfigurationParameter;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.factory.CollectionReaderFactory;
@@ -103,10 +104,10 @@ public class RelationExtractorEvaluation
public boolean gridSearch = false;
@Option(
- name = "--run-degree-of",
- usage = "if true runs the degree of relation extractor otherwise "
- + "it uses the normal entity mention pair relation extractor")
- public boolean runDegreeOf = false;
+ name = "--relations",
+ usage = "determines which relations to evaluate on (separately)",
+ required = false)
+ public List<String> relations = Arrays.asList("degree_of", "location_of");
@Option(
name = "--test-on-ctakes",
@@ -116,124 +117,127 @@ public class RelationExtractorEvaluation
}
public static final String GOLD_VIEW_NAME = "GoldView";
-
+
public static void main(String[] args) throws Exception {
Options options = new Options();
options.parseOptions(args);
-
+
// error on invalid option combinations
if (options.testDirectory != null && options.gridSearch) {
throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
}
-
+
List<File> trainFiles = Arrays.asList(options.trainDirectory.listFiles());
-
- // define the output directory for models
- File modelsDir = options.runDegreeOf
- ? new File("target/models/degree_of")
- : new File("target/models/em_pair");
-
- // determine class for the classifier annotator
- Class<? extends RelationExtractorAnnotator> annotatorClass = options.runDegreeOf
- ? DegreeOfRelationExtractorAnnotator.class
- : EntityMentionPairRelationExtractorAnnotator.class;
-
- // determine the type of classifier to be trained
- Class<? extends DataWriter<String>> dataWriterClass = LIBSVMStringOutcomeDataWriter.class;
-
- // define the set of possible training parameters
- List<ParameterSettings> possibleParams = options.runDegreeOf
- ? getDegreeOfParameterSpace(options.gridSearch)
- : getEMPairParameterSpace(options.gridSearch);
-
- // run an evaluation for each set of parameters
- Map<ParameterSettings, Double> scoredParams = new HashMap<ParameterSettings, Double>();
- for (ParameterSettings params : possibleParams) {
- System.err.println(params);
- System.err.println();
-
- // define additional configuration parameters for the annotator
- Object[] additionalParameters = new Object[] {
- RelationExtractorAnnotator.PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE,
- params.probabilityOfKeepingANegativeExample,
- EntityMentionPairRelationExtractorAnnotator.PARAM_CLASSIFY_BOTH_DIRECTIONS,
- params.classifyBothDirections,
- RelationExtractorAnnotator.PARAM_PRINT_ERRORS,
- false };
-
- // define arguments to be passed to the classifier
- String[] trainingArguments = new String[] {
- "-t",
- String.valueOf(params.svmKernelIndex),
- "-c",
- String.valueOf(params.svmCost),
- "-g",
- String.valueOf(params.svmGamma) };
-
- // create the evaluation
- RelationExtractorEvaluation evaluation = new RelationExtractorEvaluation(
- modelsDir,
- annotatorClass,
- dataWriterClass,
- additionalParameters,
- trainingArguments,
- options.testOnCTakes);
-
- if (options.devDirectory != null) {
- if (options.testDirectory != null) {
- // train on the training set + dev set and evaluate on the test set
- List<File> allTrainFiles = new ArrayList<File>();
- allTrainFiles.addAll(trainFiles);
- allTrainFiles.addAll(Arrays.asList(options.devDirectory.listFiles()));
- List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
- params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
+
+ for (String relationCategory : options.relations) {
+
+ // define the output directory for models
+ File modelsDir = new File("target/models/" + relationCategory);
+
+ // determine class for the classifier annotator
+ boolean isDegreeOf = relationCategory.equals("degree_of");
+ Class<? extends RelationExtractorAnnotator> annotatorClass = isDegreeOf
+ ? DegreeOfRelationExtractorAnnotator.class
+ : EntityMentionPairRelationExtractorAnnotator.class;
+
+ // determine the type of classifier to be trained
+ Class<? extends DataWriter<String>> dataWriterClass = LIBSVMStringOutcomeDataWriter.class;
+
+ // define the set of possible training parameters
+ List<ParameterSettings> possibleParams = isDegreeOf
+ ? getDegreeOfParameterSpace(options.gridSearch)
+ : getEMPairParameterSpace(options.gridSearch);
+
+ // run an evaluation for each set of parameters
+ Map<ParameterSettings, Double> scoredParams = new HashMap<ParameterSettings, Double>();
+ for (ParameterSettings params : possibleParams) {
+ System.err.println(relationCategory + ": " + params);
+ System.err.println();
+
+ // define additional configuration parameters for the annotator
+ Object[] additionalParameters = new Object[] {
+ RelationExtractorAnnotator.PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE,
+ params.probabilityOfKeepingANegativeExample,
+ EntityMentionPairRelationExtractorAnnotator.PARAM_CLASSIFY_BOTH_DIRECTIONS,
+ params.classifyBothDirections,
+ RelationExtractorAnnotator.PARAM_PRINT_ERRORS,
+ false };
+
+ // define arguments to be passed to the classifier
+ String[] trainingArguments = new String[] {
+ "-t",
+ String.valueOf(params.svmKernelIndex),
+ "-c",
+ String.valueOf(params.svmCost),
+ "-g",
+ String.valueOf(params.svmGamma) };
+
+ // create the evaluation
+ RelationExtractorEvaluation evaluation = new RelationExtractorEvaluation(
+ modelsDir,
+ relationCategory,
+ annotatorClass,
+ dataWriterClass,
+ additionalParameters,
+ trainingArguments,
+ options.testOnCTakes);
+
+ if (options.devDirectory != null) {
+ if (options.testDirectory != null) {
+ // train on the training set + dev set and evaluate on the test set
+ List<File> allTrainFiles = new ArrayList<File>();
+ allTrainFiles.addAll(trainFiles);
+ allTrainFiles.addAll(Arrays.asList(options.devDirectory.listFiles()));
+ List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+ params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
+ } else {
+ // train on the training set and evaluate on the dev set
+ List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
+ params.stats = evaluation.trainAndTest(trainFiles, devFiles);
+ }
} else {
- // train on the training set and evaluate on the dev set
- List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
- params.stats = evaluation.trainAndTest(trainFiles, devFiles);
+ if (options.testDirectory != null) {
+ // train on the training set and evaluate on the test set
+ List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+ params.stats = evaluation.trainAndTest(trainFiles, testFiles);
+ } else {
+ // run n-fold cross-validation on the training set
+ List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
+ params.stats = AnnotationStatistics.addAll(foldStats);
+ }
}
- } else {
- if (options.testDirectory != null) {
- // train on the training set and evaluate on the test set
- List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
- params.stats = evaluation.trainAndTest(trainFiles, testFiles);
- } else {
- // run n-fold cross-validation on the training set
- List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
- params.stats = AnnotationStatistics.addAll(foldStats);
+ scoredParams.put(params, params.stats.f1());
+ }
+
+ // print parameters sorted by F1
+ List<ParameterSettings> list = new ArrayList<ParameterSettings>(scoredParams.keySet());
+ Function<ParameterSettings, Double> getCount = Functions.forMap(scoredParams);
+ Collections.sort(list, Ordering.natural().onResultOf(getCount));
+
+ // print performance of each set of parameters
+ if (list.size() > 1) {
+ System.err.println(relationCategory + ": summary:");
+ for (ParameterSettings params : list) {
+ System.err.printf(
+ "F1=%.3f P=%.3f R=%.3f %s\n",
+ params.stats.f1(),
+ params.stats.precision(),
+ params.stats.recall(),
+ params);
}
+ System.err.println();
}
- scoredParams.put(params, params.stats.f1());
- }
- // print parameters sorted by F1
- List<ParameterSettings> list = new ArrayList<ParameterSettings>(scoredParams.keySet());
- Function<ParameterSettings, Double> getCount = Functions.forMap(scoredParams);
- Collections.sort(list, Ordering.natural().onResultOf(getCount));
-
- // print performance of each set of parameters
- if (list.size() > 1) {
- System.err.println("Summary:");
- for (ParameterSettings params : list) {
- System.err.printf(
- "F1=%.3f P=%.3f R=%.3f %s\n",
- params.stats.f1(),
- params.stats.precision(),
- params.stats.recall(),
- params);
- }
- System.err.println();
- }
-
- // print overall best model
- if (!list.isEmpty()) {
- ParameterSettings lastParams = list.get(list.size() - 1);
- System.err.println("Best model:");
- System.err.print(lastParams.stats);
- System.err.println(lastParams);
- System.err.println(lastParams.stats.confusions());
- System.err.println();
- System.err.println(lastParams.stats.confusions().toHTML());
+ // print overall best model
+ if (!list.isEmpty()) {
+ ParameterSettings lastParams = list.get(list.size() - 1);
+ System.err.println(relationCategory + ": best model:");
+ System.err.print(lastParams.stats);
+ System.err.println(lastParams);
+ System.err.println(lastParams.stats.confusions());
+ System.err.println();
+ System.err.println(lastParams.stats.confusions().toHTML());
+ }
}
}
@@ -255,18 +259,22 @@ public class RelationExtractorEvaluation
*/
public RelationExtractorEvaluation(
File baseDirectory,
+ String relationCategory,
Class<? extends RelationExtractorAnnotator> classifierAnnotatorClass,
Class<? extends DataWriter<String>> dataWriterClass,
Object[] additionalParameters,
String[] trainingArguments,
boolean testOnCTakes) {
super(baseDirectory);
+ this.relationCategory = relationCategory;
this.classifierAnnotatorClass = classifierAnnotatorClass;
this.dataWriterClass = dataWriterClass;
this.additionalParameters = additionalParameters;
this.trainingArguments = trainingArguments;
this.testOnCTakes = testOnCTakes;
}
+
+ private String relationCategory;
private Class<? extends RelationExtractorAnnotator> classifierAnnotatorClass;
@@ -297,6 +305,12 @@ public class RelationExtractorEvaluation
@Override
public void train(CollectionReader collectionReader, File directory) throws Exception {
AggregateBuilder builder = new AggregateBuilder();
+ // remove all but the relation of interest from the gold annotations
+ builder.add(AnalysisEngineFactory.createPrimitiveDescription(
+ RemoveOtherRelations.class,
+ RemoveOtherRelations.PARAM_RELATION_CATEGORY,
+ this.relationCategory),
+ CAS.NAME_DEFAULT_SOFA, GOLD_VIEW_NAME);
// replace cTAKES entity mentions and modifiers in the system view with the gold annotations
builder.add(AnalysisEngineFactory.createPrimitiveDescription(ReplaceCTakesEntityMentionsAndModifiersWithGold.class));
// add the relation extractor, configured for training mode
@@ -327,6 +341,12 @@ public class RelationExtractorEvaluation
protected AnnotationStatistics<String> test(CollectionReader collectionReader, File directory)
throws Exception {
AggregateBuilder builder = new AggregateBuilder();
+ // remove all but the relation of interest from the gold annotations
+ builder.add(AnalysisEngineFactory.createPrimitiveDescription(
+ RemoveOtherRelations.class,
+ RemoveOtherRelations.PARAM_RELATION_CATEGORY,
+ this.relationCategory),
+ CAS.NAME_DEFAULT_SOFA, GOLD_VIEW_NAME);
if (this.testOnCTakes) {
// add the modifier extractor
File file = new File("desc/analysis_engine/ModifierExtractorAnnotator.xml");
@@ -386,7 +406,7 @@ public class RelationExtractorEvaluation
getOutcome);
}
- System.err.println(directory.getName() + ":");
+ System.err.printf("%s: %s:\n", this.relationCategory, directory.getName());
System.err.print(stats);
System.err.println(stats.confusions());
System.err.println();
@@ -662,4 +682,23 @@ public class RelationExtractorEvaluation
return a == null ? null : String.format("\"%s\"(type=%d)", a.getCoveredText(), a.getTypeID());
}
}
+
+ public static class RemoveOtherRelations extends JCasAnnotator_ImplBase {
+ // Annotator that removes from the indexes every BinaryTextRelation whose category differs from relationCategory.
+ public static final String PARAM_RELATION_CATEGORY = "RelationCategory";
+ @ConfigurationParameter(name = PARAM_RELATION_CATEGORY)
+ private String relationCategory;
+ // The single relation category to keep (train()/test() configure this from the evaluation's relationCategory and run it on the gold view).
+ // process(): copies all relations into a list first, presumably so removal does not mutate the index while it is being iterated.
+ @Override
+ public void process(JCas jCas) throws AnalysisEngineProcessException {
+ List<BinaryTextRelation> relations = new ArrayList<BinaryTextRelation>();
+ relations.addAll(JCasUtil.select(jCas, BinaryTextRelation.class));
+ for (BinaryTextRelation relation : relations) {
+ if (!relation.getCategory().equals(this.relationCategory)) { // NOTE(review): throws NPE if a relation's category is unset (null) - confirm categories are always populated
+ relation.removeFromIndexes();
+ }
+ }
+ }
+ }
}
Modified: incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java
URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java?rev=1414692&r1=1414691&r2=1414692&view=diff
==============================================================================
--- incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java (original)
+++ incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java Wed Nov 28 14:03:47 2012
@@ -95,6 +95,7 @@ public class RelationExtractorTrain {
public static AnalysisEngineDescription trainRelationExtractor(
File modelsDir,
List<File> trainFiles,
+ String relationCategory,
Class<? extends RelationExtractorAnnotator> annotatorClass,
Class<? extends DataWriter<String>> dataWriterClass,
ParameterSettings params) throws Exception {
@@ -119,6 +120,7 @@ public class RelationExtractorTrain {
RelationExtractorEvaluation evaluation = new RelationExtractorEvaluation(
modelsDir,
+ relationCategory,
annotatorClass,
dataWriterClass,
additionalParameters,
@@ -167,9 +169,9 @@ public class RelationExtractorTrain {
// Train and write models
AnalysisEngineDescription modifierExtractorDesc = trainModifierExtractor(modelsDirModExtractor, trainFiles);
writeDesc(options.descDir, "ModifierExtractorAnnotator", modifierExtractorDesc);
- AnalysisEngineDescription degreeOfRelationExtractorDesc = trainRelationExtractor(modelsDirDegreeOf, trainFiles, DegreeOfRelationExtractorAnnotator.class, dataWriterClass, degreeOfParams);
+ AnalysisEngineDescription degreeOfRelationExtractorDesc = trainRelationExtractor(modelsDirDegreeOf, trainFiles, "degree_of", DegreeOfRelationExtractorAnnotator.class, dataWriterClass, degreeOfParams);
writeDesc(options.descDir, "DegreeOfRelationExtractorAnnotator", degreeOfRelationExtractorDesc);
- AnalysisEngineDescription emPairRelationExtractorDesc = trainRelationExtractor(modelsDirEMPair, trainFiles, EntityMentionPairRelationExtractorAnnotator.class, dataWriterClass, emPairParams);
+ AnalysisEngineDescription emPairRelationExtractorDesc = trainRelationExtractor(modelsDirEMPair, trainFiles, "location_of", EntityMentionPairRelationExtractorAnnotator.class, dataWriterClass, emPairParams);
writeDesc(options.descDir, "EntityMentionPairRelationExtractorAnnotator", emPairRelationExtractorDesc);
// create the aggregate description