You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by tm...@apache.org on 2014/07/15 22:36:53 UTC

svn commit: r1610842 - /ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/data/analysis/CoreferenceLinkDistanceAnalyzer.java

Author: tmill
Date: Tue Jul 15 20:36:53 2014
New Revision: 1610842

URL: http://svn.apache.org/r1610842
Log:
Added class for analyzing coref links relative to paragraph similarity.

Added:
    ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/data/analysis/CoreferenceLinkDistanceAnalyzer.java

Added: ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/data/analysis/CoreferenceLinkDistanceAnalyzer.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/data/analysis/CoreferenceLinkDistanceAnalyzer.java?rev=1610842&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/data/analysis/CoreferenceLinkDistanceAnalyzer.java (added)
+++ ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/data/analysis/CoreferenceLinkDistanceAnalyzer.java Tue Jul 15 20:36:53 2014
@@ -0,0 +1,302 @@
+package org.apache.ctakes.temporal.data.analysis;
+
+import java.io.File;
+import java.io.FilenameFilter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.ctakes.core.resource.FileLocator;
+import org.apache.ctakes.temporal.eval.EvaluationOfEventCoreference.ParagraphAnnotator;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMIReader;
+import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
+import org.apache.ctakes.typesystem.type.syntax.BaseToken;
+import org.apache.ctakes.typesystem.type.syntax.NewlineToken;
+import org.apache.ctakes.typesystem.type.syntax.WordToken;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.ctakes.typesystem.type.textspan.Paragraph;
+import org.apache.ctakes.utils.distsem.WordEmbeddings;
+import org.apache.ctakes.utils.distsem.WordVector;
+import org.apache.ctakes.utils.distsem.WordVectorReader;
+import org.apache.uima.UIMAException;
+import org.apache.uima.analysis_engine.AnalysisEngine;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.cas.FSList;
+import org.apache.uima.jcas.cas.NonEmptyFSList;
+import org.apache.uima.resource.metadata.TypeSystemDescription;
+import org.cleartk.util.ViewURIUtil;
+import org.cleartk.util.ae.UriToDocumentTextAnnotator;
+import org.cleartk.util.cr.UriCollectionReader;
+import org.uimafit.factory.AggregateBuilder;
+import org.uimafit.factory.AnalysisEngineFactory;
+import org.uimafit.pipeline.JCasIterable;
+import org.uimafit.util.JCasUtil;
+
+import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
+
+/**
+ * Exploratory analysis tool: reads gold-annotated XMI files alongside their
+ * source text, segments each document into paragraphs, builds one
+ * bag-of-words embedding vector per paragraph, and then measures -- for a
+ * fixed set of cosine-similarity thresholds -- (a) how many gold coreference
+ * links would still be reachable if paragraph pairs below the threshold were
+ * pruned ("recall"), and (b) how many paragraph pairs the pruning would let
+ * a coreference system skip entirely ("savings").
+ */
+public class CoreferenceLinkDistanceAnalyzer {
+  // Command-line options, parsed by jewel-cli from the @Option annotations.
+  static interface Options {
+
+    @Option(
+        shortName = "i",
+        description = "specify the path to the directory containing the text files")
+    public File getInputDirectory();
+    
+    @Option(
+        shortName = "x",
+        description = "Specify the path to the directory containing the xmis")
+    public File getXMIDirectory();
+  }
+  
+  // Name of the CAS view that holds the gold-standard annotations
+  // (Markables and CollectionTextRelation chains are read from this view).
+  public static final String GOLD_VIEW_NAME = "GoldView";
+  
+  public static void main(String[] args) throws UIMAException, IOException {
+    Options options = CliFactory.parseArguments(Options.class, args);
+    // Only process text files that have a matching .xmi (see getFiles()).
+    CollectionReader reader = UriCollectionReader.getCollectionReaderFromFiles(getFiles(options.getInputDirectory(), options.getXMIDirectory()));
+    // Pipeline: attach document text, overlay the gold XMI, then mark paragraphs.
+    AggregateBuilder aggregateBuilder = new AggregateBuilder();
+    aggregateBuilder.add(UriToDocumentTextAnnotator.getDescription());
+    aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(
+        XMIReader.class,
+        XMIReader.PARAM_XMI_DIRECTORY,
+        options.getXMIDirectory()));
+    aggregateBuilder.add(AnalysisEngineFactory.createPrimitiveDescription(ParagraphAnnotator.class));
+    
+    // Word vectors are loaded from a classpath resource bundled with the project.
+    WordEmbeddings words = WordVectorReader.getEmbeddings(FileLocator.getAsStream("org/apache/ctakes/coreference/distsem/mimic_vectors.txt"));
+
+    // Scratch accumulator for the current paragraph's summed word vectors;
+    // reused across paragraphs (re-allocated after each one is stored).
+    double[] parVec = new double[words.getDimensionality()];
+    Arrays.fill(parVec, 0.0);
+    // NOTE(review): numWords is only read/written by the commented-out
+    // newline-based segmentation code below; it is dead in the live path.
+    int numWords = 0;
+    // Candidate similarity cutoffs to evaluate.
+    double[] thresholds = {0.1, 0.25, 0.5, 0.75};
+    // Per threshold, accumulated across ALL documents:
+    // [i][0] = paragraph pairs below threshold i (i.e., skippable),
+    // [i][1] = total paragraph pairs considered.
+    int[][] thresholdSavings = new int[thresholds.length][2];
+    // Per threshold: sum of per-document recalls; averaged over numDocs at the end.
+    double[] recalls = new double[thresholds.length];
+    int numDocs = 0;
+    
+    // compute paragraph vectors for every paragraph
+    AnalysisEngine ae = aggregateBuilder.createAggregate();
+    
+    for(JCas jcas : new JCasIterable(reader, ae)){
+      numDocs++;
+      // print out document name
+      System.out.println("######### Document id: " + ViewURIUtil.getURI(jcas).toString());
+      JCas goldView = jcas.getView(GOLD_VIEW_NAME);
+      
+      // Maps each gold Markable to the index of the paragraph containing it;
+      // vectors.get(i) is the (normalized) embedding vector of paragraph i.
+      Map<Markable,Integer> markable2par = new HashMap<>();
+      List<double[]> vectors = new ArrayList<>();
+      
+      for(Paragraph par : JCasUtil.select(jcas, Paragraph.class)){
+        // map markables to paragraph numbers
+        // (vectors.size() is the index this paragraph is about to receive)
+        Collection<Markable> markables = JCasUtil.selectCovered(goldView, Markable.class, par);
+        for(Markable markable : markables){
+          markable2par.put(markable, vectors.size());
+        }
+
+        // build embedding vector for this paragraph
+        // (sum of vectors for lower-cased in-vocabulary word tokens)
+        List<BaseToken> tokens = JCasUtil.selectCovered(BaseToken.class, par);
+        for(int i = 0; i < tokens.size(); i++){
+          BaseToken token = tokens.get(i);
+          if(token instanceof WordToken){
+            String word = token.getCoveredText().toLowerCase();
+            if(words.containsKey(word)){
+              WordVector wv = words.getVector(word);
+              for(int j = 0; j < parVec.length; j++){
+                parVec[j] += wv.getValue(j);
+              }
+            }          
+          }
+        }
+        // NOTE(review): if no token of this paragraph is in the vocabulary,
+        // parVec is all zeros and normalize() divides by zero, filling the
+        // stored vector with NaN; all similarities against it then become NaN.
+        normalize(parVec);
+        vectors.add(parVec);
+        // Re-allocate rather than clear: the previous array is now owned by
+        // the vectors list.
+        parVec = new double[words.getDimensionality()];
+        Arrays.fill(parVec, 0.0);        
+      }
+        
+        
+      // Dead code retained from an earlier approach: paragraph segmentation
+      // based on consecutive NewlineTokens instead of the ParagraphAnnotator.
+//      List<BaseToken> tokens = new ArrayList<>(JCasUtil.select(jcas, BaseToken.class));
+      
+//      BaseToken lastToken = null;
+//      int parStart = 0;
+//      for(int i = 0; i < tokens.size(); i++){
+//        BaseToken token = tokens.get(i);
+//        if(token instanceof WordToken){
+//          String word = token.getCoveredText().toLowerCase();
+//          if(words.containsKey(word)){
+//            numWords++;
+//            WordVector wv = words.getVector(word);
+//            for(int j = 0; j < parVec.length; j++){
+//              parVec[j] += wv.getValue(j);
+//            }
+//          }
+//        }else if(lastToken != null && lastToken instanceof NewlineToken && token instanceof NewlineToken){
+//          if(numWords > 0){
+//            int parEnd = token.getEnd();
+//            Collection<Markable> markables = JCasUtil.selectCovered(goldView, Markable.class, parStart, parEnd);
+//            for(Markable markable : markables){
+//              markable2par.put(markable, vectors.size());
+//            }
+//            Paragraph par = new Paragraph(jcas, parStart, parEnd);
+//            normalize(parVec);
+//            vectors.add(parVec);
+//            parVec = new double[words.getDimensionality()];
+//            Arrays.fill(parVec, 0.0);
+//            numWords = 0;
+//            parStart = parEnd;
+//          }
+//        }
+//        lastToken = token;
+//      }
+
+      // Only the upper triangle (i < j) of sims is filled; later lookups use
+      // sims[antePar][anaPar] with antePar < anaPar, which stays in that triangle.
+      double[][] sims = new double[vectors.size()][vectors.size()];
+      // compute similarities between every pair of vectors
+      for(int i = 0; i < vectors.size(); i++){
+        sims[i][i] = 1.0;
+        for(int j = i+1; j< vectors.size(); j++){
+          double sim = getSimilarity(vectors.get(i), vectors.get(j));
+          sims[i][j] = sim;
+          // Tally, per threshold, how many pairs fall below it (skippable)
+          // and the total number of pairs seen.
+          for(int ind = 0; ind < thresholds.length; ind++){
+            if(sim < thresholds[ind]){
+              thresholdSavings[ind][0]++;
+            }
+            thresholdSavings[ind][1]++;
+          }
+          System.out.printf("Similarity between paragraphs %d and %d = %f\n", i, j, sim);
+        }
+      }
+
+      // build markable chains in easier to access way
+      // Each gold chain becomes a sorted list of distinct paragraph indices;
+      // single-paragraph chains are dropped (no cross-paragraph link to test).
+      List<List<Integer>> parChains = new ArrayList<>();
+      for(CollectionTextRelation chain : JCasUtil.select(goldView, CollectionTextRelation.class)){
+        Set<Integer> pars = new HashSet<>();
+        
+        // Walk the UIMA FSList of chain members.
+        FSList list = chain.getMembers();
+        while(list instanceof NonEmptyFSList){
+          Markable member = (Markable) ((NonEmptyFSList) list).getHead();
+          if(markable2par.containsKey(member)){
+            pars.add(markable2par.get(member));
+          }else{
+            System.err.println("Markable not found in any paragraph: " + member.getCoveredText() + " [" + member.getBegin() + "," + member.getEnd() + "]");
+          }
+          list = ((NonEmptyFSList) list).getTail();
+        }
+        if(pars.size() > 1){
+          List<Integer> parList = new ArrayList<>(pars);
+          Collections.sort(parList);
+          parChains.add(parList);
+        }
+      }
+      
+      for(int i = 0; i < thresholds.length; i++){
+        double threshold = thresholds[i];
+        int tps = 0;
+        int fns = 0;
+        
+        // figure out our leakage rate:
+        for(List<Integer> chain : parChains){
+          // for any paragraph with an anaphor, look at all the earlier paragraphs
+          // with antecedents
+          for(int anaParInd = 1; anaParInd < chain.size(); anaParInd++){
+            int anteParInd = 0;
+            for(anteParInd = 0; anteParInd < anaParInd; anteParInd++){
+              int anaPar = chain.get(anaParInd);
+              int antePar = chain.get(anteParInd);
+              // if any of the previous paragraphs has an antecedent we are ok
+              if(sims[antePar][anaPar] > threshold){
+                tps++;
+                break;
+              }
+            }
+            // if we got to the exit condition of the for-loop we didn't
+            // have any matching paragraphs with high enough similarity
+            if(anteParInd == anaParInd){
+              fns++;
+            }
+          }
+          // Dead code retained: an earlier, pair-driven formulation of the
+          // same recall computation.
+//        for(int focusPar = vectors.size()-1; focusPar >= 0; focusPar--){
+//          for(int otherPar = 0; otherPar < focusPar; otherPar++){
+//            double sim = sims[otherPar][focusPar];
+//            for(List<Integer> chain : parChains){
+//              for(int ind = chain.size()-1; ind > 0; ind--){
+//                int anaPar = chain.get(ind);
+//                if(focusPar == anaPar){
+//                  // see if there are antecedents in any of the threshold-passing paragraphs
+//                  int prev;
+//                  for(prev = ind-1; prev >= 0; prev--){
+//                    int antePar = chain.get(prev);
+//                    if(sim > threshold && antePar == otherPar){
+//                      hits++;
+//                      break;
+//                    }
+//                  }
+//                  if(prev < 0){
+//                    misses++;
+//                  }
+//                }
+//              }
+//            }
+//          }
+        }
+        // NOTE(review): if a document has no multi-paragraph chains,
+        // tps + fns == 0 and recall is NaN, which then poisons the running
+        // average in recalls[i].
+        double recall = (double) tps / (tps + fns);
+        recalls[i] += recall;
+        System.out.printf("With threshold %f, recall is %f with %d hits and %d misses\n", threshold, recall, tps, fns);
+      }
+      
+      System.out.println("\n\n");
+    }
+    
+    // Final report: macro-averaged recall per threshold, plus global savings.
+    for(int i = 0; i < thresholds.length; i++){
+      System.out.printf("Threshold %f has average recall %f\n", thresholds[i], recalls[i] / numDocs);
+      System.out.printf("Was able to ignore %d pairs out of %d possible pairs\n", thresholdSavings[i][0], thresholdSavings[i][1]);
+    }
+  }
+  
+  /**
+   * Scales {@code vec} in place to unit Euclidean length.
+   * NOTE(review): an all-zero input yields sum == 0 and a division by zero,
+   * filling the vector with NaN -- see the caller note in main().
+   */
+  public static final void normalize(double[] vec){
+    double sum = 0.0;
+    for(int i = 0; i < vec.length; i++){
+      sum += (vec[i]*vec[i]);
+    }
+    sum = Math.sqrt(sum);
+    for(int i = 0; i < vec.length; i++){
+      vec[i] /= sum;
+    }
+  }
+  
+  /**
+   * Cosine similarity of two equal-length vectors: dot(v1,v2) / (|v1|*|v2|).
+   * Re-computes both norms even though callers pass already-normalized
+   * vectors, so it is correct for unnormalized input too.
+   * NOTE(review): returns NaN if either vector has zero norm.
+   */
+  private static final double getSimilarity(double[] v1, double[] v2){
+    assert v1.length == v2.length;
+    double sim = 0;
+    double v1norm=0, v2norm=0;
+    for(int i = 0; i < v1.length; i++){
+      sim += (v1[i] * v2[i]);
+      v1norm += (v1[i]*v1[i]);
+      v2norm += (v2[i]*v2[i]);
+    }
+    v1norm = Math.sqrt(v1norm);
+    v2norm = Math.sqrt(v2norm);
+    
+    sim = sim / (v1norm * v2norm);
+    return sim;
+  }
+  
+  /**
+   * For every "*.xmi" file in xmiDir, returns the File in textDir whose name
+   * is the xmi name with its last 4 characters (".xmi") stripped.
+   * NOTE(review): existence of the text file is not checked here, and
+   * xmiDir.listFiles() returns null (NPE below) if xmiDir is not a readable
+   * directory -- acceptable for a one-off analysis tool, but worth knowing.
+   */
+  public static Collection<File> getFiles(File textDir, File xmiDir){
+    Collection<File> files = new HashSet<>();
+    
+    File[] xmiFiles = xmiDir.listFiles(new FilenameFilter(){
+
+      public boolean accept(File dir, String name) {
+        return name.endsWith("xmi");
+      }});
+    
+    for(File xmiFile : xmiFiles){
+      String name = xmiFile.getName();
+      files.add(new File(textDir, name.substring(0, name.length()-4)));
+    }
+    return files;
+  }
+}