You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by st...@apache.org on 2012/11/27 13:43:19 UTC

svn commit: r1414158 - /incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java

Author: stevenbethard
Date: Tue Nov 27 12:43:18 2012
New Revision: 1414158

URL: http://svn.apache.org/viewvc?rev=1414158&view=rev
Log:
Fixes handling of the various train/dev/test directory combinations in the relation extraction evaluation

Modified:
    incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java

Modified: incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java?rev=1414158&r1=1414157&r2=1414158&view=diff
==============================================================================
--- incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java (original)
+++ incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java Tue Nov 27 12:43:18 2012
@@ -121,6 +121,11 @@ public class RelationExtractorEvaluation
     Options options = new Options();
     options.parseOptions(args);
     
+    // error on invalid option combinations
+    if (options.testDirectory != null && options.gridSearch) {
+      throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
+    }
+    
     List<File> trainFiles = Arrays.asList(options.trainDirectory.listFiles());
     
     // define the output directory for models
@@ -174,40 +179,31 @@ public class RelationExtractorEvaluation
           trainingArguments,
           options.testOnCTakes);
       
-      if(options.testDirectory == null && options.devDirectory == null) {
-      	// run n-fold cross-validation on the training set
-        
-      	List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
-      	params.stats = AnnotationStatistics.addAll(foldStats);
-        
-      	System.err.println("overall:");
-        System.err.print(params.stats);
-        System.err.println(params.stats.confusions());
-        System.err.println();
-
-        // store these parameter settings and the respective performance
-        scoredParams.put(params, params.stats.f1());
-      } else if(options.devDirectory != null && options.gridSearch) {
-        // tune parameters on the development set
-        
-        List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
-        params.stats = evaluation.trainAndTest(trainFiles, devFiles);
-        
-        System.err.println("overall:");
-        System.err.print(params.stats);
-        System.err.println(params.stats.confusions());
-        System.err.println();
-        
-        // store these parameter settings and the respective performance
-        scoredParams.put(params, params.stats.f1());
+      if (options.devDirectory != null) {
+        if (options.testDirectory != null) {
+          // train on the training set + dev set and evaluate on the test set
+          List<File> allTrainFiles = new ArrayList<File>();
+          allTrainFiles.addAll(trainFiles);
+          allTrainFiles.addAll(Arrays.asList(options.devDirectory.listFiles()));
+          List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+          params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
+        } else {
+          // train on the training set and evaluate on the dev set
+          List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
+          params.stats = evaluation.trainAndTest(trainFiles, devFiles);
+        }
       } else {
-      	// train on the entire training set and evaluate on the test set
-        
-      	List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
-      	AnnotationStatistics<String> stats = evaluation.trainAndTest(trainFiles, testFiles);
-      	System.err.print(stats);
-      	return;
+        if (options.testDirectory != null) {
+          // train on the training set and evaluate on the test set
+          List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+          params.stats = evaluation.trainAndTest(trainFiles, testFiles);
+        } else {
+          // run n-fold cross-validation on the training set
+          List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
+          params.stats = AnnotationStatistics.addAll(foldStats);
+        }
       }
+      scoredParams.put(params, params.stats.f1());
     }
 
     // print parameters sorted by F1
@@ -216,21 +212,22 @@ public class RelationExtractorEvaluation
     Collections.sort(list, Ordering.natural().onResultOf(getCount));
 
     // print performance of each set of parameters
-    System.err.println("Summary:");
-    ParameterSettings lastParams = null;
-    for (ParameterSettings params : list) {
-      System.err.printf(
-          "F1=%.3f P=%.3f R=%.3f %s\n",
-          params.stats.f1(),
-          params.stats.precision(),
-          params.stats.recall(),
-          params);
-      lastParams = params;
+    if (list.size() > 1) {
+      System.err.println("Summary:");
+      for (ParameterSettings params : list) {
+        System.err.printf(
+            "F1=%.3f P=%.3f R=%.3f %s\n",
+            params.stats.f1(),
+            params.stats.precision(),
+            params.stats.recall(),
+            params);
+      }
+      System.err.println();
     }
 
     // print overall best model
-    if (lastParams != null) {
-      System.err.println();
+    if (!list.isEmpty()) {
+      ParameterSettings lastParams = list.get(list.size() - 1);
       System.err.println("Best model:");
       System.err.print(lastParams.stats);
       System.err.println(lastParams);