You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this plain-text rendering.
Posted to commits@ctakes.apache.org by st...@apache.org on 2012/11/27 13:43:19 UTC
svn commit: r1414158 -
/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
Author: stevenbethard
Date: Tue Nov 27 12:43:18 2012
New Revision: 1414158
URL: http://svn.apache.org/viewvc?rev=1414158&view=rev
Log:
Fixes handling of different possible train/dev/test combinations in relation extraction evaluation
Modified:
incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
Modified: incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java?rev=1414158&r1=1414157&r2=1414158&view=diff
==============================================================================
--- incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java (original)
+++ incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java Tue Nov 27 12:43:18 2012
@@ -121,6 +121,11 @@ public class RelationExtractorEvaluation
Options options = new Options();
options.parseOptions(args);
+ // error on invalid option combinations
+ if (options.testDirectory != null && options.gridSearch) {
+ throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
+ }
+
List<File> trainFiles = Arrays.asList(options.trainDirectory.listFiles());
// define the output directory for models
@@ -174,40 +179,31 @@ public class RelationExtractorEvaluation
trainingArguments,
options.testOnCTakes);
- if(options.testDirectory == null && options.devDirectory == null) {
- // run n-fold cross-validation on the training set
-
- List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
- params.stats = AnnotationStatistics.addAll(foldStats);
-
- System.err.println("overall:");
- System.err.print(params.stats);
- System.err.println(params.stats.confusions());
- System.err.println();
-
- // store these parameter settings and the respective performance
- scoredParams.put(params, params.stats.f1());
- } else if(options.devDirectory != null && options.gridSearch) {
- // tune parameters on the development set
-
- List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
- params.stats = evaluation.trainAndTest(trainFiles, devFiles);
-
- System.err.println("overall:");
- System.err.print(params.stats);
- System.err.println(params.stats.confusions());
- System.err.println();
-
- // store these parameter settings and the respective performance
- scoredParams.put(params, params.stats.f1());
+ if (options.devDirectory != null) {
+ if (options.testDirectory != null) {
+ // train on the training set + dev set and evaluate on the test set
+ List<File> allTrainFiles = new ArrayList<File>();
+ allTrainFiles.addAll(trainFiles);
+ allTrainFiles.addAll(Arrays.asList(options.devDirectory.listFiles()));
+ List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+ params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
+ } else {
+ // train on the training set and evaluate on the dev set
+ List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
+ params.stats = evaluation.trainAndTest(trainFiles, devFiles);
+ }
} else {
- // train on the entire training set and evaluate on the test set
-
- List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
- AnnotationStatistics<String> stats = evaluation.trainAndTest(trainFiles, testFiles);
- System.err.print(stats);
- return;
+ if (options.testDirectory != null) {
+ // train on the training set and evaluate on the test set
+ List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+ params.stats = evaluation.trainAndTest(trainFiles, testFiles);
+ } else {
+ // run n-fold cross-validation on the training set
+ List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
+ params.stats = AnnotationStatistics.addAll(foldStats);
+ }
}
+ scoredParams.put(params, params.stats.f1());
}
// print parameters sorted by F1
@@ -216,21 +212,22 @@ public class RelationExtractorEvaluation
Collections.sort(list, Ordering.natural().onResultOf(getCount));
// print performance of each set of parameters
- System.err.println("Summary:");
- ParameterSettings lastParams = null;
- for (ParameterSettings params : list) {
- System.err.printf(
- "F1=%.3f P=%.3f R=%.3f %s\n",
- params.stats.f1(),
- params.stats.precision(),
- params.stats.recall(),
- params);
- lastParams = params;
+ if (list.size() > 1) {
+ System.err.println("Summary:");
+ for (ParameterSettings params : list) {
+ System.err.printf(
+ "F1=%.3f P=%.3f R=%.3f %s\n",
+ params.stats.f1(),
+ params.stats.precision(),
+ params.stats.recall(),
+ params);
+ }
+ System.err.println();
}
// print overall best model
- if (lastParams != null) {
- System.err.println();
+ if (!list.isEmpty()) {
+ ParameterSettings lastParams = list.get(list.size() - 1);
System.err.println("Best model:");
System.err.print(lastParams.stats);
System.err.println(lastParams);