You are viewing a plain-text version of this content. The canonical link to it was attached to the word "here", and that hyperlink is not preserved in this plain-text rendering.
Posted to commits@ctakes.apache.org by dl...@apache.org on 2015/09/25 20:29:05 UTC

svn commit: r1705338 - /ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java

Author: dligach
Date: Fri Sep 25 18:29:05 2015
New Revision: 1705338

URL: http://svn.apache.org/viewvc?rev=1705338&view=rev
Log:
Added a "words between the arguments" feature, represented as the average of the word embeddings of the tokens between the two arguments.

Modified:
    ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java

Modified: ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java?rev=1705338&r1=1705337&r2=1705338&view=diff
==============================================================================
--- ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java (original)
+++ ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java Fri Sep 25 18:29:05 2015
@@ -19,12 +19,15 @@
 package org.apache.ctakes.relationextractor.ae.features;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
 import org.apache.ctakes.relationextractor.data.analysis.Utils;
+import org.apache.ctakes.typesystem.type.syntax.WordToken;
 import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
 import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
+import org.apache.uima.fit.util.JCasUtil;
 import org.apache.uima.jcas.JCas;
 import org.cleartk.ml.Feature;
 
@@ -48,7 +51,6 @@ public class EmbeddingFeatureExtractor i
 
     List<Feature> features = new ArrayList<>();
 
-    // get head words
     String arg1LastWord = Utils.getLastWord(jCas, arg1).toLowerCase();
     String arg2LastWord = Utils.getLastWord(jCas, arg2).toLowerCase();
 
@@ -64,10 +66,8 @@ public class EmbeddingFeatureExtractor i
     } else {
       arg2Vector = wordVectors.get("oov");
     }
-
-    double similarity = computeCosineSimilarity(arg1Vector, arg2Vector); 
-    features.add(new Feature("arg_cos_sim", similarity));
-
+    
+    // head word feataures
     for(int dim = 0; dim < numberOfDimensions; dim++) {
       String featureName = String.format("arg1_dim_%d", dim);
       features.add(new Feature(featureName, arg1Vector.get(dim)));
@@ -77,6 +77,32 @@ public class EmbeddingFeatureExtractor i
       features.add(new Feature(featureName, arg2Vector.get(dim)));
     }    
 
+    // head word similarity features
+    double similarity = computeCosineSimilarity(arg1Vector, arg2Vector); 
+    features.add(new Feature("arg_cos_sim", similarity));
+    
+    // words between argument features
+    List<WordToken> wordsBetweenArgs = JCasUtil.selectBetween(jCas, WordToken.class, arg1, arg2);
+    if(wordsBetweenArgs.size() < 1) {
+      return features;  
+    }
+    
+    List<Double> sum = new ArrayList<>(Collections.nCopies(numberOfDimensions, 0.0));
+    for(WordToken wordToken : wordsBetweenArgs) {
+      List<Double> wordVector;
+      if(wordVectors.containsKey(wordToken.getCoveredText().toLowerCase())) {
+        wordVector = wordVectors.get(wordToken.getCoveredText().toLowerCase());
+      } else {
+        wordVector = wordVectors.get("oov");
+      }
+      sum = addVectors(sum, wordVector);      
+    }
+
+    for(int dim = 0; dim < numberOfDimensions; dim++) {
+      String featureName = String.format("average_dim_%d", dim);
+      features.add(new Feature(featureName, sum.get(dim) / wordsBetweenArgs.size()));
+    }
+
     return features;
   }
 
@@ -97,4 +123,17 @@ public class EmbeddingFeatureExtractor i
 
     return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
   }
+  
+  /**
+   * Add two vectors. Return the sum vector.
+   */
+  public List<Double> addVectors(List<Double> vector1, List<Double> vector2) {
+    
+    List<Double> sum = new ArrayList<>();
+    for(int dim = 0; dim < numberOfDimensions; dim++) {
+      sum.add(vector1.get(dim) + vector2.get(dim));
+    }
+    
+    return sum;
+  }
 }