You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by tm...@apache.org on 2014/07/16 16:18:42 UTC
svn commit: r1611018 - in
/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal:
ae/EventCoreferenceAnnotator.java eval/EvaluationOfEventCoreference.java
Author: tmill
Date: Wed Jul 16 14:18:42 2014
New Revision: 1611018
URL: http://svn.apache.org/r1611018
Log:
CTAKES-199: Event coreference annotator and evaluator.
Added:
ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/EventCoreferenceAnnotator.java
ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfEventCoreference.java
Added: ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/EventCoreferenceAnnotator.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/EventCoreferenceAnnotator.java?rev=1611018&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/EventCoreferenceAnnotator.java (added)
+++ ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/EventCoreferenceAnnotator.java Wed Jul 16 14:18:42 2014
@@ -0,0 +1,225 @@
+package org.apache.ctakes.temporal.ae;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.ctakes.coreference.ae.features.DistSemFeatureExtractor;
+import org.apache.ctakes.coreference.ae.features.DistanceFeatureExtractor;
+import org.apache.ctakes.coreference.ae.features.StringMatchingFeatureExtractor;
+import org.apache.ctakes.coreference.ae.features.TokenFeatureExtractor;
+import org.apache.ctakes.coreference.ae.features.UMLSFeatureExtractor;
+import org.apache.ctakes.relationextractor.ae.RelationExtractorAnnotator;
+import org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor;
+import org.apache.ctakes.typesystem.type.relation.BinaryTextRelation;
+import org.apache.ctakes.typesystem.type.relation.CoreferenceRelation;
+import org.apache.ctakes.typesystem.type.relation.RelationArgument;
+import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.ctakes.typesystem.type.textspan.Paragraph;
+import org.apache.ctakes.typesystem.type.textspan.Sentence;
+import org.apache.uima.analysis_engine.AnalysisEngineDescription;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.cas.FSArray;
+import org.apache.uima.jcas.cas.FloatArray;
+import org.apache.uima.jcas.tcas.Annotation;
+import org.apache.uima.jcas.tcas.DocumentAnnotation;
+import org.apache.uima.resource.ResourceInitializationException;
+import org.cleartk.classifier.CleartkAnnotator;
+import org.cleartk.classifier.DataWriter;
+import org.cleartk.classifier.jar.DefaultDataWriterFactory;
+import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
+import org.cleartk.classifier.jar.GenericJarClassifierFactory;
+import org.uimafit.descriptor.ConfigurationParameter;
+import org.uimafit.factory.AnalysisEngineFactory;
+import org.uimafit.util.JCasUtil;
+
+/**
+ * Event coreference annotator.  Builds on the generic pairwise
+ * RelationExtractorAnnotator: candidate (antecedent, anaphor) Markable pairs are
+ * generated within and across paragraphs and classified as coreferent or not.
+ * Within a paragraph, pairs are limited by sentence distance; across paragraphs,
+ * pairs are only generated when the paragraphs' embedding vectors (computed
+ * upstream and stored in the CAS as a single FSArray of FloatArrays) are
+ * sufficiently similar by cosine similarity.
+ */
+public class EventCoreferenceAnnotator extends RelationExtractorAnnotator {
+
+ // Maximum sentence distance allowed between the two mentions of a candidate pair.
+ public static final int DEFAULT_SENT_DIST = 5;
+ public static final String PARAM_SENT_DIST = "SentenceDistance";
+ @ConfigurationParameter(name = PARAM_SENT_DIST, mandatory = false, description = "Number of sentences allowed between coreferent mentions")
+ private int maxSentDist = DEFAULT_SENT_DIST;
+
+ // Cosine-similarity threshold above which two paragraphs are paired up.
+ // NOTE(review): "Pararaph" is a typo in the published parameter name; renaming it
+ // would break existing descriptors/configs, so it is only flagged here.
+ public static final double DEFAULT_PAR_SIM = 0.5;
+ public static final String PARAM_PAR_SIM = "PararaphSimilarity";
+ @ConfigurationParameter(name = PARAM_PAR_SIM, mandatory = false, description = "Similarity required to pair paragraphs for coreference")
+ private double simThreshold = DEFAULT_PAR_SIM;
+
+ /**
+  * Creates a training-mode descriptor that writes classifier training data.
+  *
+  * @param dataWriterClass data writer implementation used to serialize instances
+  * @param outputDirectory directory that receives the training data
+  * @param downsamplingRate probability of keeping each negative example
+  * @throws ResourceInitializationException if the descriptor cannot be created
+  */
+ public static AnalysisEngineDescription createDataWriterDescription(
+ Class<? extends DataWriter<String>> dataWriterClass,
+ File outputDirectory,
+ float downsamplingRate) throws ResourceInitializationException {
+ return AnalysisEngineFactory.createPrimitiveDescription(
+ EventCoreferenceAnnotator.class,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ true,
+ RelationExtractorAnnotator.PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE,
+ downsamplingRate,
+ DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+ dataWriterClass,
+ DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+ outputDirectory);
+ }
+
+ /**
+  * Creates a decoding-mode descriptor that classifies with a packaged model.
+  *
+  * @param modelDirectory directory containing the trained model.jar
+  * @throws ResourceInitializationException if the descriptor cannot be created
+  */
+ public static AnalysisEngineDescription createAnnotatorDescription(File modelDirectory)
+ throws ResourceInitializationException {
+ return AnalysisEngineFactory.createPrimitiveDescription(
+ EventCoreferenceAnnotator.class,
+ CleartkAnnotator.PARAM_IS_TRAINING,
+ false,
+ GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+ new File(modelDirectory, "model.jar"));
+ }
+
+ /**
+  * Adds coreference-specific feature extractors on top of the base relation
+  * extractor's feature set.
+  */
+ @Override
+ protected List<RelationFeaturesExtractor> getFeatureExtractors() {
+ List<RelationFeaturesExtractor> featureList = new ArrayList<>();
+
+ featureList.addAll(super.getFeatureExtractors());
+
+ featureList.add(new DistanceFeatureExtractor());
+ featureList.add(new StringMatchingFeatureExtractor());
+ featureList.add(new TokenFeatureExtractor());
+ featureList.add(new UMLSFeatureExtractor());
+ try{
+ featureList.add(new DistSemFeatureExtractor());
+ }catch(IOException e){
+ // NOTE(review): if its resources fail to load, the distributional-semantics
+ // extractor is silently dropped; consider logging via a Logger or failing fast.
+ e.printStackTrace();
+ }
+ return featureList;
+ }
+
+
+ /**
+  * Generates candidate (antecedent, anaphor) pairs for one document.
+  * Within each paragraph, all mention pairs up to maxSentDist sentences apart are
+  * produced; across paragraphs, all pairs are produced whenever the two
+  * paragraphs' precomputed embedding vectors exceed simThreshold in similarity.
+  */
+ @Override
+ protected List<IdentifiedAnnotationPair> getCandidateRelationArgumentPairs(
+ JCas jcas, Annotation docAnnotation) {
+ // NOTE(review): 'markables' feeds only the commented-out strategy below and is
+ // currently unused.
+ List<Markable> markables = new ArrayList<>(JCasUtil.select(jcas, Markable.class));
+// List<Markable> markables = JCasUtil.selectCovered(Markable.class, docAnnotation);
+ List<IdentifiedAnnotationPair> pairs = new ArrayList<>();
+
+// // CODE FOR SENTENCE-DISTANCE-LIMITED PAIR MATCHING
+// for(int i = 1; i < markables.size(); i++){
+// for(int j = i-1; j >= 0; j--){
+// IdentifiedAnnotation ante = markables.get(j);
+// IdentifiedAnnotation ana = markables.get(i);
+// int sentdist = sentDist(jcas, ante, ana);
+// if(sentdist > maxSentDist) break;
+// pairs.add(new IdentifiedAnnotationPair(ante, ana));
+// }
+// }
+
+ // Paragraph vectors written upstream (see ParagraphVectorAnnotator in the
+ // evaluation code); assumes the CAS holds exactly one indexed FSArray --
+ // selectSingle throws otherwise.
+ FSArray parVecs = JCasUtil.selectSingle(jcas, FSArray.class);
+
+ // CODE FOR PARAGRAPH-BASED MATCHING
+ List<Paragraph> pars = new ArrayList<>(JCasUtil.select(jcas, Paragraph.class));
+ double[][] sims = new double[pars.size()][pars.size()];
+ // NOTE(review): Java arrays are zero-initialized already; this fill is redundant.
+ for(int i = 0; i < sims.length; i++){
+ Arrays.fill(sims[i], 0.0);
+ }
+
+ for(int i = 0; i < pars.size(); i++){
+ // get all pairs within this paragraph
+ List<Markable> curParMarkables = JCasUtil.selectCovered(Markable.class, pars.get(i));
+ for(int anaId = 1; anaId < curParMarkables.size(); anaId++){
+ for(int anteId = anaId-1; anteId >= 0; anteId--){
+ Markable ana = curParMarkables.get(anaId);
+ Markable ante = curParMarkables.get(anteId);
+ int sentdist = sentDist(jcas, ante, ana);
+ // Antecedents are scanned right-to-left, so sentence distance only grows;
+ // once over the cap, no earlier antecedent in this paragraph can qualify.
+ if(sentdist > maxSentDist) break;
+ pairs.add(new IdentifiedAnnotationPair(ante, ana));
+ }
+ }
+
+ // now get all pairs between markables in this paragraph and others
+ FloatArray parVec = (FloatArray) parVecs.get(i);
+ for(int j = i-1; j >= 0; j--){
+ // 0.0 doubles as the "not yet computed" sentinel; a true similarity of
+ // exactly 0.0 would just be recomputed, which is harmless.
+ if(sims[i][j] == 0.0){
+ // compute the sim explicitly
+ FloatArray prevParVec = (FloatArray) parVecs.get(j);
+ sims[i][j] = calculateSimilarity(parVec, prevParVec);
+ }
+
+ if(sims[i][j] > simThreshold){
+ // pair up all markables in each paragraph
+ List<Markable> prevParMarkables = JCasUtil.selectCovered(Markable.class, pars.get(j));
+ for(int anaId = 0; anaId < curParMarkables.size(); anaId++){
+ for(int anteId = prevParMarkables.size()-1; anteId >= 0; anteId--){
+ Markable ana = curParMarkables.get(anaId);
+ Markable ante = prevParMarkables.get(anteId);
+ int sentdist = sentDist(jcas, ante, ana);
+ // Same monotone-distance argument as above: antecedents are taken from
+ // the end of the earlier paragraph backwards.
+ if(sentdist > maxSentDist) break;
+ pairs.add(new IdentifiedAnnotationPair(ante, ana));
+ }
+ }
+ }
+ }
+ }
+ return pairs;
+ }
+
+ // Candidate pairs are generated over the whole document rather than per sentence.
+ @Override
+ protected Class<? extends Annotation> getCoveringClass() {
+ return DocumentAnnotation.class;
+ }
+
+ @Override
+ protected Class<? extends BinaryTextRelation> getRelationClass() {
+ return CoreferenceRelation.class;
+ }
+
+ // Anaphors that already received an antecedent; enforces at most one antecedent
+ // per anaphor (the first positive classification wins).
+ // NOTE(review): never cleared between documents, so it grows for the lifetime of
+ // the annotator instance; consider clearing it per CAS.
+ protected HashSet<IdentifiedAnnotation> foundAnaphors = new HashSet<>();
+
+ /**
+  * Adds a CoreferenceRelation to the CAS for a positively classified pair, unless
+  * the anaphor was already linked to an earlier antecedent.
+  */
+ @Override
+ protected void createRelation(
+ JCas jCas,
+ IdentifiedAnnotation ante,
+ IdentifiedAnnotation ana,
+ String predictedCategory) {
+ // check if its already been linked
+ if(!foundAnaphors.contains(ana)){
+ // add the relation to the CAS
+ RelationArgument relArg1 = new RelationArgument(jCas);
+ relArg1.setArgument(ante);
+ relArg1.setRole("Antecedent");
+ relArg1.addToIndexes();
+ RelationArgument relArg2 = new RelationArgument(jCas);
+ relArg2.setArgument(ana);
+ relArg2.setRole("Anaphor");
+ relArg2.addToIndexes();
+ CoreferenceRelation relation = new CoreferenceRelation(jCas);
+ relation.setArg1(relArg1);
+ relation.setArg2(relArg2);
+ relation.setCategory(predictedCategory);
+ relation.addToIndexes();
+ foundAnaphors.add(ana);
+ }
+ }
+
+ /**
+  * Approximate sentence distance: the number of Sentence annotations fully
+  * contained in the span from arg1's begin to arg2's end -- roughly the sentences
+  * strictly between the two mentions.  Assumes arg1 precedes arg2 in the text.
+  */
+ private static int sentDist(JCas jcas, IdentifiedAnnotation arg1,
+ IdentifiedAnnotation arg2) {
+ Collection<Sentence> sents = JCasUtil.selectCovered(jcas, Sentence.class, arg1.getBegin(), arg2.getEnd());
+ return sents.size();
+ }
+
+ /**
+  * Cosine similarity of two equal-length vectors.
+  * NOTE(review): returns NaN when either vector has zero magnitude (division by
+  * zero); any threshold comparison against NaN is then false, so such pairs are
+  * silently skipped by callers.
+  */
+ private static double calculateSimilarity(FloatArray f1, FloatArray f2){
+ double sim = 0.0f;
+ double f1len = 0.0;
+ double f2len = 0.0;
+
+ for(int i = 0; i < f1.size(); i++){
+ sim += (f1.get(i) * f2.get(i));
+ f1len += (f1.get(i) * f1.get(i));
+ f2len += (f2.get(i) * f2.get(i));
+ }
+ f1len = Math.sqrt(f1len);
+ f2len = Math.sqrt(f2len);
+ sim = sim / (f1len * f2len);
+
+ return sim;
+ }
+}
Added: ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfEventCoreference.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfEventCoreference.java?rev=1611018&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfEventCoreference.java (added)
+++ ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfEventCoreference.java Wed Jul 16 14:18:42 2014
@@ -0,0 +1,505 @@
+package org.apache.ctakes.temporal.eval;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+import org.apache.ctakes.core.resource.FileLocator;
+import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
+import org.apache.ctakes.coreference.ae.CoreferenceChainAnnotator;
+import org.apache.ctakes.relationextractor.eval.RelationExtractorEvaluation.HashableArguments;
+import org.apache.ctakes.temporal.ae.EventCoreferenceAnnotator;
+import org.apache.ctakes.temporal.eval.EvaluationOfEventTimeRelations.ParameterSettings;
+import org.apache.ctakes.typesystem.type.relation.BinaryTextRelation;
+import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
+import org.apache.ctakes.typesystem.type.relation.CoreferenceRelation;
+import org.apache.ctakes.typesystem.type.syntax.BaseToken;
+import org.apache.ctakes.typesystem.type.syntax.NewlineToken;
+import org.apache.ctakes.typesystem.type.syntax.WordToken;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.ctakes.typesystem.type.textspan.Paragraph;
+import org.apache.ctakes.utils.distsem.WordEmbeddings;
+import org.apache.ctakes.utils.distsem.WordVector;
+import org.apache.ctakes.utils.distsem.WordVectorReader;
+import org.apache.log4j.Logger;
+import org.apache.uima.UimaContext;
+import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
+import org.apache.uima.cas.CAS;
+import org.apache.uima.cas.CASException;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.cas.FSArray;
+import org.apache.uima.jcas.cas.FSList;
+import org.apache.uima.jcas.cas.FloatArray;
+import org.apache.uima.jcas.cas.NonEmptyFSList;
+import org.apache.uima.jcas.tcas.Annotation;
+import org.apache.uima.resource.ResourceInitializationException;
+import org.apache.uima.util.FileUtils;
+import org.cleartk.classifier.jar.JarClassifierBuilder;
+import org.cleartk.classifier.liblinear.LIBLINEARStringOutcomeDataWriter;
+import org.cleartk.classifier.tksvmlight.model.CompositeKernel.ComboOperator;
+import org.cleartk.eval.AnnotationStatistics;
+import org.cleartk.util.ViewURIUtil;
+import org.uimafit.component.JCasAnnotator_ImplBase;
+import org.uimafit.descriptor.ConfigurationParameter;
+import org.uimafit.factory.AggregateBuilder;
+import org.uimafit.factory.AnalysisEngineFactory;
+import org.uimafit.pipeline.JCasIterable;
+import org.uimafit.pipeline.SimplePipeline;
+import org.uimafit.util.JCasUtil;
+
+import com.google.common.base.Function;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Sets;
+import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
+
+/**
+ * End-to-end train/test evaluation for event coreference on the THYME corpus.
+ * Training writes classifier data from gold markables and coreference relations
+ * and packages a model; testing runs the trained EventCoreferenceAnnotator plus
+ * chain construction, dumps CoNLL-style chain files for external scoring, and
+ * accumulates pairwise AnnotationStatistics across documents.
+ */
+public class EvaluationOfEventCoreference extends EvaluationOfTemporalRelations_ImplBase {
+
+ // Command-line options, extending the shared temporal-relation options.
+ static interface CoreferenceOptions extends TempRelOptions{
+ @Option
+ public String getOutputDirectory();
+
+ @Option
+ public boolean getUseTmp();
+
+ }
+
+ // Probability of keeping each negative training example.
+ // NOTE(review): public, static and non-final; could be final to prevent mutation.
+ public static float COREF_DOWNSAMPLE = 0.5f;
+ protected static ParameterSettings allParams = new ParameterSettings(DEFAULT_BOTH_DIRECTIONS, COREF_DOWNSAMPLE, "tk",
+ 1.0, 1.0, "linear", ComboOperator.SUM, 0.1, 0.5); // (0.3, 0.4 for tklibsvm)
+
+ public static void main(String[] args) throws Exception {
+ CoreferenceOptions options = CliFactory.parseArguments(CoreferenceOptions.class, args);
+ List<Integer> trainItems = null;
+ List<Integer> devItems = null;
+ List<Integer> testItems = null;
+
+ List<Integer> patientSets = options.getPatients().getList();
+ trainItems = THYMEData.getTrainPatientSets(patientSets);
+ devItems = THYMEData.getDevPatientSets(patientSets);
+ testItems = THYMEData.getTestPatientSets(patientSets);
+ ParameterSettings params = allParams;
+ File workingDir = new File("target/eval/temporal-relations/coreference");
+ if(!workingDir.exists()) workingDir.mkdirs();
+ if(options.getUseTmp()){
+ // Unique scratch dir: createTempFile reserves a name, which is then replaced
+ // by a directory so concurrent runs don't collide; removed again below.
+ File tempModelDir = File.createTempFile("temporal", null, workingDir);
+ tempModelDir.delete();
+ tempModelDir.mkdir();
+ workingDir = tempModelDir;
+ }
+ EvaluationOfEventCoreference eval = new EvaluationOfEventCoreference(
+ workingDir,
+ options.getRawTextDirectory(),
+ options.getXMLDirectory(),
+ options.getXMLFormat(),
+ options.getXMIDirectory(),
+ options.getTreebankDirectory(),
+ options.getCoreferenceDirectory(),
+ options.getPrintErrors(),
+ options.getPrintFormattedRelations(),
+ params,
+ options.getKernelParams(),
+ options.getOutputDirectory());
+
+ eval.prepareXMIsFor(patientSets);
+ List<Integer> training = trainItems;
+ List<Integer> testing = null;
+ if(options.getTest()){
+ // In test mode the dev set is folded into training.
+ // NOTE(review): mutates the list returned by THYMEData.getTrainPatientSets.
+ training.addAll(devItems);
+ testing = testItems;
+ }else{
+ testing = devItems;
+ }
+ params.stats = eval.trainAndTest(training, testing);//training);//
+ // System.err.println(options.getKernelParams() == null ? params : options.getKernelParams());
+ System.err.println(params.stats);
+
+ if(options.getUseTmp()){
+ FileUtils.deleteRecursive(workingDir);
+ }
+ }
+
+ // Prefix for the chain output files; used by plain string concatenation below
+ // (e.g. outputDirectory + "gold.chains"), so it must end with a path separator
+ // to behave like a directory.
+ private String outputDirectory;
+
+ /**
+  * @param cmdParams raw classifier options as a space-separated string, or null
+  *                  to derive options from {@code params}
+  * @param outputDirectory prefix for chain output files (should end with '/')
+  */
+ public EvaluationOfEventCoreference(File baseDirectory,
+ File rawTextDirectory, File xmlDirectory,
+ org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat xmlFormat,
+ File xmiDirectory, File treebankDirectory, File coreferenceDirectory, boolean printErrors,
+ boolean printRelations, ParameterSettings params, String cmdParams, String outputDirectory) {
+ super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, xmiDirectory,
+ treebankDirectory, coreferenceDirectory, printErrors, printRelations, params);
+ this.outputDirectory = outputDirectory;
+ this.kernelParams = cmdParams == null ? null : cmdParams.split(" ");
+ }
+
+ /**
+  * Runs preprocessing with gold markables/relations/chains copied into the system
+  * view, writes training data via the coreference annotator, then trains and
+  * packages the classifier model into {@code directory}.
+  */
+ @Override
+ protected void train(CollectionReader collectionReader, File directory)
+ throws Exception {
+ AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(DocumentIDPrinter.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(ParagraphAnnotator.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(ParagraphVectorAnnotator.class));
+ aggregateBuilder.add(CopyFromGold.getDescription(Markable.class, CoreferenceRelation.class, CollectionTextRelation.class));
+ aggregateBuilder.add(EventCoreferenceAnnotator.createDataWriterDescription(
+// TKSVMlightStringOutcomeDataWriter.class,
+ LIBLINEARStringOutcomeDataWriter.class,
+ directory,
+ params.probabilityOfKeepingANegativeExample
+ ));
+ // create gold chains for writing out which we can then use for our scoring tool
+// aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(CoreferenceChainScoringOutput.class,
+// CoreferenceChainScoringOutput.PARAM_OUTPUT_DIR,
+// this.outputDirectory + "train"));
+ SimplePipeline.runPipeline(collectionReader, aggregateBuilder.createAggregate());
+ String[] optArray;
+
+ if(this.kernelParams == null){
+ // Build classifier options from ParameterSettings.
+ ArrayList<String> svmOptions = new ArrayList<>();
+ svmOptions.add("-c"); svmOptions.add(""+params.svmCost); // svm cost
+ svmOptions.add("-t"); svmOptions.add(""+params.svmKernelIndex); // kernel index
+ svmOptions.add("-d"); svmOptions.add("3"); // degree parameter for polynomial
+ svmOptions.add("-g"); svmOptions.add(""+params.svmGamma);
+ if(params.svmKernelIndex==ParameterSettings.SVM_KERNELS.indexOf("tk")){
+ svmOptions.add("-S"); svmOptions.add(""+params.secondKernelIndex); // second kernel index (similar to -t) for composite kernel
+ String comboFlag = (params.comboOperator == ComboOperator.SUM ? "+" : params.comboOperator == ComboOperator.PRODUCT ? "*" : params.comboOperator == ComboOperator.TREE_ONLY ? "T" : "V");
+ svmOptions.add("-C"); svmOptions.add(comboFlag);
+ svmOptions.add("-L"); svmOptions.add(""+params.lambda);
+ svmOptions.add("-T"); svmOptions.add(""+params.tkWeight);
+ svmOptions.add("-N"); svmOptions.add("3"); // normalize trees and features
+ }
+ optArray = svmOptions.toArray(new String[]{});
+ }else{
+ // User-supplied params arrive as bare flag/value pairs; re-add the leading
+ // dash to every even-indexed element (the flags).
+ optArray = this.kernelParams;
+ for(int i = 0; i < optArray.length; i+=2){
+ optArray[i] = "-" + optArray[i];
+ }
+ }
+ JarClassifierBuilder.trainAndPackage(directory, optArray);
+ }
+
+ /**
+  * Runs the decoding pipeline: gold markables are copied in, gold chains dumped,
+  * the classifier links markables, chains are assembled, system chains dumped;
+  * pairwise precision/recall stats are accumulated over all documents and, when
+  * printErrors is set, per-pair diffs are printed.
+  */
+ @Override
+ protected AnnotationStatistics<String> test(
+ CollectionReader collectionReader, File directory) throws Exception {
+ AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(DocumentIDPrinter.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(ParagraphAnnotator.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(ParagraphVectorAnnotator.class));
+ aggregateBuilder.add(CopyFromGold.getDescription(Markable.class));
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(CoreferenceChainScoringOutput.class,
+ CoreferenceChainScoringOutput.PARAM_OUTPUT_FILENAME,
+ this.outputDirectory + "gold.chains",
+ CoreferenceChainScoringOutput.PARAM_USE_GOLD_CHAINS,
+ true));
+ aggregateBuilder.add(EventCoreferenceAnnotator.createAnnotatorDescription(directory));
+ aggregateBuilder.add(CoreferenceChainAnnotator.createAnnotatorDescription());
+ aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(CoreferenceChainScoringOutput.class,
+ CoreferenceChainScoringOutput.PARAM_OUTPUT_FILENAME,
+ this.outputDirectory + "system.chains"));
+
+ // Pairwise comparison keys on argument spans; all relations share one outcome
+ // label since coreference is a single-category relation here.
+ Function<CoreferenceRelation, ?> getSpan = new Function<CoreferenceRelation, HashableArguments>() {
+ public HashableArguments apply(CoreferenceRelation relation) {
+ return new HashableArguments(relation);
+ }
+ };
+ Function<CoreferenceRelation, String> getOutcome = new Function<CoreferenceRelation,String>() {
+ public String apply(CoreferenceRelation relation){
+ return "Coreference";
+ }
+ };
+ AnnotationStatistics<String> stats = new AnnotationStatistics<>();
+
+ JCasIterable jcasIter =new JCasIterable(collectionReader, aggregateBuilder.createAggregate());
+ JCas jCas = null;
+ while(jcasIter.hasNext()) {
+ jCas = jcasIter.next();
+ JCas goldView = jCas.getView(GOLD_VIEW_NAME);
+ JCas systemView = jCas.getView(CAS.NAME_DEFAULT_SOFA);
+ Collection<CoreferenceRelation> goldRelations = JCasUtil.select(
+ goldView,
+ CoreferenceRelation.class);
+ Collection<CoreferenceRelation> systemRelations = JCasUtil.select(
+ systemView,
+ CoreferenceRelation.class);
+ stats.add(goldRelations, systemRelations, getSpan, getOutcome);
+ if(this.printErrors){
+ // Diff gold vs. system pairs keyed on argument spans, in sorted span order.
+ Map<HashableArguments, BinaryTextRelation> goldMap = Maps.newHashMap();
+ for (BinaryTextRelation relation : goldRelations) {
+ goldMap.put(new HashableArguments(relation), relation);
+ }
+ Map<HashableArguments, BinaryTextRelation> systemMap = Maps.newHashMap();
+ for (BinaryTextRelation relation : systemRelations) {
+ systemMap.put(new HashableArguments(relation), relation);
+ }
+ Set<HashableArguments> all = Sets.union(goldMap.keySet(), systemMap.keySet());
+ List<HashableArguments> sorted = Lists.newArrayList(all);
+ Collections.sort(sorted);
+ for (HashableArguments key : sorted) {
+ BinaryTextRelation goldRelation = goldMap.get(key);
+ BinaryTextRelation systemRelation = systemMap.get(key);
+ if (goldRelation == null) {
+ System.out.println("System added: " + formatRelation(systemRelation));
+ } else if (systemRelation == null) {
+ System.out.println("System dropped: " + formatRelation(goldRelation));
+ } else if (!systemRelation.getCategory().equals(goldRelation.getCategory())) {
+ String label = systemRelation.getCategory();
+ System.out.printf("System labeled %s for %s\n", label, formatRelation(goldRelation));
+ } else{
+ System.out.println("Nailed it! " + formatRelation(systemRelation));
+ }
+ }
+ }
+ }
+
+ return stats;
+ }
+
+ /**
+  * Writes coreference chains in CoNLL-style coreference column format (one token
+  * per line with "(id", "id)" and "(id)" mention brackets) for external scorers.
+  */
+ public static class CoreferenceChainScoringOutput extends JCasAnnotator_ImplBase {
+ // NOTE(review): the constant's value "OutputDirectory" is misleading -- the
+ // parameter actually names a single output file; kept as-is because the string
+ // is part of the public configuration interface.
+ public static final String PARAM_OUTPUT_FILENAME = "OutputDirectory";
+ @ConfigurationParameter(
+ name = PARAM_OUTPUT_FILENAME,
+ mandatory = true,
+ description = "Directory to write output"
+ )
+ private String outputFilename;
+ // NOTE(review): never closed or flushed; consider overriding
+ // collectionProcessComplete()/destroy() to close the writer, otherwise trailing
+ // buffered output may be lost.
+ private PrintWriter out = null;
+
+ public static final String PARAM_USE_GOLD_CHAINS = "UseGoldChains";
+ @ConfigurationParameter(
+ name = PARAM_USE_GOLD_CHAINS,
+ mandatory = false,
+ description = "Whether to use gold chains for writing output"
+ )
+ private boolean useGoldChains = false;
+
+ @Override
+ public void initialize(final UimaContext context) throws ResourceInitializationException{
+ super.initialize(context);
+
+ try {
+ out = new PrintWriter(outputFilename);
+ } catch (FileNotFoundException e) {
+ e.printStackTrace();
+ throw new ResourceInitializationException(e);
+ }
+ }
+
+ /**
+  * Emits one "#begin document"/"#end document" section per CAS: first maps every
+  * chain member to a numeric chain id, then writes one line per token with the
+  * bracket notation for mentions starting, ending, or exactly covering it.
+  */
+ @Override
+ public void process(JCas jCas) throws AnalysisEngineProcessException {
+ File filename = new File(ViewURIUtil.getURI(jCas));
+ JCas chainsCas = null;
+ try {
+ chainsCas = useGoldChains? jCas.getView(GOLD_VIEW_NAME) : jCas;
+ } catch (CASException e) {
+ e.printStackTrace();
+ throw new AnalysisEngineProcessException(e);
+ }
+ int chainNum = 1;
+ HashMap<Annotation, Integer> ent2chain = new HashMap<>();
+ if(useGoldChains) System.out.println("Gold chains:");
+ else System.out.println("System chains:");
+ // Walk each chain's FSList, assigning every mention the chain's number.
+ for(CollectionTextRelation chain : JCasUtil.select(chainsCas, CollectionTextRelation.class)){
+ FSList members = chain.getMembers();
+ while(members instanceof NonEmptyFSList){
+ Annotation mention = (Annotation) ((NonEmptyFSList) members).getHead();
+ ent2chain.put(mention, chainNum);
+ members = ((NonEmptyFSList)members).getTail();
+ System.out.print("Mention: " + mention.getCoveredText());
+ System.out.print(" (" + mention.getBegin() + ", " + mention.getEnd() + ")");
+ System.out.print(" -----> ");
+ }
+ System.out.println();
+ chainNum++;
+ }
+
+ out.println("#begin document " + filename.getPath());
+ List<BaseToken> tokens = new ArrayList<>(JCasUtil.select(jCas, BaseToken.class));
+ // Chain ids of mentions currently open (bracket started but not yet closed).
+ Stack<Integer> endStack = new Stack<>();
+ for(int i = 0; i < tokens.size(); i++){
+ BaseToken token = tokens.get(i);
+ List<Markable> markables = new ArrayList<>(JCasUtil.selectCovering(chainsCas, Markable.class, token.getBegin(), token.getEnd()));
+ List<Integer> startMention = new ArrayList<>();
+ Multiset<Integer> endMention = HashMultiset.create();
+ List<Integer> wholeMention = new ArrayList<>();
+
+ // Classify each chain mention touching this token: starts here, ends here,
+ // or spans exactly this one token.
+ for(Annotation markable : markables){
+ if(ent2chain.containsKey(markable)){
+ if(markable.getBegin() == token.getBegin()){
+ if(markable.getEnd() == token.getEnd()){
+ wholeMention.add(ent2chain.get(markable));
+ }else{
+ startMention.add(ent2chain.get(markable));
+ }
+ }else if(markable.getEnd() <= token.getEnd()){
+ if(endMention.contains(ent2chain.get(markable))){
+ System.err.println("There is a duplicate element -- should be handled by multiset");
+ }
+ if(markable.getEnd() < token.getEnd()){
+ System.err.println("There is a markable that ends in the middle of a token!");
+ }
+ endMention.add(ent2chain.get(markable));
+ }
+ }
+ }
+ out.print(i+1);
+ out.print('\t');
+ StringBuffer buff = new StringBuffer();
+ // Close mentions in LIFO order so nested brackets pair up correctly.
+ while(endStack.size() > 0 && endMention.contains(endStack.peek())){
+ int ind = endStack.pop();
+ buff.append(ind);
+ buff.append(')');
+ buff.append('|');
+ endMention.remove(ind);
+ }
+ for(int ind : wholeMention){
+ buff.append('(');
+ buff.append(ind);
+ buff.append(')');
+ buff.append('|');
+ }
+ for(int ind : startMention){
+ buff.append('(');
+ buff.append(ind);
+ buff.append('|');
+ endStack.push(ind);
+ }
+// for(int ind : endMention){
+// buff.append(ind);
+// buff.append(')');
+// buff.append('|');
+// }
+ if(buff.length() > 0){
+ // Trim the trailing '|' separator; "_" marks a token with no brackets.
+ out.println(buff.substring(0, buff.length()-1));
+ }else{
+ out.println("_");
+ }
+ }
+ out.println("#end document " + filename.getPath());
+ out.println();
+ }
+ }
+
+ /** Orders annotations by begin offset ascending, then by end offset ascending. */
+ public static class AnnotationComparator implements Comparator<Annotation> {
+
+ @Override
+ public int compare(Annotation o1, Annotation o2) {
+ if(o1.getBegin() < o2.getBegin()){
+ return -1;
+ }else if(o1.getBegin() == o2.getBegin() && o1.getEnd() < o2.getEnd()){
+ return -1;
+ }else if(o1.getBegin() == o2.getBegin() && o1.getEnd() > o2.getEnd()){
+ return 1;
+ }else if(o2.getBegin() < o1.getBegin()){
+ return 1;
+ }else{
+ return 0;
+ }
+ }
+ }
+ /** Logs each document's id (or its URI basename as a fallback) as it is read. */
+ public static class DocumentIDPrinter extends JCasAnnotator_ImplBase {
+ static Logger logger = Logger.getLogger(DocumentIDPrinter.class);
+ @Override
+ public void process(JCas jCas) throws AnalysisEngineProcessException {
+ String docId = DocumentIDAnnotationUtil.getDocumentID(jCas);
+ if(docId == null){
+ docId = new File(ViewURIUtil.getURI(jCas)).getName();
+ }
+ logger.info(String.format("Processing %s\n", docId));
+ }
+
+ }
+
+ /**
+  * Marks Paragraph annotations: a paragraph runs from the first token after a
+  * newline (or document start) to the last token before the next NewlineToken;
+  * runs of consecutive newlines are skipped rather than yielding empty paragraphs.
+  * NOTE(review): trailing text after the final newline never receives a Paragraph.
+  */
+ public static class ParagraphAnnotator extends JCasAnnotator_ImplBase {
+
+ @Override
+ public void process(JCas jcas) throws AnalysisEngineProcessException {
+ List<BaseToken> tokens = new ArrayList<>(JCasUtil.select(jcas, BaseToken.class));
+ BaseToken lastToken = null;
+ int parStart = 0;
+
+ for(int i = 0; i < tokens.size(); i++){
+ BaseToken token = tokens.get(i);
+ if(parStart == i && token instanceof NewlineToken){
+ // we've just created a pargraph ending but there were multiple newlines -- don't want to start the
+ // new paragraph until we are past the newlines -- increment the parStart index and move forward
+ parStart++;
+ }else if(lastToken != null && token instanceof NewlineToken){
+ Paragraph par = new Paragraph(jcas, tokens.get(parStart).getBegin(), lastToken.getEnd());
+ par.addToIndexes();
+ parStart = i+1;
+ }
+ lastToken = token;
+ }
+
+ }
+
+ }
+
+ /**
+  * Computes one embedding vector per Paragraph: the sum of the word vectors of
+  * its WordTokens (lower-cased; words missing from the embeddings are skipped),
+  * L2-normalized, and stored in the CAS as a single indexed FSArray of
+  * FloatArrays aligned with paragraph order (consumed via selectSingle upstream).
+  */
+ public static class ParagraphVectorAnnotator extends JCasAnnotator_ImplBase {
+ WordEmbeddings words = null;
+
+ @Override
+ public void initialize(final UimaContext context) throws ResourceInitializationException{
+ try {
+ words = WordVectorReader.getEmbeddings(FileLocator.getAsStream("org/apache/ctakes/coreference/distsem/mimic_vectors.txt"));
+ } catch (IOException e) {
+ e.printStackTrace();
+ throw new ResourceInitializationException(e);
+ }
+ }
+
+ @Override
+ public void process(JCas jcas) throws AnalysisEngineProcessException {
+ List<Paragraph> pars = new ArrayList<>(JCasUtil.select(jcas, Paragraph.class));
+ FSArray parVecs = new FSArray(jcas, pars.size());
+ for(int parNum = 0; parNum < pars.size(); parNum++){
+ Paragraph par = pars.get(parNum);
+ float[] parVec = new float[words.getDimensionality()];
+
+ List<BaseToken> tokens = JCasUtil.selectCovered(BaseToken.class, par);
+ for(int i = 0; i < tokens.size(); i++){
+ BaseToken token = tokens.get(i);
+ if(token instanceof WordToken){
+ String word = token.getCoveredText().toLowerCase();
+ if(words.containsKey(word)){
+ WordVector wv = words.getVector(word);
+ for(int j = 0; j < parVec.length; j++){
+ parVec[j] += wv.getValue(j);
+ }
+ }
+ }
+ }
+ normalize(parVec);
+ FloatArray vec = new FloatArray(jcas, words.getDimensionality());
+ vec.copyFromArray(parVec, 0, 0, parVec.length);
+ vec.addToIndexes();
+ parVecs.set(parNum, vec);
+ }
+ // Index the array itself so consumers can retrieve it with selectSingle().
+ parVecs.addToIndexes();
+ }
+
+ /**
+  * In-place L2 normalization.
+  * NOTE(review): if vec is all zeros (a paragraph with no in-vocabulary words)
+  * this divides by zero and fills the vector with NaN.
+  */
+ private static final void normalize(float[] vec) {
+ double sum = 0.0;
+ for(int i = 0; i < vec.length; i++){
+ sum += (vec[i]*vec[i]);
+ }
+ sum = Math.sqrt(sum);
+ for(int i = 0; i < vec.length; i++){
+ vec[i] /= sum;
+ }
+ }
+ }
+}