You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jz...@apache.org on 2022/08/06 12:47:22 UTC

[opennlp] branch master updated: OPENNLP-1375: Adding option for GPU inference. (#421)

This is an automated email from the ASF dual-hosted git repository.

jzemerick pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git


The following commit(s) were added to refs/heads/master by this push:
     new 8455ccbc OPENNLP-1375: Adding option for GPU inference. (#421)
8455ccbc is described below

commit 8455ccbccb1fb4121edd0eef6d5b59dba2c8a065
Author: Jeff Zemerick <13...@users.noreply.github.com>
AuthorDate: Sat Aug 6 08:47:17 2022 -0400

    OPENNLP-1375: Adding option for GPU inference. (#421)
    
    * OPENNLP-1375: Adding option for GPU inference.
---
 opennlp-dl/pom.xml                                 |  3 +-
 opennlp-dl/src/main/java/opennlp/dl/Inference.java |  8 ++++-
 .../src/main/java/opennlp/dl/InferenceOptions.java | 36 ++++++++++++++------
 .../dl/doccat/DocumentCategorizerDLEval.java       | 39 ++++++++++++++++++++--
 4 files changed, 72 insertions(+), 14 deletions(-)

diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index c3c4a81a..db9159a6 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -38,7 +38,8 @@
     </dependency>
     <dependency>
       <groupId>com.microsoft.onnxruntime</groupId>
-      <artifactId>onnxruntime</artifactId>
+      <!-- This dependency supports CPU and GPU -->
+      <artifactId>onnxruntime_gpu</artifactId>
       <version>${onnxruntime.version}</version>
     </dependency>
     <dependency>
diff --git a/opennlp-dl/src/main/java/opennlp/dl/Inference.java b/opennlp-dl/src/main/java/opennlp/dl/Inference.java
index 66ac9b99..03122f0a 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/Inference.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/Inference.java
@@ -62,7 +62,13 @@ public abstract class Inference {
       throws OrtException, IOException {
 
     this.env = OrtEnvironment.getEnvironment();
-    this.session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
+
+    final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
+    if (inferenceOptions.isGpu()) {
+      sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
+    }
+
+    this.session = env.createSession(model.getPath(), sessionOptions);
     this.vocabulary = loadVocab(vocab);
     this.tokenizer = new WordpieceTokenizer(vocabulary.keySet());
     this.inferenceOptions = inferenceOptions;
diff --git a/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java b/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
index 99d3c833..9241a206 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
@@ -19,25 +19,41 @@ package opennlp.dl;
 
 public class InferenceOptions {
 
-  private boolean includeAttentionMask;
-  private boolean includeTokenTypeIds;
+  private boolean includeAttentionMask = true;
+  private boolean includeTokenTypeIds = true;
+  private boolean gpu;
+  private int gpuDeviceId = 0;
 
-  public InferenceOptions() {
-    this.includeAttentionMask = true;
-    this.includeTokenTypeIds = true;
+  public boolean isIncludeAttentionMask() {
+    return includeAttentionMask;
   }
 
-  public InferenceOptions(boolean includeAttentionMask, boolean includeTokenTypeIds) {
+  public void setIncludeAttentionMask(boolean includeAttentionMask) {
     this.includeAttentionMask = includeAttentionMask;
+  }
+
+  public boolean isIncludeTokenTypeIds() {
+    return includeTokenTypeIds;
+  }
+
+  public void setIncludeTokenTypeIds(boolean includeTokenTypeIds) {
     this.includeTokenTypeIds = includeTokenTypeIds;
   }
 
-  public boolean isIncludeAttentionMask() {
-    return includeAttentionMask;
+  public boolean isGpu() {
+    return gpu;
   }
 
-  public boolean isIncludeTokenTypeIds() {
-    return includeTokenTypeIds;
+  public void setGpu(boolean gpu) {
+    this.gpu = gpu;
+  }
+
+  public int getGpuDeviceId() {
+    return gpuDeviceId;
+  }
+
+  public void setGpuDeviceId(int gpuDeviceId) {
+    this.gpuDeviceId = gpuDeviceId;
   }
 
 }
diff --git a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index a2d58471..577ef2c8 100644
--- a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++ b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -24,6 +24,7 @@ import java.util.Map;
 import java.util.Set;
 
 import org.junit.Assert;
+import org.junit.Ignore;
 import org.junit.Test;
 
 import opennlp.dl.AbstactDLTest;
@@ -60,6 +61,40 @@ public class DocumentCategorizerDLEval extends AbstactDLTest {
 
   }
 
+  @Ignore("This test will only run if a GPU device is present.")
+  @Test
+  public void categorizeWithGpu() throws Exception {
+
+    final File model = new File(getOpennlpDataDir(),
+        "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx");
+    final File vocab = new File(getOpennlpDataDir(),
+        "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab");
+
+    final InferenceOptions inferenceOptions = new InferenceOptions();
+    inferenceOptions.setGpu(true);
+    inferenceOptions.setGpuDeviceId(0);
+
+    final DocumentCategorizerDL documentCategorizerDL =
+        new DocumentCategorizerDL(model, vocab, getCategories(), inferenceOptions);
+
+    final double[] result = documentCategorizerDL.categorize(new String[]{"I am happy"});
+    System.out.println(Arrays.toString(result));
+
+    final double[] expected = new double[]
+        {0.007819971069693565,
+            0.006593209225684404,
+            0.04995147883892059,
+            0.3003573715686798,
+            0.6352779865264893};
+
+    Assert.assertTrue(Arrays.equals(expected, result));
+    Assert.assertEquals(5, result.length);
+
+    final String category = documentCategorizerDL.getBestCategory(result);
+    Assert.assertEquals("very good", category);
+
+  }
+
   @Test
   public void categorizeWithInferenceOptions() throws Exception {
 
@@ -68,8 +103,8 @@ public class DocumentCategorizerDLEval extends AbstactDLTest {
     final File vocab = new File(getOpennlpDataDir(),
         "onnx/doccat/lvwerra_distilbert-imdb.vocab");
 
-    final InferenceOptions inferenceOptions =
-        new InferenceOptions(true, false);
+    final InferenceOptions inferenceOptions = new InferenceOptions();
+    inferenceOptions.setIncludeTokenTypeIds(false);
 
     final Map<Integer, String> categories = new HashMap<>();
     categories.put(0, "negative");