Posted to commits@opennlp.apache.org by jz...@apache.org on 2022/08/06 12:47:22 UTC
[opennlp] branch master updated: OPENNLP-1375: Adding option for GPU inference. (#421)
This is an automated email from the ASF dual-hosted git repository.
jzemerick pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/master by this push:
new 8455ccbc OPENNLP-1375: Adding option for GPU inference. (#421)
8455ccbc is described below
commit 8455ccbccb1fb4121edd0eef6d5b59dba2c8a065
Author: Jeff Zemerick <13...@users.noreply.github.com>
AuthorDate: Sat Aug 6 08:47:17 2022 -0400
OPENNLP-1375: Adding option for GPU inference. (#421)
* OPENNLP-1375: Adding option for GPU inference.
---
opennlp-dl/pom.xml | 3 +-
opennlp-dl/src/main/java/opennlp/dl/Inference.java | 8 ++++-
.../src/main/java/opennlp/dl/InferenceOptions.java | 36 ++++++++++++++------
.../dl/doccat/DocumentCategorizerDLEval.java | 39 ++++++++++++++++++++--
4 files changed, 72 insertions(+), 14 deletions(-)
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index c3c4a81a..db9159a6 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -38,7 +38,8 @@
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
- <artifactId>onnxruntime</artifactId>
+ <!-- This dependency supports CPU and GPU -->
+ <artifactId>onnxruntime_gpu</artifactId>
<version>${onnxruntime.version}</version>
</dependency>
<dependency>
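Note on the dependency swap above: the onnxruntime_gpu artifact ships a native library built with both the CPU and CUDA execution providers, so existing CPU-only usage keeps working. As a rough sanity check (not part of this commit, and assuming the onnxruntime version in use exposes OrtEnvironment.getAvailableProviders()), a consumer could verify that the CUDA provider is present before turning on GPU inference:

    import ai.onnxruntime.OrtEnvironment;
    import ai.onnxruntime.OrtProvider;

    public class ProviderCheck {
        public static void main(String[] args) {
            // Lists the execution providers compiled into the native onnxruntime
            // library; with onnxruntime_gpu on the classpath this should include CUDA.
            boolean cudaAvailable =
                OrtEnvironment.getAvailableProviders().contains(OrtProvider.CUDA);
            System.out.println("CUDA execution provider available: " + cudaAvailable);
        }
    }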
diff --git a/opennlp-dl/src/main/java/opennlp/dl/Inference.java b/opennlp-dl/src/main/java/opennlp/dl/Inference.java
index 66ac9b99..03122f0a 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/Inference.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/Inference.java
@@ -62,7 +62,13 @@ public abstract class Inference {
throws OrtException, IOException {
this.env = OrtEnvironment.getEnvironment();
- this.session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
+
+ final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
+ if (inferenceOptions.isGpu()) {
+ sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
+ }
+
+ this.session = env.createSession(model.getPath(), sessionOptions);
this.vocabulary = loadVocab(vocab);
this.tokenizer = new WordpieceTokenizer(vocabulary.keySet());
this.inferenceOptions = inferenceOptions;
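The constructor change above follows the usual ONNX Runtime Java pattern: the CUDA execution provider is registered on the session options before the session is created, otherwise the session falls back to the default CPU provider. A minimal standalone sketch of that pattern (the model path and device id are placeholders, not values from this commit):

    import ai.onnxruntime.OrtEnvironment;
    import ai.onnxruntime.OrtSession;

    public class GpuSessionSketch {
        public static void main(String[] args) throws Exception {
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();
            // Register the CUDA execution provider on device 0; omit this call
            // to stay on the CPU provider.
            options.addCUDA(0);
            try (OrtSession session = env.createSession("/path/to/model.onnx", options)) {
                System.out.println("Model inputs: " + session.getInputNames());
            }
        }
    }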
diff --git a/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java b/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
index 99d3c833..9241a206 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
@@ -19,25 +19,41 @@ package opennlp.dl;
public class InferenceOptions {
- private boolean includeAttentionMask;
- private boolean includeTokenTypeIds;
+ private boolean includeAttentionMask = true;
+ private boolean includeTokenTypeIds = true;
+ private boolean gpu;
+ private int gpuDeviceId = 0;
- public InferenceOptions() {
- this.includeAttentionMask = true;
- this.includeTokenTypeIds = true;
+ public boolean isIncludeAttentionMask() {
+ return includeAttentionMask;
}
- public InferenceOptions(boolean includeAttentionMask, boolean includeTokenTypeIds) {
+ public void setIncludeAttentionMask(boolean includeAttentionMask) {
this.includeAttentionMask = includeAttentionMask;
+ }
+
+ public boolean isIncludeTokenTypeIds() {
+ return includeTokenTypeIds;
+ }
+
+ public void setIncludeTokenTypeIds(boolean includeTokenTypeIds) {
this.includeTokenTypeIds = includeTokenTypeIds;
}
- public boolean isIncludeAttentionMask() {
- return includeAttentionMask;
+ public boolean isGpu() {
+ return gpu;
}
- public boolean isIncludeTokenTypeIds() {
- return includeTokenTypeIds;
+ public void setGpu(boolean gpu) {
+ this.gpu = gpu;
+ }
+
+ public int getGpuDeviceId() {
+ return gpuDeviceId;
+ }
+
+ public void setGpuDeviceId(int gpuDeviceId) {
+ this.gpuDeviceId = gpuDeviceId;
}
}
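With the two-argument constructor replaced by setters, callers opt into GPU inference through InferenceOptions. A usage sketch modeled on the test below (the model/vocab paths and category labels are placeholders, not part of this commit):

    import java.io.File;
    import java.util.HashMap;
    import java.util.Map;

    import opennlp.dl.InferenceOptions;
    import opennlp.dl.doccat.DocumentCategorizerDL;

    public class GpuDoccatExample {
        public static void main(String[] args) throws Exception {
            InferenceOptions inferenceOptions = new InferenceOptions();
            inferenceOptions.setGpu(true);       // run on the CUDA execution provider
            inferenceOptions.setGpuDeviceId(0);  // first GPU (the default)

            Map<Integer, String> categories = new HashMap<>();
            categories.put(0, "negative");
            categories.put(1, "positive");

            DocumentCategorizerDL categorizer = new DocumentCategorizerDL(
                new File("/path/to/model.onnx"),
                new File("/path/to/model.vocab"),
                categories,
                inferenceOptions);

            double[] scores = categorizer.categorize(new String[] {"I am happy"});
            System.out.println(categorizer.getBestCategory(scores));
        }
    }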
diff --git a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index a2d58471..577ef2c8 100644
--- a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++ b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -24,6 +24,7 @@ import java.util.Map;
import java.util.Set;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import opennlp.dl.AbstactDLTest;
@@ -60,6 +61,40 @@ public class DocumentCategorizerDLEval extends AbstactDLTest {
}
+ @Ignore("This test will only run if a GPU device is present.")
+ @Test
+ public void categorizeWithGpu() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(),
+ "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx");
+ final File vocab = new File(getOpennlpDataDir(),
+ "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab");
+
+ final InferenceOptions inferenceOptions = new InferenceOptions();
+ inferenceOptions.setGpu(true);
+ inferenceOptions.setGpuDeviceId(0);
+
+ final DocumentCategorizerDL documentCategorizerDL =
+ new DocumentCategorizerDL(model, vocab, getCategories(), inferenceOptions);
+
+ final double[] result = documentCategorizerDL.categorize(new String[]{"I am happy"});
+ System.out.println(Arrays.toString(result));
+
+ final double[] expected = new double[]
+ {0.007819971069693565,
+ 0.006593209225684404,
+ 0.04995147883892059,
+ 0.3003573715686798,
+ 0.6352779865264893};
+
+ Assert.assertTrue(Arrays.equals(expected, result));
+ Assert.assertEquals(5, result.length);
+
+ final String category = documentCategorizerDL.getBestCategory(result);
+ Assert.assertEquals("very good", category);
+
+ }
+
@Test
public void categorizeWithInferenceOptions() throws Exception {
@@ -68,8 +103,8 @@ public class DocumentCategorizerDLEval extends AbstactDLTest {
final File vocab = new File(getOpennlpDataDir(),
"onnx/doccat/lvwerra_distilbert-imdb.vocab");
- final InferenceOptions inferenceOptions =
- new InferenceOptions(true, false);
+ final InferenceOptions inferenceOptions = new InferenceOptions();
+ inferenceOptions.setIncludeTokenTypeIds(false);
final Map<Integer, String> categories = new HashMap<>();
categories.put(0, "negative");