You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by tm...@apache.org on 2015/03/13 17:23:46 UTC
svn commit: r1666502 - in
/ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval:
EvaluationOfEventCoreference.java EvaluationOfMarkableSalience.java
EvaluationOfMarkableSpans.java
Author: tmill
Date: Fri Mar 13 16:23:46 2015
New Revision: 1666502
URL: http://svn.apache.org/r1666502
Log:
Evaluation code refactored from temporal
Added:
ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java
ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java
ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java
Added: ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java
URL: http://svn.apache.org/viewvc/ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java?rev=1666502&view=auto
==============================================================================
--- ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java (added)
+++ ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java Fri Mar 13 16:23:46 2015
@@ -0,0 +1,700 @@
+package org.apache.ctakes.coreference.eval;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.ctakes.assertion.medfacts.cleartk.PolarityCleartkAnalysisEngine;
+import org.apache.ctakes.core.resource.FileLocator;
+import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
+import org.apache.ctakes.coreference.ae.CoreferenceChainScoringOutput;
+import org.apache.ctakes.coreference.ae.DeterministicMarkableAnnotator;
+import org.apache.ctakes.coreference.ae.EventCoreferenceAnnotator;
+import org.apache.ctakes.coreference.ae.MarkableSalienceAnnotator;
+import org.apache.ctakes.coreference.ae.MentionClusterCoreferenceAnnotator;
+import org.apache.ctakes.coreference.ae.PersonChainAnnotator;
+import org.apache.ctakes.dependency.parser.util.DependencyUtility;
+import org.apache.ctakes.relationextractor.eval.RelationExtractorEvaluation.HashableArguments;
+import org.apache.ctakes.temporal.ae.DocTimeRelAnnotator;
+import org.apache.ctakes.temporal.ae.EventAnnotator;
+import org.apache.ctakes.temporal.eval.EvaluationOfEventTimeRelations;
+import org.apache.ctakes.temporal.eval.EvaluationOfTemporalRelations_ImplBase;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
+import org.apache.ctakes.temporal.eval.EvaluationOfEventTimeRelations.ParameterSettings;
+import org.apache.ctakes.temporal.eval.EvaluationOfTemporalRelations_ImplBase.TempRelOptions;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.Subcorpus;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat;
+import org.apache.ctakes.typesystem.type.constants.CONST;
+import org.apache.ctakes.typesystem.type.relation.BinaryTextRelation;
+import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
+import org.apache.ctakes.typesystem.type.relation.CoreferenceRelation;
+import org.apache.ctakes.typesystem.type.relation.RelationArgument;
+import org.apache.ctakes.typesystem.type.syntax.BaseToken;
+import org.apache.ctakes.typesystem.type.syntax.ConllDependencyNode;
+import org.apache.ctakes.typesystem.type.syntax.NewlineToken;
+import org.apache.ctakes.typesystem.type.syntax.WordToken;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.ctakes.typesystem.type.textspan.Paragraph;
+import org.apache.ctakes.utils.distsem.WordEmbeddings;
+import org.apache.ctakes.utils.distsem.WordVector;
+import org.apache.ctakes.utils.distsem.WordVectorReader;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.uima.UimaContext;
+import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
+import org.apache.uima.analysis_engine.metadata.FixedFlow;
+import org.apache.uima.analysis_engine.metadata.FlowConstraints;
+import org.apache.uima.cas.CAS;
+import org.apache.uima.cas.CASException;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.fit.component.ViewCreatorAnnotator;
+import org.apache.uima.fit.descriptor.ConfigurationParameter;
+import org.apache.uima.fit.factory.AggregateBuilder;
+import org.apache.uima.fit.factory.AnalysisEngineFactory;
+import org.apache.uima.fit.factory.FlowControllerFactory;
+import org.apache.uima.fit.pipeline.JCasIterator;
+import org.apache.uima.fit.pipeline.SimplePipeline;
+import org.apache.uima.fit.util.JCasUtil;
+import org.apache.uima.flow.FinalStep;
+import org.apache.uima.flow.Flow;
+import org.apache.uima.flow.FlowControllerContext;
+import org.apache.uima.flow.FlowControllerDescription;
+import org.apache.uima.flow.JCasFlow_ImplBase;
+import org.apache.uima.flow.SimpleStep;
+import org.apache.uima.flow.Step;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.cas.EmptyFSList;
+import org.apache.uima.jcas.cas.FSArray;
+import org.apache.uima.jcas.cas.FSList;
+import org.apache.uima.jcas.cas.FloatArray;
+import org.apache.uima.jcas.cas.NonEmptyFSList;
+import org.apache.uima.jcas.tcas.Annotation;
+import org.apache.uima.resource.ResourceInitializationException;
+import org.apache.uima.util.FileUtils;
+import org.cleartk.eval.AnnotationStatistics;
+import org.cleartk.ml.jar.JarClassifierBuilder;
+import org.cleartk.ml.liblinear.LibLinearStringOutcomeDataWriter;
+import org.cleartk.ml.libsvm.tk.TkLibSvmStringOutcomeDataWriter;
+import org.cleartk.ml.tksvmlight.model.CompositeKernel.ComboOperator;
+import org.cleartk.util.ViewUriUtil;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
+
+public class EvaluationOfEventCoreference extends EvaluationOfTemporalRelations_ImplBase {
+
+
+ /**
+  * Command-line options for the coreference evaluation, layered on top of the
+  * temporal-relation evaluation options.
+  */
+ interface CoreferenceOptions extends TempRelOptions {
+     /** Prefix that "gold.chains" and "system.chains" are appended to when chain files are written. */
+     @Option
+     String getOutputDirectory();
+
+     /** Train and test inside a scratch directory under the working directory, deleted afterwards. */
+     @Option
+     boolean getUseTmp();
+
+     /** Evaluate on the training items instead of the test items. */
+     @Option
+     boolean getTestOnTrain();
+
+     /** Run the external perl reference scorer over the written chain files. */
+     @Option(longName="external")
+     boolean getUseExternalScorer();
+
+     /** Coreference system to train and evaluate. */
+     @Option(shortName="t", defaultValue={"MENTION_PAIR"})
+     EVAL_SYSTEM getEvalSystem();
+ }
+
+ private static final Logger logger = Logger.getLogger(EvaluationOfEventCoreference.class);
+
+ /** Presumably the probability of keeping a negative training example (passed as the second ParameterSettings argument) -- confirm against ParameterSettings. */
+ public static float COREF_DOWNSAMPLE = 0.5f;
+
+ /** Default classifier settings used by main(). */
+ protected static ParameterSettings allParams = new ParameterSettings(DEFAULT_BOTH_DIRECTIONS, COREF_DOWNSAMPLE, "tk",
+         1.0, 1.0, "linear", ComboOperator.SUM, 0.1, 0.5);
+
+ /**
+  * Command-line entry point: trains the selected coreference system, evaluates it,
+  * prints the resulting statistics to stderr and, when requested, runs the external
+  * reference coreference scorer over the chain files written during testing.
+  *
+  * @param args command-line arguments, parsed into {@link CoreferenceOptions}
+  * @throws Exception if option parsing, training, testing or the external scorer fails
+  */
+ public static void main(String[] args) throws Exception {
+     CoreferenceOptions options = CliFactory.parseArguments(CoreferenceOptions.class, args);
+
+     List<Integer> patientSets = options.getPatients().getList();
+     List<Integer> trainItems = getTrainItems(options);
+     List<Integer> testItems = options.getTestOnTrain() ? getTrainItems(options) : getTestItems(options);
+
+     ParameterSettings params = allParams;
+     File workingDir = new File("target/eval/temporal-relations/coreference");
+     if(!workingDir.exists()) workingDir.mkdirs();
+     if(options.getUseTmp()){
+         // reserve a unique name with createTempFile, then turn it into a scratch directory
+         File tempModelDir = File.createTempFile("temporal", null, workingDir);
+         tempModelDir.delete();
+         tempModelDir.mkdir();
+         workingDir = tempModelDir;
+     }
+     EvaluationOfEventCoreference eval = new EvaluationOfEventCoreference(
+             workingDir,
+             options.getRawTextDirectory(),
+             options.getXMLDirectory(),
+             options.getXMLFormat(),
+             options.getSubcorpus(),
+             options.getXMIDirectory(),
+             options.getTreebankDirectory(),
+             options.getPrintErrors(),
+             options.getPrintFormattedRelations(),
+             params,
+             options.getKernelParams(),
+             options.getOutputDirectory());
+
+     if(options.getSkipTrain()){
+         eval.skipTrain = true;
+     }
+     if(options.getSkipDataWriting()){
+         eval.skipWrite = true;
+     }
+     eval.evalType = options.getEvalSystem();
+     eval.prepareXMIsFor(patientSets);
+
+     params.stats = eval.trainAndTest(trainItems, testItems);
+     System.err.println(params.stats);
+
+     if(options.getUseTmp()){
+         FileUtils.deleteRecursive(workingDir);
+     }
+
+     if(options.getUseExternalScorer()){
+         runExternalScorer(options.getOutputDirectory());
+     }
+ }
+
+ /** System property that overrides the location of the reference scorer's scorer.pl. */
+ private static final String SCORER_PATH_PROPERTY = "coref.scorer.path";
+
+ /** scorer.pl location used when SCORER_PATH_PROPERTY is not set. */
+ private static final String DEFAULT_SCORER_PATH = "/home/tmill/soft/reference-coreference-scorers-read-only/scorer.pl";
+
+ /**
+  * Matches a scorer summary line ("Coreference: ..." or "BLANC: ...").
+  * Groups 1-3 are the recall, precision and F1 percentages.
+  */
+ private static final Pattern SCORER_LINE_PATTERN = Pattern.compile("(?:Coreference|BLANC): Recall: \\([^\\)]*\\) (\\S+)%.*Precision: \\([^\\)]*\\) (\\S+)%.*F1: (\\S+)%");
+
+ /**
+  * Runs the perl reference scorer with the "all" metric over gold.chains and system.chains,
+  * prints one recall/precision/F1 row per reported metric and, when muc, bcub and ceafe
+  * were all reported, their mean F1 as the CoNLL score.
+  *
+  * @param outputDirectory prefix the chain file names are appended to (the same prefix test() writes with)
+  * @throws IOException if the scorer cannot be started or its output cannot be read
+  * @throws InterruptedException if interrupted while waiting for the scorer to exit
+  */
+ private static void runExternalScorer(String outputDirectory) throws IOException, InterruptedException {
+     ProcessBuilder builder = new ProcessBuilder(
+             "perl",
+             System.getProperty(SCORER_PATH_PROPERTY, DEFAULT_SCORER_PATH),
+             "all",
+             outputDirectory + "gold.chains",
+             outputDirectory + "system.chains",
+             "none");
+     // Pass the scorer's stderr through to ours: left unread, a full stderr pipe can block the child.
+     builder.redirectError(ProcessBuilder.Redirect.INHERIT);
+     Process p = builder.start();
+
+     System.out.println(String.format("%10s%7s%7s%7s", "Metric", "Rec", "Prec", "F1"));
+     Map<String,Double> scores = new HashMap<>();
+     String metric = null;
+     try(BufferedReader reader = new BufferedReader(new InputStreamReader(p.getInputStream(), java.nio.charset.StandardCharsets.UTF_8))){
+         String line;
+         while((line = reader.readLine()) != null){
+             line = line.trim();
+             if(line.startsWith("METRIC")){
+                 metric = line.substring(7); // everything after "METRIC "
+                 metric = metric.substring(0, metric.length()-1); // remove colon from the end
+             }else if(line.startsWith("Coreference") || line.startsWith("BLANC")){
+                 // the pattern accepts BLANC summary lines too, so they must not be filtered out here
+                 Matcher m = SCORER_LINE_PATTERN.matcher(line);
+                 if(m.matches()){
+                     double recall = Double.parseDouble(m.group(1));
+                     double precision = Double.parseDouble(m.group(2));
+                     double f1 = Double.parseDouble(m.group(3));
+                     System.out.println(String.format("%10s%7.2f%7.2f%7.2f", metric, recall, precision, f1));
+                     scores.put(metric, f1);
+                 }
+             }
+         }
+     }
+     int exitCode = p.waitFor();
+     if(exitCode != 0){
+         System.err.println("Coreference scorer exited with status " + exitCode);
+     }
+
+     if(scores.containsKey("muc") && scores.containsKey("bcub") && scores.containsKey("ceafe")){
+         double conll = (scores.get("muc") + scores.get("bcub") + scores.get("ceafe")) / 3.0;
+         System.out.println(String.format("%10s %7.2f", "Conll", conll));
+     }
+ }
+
+ /** When set, train() returns immediately without writing data or training. */
+ boolean skipTrain;
+ /** When set, train() skips the data-writing pipeline and only trains/packages the model. */
+ boolean skipWrite;
+
+ /** Coreference systems this evaluation knows how to train and test. */
+ public enum EVAL_SYSTEM { BASELINE, MENTION_PAIR, MENTION_CLUSTER }
+ /** System chosen on the command line; selects the writer in train() and the annotator in test(). */
+ EVAL_SYSTEM evalType;
+
+ /** Prefix that "gold.chains" / "system.chains" are appended to in test(). */
+ private String outputDirectory;
+
+ /**
+  * Creates the evaluation; all arguments except the last two are forwarded to the superclass.
+  *
+  * @param cmdParams space-separated trainer options whose flags lack the leading "-" (train() adds it),
+  *                  or null to build the options from {@code params}
+  * @param outputDirectory prefix for the chain files written during testing
+  */
+ public EvaluationOfEventCoreference(File baseDirectory,
+         File rawTextDirectory, File xmlDirectory,
+         XMLFormat xmlFormat, Subcorpus subcorpus,
+         File xmiDirectory, File treebankDirectory, boolean printErrors,
+         boolean printRelations, ParameterSettings params, String cmdParams, String outputDirectory) {
+     super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, subcorpus, xmiDirectory,
+             treebankDirectory, printErrors, printRelations, params);
+     this.outputDirectory = outputDirectory;
+     if(cmdParams != null){
+         this.kernelParams = cmdParams.split(" ");
+     }else{
+         this.kernelParams = null;
+     }
+ }
+
+ /**
+  * Trains the selected coreference classifier. Unless skipWrite is set, first runs the
+  * preprocessing and coreference data-writing pipeline over the training documents; then
+  * trains and packages the model in {@code directory}. Does nothing when skipTrain is set.
+  *
+  * @param collectionReader reader over the training documents
+  * @param directory directory receiving the written training data and the packaged model
+  */
+ @Override
+ protected void train(CollectionReader collectionReader, File directory)
+         throws Exception {
+     if(skipTrain) return;
+     if(!skipWrite){
+         AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+         aggregateBuilder.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+         aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class, ViewCreatorAnnotator.PARAM_VIEW_NAME, "Baseline"));
+         aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+         aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ParagraphAnnotator.class));
+         aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ParagraphVectorAnnotator.class));
+         aggregateBuilder.add(EventAnnotator.createAnnotatorDescription());
+         aggregateBuilder.add(DocTimeRelAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/doctimerel/model.jar"));
+         aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+         // aggregateBuilder.add(CopyFromGold.getDescription(/*Markable.class,*/ CoreferenceRelation.class, CollectionTextRelation.class));
+         aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+         aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(CopyCoreferenceRelations.class, CopyCoreferenceRelations.PARAM_GOLD_VIEW, GOLD_VIEW_NAME));
+         aggregateBuilder.add(MarkableSalienceAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/salience/model.jar"));
+         if(this.evalType == EVAL_SYSTEM.MENTION_PAIR){
+             aggregateBuilder.add(EventCoreferenceAnnotator.createDataWriterDescription(
+                     // TKSVMlightStringOutcomeDataWriter.class,
+                     // LibLinearStringOutcomeDataWriter.class,
+                     TkLibSvmStringOutcomeDataWriter.class,
+                     directory,
+                     params.probabilityOfKeepingANegativeExample
+                     ));
+         }else if(this.evalType == EVAL_SYSTEM.MENTION_CLUSTER){
+             aggregateBuilder.add(MentionClusterCoreferenceAnnotator.createDataWriterDescription(
+                     LibLinearStringOutcomeDataWriter.class,
+                     // TkLibSvmStringOutcomeDataWriter.class,
+                     directory,
+                     params.probabilityOfKeepingANegativeExample
+                     ));
+         }
+         Logger.getLogger(EventCoreferenceAnnotator.class).setLevel(Level.WARN);
+         // create gold chains for writing out which we can then use for our scoring tool
+         // aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(CoreferenceChainScoringOutput.class,
+         //     CoreferenceChainScoringOutput.PARAM_OUTPUT_DIR,
+         //     this.outputDirectory + "train"));
+         // skip documents without gold coreference relations (see CorefEvalFlowController)
+         FlowControllerDescription corefFlowControl = FlowControllerFactory.createFlowControllerDescription(CorefEvalFlowController.class);
+         aggregateBuilder.setFlowControllerDescription(corefFlowControl);
+
+         SimplePipeline.runPipeline(collectionReader, aggregateBuilder.createAggregate());
+     }
+     JarClassifierBuilder.trainAndPackage(directory, buildTrainingOptions());
+ }
+
+ /**
+  * Returns the options handed to the classifier trainer: either the command-line kernel
+  * parameters with "-" prepended to every even-indexed element (the flags), or options
+  * built from {@code params}. kernelParams is copied before prefixing so repeated calls
+  * (e.g. one per cross-validation fold) never produce doubly-prefixed flags like "--c".
+  */
+ private String[] buildTrainingOptions(){
+     if(this.kernelParams != null){
+         String[] optArray = this.kernelParams.clone();
+         for(int i = 0; i < optArray.length; i+=2){
+             optArray[i] = "-" + optArray[i];
+         }
+         return optArray;
+     }
+
+     List<String> svmOptions = new ArrayList<>();
+     svmOptions.add("-c"); svmOptions.add(""+params.svmCost); // svm cost
+     svmOptions.add("-t"); svmOptions.add(""+params.svmKernelIndex); // kernel index
+     svmOptions.add("-d"); svmOptions.add("3"); // degree parameter for polynomial
+     svmOptions.add("-g"); svmOptions.add(""+params.svmGamma);
+     if(params.svmKernelIndex==ParameterSettings.SVM_KERNELS.indexOf("tk")){
+         svmOptions.add("-S"); svmOptions.add(""+params.secondKernelIndex); // second kernel index (similar to -t) for composite kernel
+         String comboFlag = (params.comboOperator == ComboOperator.SUM ? "+" : params.comboOperator == ComboOperator.PRODUCT ? "*" : params.comboOperator == ComboOperator.TREE_ONLY ? "T" : "V");
+         svmOptions.add("-C"); svmOptions.add(comboFlag);
+         svmOptions.add("-L"); svmOptions.add(""+params.lambda);
+         svmOptions.add("-T"); svmOptions.add(""+params.tkWeight);
+         svmOptions.add("-N"); svmOptions.add("3"); // normalize trees and features
+     }
+     return svmOptions.toArray(new String[svmOptions.size()]);
+ }
+
+ @Override
+ protected AnnotationStatistics<String> test(
+ CollectionReader collectionReader, File directory) throws Exception {
+ AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+ aggregateBuilder.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+ aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ParagraphAnnotator.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ParagraphVectorAnnotator.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+ aggregateBuilder.add(EventAnnotator.createAnnotatorDescription());
+ aggregateBuilder.add(DocTimeRelAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/doctimerel/model.jar"));
+ aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(CoreferenceChainScoringOutput.class,
+ CoreferenceChainScoringOutput.PARAM_OUTPUT_FILENAME,
+ this.outputDirectory + "gold.chains",
+ CoreferenceChainScoringOutput.PARAM_GOLD_VIEW_NAME,
+ GOLD_VIEW_NAME));
+ aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+ aggregateBuilder.add(MarkableSalienceAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/salience/model.jar"));
+ if(this.evalType == EVAL_SYSTEM.MENTION_PAIR){
+ aggregateBuilder.add(EventCoreferenceAnnotator.createAnnotatorDescription(directory.getAbsolutePath() + File.separator + "model.jar"));
+ }else if(this.evalType == EVAL_SYSTEM.MENTION_CLUSTER){
+ aggregateBuilder.add(MentionClusterCoreferenceAnnotator.createAnnotatorDescription(directory.getAbsolutePath() + File.separator + "model.jar"));
+ }
+// aggregateBuilder.add(CoreferenceChainAnnotator.createAnnotatorDescription());
+ aggregateBuilder.add(PersonChainAnnotator.createAnnotatorDescription());
+ aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(CoreferenceChainScoringOutput.class,
+ CoreferenceChainScoringOutput.PARAM_OUTPUT_FILENAME,
+ this.outputDirectory + "system.chains"));
+
+ FlowControllerDescription corefFlowControl = FlowControllerFactory.createFlowControllerDescription(CorefEvalFlowController.class);
+ aggregateBuilder.setFlowControllerDescription(corefFlowControl);
+// aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(XMIWriter.class));
+ Function<CoreferenceRelation, ?> getSpan = new Function<CoreferenceRelation, HashableArguments>() {
+ public HashableArguments apply(CoreferenceRelation relation) {
+ return new HashableArguments(relation);
+ }
+ };
+ Function<CoreferenceRelation, String> getOutcome = new Function<CoreferenceRelation,String>() {
+ public String apply(CoreferenceRelation relation){
+ return "Coreference";
+ }
+ };
+
+ AnnotationStatistics<String> corefStats = new AnnotationStatistics<>();
+
+ for(Iterator<JCas> casIter =new JCasIterator(collectionReader, aggregateBuilder.createAggregate()); casIter.hasNext();){
+ JCas jCas = casIter.next();
+ JCas goldView = jCas.getView(GOLD_VIEW_NAME);
+ JCas systemView = jCas.getView(CAS.NAME_DEFAULT_SOFA);
+ Collection<CoreferenceRelation> goldRelations = JCasUtil.select(
+ goldView,
+ CoreferenceRelation.class);
+ Collection<CoreferenceRelation> systemRelations = JCasUtil.select(
+ systemView,
+ CoreferenceRelation.class);
+ corefStats.add(goldRelations, systemRelations, getSpan, getOutcome);
+ if(this.printErrors){
+ Map<HashableArguments, BinaryTextRelation> goldMap = Maps.newHashMap();
+ for (BinaryTextRelation relation : goldRelations) {
+ goldMap.put(new HashableArguments(relation), relation);
+ }
+ Map<HashableArguments, BinaryTextRelation> systemMap = Maps.newHashMap();
+ for (BinaryTextRelation relation : systemRelations) {
+ systemMap.put(new HashableArguments(relation), relation);
+ }
+ Set<HashableArguments> all = Sets.union(goldMap.keySet(), systemMap.keySet());
+ List<HashableArguments> sorted = Lists.newArrayList(all);
+ Collections.sort(sorted);
+ for (HashableArguments key : sorted) {
+ BinaryTextRelation goldRelation = goldMap.get(key);
+ BinaryTextRelation systemRelation = systemMap.get(key);
+ if (goldRelation == null) {
+ System.out.println("System added: " + formatRelation(systemRelation));
+ } else if (systemRelation == null) {
+ System.out.println("System dropped: " + formatRelation(goldRelation));
+ } else if (!systemRelation.getCategory().equals(goldRelation.getCategory())) {
+ String label = systemRelation.getCategory();
+ System.out.printf("System labeled %s for %s\n", label, formatRelation(goldRelation));
+ } else{
+ System.out.println("Nailed it! " + formatRelation(systemRelation));
+ }
+ }
+ }
+ }
+
+ return corefStats;
+ }
+
+ public static class AnnotationComparator implements Comparator<Annotation> {
+
+ @Override
+ public int compare(Annotation o1, Annotation o2) {
+ if(o1.getBegin() < o2.getBegin()){
+ return -1;
+ }else if(o1.getBegin() == o2.getBegin() && o1.getEnd() < o2.getEnd()){
+ return -1;
+ }else if(o1.getBegin() == o2.getBegin() && o1.getEnd() > o2.getEnd()){
+ return 1;
+ }else if(o2.getBegin() < o1.getBegin()){
+ return 1;
+ }else{
+ return 0;
+ }
+ }
+ }
+ /**
+  * Logs the ID of each document passing through the pipeline, falling back to the file
+  * name of the view's URI when no document ID is set. CorefEvalFlowController checks for
+  * gold coreference right after this component runs.
+  */
+ public static class DocumentIDPrinter extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+     static Logger logger = Logger.getLogger(DocumentIDPrinter.class);
+
+     @Override
+     public void process(JCas jCas) throws AnalysisEngineProcessException {
+         String docId = DocumentIDAnnotationUtil.getDocumentID(jCas);
+         if(docId == null){
+             docId = new File(ViewUriUtil.getURI(jCas)).getName();
+         }
+         // no trailing "\n": the log layout supplies the line separator, an extra one adds blank lines
+         logger.info("Processing " + docId);
+     }
+
+ }
+
+ /**
+  * Creates a {@link Paragraph} for each run of tokens delimited by {@link NewlineToken}s.
+  * Consecutive newlines are collapsed so no empty paragraphs are produced, and the tokens
+  * after the final newline (a document need not end with one) form a paragraph too.
+  */
+ public static class ParagraphAnnotator extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+     @Override
+     public void process(JCas jcas) throws AnalysisEngineProcessException {
+         List<BaseToken> tokens = new ArrayList<>(JCasUtil.select(jcas, BaseToken.class));
+         BaseToken lastToken = null;
+         int parStart = 0;
+
+         for(int i = 0; i < tokens.size(); i++){
+             BaseToken token = tokens.get(i);
+             if(parStart == i && token instanceof NewlineToken){
+                 // we've just ended a paragraph (or are at the document start) and hit another newline --
+                 // don't start the new paragraph until we are past the newlines.
+                 parStart++;
+             }else if(lastToken != null && token instanceof NewlineToken){
+                 addParagraph(jcas, tokens.get(parStart), lastToken);
+                 parStart = i+1;
+             }
+             lastToken = token;
+         }
+
+         // Every newline leaves parStart just past itself, so tokens remaining here are the
+         // non-newline tail of the document; without this they would belong to no paragraph.
+         if(parStart < tokens.size()){
+             addParagraph(jcas, tokens.get(parStart), lastToken);
+         }
+     }
+
+     /** Indexes a Paragraph spanning from the start of {@code first} to the end of {@code last}. */
+     private static void addParagraph(JCas jcas, BaseToken first, BaseToken last){
+         Paragraph par = new Paragraph(jcas, first.getBegin(), last.getEnd());
+         par.addToIndexes();
+     }
+ }
+
+ /**
+  * Computes one vector per Paragraph: the L2-normalized sum of the embeddings (read from
+  * mimic_vectors.txt) of the lower-cased WordTokens it covers. The vectors are stored, in
+  * paragraph order, in an FSArray that is added to the CAS indexes.
+  */
+ public static class ParagraphVectorAnnotator extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+     WordEmbeddings words = null;
+
+     @Override
+     public void initialize(final UimaContext context) throws ResourceInitializationException{
+         // let the base class store the context and perform its own initialization
+         super.initialize(context);
+         try {
+             words = WordVectorReader.getEmbeddings(FileLocator.getAsStream("org/apache/ctakes/coreference/distsem/mimic_vectors.txt"));
+         } catch (IOException e) {
+             throw new ResourceInitializationException(e);
+         }
+     }
+
+     @Override
+     public void process(JCas jcas) throws AnalysisEngineProcessException {
+         List<Paragraph> pars = new ArrayList<>(JCasUtil.select(jcas, Paragraph.class));
+         int dims = words.getDimensionality();
+         FSArray parVecs = new FSArray(jcas, pars.size());
+         for(int parNum = 0; parNum < pars.size(); parNum++){
+             float[] parVec = new float[dims];
+
+             for(BaseToken token : JCasUtil.selectCovered(BaseToken.class, pars.get(parNum))){
+                 if(token instanceof WordToken){
+                     // Locale.ROOT keeps lookups independent of the JVM's default locale
+                     String word = token.getCoveredText().toLowerCase(java.util.Locale.ROOT);
+                     if(words.containsKey(word)){
+                         WordVector wv = words.getVector(word);
+                         for(int j = 0; j < parVec.length; j++){
+                             parVec[j] += wv.getValue(j);
+                         }
+                     }
+                 }
+             }
+             normalize(parVec);
+             FloatArray vec = new FloatArray(jcas, dims);
+             vec.copyFromArray(parVec, 0, 0, parVec.length);
+             vec.addToIndexes();
+             parVecs.set(parNum, vec);
+         }
+         parVecs.addToIndexes();
+     }
+
+     /**
+      * Scales {@code vec} in place to unit L2 length. An all-zero vector (a paragraph with
+      * no in-vocabulary words) is left as zeros rather than divided by a zero norm, which
+      * would turn every component into NaN.
+      */
+     private static void normalize(float[] vec) {
+         double sum = 0.0;
+         for(float v : vec){
+             sum += (v*v);
+         }
+         double norm = Math.sqrt(sum);
+         if(norm == 0.0){
+             return;
+         }
+         for(int i = 0; i < vec.length; i++){
+             vec[i] /= norm;
+         }
+     }
+ }
+
+ /**
+  * Copies gold coreference chains and relations from the gold view into the default view,
+  * mapping each gold markable onto a system markable with the same nominal dependency head.
+  * A chain is not copied if any member could not be mapped (e.g. person mentions removed
+  * earlier in the pipeline); a relation is copied only when both arguments were mapped.
+  */
+ public static class CopyCoreferenceRelations extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+     public static final String PARAM_GOLD_VIEW = "GoldViewName";
+     @ConfigurationParameter(name=PARAM_GOLD_VIEW, mandatory=true, description="View containing gold standard annotations")
+     private String goldViewName;
+
+     @SuppressWarnings("synthetic-access")
+     @Override
+     public void process(JCas jcas) throws AnalysisEngineProcessException {
+         JCas goldView;
+         try {
+             goldView = jcas.getView(goldViewName);
+         } catch (CASException e) {
+             throw new AnalysisEngineProcessException(e);
+         }
+
+         int docLength = jcas.getDocumentText().length();
+         HashMap<Markable,Markable> gold2sys = new HashMap<>();
+         Map<ConllDependencyNode,Collection<Markable>> depIndex = JCasUtil.indexCovering(jcas, ConllDependencyNode.class, Markable.class);
+
+         for(CollectionTextRelation goldChain : JCasUtil.select(goldView, CollectionTextRelation.class)){
+             FSList head = goldChain.getMembers();
+             NonEmptyFSList sysList = new NonEmptyFSList(jcas);
+             NonEmptyFSList listEnd = sysList;
+             boolean removeChain = false;
+
+             // first one is guaranteed to be nonempty otherwise it would not be in cas
+             do{
+                 NonEmptyFSList element = (NonEmptyFSList) head;
+                 // if this is not the first time through, append a new node and move listEnd to it
+                 if(listEnd.getHead() != null){
+                     listEnd.setTail(new NonEmptyFSList(jcas));
+                     listEnd.addToIndexes();
+                     listEnd = (NonEmptyFSList) listEnd.getTail();
+                 }
+                 Markable goldMarkable = (Markable) element.getHead();
+                 // annotation end offsets are exclusive, so end == docLength is still a valid span
+                 if(goldMarkable.getBegin() >= 0 && goldMarkable.getEnd() <= docLength){
+                     ConllDependencyNode headNode = DependencyUtility.getNominalHeadNode(jcas, goldMarkable);
+                     Collection<Markable> candidates = (headNode == null) ? null : depIndex.get(headNode);
+                     if(candidates != null){
+                         for(Markable sysMarkable : candidates){
+                             ConllDependencyNode markNode = DependencyUtility.getNominalHeadNode(jcas, sysMarkable);
+                             if(markNode == headNode){
+                                 gold2sys.put(goldMarkable, sysMarkable);
+                                 break;
+                             }
+                         }
+                     }
+                     if(!gold2sys.containsKey(goldMarkable)){
+                         // NOTE(review): this markable is indexed but never added to gold2sys, so the
+                         // chain is still dropped below -- confirm whether mapping it was intended.
+                         Markable mappedGold = new Markable(jcas, goldMarkable.getBegin(), goldMarkable.getEnd());
+                         mappedGold.addToIndexes();
+                     }
+                 }else{
+                     // Have seen some instances where anafora writes a span that is not possible, log them
+                     // so they can be found and fixed:
+                     logger.warn(String.format("There is a markable with span [%d, %d] in a document with length %d",
+                             goldMarkable.getBegin(), goldMarkable.getEnd(), docLength));
+                 }
+
+                 // add markable to end of list:
+                 if(gold2sys.get(goldMarkable) == null){
+                     logger.warn(String.format("There is a gold markable [%d, %d] which could not map to a system markable.",
+                             goldMarkable.getBegin(), goldMarkable.getEnd()));
+                     removeChain = true;
+                     break;
+                 }
+                 listEnd.setHead(gold2sys.get(goldMarkable));
+
+                 head = element.getTail();
+             }while(head instanceof NonEmptyFSList);
+
+             // don't bother copying over -- the chain had a member with no system markable (e.g. person mentions)
+             if(!removeChain){
+                 listEnd.setTail(new EmptyFSList(jcas));
+                 listEnd.addToIndexes();
+                 listEnd.getTail().addToIndexes();
+                 sysList.addToIndexes();
+                 CollectionTextRelation sysRel = new CollectionTextRelation(jcas);
+                 sysRel.setMembers(sysList);
+                 sysRel.addToIndexes();
+             }
+         }
+
+         for(CoreferenceRelation goldRel : JCasUtil.select(goldView, CoreferenceRelation.class)){
+             if((gold2sys.containsKey(goldRel.getArg1().getArgument()) && gold2sys.containsKey(goldRel.getArg2().getArgument()))){
+                 CoreferenceRelation sysRel = new CoreferenceRelation(jcas);
+                 sysRel.setCategory(goldRel.getCategory());
+                 sysRel.setDiscoveryTechnique(CONST.REL_DISCOVERY_TECH_GOLD_ANNOTATION);
+
+                 RelationArgument arg1 = new RelationArgument(jcas);
+                 arg1.setArgument(gold2sys.get(goldRel.getArg1().getArgument()));
+                 sysRel.setArg1(arg1);
+                 arg1.addToIndexes();
+
+                 RelationArgument arg2 = new RelationArgument(jcas);
+                 arg2.setArgument(gold2sys.get(goldRel.getArg2().getArgument()));
+                 sysRel.setArg2(arg2);
+                 arg2.addToIndexes();
+
+                 sysRel.addToIndexes();
+             }
+         }
+     }
+ }
+ public static class RemovePersonMarkables extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+ @Override
+ public void process(JCas jcas) throws AnalysisEngineProcessException {
+// JCas systemView=null, goldView=null;
+// try{
+// systemView = jcas.getView(CAS.NAME_DEFAULT_SOFA);
+// goldView = jcas.getView(GOLD_VIEW_NAME);
+// }catch(Exception e){
+// throw new AnalysisEngineProcessException(e);
+// }
+ List<Markable> toRemove = new ArrayList<>();
+ for(Markable markable : JCasUtil.select(jcas, Markable.class)){
+ List<BaseToken> coveredTokens = JCasUtil.selectCovered(jcas, BaseToken.class, markable);
+ if(coveredTokens.size() == 1 && coveredTokens.get(0).getPartOfSpeech().startsWith("PRP")){
+ toRemove.add(markable);
+ }else if(coveredTokens.size() == 2 &&
+ (coveredTokens.get(0).getCoveredText().startsWith("Mr.") || coveredTokens.get(0).getCoveredText().startsWith("Dr.") ||
+ coveredTokens.get(0).getCoveredText().startsWith("Mrs.") || coveredTokens.get(0).getCoveredText().startsWith("Ms."))){
+ toRemove.add(markable);
+ }else if(markable.getCoveredText().toLowerCase().equals("patient")){
+ toRemove.add(markable);
+ }
+ }
+
+ for(Markable markable : toRemove){
+ markable.removeFromIndexes();
+ }
+ }
+ }
+
+ /*
+  * Flow controller adapted from UIMA's FixedFlowController and its internal Flow object.
+  * It runs the aggregate's components in their fixed order, except that right after the
+  * DocumentIDPrinter step it checks the gold view: a document with no gold
+  * CoreferenceRelation annotations is finished early, so no time is spent running the
+  * coreference components on it (its answers would not be printed anyway).
+  */
+ public static class CorefEvalFlowController extends org.apache.uima.flow.JCasFlowController_ImplBase {
+     /** Keys of the aggregate's components, in fixed-flow order. */
+     List<String> mSequence;
+
+     @Override
+     public void initialize(FlowControllerContext context)
+             throws ResourceInitializationException {
+         super.initialize(context);
+
+         FlowConstraints flowConstraints = context.getAggregateMetadata().getFlowConstraints();
+         mSequence = new ArrayList<>();
+         if (!(flowConstraints instanceof FixedFlow)) {
+             throw new ResourceInitializationException(ResourceInitializationException.FLOW_CONTROLLER_REQUIRES_FLOW_CONSTRAINTS,
+                     new Object[]{this.getClass().getName(), "fixedFlow", context.getAggregateMetadata().getSourceUrlString()});
+         }
+         Collections.addAll(mSequence, ((FixedFlow) flowConstraints).getFixedFlow());
+     }
+
+     @Override
+     public Flow computeFlow(JCas jcas) throws AnalysisEngineProcessException {
+         return new CorefEvalFlow(jcas, 0);
+     }
+
+     /** Steps through mSequence for one CAS, possibly stopping after DocumentIDPrinter. */
+     class CorefEvalFlow extends JCasFlow_ImplBase {
+
+         private JCas jcas;
+         private int currentStep;
+
+         public CorefEvalFlow(JCas jcas, int step){
+             this.jcas = jcas;
+             this.currentStep = step;
+         }
+
+         @Override
+         public Step next() {
+             // every component has run
+             if (currentStep >= mSequence.size()) {
+                 return new FinalStep();
+             }
+
+             if (documentIdPrinterJustRan() && goldViewHasNoCoreference()) {
+                 System.out.println("Skipping this document with no coreference relations.");
+                 return new FinalStep();
+             }
+
+             String nextComponent = mSequence.get(currentStep);
+             currentStep++;
+             return new SimpleStep(nextComponent);
+         }
+
+         /** True when the previously issued step was the DocumentIDPrinter component. */
+         private boolean documentIdPrinterJustRan() {
+             return currentStep > 0 && mSequence.get(currentStep-1).equals(DocumentIDPrinter.class.getName());
+         }
+
+         /** True if the gold view holds no CoreferenceRelation; false if the view cannot be fetched. */
+         private boolean goldViewHasNoCoreference() {
+             try {
+                 JCas goldView = jcas.getView(GOLD_VIEW_NAME);
+                 return JCasUtil.select(goldView, CoreferenceRelation.class).isEmpty();
+             } catch (CASException e) {
+                 // no need to stop flow -- just go ahead to default simple step.
+                 e.printStackTrace();
+                 return false;
+             }
+         }
+     }
+ }
+}
Added: ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java
URL: http://svn.apache.org/viewvc/ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java?rev=1666502&view=auto
==============================================================================
--- ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java (added)
+++ ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java Fri Mar 13 16:23:46 2015
@@ -0,0 +1,239 @@
+package org.apache.ctakes.coreference.eval;
+
+import java.io.File;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.ctakes.assertion.medfacts.cleartk.PolarityCleartkAnalysisEngine;
+import org.apache.ctakes.coreference.ae.DeterministicMarkableAnnotator;
+import org.apache.ctakes.coreference.ae.MarkableSalienceAnnotator;
+import org.apache.ctakes.coreference.eval.EvaluationOfEventCoreference.DocumentIDPrinter;
+import org.apache.ctakes.coreference.eval.EvaluationOfEventCoreference.RemovePersonMarkables;
+import org.apache.ctakes.dependency.parser.util.DependencyUtility;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
+import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
+import org.apache.ctakes.typesystem.type.syntax.ConllDependencyNode;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
+import org.apache.uima.cas.CAS;
+import org.apache.uima.cas.CASException;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.fit.component.ViewCreatorAnnotator;
+import org.apache.uima.fit.descriptor.ConfigurationParameter;
+import org.apache.uima.fit.factory.AggregateBuilder;
+import org.apache.uima.fit.factory.AnalysisEngineFactory;
+import org.apache.uima.fit.pipeline.JCasIterator;
+import org.apache.uima.fit.pipeline.SimplePipeline;
+import org.apache.uima.fit.util.JCasUtil;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.cas.FSList;
+import org.apache.uima.jcas.cas.NonEmptyFSList;
+import org.cleartk.eval.AnnotationStatistics;
+import org.cleartk.ml.jar.JarClassifierBuilder;
+import org.cleartk.ml.liblinear.LibLinearBooleanOutcomeDataWriter;
+
+import com.google.common.base.Function;
+import com.lexicalscope.jewel.cli.CliFactory;
+
+/**
+ * Trains and evaluates {@link MarkableSalienceAnnotator} as a binary classifier over system
+ * markables. A system markable's label is {@code true} when its nominal head node is also the
+ * nominal head node of some member of a gold coreference chain, and {@code false} otherwise.
+ */
+public class EvaluationOfMarkableSalience extends Evaluation_ImplBase<AnnotationStatistics<Boolean>> {
+
+  /** View holding copies of the system markables with gold-derived confidences (used at test time). */
+  private static final String PSEUDO_GOLD_VIEW_NAME = "PseudoGold";
+
+  public static void main(String[] args) throws Exception {
+    Options options = CliFactory.parseArguments(Options.class, args);
+    List<Integer> patientSets = options.getPatients().getList();
+    List<Integer> trainItems = getTrainItems(options);
+    List<Integer> testItems = getTestItems(options);
+
+    // no treebank directory is passed for this evaluation
+    EvaluationOfMarkableSalience eval =
+        new EvaluationOfMarkableSalience(new File("target/eval/salience"),
+            options.getRawTextDirectory(),
+            options.getXMLDirectory(),
+            options.getXMLFormat(),
+            options.getSubcorpus(),
+            options.getXMIDirectory(), null);
+    eval.prepareXMIsFor(patientSets);
+
+    AnnotationStatistics<Boolean> stats = eval.trainAndTest(trainItems, testItems);
+    System.out.println(stats);
+    System.out.println(stats.confusions());
+  }
+
+  public EvaluationOfMarkableSalience(File baseDirectory,
+      File rawTextDirectory, File xmlDirectory,
+      org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat xmlFormat,
+      org.apache.ctakes.temporal.eval.Evaluation_ImplBase.Subcorpus subcorpus,
+      File xmiDirectory, File treebankDirectory) {
+    super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, subcorpus,
+        xmiDirectory, treebankDirectory);
+  }
+
+  /**
+   * Runs the preprocessing pipeline, labels system markables via {@link SetGoldConfidence},
+   * writes LibLinear training data, then trains and packages the model into {@code directory}.
+   */
+  @Override
+  protected void train(CollectionReader collectionReader, File directory)
+      throws Exception {
+    AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+    aggregateBuilder.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(SetGoldConfidence.class, SetGoldConfidence.PARAM_GOLD_VIEW, GOLD_VIEW_NAME));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(MarkableSalienceAnnotator.createDataWriterDescription(
+        LibLinearBooleanOutcomeDataWriter.class,
+        directory
+        )));
+    SimplePipeline.runPipeline(collectionReader, aggregateBuilder.createAggregate());
+    // s=0 -> logistic regression with L2-norm (gives probabilistic outputs)
+    String[] optArray = new String[]{ "-s", "0", "-c", "1", "-w1", "1"};
+    JarClassifierBuilder.trainAndPackage(directory, optArray);
+  }
+
+  /**
+   * Copies every system markable into a pseudo-gold view (confidence 1.0 if it matches a gold
+   * chain member, 0.0 otherwise), runs the trained salience model over the system markables,
+   * and compares the two views span-by-span, mapping confidences to booleans with
+   * {@link #mapConfidenceToBoolean()}.
+   */
+  @Override
+  protected AnnotationStatistics<Boolean> test(
+      CollectionReader collectionReader, File directory) throws Exception {
+    AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+    aggregateBuilder.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class, ViewCreatorAnnotator.PARAM_VIEW_NAME, PSEUDO_GOLD_VIEW_NAME));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(CreatePseudoGoldMarkables.class, CreatePseudoGoldMarkables.PARAM_GOLD_VIEW, GOLD_VIEW_NAME, CreatePseudoGoldMarkables.PARAM_PSEUDO_GOLD_VIEW, PSEUDO_GOLD_VIEW_NAME));
+    aggregateBuilder.add(MarkableSalienceAnnotator.createAnnotatorDescription(directory.getAbsolutePath() + File.separator + "model.jar"));
+    AnnotationStatistics<Boolean> stats = new AnnotationStatistics<>();
+
+    for(Iterator<JCas> casIter = new JCasIterator(collectionReader, aggregateBuilder.createAggregate()); casIter.hasNext();){
+      JCas jCas = casIter.next();
+      JCas goldView = jCas.getView(PSEUDO_GOLD_VIEW_NAME);
+      JCas systemView = jCas.getView(CAS.NAME_DEFAULT_SOFA);
+
+      stats.add(JCasUtil.select(goldView, Markable.class),
+          JCasUtil.select(systemView, Markable.class),
+          AnnotationStatistics.<Markable>annotationToSpan(),
+          mapConfidenceToBoolean());
+    }
+
+    return stats;
+  }
+
+  /**
+   * Finds the system markables in {@code jcas} that correspond to members of gold coreference
+   * chains in {@code goldView}. For each gold chain member whose span lies inside the document
+   * text, its nominal head node is computed; the first system markable covering that node whose
+   * own nominal head is the same node is taken as the match.
+   *
+   * @param jcas view holding the system markables and the dependency parse
+   * @param goldView view holding the gold {@link CollectionTextRelation} chains
+   * @return the matched system markables (possibly empty, never null)
+   */
+  private static Set<Markable> findSystemMatchesForGoldMarkables(JCas jcas, JCas goldView){
+    Set<Markable> matched = new HashSet<>();
+    // map each dependency node to the system markables that cover it
+    Map<ConllDependencyNode,Collection<Markable>> depIndex = JCasUtil.indexCovering(jcas, ConllDependencyNode.class, Markable.class);
+    int docLength = jcas.getDocumentText().length();
+
+    for(CollectionTextRelation goldChain : JCasUtil.select(goldView, CollectionTextRelation.class)){
+      // walk the chain's member list; an empty (or unset) list contributes nothing
+      FSList node = goldChain.getMembers();
+      while(node instanceof NonEmptyFSList){
+        NonEmptyFSList element = (NonEmptyFSList) node;
+        Markable goldMarkable = (Markable) element.getHead();
+        node = element.getTail();
+
+        // UIMA end offsets are exclusive, so end == docLength is still a valid span
+        if(goldMarkable.getBegin() < 0 || goldMarkable.getEnd() > docLength){
+          continue;
+        }
+        ConllDependencyNode headNode = DependencyUtility.getNominalHeadNode(jcas, goldMarkable);
+        if(headNode == null){
+          continue;
+        }
+        Collection<Markable> covering = depIndex.get(headNode);
+        if(covering == null){
+          continue;
+        }
+        for(Markable sysMarkable : covering){
+          if(DependencyUtility.getNominalHeadNode(jcas, sysMarkable) == headNode){
+            matched.add(sysMarkable);
+            break;
+          }
+        }
+      }
+    }
+    return matched;
+  }
+
+  /**
+   * Training-time annotator: sets the confidence of every system markable that matches a gold
+   * chain member to 1.0. Confidences of unmatched markables are left untouched.
+   */
+  public static class SetGoldConfidence extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+    public static final String PARAM_GOLD_VIEW = "GoldViewName";
+    @ConfigurationParameter(name=PARAM_GOLD_VIEW, mandatory=true, description="View containing gold standard annotations")
+    private String goldViewName;
+
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+      JCas goldView;
+      try {
+        goldView = jcas.getView(goldViewName);
+      } catch (CASException e) {
+        throw new AnalysisEngineProcessException(e);
+      }
+
+      for(Markable sysMarkable : findSystemMatchesForGoldMarkables(jcas, goldView)){
+        sysMarkable.setConfidence(1.0f);
+      }
+    }
+  }
+
+  /**
+   * Test-time annotator: copies every system markable into the pseudo-gold view, with
+   * confidence 1.0 if it matches a gold chain member and 0.0 otherwise.
+   */
+  public static class CreatePseudoGoldMarkables extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+    public static final String PARAM_PSEUDO_GOLD_VIEW = "PseudoViewName";
+    @ConfigurationParameter(name = PARAM_PSEUDO_GOLD_VIEW)
+    private String fakeGoldName;
+
+    public static final String PARAM_GOLD_VIEW = "GoldViewName";
+    @ConfigurationParameter(name = PARAM_GOLD_VIEW)
+    private String goldViewName;
+
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+      JCas fakeView;
+      JCas goldView;
+      try{
+        fakeView = jcas.getView(fakeGoldName);
+        goldView = jcas.getView(goldViewName);
+      }catch(CASException e){
+        throw new AnalysisEngineProcessException(e);
+      }
+
+      Set<Markable> matched = findSystemMatchesForGoldMarkables(jcas, goldView);
+
+      // add all system markables to pseudo-gold, with confidence based on whether they map to gold
+      for(Markable markable : JCasUtil.select(jcas, Markable.class)){
+        Markable fakeMarkable = new Markable(fakeView, markable.getBegin(), markable.getEnd());
+        fakeMarkable.setConfidence(matched.contains(markable) ? 1.0f : 0.0f);
+        fakeMarkable.addToIndexes();
+      }
+    }
+  }
+
+  /**
+   * Maps a markable to {@code true} when its confidence exceeds 0.5. Note that this predicts
+   * non-singletons rather than singletons.
+   */
+  public static Function<Markable,Boolean> mapConfidenceToBoolean(){
+    return new Function<Markable,Boolean>() {
+      @Override
+      public Boolean apply(Markable markable) {
+        return markable.getConfidence() > 0.5;
+      }
+    };
+  }
+}
Added: ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java
URL: http://svn.apache.org/viewvc/ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java?rev=1666502&view=auto
==============================================================================
--- ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java (added)
+++ ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java Fri Mar 13 16:23:46 2015
@@ -0,0 +1,208 @@
+package org.apache.ctakes.coreference.eval;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.logging.Level;
+
+import org.apache.ctakes.coreference.ae.DeterministicMarkableAnnotator;
+import org.apache.ctakes.coreference.ae.MarkableAnnotator;
+import org.apache.ctakes.temporal.eval.EvaluationOfAnnotationSpans_ImplBase;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
+import org.apache.ctakes.temporal.eval.THYMEData;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.CopyFromGold;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.Options;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.ctakes.typesystem.type.textspan.Segment;
+import org.apache.uima.analysis_engine.AnalysisEngineDescription;
+import org.apache.uima.cas.CAS;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
+import org.apache.uima.fit.factory.AggregateBuilder;
+import org.apache.uima.fit.factory.AnalysisEngineFactory;
+import org.apache.uima.fit.pipeline.SimplePipeline;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.tcas.Annotation;
+import org.apache.uima.resource.ResourceInitializationException;
+import org.cleartk.eval.AnnotationStatistics;
+import org.cleartk.ml.CleartkAnnotator;
+import org.cleartk.ml.jar.DefaultDataWriterFactory;
+import org.cleartk.ml.jar.DirectoryDataWriterFactory;
+import org.cleartk.ml.jar.GenericJarClassifierFactory;
+import org.cleartk.ml.jar.JarClassifierBuilder;
+import org.cleartk.ml.liblinear.LibLinearStringOutcomeDataWriter;
+
+import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
+
+public class EvaluationOfMarkableSpans extends EvaluationOfAnnotationSpans_ImplBase {
+ static interface Options extends Evaluation_ImplBase.Options {
+ @Option
+ public boolean getUseTmp();
+
+ @Option
+ public boolean getPul();
+
+ @Option(shortName="m")
+ public boolean getUseMachineLearning();
+ }
+
+
+ public static void main(String[] args) throws Exception {
+ Options options = CliFactory.parseArguments(Options.class, args);
+ List<Integer> trainItems = null;
+ List<Integer> devItems = null;
+ List<Integer> testItems = null;
+
+ List<Integer> patientSets = options.getPatients().getList();
+ trainItems = THYMEData.getTrainPatientSets(patientSets);
+ devItems = THYMEData.getDevPatientSets(patientSets);
+ testItems = THYMEData.getTestPatientSets(patientSets);
+ File workingDir = new File("target/eval/markable-spans/");
+ if(!workingDir.exists()) workingDir.mkdirs();
+ if(options.getUseTmp()){
+ File tempModelDir = File.createTempFile("temporal", null, workingDir);
+ tempModelDir.delete();
+ tempModelDir.mkdir();
+ workingDir = tempModelDir;
+ }
+
+ List<Integer> allTrain = new ArrayList<>(trainItems);
+ List<Integer> allTest = null;
+
+ if(options.getTest()){
+ allTrain.addAll(devItems);
+ allTest = new ArrayList<>(testItems);
+ }else{
+ allTest = new ArrayList<>(devItems);
+ }
+
+ EvaluationOfMarkableSpans eval = new EvaluationOfMarkableSpans(
+ workingDir,
+ options.getRawTextDirectory(),
+ options.getXMLDirectory(),
+ options.getXMLFormat(),
+ options.getXMIDirectory(),
+ options.getTreebankDirectory(),
+ options.getPrintErrors());
+
+
+ eval.trainingArguments = new String[]{ "-c", "1.0", "-s", "0"};
+ eval.annotatorClass = options.getUseMachineLearning() ? MarkableAnnotator.class : DeterministicMarkableAnnotator.class;
+ String name = String.format("%s.errors", eval.annotatorClass.getSimpleName());
+ eval.setLogging(Level.FINE, new File("target/eval", name));
+
+ AnnotationStatistics<String> stats = null;
+ if(options.getPul()){
+ stats = eval.trainAndRetrainAndTest(allTrain, allTest);
+ }else{
+ stats = eval.trainAndTest(allTrain, allTest);
+ }
+ System.out.println(stats);
+ }
+
+ protected String[] trainingArguments;
+
+ protected Class<? extends JCasAnnotator_ImplBase> annotatorClass = null;
+
+ public EvaluationOfMarkableSpans(File workingDir, File rawTextDirectory,
+ File xmlDirectory,
+ org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat xmlFormat,
+ File xmiDirectory, File treebankDirectory,
+ boolean printErrors) {
+ super(workingDir, rawTextDirectory, xmlDirectory, xmlFormat, xmiDirectory, treebankDirectory, Markable.class);
+ this.printErrors = printErrors;
+ }
+
+ public AnnotationStatistics<String> trainAndRetrainAndTest(List<Integer> trainItems, List<Integer> testItems)
+ throws Exception {
+ File subDirectory = new File(this.baseDirectory, "train_and_test");
+ subDirectory.mkdirs();
+ this.train(this.getCollectionReader(trainItems), subDirectory);
+ this.retrain(this.getCollectionReader(trainItems), subDirectory);
+ return this.test(this.getCollectionReader(testItems), subDirectory);
+ }
+
+
+
+ @Override
+ protected void train(CollectionReader collectionReader, File directory)
+ throws Exception {
+ if(this.annotatorClass == MarkableAnnotator.class){
+ super.train(collectionReader, directory);
+ }
+ }
+
+
+ protected void retrain(CollectionReader collectionReader, File directory) throws Exception{
+ AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+ aggregateBuilder.add(CopyFromGold.getDescription(Markable.class));
+ aggregateBuilder.add(this.getDataRewriterDescription(directory), "TimexView", CAS.NAME_DEFAULT_SOFA);
+ SimplePipeline.runPipeline(collectionReader, aggregateBuilder.createAggregate());
+ this.trainAndPackage(directory);
+ }
+
+ @Override
+ protected void trainAndPackage(File directory) throws Exception {
+ JarClassifierBuilder.trainAndPackage(getModelDirectory(directory), this.trainingArguments);
+ }
+
+ @Override
+ protected AnalysisEngineDescription getDataWriterDescription(File directory)
+ throws ResourceInitializationException {
+ return AnalysisEngineFactory.createEngineDescription(
+ MarkableAnnotator.class,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ true,
+ DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+ LibLinearStringOutcomeDataWriter.class,
+ DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+ getModelDirectory(directory));
+ }
+
+ protected AnalysisEngineDescription getDataRewriterDescription(File directory)
+ throws ResourceInitializationException {
+ return AnalysisEngineFactory.createEngineDescription(
+ MarkableAnnotator.class,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ false,
+ MarkableAnnotator.PARAM_IS_RETRAINING,
+ true,
+ DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+ LibLinearStringOutcomeDataWriter.class,
+ DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+ getModelDirectory(directory),
+ GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+ new File(getModelDirectory(directory), "model.jar"));
+ }
+
+ @Override
+ protected AnalysisEngineDescription getAnnotatorDescription(File directory)
+ throws ResourceInitializationException {
+ return AnalysisEngineFactory.createEngineDescription(
+ annotatorClass,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ false,
+ GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+ new File(getModelDirectory(directory), "model.jar"));
+ }
+
+ @Override
+ protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas,
+ Segment segment) {
+ return selectExact(jCas, Markable.class, segment);
+ }
+
+ @Override
+ protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas,
+ Segment segment) {
+ return selectExact(jCas, Markable.class, segment);
+ }
+
+ private static File getModelDirectory(File directory) {
+ return new File(directory, MarkableAnnotator.class.getSimpleName());
+ }
+
+}