Posted to dev@opennlp.apache.org by GitBox <gi...@apache.org> on 2022/01/04 05:23:55 UTC

[GitHub] [opennlp] kinow commented on a change in pull request #400: OPENNLP-1351: Add ONNX model support for doccat and namefinder

kinow commented on a change in pull request #400:
URL: https://github.com/apache/opennlp/pull/400#discussion_r777801110



##########
File path: NOTICE
##########
@@ -1,12 +1,66 @@
 Apache OpenNLP
-Copyright 2017 The Apache Software Foundation
+Copyright 2021 The Apache Software Foundation
 
 This product includes software developed at
 The Apache Software Foundation (http://www.apache.org/).
 
+============================================================================
 
 The snowball stemmers in
 opennlp-tools/src/main/java/opennlp/tools/stemmer/snowball
 were developed by Martin Porter and Richard Boulton.
 The full snowball package is available from
 http://snowball.tartarus.org/
+
+============================================================================
+
+Wordpiece tokenizer
+https://github.com/robrua/easy-bert
+
+The MIT License (MIT)
+
+Copyright (c) 2019 Rob Rua
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+============================================================================
+
+ONNX Runtime
+
+MIT License
+
+Copyright (c) Microsoft Corporation
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

Review comment:
       Missing newline?

##########
File path: pom.xml
##########
@@ -17,16 +16,14 @@
    KIND, either express or implied.  See the License for the
    specific language governing permissions and limitations
    under the License.
--->
-
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+--><project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">

Review comment:
       I think the `<project>` tag is now on the same line as the end-of-comment token (`-->`). Maybe move it to the next line?

##########
File path: opennlp-dl/README.md
##########
@@ -0,0 +1,43 @@
+# OpenNLP DL
+
+This module provides OpenNLP interface implementations for ONNX models using the `onnxruntime` dependency.
+
+**Important**: This does not provide the ability to train models. Model training is done outside of OpenNLP. This code provides the ability to use ONNX models from OpenNLP.

Review comment:
       :+1: 

##########
File path: opennlp-tools/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.tokenize;
+
+import opennlp.tools.util.Span;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * A WordPiece tokenizer.
+ *
+ * Adapted from https://github.com/robrua/easy-bert under the MIT license.
+ */
+public class WordpieceTokenizer implements Tokenizer {
+
+    private static final String CLASSIFICATION_TOKEN = "[CLS]";
+    private static final String SEPARATOR_TOKEN = "[SEP]";
+    private static final String UNKNOWN_TOKEN = "[UNK]";
+
+    private Set<String> vocabulary;
+    private int maxTokenLength = 50;
+
+    public WordpieceTokenizer(Set<String> vocabulary) {
+        this.vocabulary = vocabulary;

Review comment:
       Should we prevent modification by making `this.vocabulary` an unmodifiable/immutable copy of the given `vocabulary`? :point_down: :point_up:
   
   Then `this.vocabulary` can be final too.
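   
   A minimal sketch of that idea, using a hypothetical `VocabularyHolder` class rather than the PR's `WordpieceTokenizer`. It assumes the module has to stay Java 8 compatible; on Java 10+ `Set.copyOf(vocabulary)` would do the copy-and-wrap in one call (and also rejects `null` elements).
   
   ```java
   import java.util.Collections;
   import java.util.HashSet;
   import java.util.Set;
   
   // Hypothetical class illustrating the review suggestion; not code from the PR.
   public class VocabularyHolder {
   
       // final: the reference cannot be reassigned after construction.
       private final Set<String> vocabulary;
   
       public VocabularyHolder(Set<String> vocabulary) {
           // Copy first so later changes to the caller's set are not visible here,
           // then wrap so nobody can modify our copy through the getter.
           this.vocabulary = Collections.unmodifiableSet(new HashSet<>(vocabulary));
       }
   
       public Set<String> getVocabulary() {
           return vocabulary;
       }
   
       public static void main(String[] args) {
           Set<String> source = new HashSet<>(Collections.singleton("[CLS]"));
           VocabularyHolder holder = new VocabularyHolder(source);
           source.add("[SEP]"); // does not leak into holder
           System.out.println(holder.getVocabulary().contains("[SEP]")); // false
       }
   }
   ```
   
   The same pattern would apply to the second constructor that also takes `maxTokenLength`.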

##########
File path: opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
##########
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.namefinder;
+
+import opennlp.tools.namefind.TokenNameFinder;
+import opennlp.tools.util.Span;
+
+import java.io.File;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An implementation of {@link TokenNameFinder} that uses ONNX models.
+ */
+public class NameFinderDL implements TokenNameFinder {
+
+    public static final String I_PER = "I-PER";
+    public static final String B_PER = "B-PER";
+
+    private final TokenNameFinderInference inference;
+    private final Map<Integer, String> ids2Labels;
+
+    /**
+     * Creates a new NameFinderDL for entity recognition using ONNX models.
+     * @param model The ONNX model file.
+     * @param vocab The model's vocabulary file.
+     * @param doLowerCase Whether or not to lowercase the text prior to inference.
+     * @param ids2Labels A map of values and their assigned labels used to train the model.
+     * @throws Exception Thrown if the models cannot be loaded.
+     */
+    public NameFinderDL(File model, File vocab, boolean doLowerCase, Map<Integer, String> ids2Labels) throws Exception {
+
+        this.ids2Labels = ids2Labels;
+        this.inference = new TokenNameFinderInference(model, vocab, doLowerCase);
+
+    }
+
+    @Override
+    public Span[] find(String[] tokens) {
+
+        final List<Span> spans = new LinkedList<>();
+        final String text = String.join(" ", tokens);
+
+        try {
+
+            final double[][] v = inference.infer(text);
+
+            // Find consecutive B-PER and I-PER labels and combine the spans where necessary.
+            // There are also B-LOC and I-LOC tags for locations that might be useful at some point.
+
+            // Keep track of where the last span was so when there are multiple/duplicate
+            // spans we can get the next one instead of the first one each time.
+            int characterStart = 0;
+
+            // We are looping over the vector for each word,
+            // finding the index of the array that has the maximum value,
+            // and then finding the token classification that corresponds to that index.
+            for(int x = 0; x < v.length; x++) {
+
+                final double[] arr = v[x];
+                final int maxIndex = maxIndex(arr);
+                final String label = ids2Labels.get(maxIndex);
+
+                // TODO: Need to make sure this value is between 0 and 1?
+                final double probability = arr[maxIndex] / 10;

Review comment:
       I think `arr` holds the values returned from running the ONNX inference session. I'm not sure, though, why the probability is calculated by dividing the value by `10`, nor whether the result will be between `0` and `1`. I had a look at the existing code, but couldn't find any class that we could reuse here. Maybe create an issue to follow this up after this PR?
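   
   One option for that follow-up: `DocumentCategorizerDL` in this same PR already calls `inference.softmax(vectors[0])`, so the name finder could apply a softmax to each token's scores instead of dividing by `10`. A minimal sketch, assuming `arr` contains raw, unnormalized logits (not confirmed above); `SoftmaxSketch` is a hypothetical name:
   
   ```java
   import java.util.Arrays;
   
   // Hypothetical helper; assumes the model's per-token outputs are raw logits.
   public final class SoftmaxSketch {
   
       static double[] softmax(double[] logits) {
           // Subtract the max logit before exponentiating to avoid overflow;
           // the result is unchanged because softmax is shift-invariant.
           final double max = Arrays.stream(logits).max().orElse(0.0);
           final double[] exps = new double[logits.length];
           double sum = 0.0;
           for (int i = 0; i < logits.length; i++) {
               exps[i] = Math.exp(logits[i] - max);
               sum += exps[i];
           }
           // Normalize so every value lies in [0, 1] and they sum to 1.
           for (int i = 0; i < exps.length; i++) {
               exps[i] /= sum;
           }
           return exps;
       }
   
       public static void main(String[] args) {
           double[] p = softmax(new double[]{2.0, 1.0, 0.1});
           System.out.println(Arrays.toString(p)); // values in [0, 1] summing to 1
       }
   }
   ```
   
   The value at `maxIndex` of the softmaxed array could then serve as the span probability, though whether that is the intended meaning is a question for the follow-up issue.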

##########
File path: opennlp-dl/pom.xml
##########
@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.opennlp</groupId>
+    <artifactId>opennlp</artifactId>
+    <version>1.9.5-SNAPSHOT</version>
+    <relativePath>../pom.xml</relativePath>
+  </parent>
+  <groupId>org.apache.opennlp</groupId>
+  <artifactId>opennlp-dl</artifactId>
+  <name>opennlp-dl</name>
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.opennlp</groupId>
+      <artifactId>opennlp-tools</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>com.microsoft.onnxruntime</groupId>
+      <artifactId>onnxruntime</artifactId>
+      <version>${onnxruntime.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>${junit.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.opennlp</groupId>
+      <artifactId>opennlp-tools</artifactId>
+    </dependency>

Review comment:
       This is duplicated, so running `mvn` is failing at the moment.

##########
File path: opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
##########
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.namefinder;
+
+import opennlp.tools.namefind.TokenNameFinder;
+import opennlp.tools.util.Span;
+
+import java.io.File;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An implementation of {@link TokenNameFinder} that uses ONNX models.
+ */
+public class NameFinderDL implements TokenNameFinder {
+
+    public static final String I_PER = "I-PER";
+    public static final String B_PER = "B-PER";
+
+    private final TokenNameFinderInference inference;
+    private final Map<Integer, String> ids2Labels;
+
+    /**
+     * Creates a new NameFinderDL for entity recognition using ONNX models.
+     * @param model The ONNX model file.
+     * @param vocab The model's vocabulary file.
+     * @param doLowerCase Whether or not to lowercase the text prior to inference.
+     * @param ids2Labels A map of values and their assigned labels used to train the model.
+     * @throws Exception Thrown if the models cannot be loaded.
+     */
+    public NameFinderDL(File model, File vocab, boolean doLowerCase, Map<Integer, String> ids2Labels) throws Exception {
+
+        this.ids2Labels = ids2Labels;
+        this.inference = new TokenNameFinderInference(model, vocab, doLowerCase);
+
+    }
+
+    @Override
+    public Span[] find(String[] tokens) {
+
+        final List<Span> spans = new LinkedList<>();
+        final String text = String.join(" ", tokens);
+
+        try {
+
+            final double[][] v = inference.infer(text);
+
+            // Find consecutive B-PER and I-PER labels and combine the spans where necessary.
+            // There are also B-LOC and I-LOC tags for locations that might be useful at some point.
+
+            // Keep track of where the last span was so when there are multiple/duplicate
+            // spans we can get the next one instead of the first one each time.
+            int characterStart = 0;
+
+            // We are looping over the vector for each word,
+            // finding the index of the array that has the maximum value,
+            // and then finding the token classification that corresponds to that index.
+            for(int x = 0; x < v.length; x++) {
+
+                final double[] arr = v[x];
+                final int maxIndex = maxIndex(arr);
+                final String label = ids2Labels.get(maxIndex);
+
+                // TODO: Need to make sure this value is between 0 and 1?
+                final double probability = arr[maxIndex] / 10;
+
+                if (B_PER.equalsIgnoreCase(label)) {
+
+                    // This is the start of a person entity.
+                    final String spanText;
+
+                    // Find the end index of the span in the array (where the label is not I-PER).
+                    final int endIndex = findSpanEnd(v, x, ids2Labels);
+
+                    // If the end is -1 it means this is a single-span token.
+                    // If the end is != -1 it means this is a multi-span token.
+                    if(endIndex != -1) {
+
+                        // Subtract one for the beginning token not part of the text.
+                        spanText = String.join(" ", Arrays.copyOfRange(tokens, x - 1, endIndex));
+
+                        spans.add(new Span(x - 1, endIndex, spanText, probability));
+
+                        x = endIndex;
+
+                    } else {
+
+                        // This is a single-token span so there is nothing else to do except grab the token.
+                        spanText = tokens[x];
+
+                        // Subtract one for the beginning token not part of the text.
+                        spans.add(new Span(x - 1, endIndex, spanText, probability));
+
+                    }
+
+                }
+
+            }
+
+        } catch (Exception ex) {
+            System.err.println("Error performing namefinder inference: " + ex.getMessage());
+        }
+
+        return spans.toArray(new Span[0]);
+
+    }
+
+    @Override
+    public void clearAdaptiveData() {
+        // No use for this in this implementation.
+    }
+
+    private int findSpanEnd(double[][] v, int startIndex, Map<Integer, String> id2Labels) {
+
+        // This will be the index of the last token in the span.
+        // -1 means there is no follow-up token, so it is a single-token span.
+        int index = -1;
+
+        // Starts at the span start in the vector.
+        // Looks at the next token to see if it is an I-PER.
+        // Go until the next token is something other than I-PER.
+        // When the next token is not I-PER, return the previous index.
+
+        for(int x = startIndex + 1; x < v[0].length; x++) {
+
+            // Get the next item.
+            final double[] arr = v[x];
+
+            // See if the next token has an I-PER label.
+            final String nextTokenClassification = id2Labels.get(maxIndex(arr));
+
+            if(!I_PER.equalsIgnoreCase(nextTokenClassification)) {
+                index = x - 1;
+                break;
+            }
+
+        }
+
+        return index;
+
+    }
+
+    private int maxIndex(double[] arr) {
+
+        double max = Double.NEGATIVE_INFINITY;
+        int index = -1;
+
+        for(int x = 0; x < arr.length; x++) {
+            if(arr[x] > max) {
+                index = x;
+                max = arr[x];
+            }
+        }
+
+        return index;
+
+    }

Review comment:
       ```suggestion
       private int maxIndex(double[] arr) {
           return IntStream.range(0, arr.length)
                   .reduce((i, j) -> arr[i] > arr[j] ? i : j)
                   .orElse(-1);
       }
   ```
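   
   One detail to check before applying the suggestion (it needs `import java.util.stream.IntStream;`): on ties the two versions pick different indices. The existing loop uses a strict `>` and keeps the first maximum, while `reduce((i, j) -> arr[i] > arr[j] ? i : j)` returns `j` when the values are equal, so it keeps the last one. Flipping the comparison restores first-wins. A self-contained comparison; `ArgMaxTies` and the method names are made up for illustration:
   
   ```java
   import java.util.stream.IntStream;
   
   public class ArgMaxTies {
   
       // Mirrors the loop in the PR: strict '>' keeps the first maximum.
       static int loopArgMax(double[] arr) {
           double max = Double.NEGATIVE_INFINITY;
           int index = -1;
           for (int x = 0; x < arr.length; x++) {
               if (arr[x] > max) {
                   index = x;
                   max = arr[x];
               }
           }
           return index;
       }
   
       // The suggested version: on equal values it keeps j, i.e. the last maximum.
       static int suggestedArgMax(double[] arr) {
           return IntStream.range(0, arr.length)
                   .reduce((i, j) -> arr[i] > arr[j] ? i : j)
                   .orElse(-1);
       }
   
       // Replacing j only when it is strictly larger keeps the first maximum.
       static int firstWinsArgMax(double[] arr) {
           return IntStream.range(0, arr.length)
                   .reduce((i, j) -> arr[j] > arr[i] ? j : i)
                   .orElse(-1);
       }
   
       public static void main(String[] args) {
           double[] scores = {1.0, 3.0, 3.0};
           System.out.println(loopArgMax(scores));      // 1
           System.out.println(suggestedArgMax(scores)); // 2
           System.out.println(firstWinsArgMax(scores)); // 1
       }
   }
   ```
   
   Exact ties may be rare with floating-point model outputs, so this might not matter in practice, but it keeps the refactoring behavior-preserving.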

##########
File path: opennlp-dl/src/main/java/opennlp/dl/Tokens.java
##########
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+/**
+ * Holds the tokens for input to an ONNX model.
+ */
+public class Tokens {
+
+    private String[] tokens;
+    private long[] ids;
+    private long[] mask;
+    private long[] types;
+
+    /**
+     * Creates a new instance to hold the tokens for input to an ONNX model.
+     * @param tokens The tokens themselves.
+     * @param ids The token IDs as retrieved from the vocabulary.
+     * @param mask The token mask. (Typically all 1.)
+     * @param types The token types. (Typically all 1.)
+     */
+    public Tokens(String[] tokens, long[] ids, long[] mask, long[] types) {
+
+        this.tokens = tokens;
+        this.ids = ids;
+        this.mask = mask;
+        this.types = types;
+
+    }
+
+    public String[] getTokens() {
+        return tokens;
+    }
+
+    public long[] getIds() {
+        return ids;
+    }
+
+    public long[] getMask() {
+        return mask;
+    }
+
+    public long[] getTypes() {
+        return types;
+    }
+
+}

Review comment:
       Missing newline.

##########
File path: pom.xml
##########
@@ -506,6 +504,7 @@
 		<module>opennlp-morfologik-addon</module>
 		<module>opennlp-docs</module>
 		<module>opennlp-distr</module>
-	</modules>
+		<module>opennlp-dl</module>
+  </modules>
 
-</project>
+</project>

Review comment:
       Missing newline?

##########
File path: opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.namefinder;
+
+import ai.onnxruntime.OrtException;
+import opennlp.tools.util.Span;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+
+public class NameFinderDLEval {
+
+    @Test
+    public void tokenNameFinder1Test() throws Exception {
+
+        // This test was written using the dslim/bert-base-NER model.
+        // You will need to update the ids2Labels and assertions if you use a different model.
+
+        final File model = new File(getClass().getClassLoader().getResource("namefinder/model.onnx").toURI());
+        final File vocab = new File(getClass().getClassLoader().getResource("namefinder/vocab.txt").toURI());
+
+        final String[] tokens = new String[]{"George", "Washington", "was", "president", "of", "the", "United", "States", "."};
+
+        final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, false, getIds2Labels());
+        final Span[] spans = nameFinderDL.find(tokens);
+
+        for(Span span : spans) {
+            System.out.println(span.toString());
+        }

Review comment:
       Remove to reduce the noise in the logs? Unless it's useful to see the spans when running tests.
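   
   If the printed spans are mostly there to eyeball whether B-PER/I-PER tokens get merged correctly, that part could be pinned down with plain assertions that don't need the ONNX model at all. A sketch, assuming the per-token labels have already been decoded; `PersonSpanMerger` and `mergePersonSpans` are hypothetical, and the sketch ignores the `[CLS]` offset and WordPiece sub-token alignment that `NameFinderDL` deals with:
   
   ```java
   import java.util.ArrayList;
   import java.util.Arrays;
   import java.util.List;
   
   public class PersonSpanMerger {
   
       // Merges a B-PER label plus any directly following I-PER labels into one
       // [start, end) token span (end exclusive, like opennlp.tools.util.Span).
       // An I-PER without a preceding B-PER is ignored in this sketch.
       static List<int[]> mergePersonSpans(String[] labels) {
           final List<int[]> spans = new ArrayList<>();
           int i = 0;
           while (i < labels.length) {
               if ("B-PER".equalsIgnoreCase(labels[i])) {
                   int end = i + 1;
                   while (end < labels.length && "I-PER".equalsIgnoreCase(labels[end])) {
                       end++;
                   }
                   spans.add(new int[]{i, end});
                   i = end;
               } else {
                   i++;
               }
           }
           return spans;
       }
   
       public static void main(String[] args) {
           // Hand-written labels for illustration, not actual model output.
           String[] labels = {"B-PER", "I-PER", "O", "O", "O", "O", "B-LOC", "I-LOC", "O"};
           for (int[] span : mergePersonSpans(labels)) {
               System.out.println(Arrays.toString(span)); // [0, 2]
           }
       }
   }
   ```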

##########
File path: opennlp-tools/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.tokenize;
+
+import opennlp.tools.util.Span;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * A WordPiece tokenizer.
+ *
+ * Adapted from https://github.com/robrua/easy-bert under the MIT license.
+ */
+public class WordpieceTokenizer implements Tokenizer {
+
+    private static final String CLASSIFICATION_TOKEN = "[CLS]";
+    private static final String SEPARATOR_TOKEN = "[SEP]";
+    private static final String UNKNOWN_TOKEN = "[UNK]";
+
+    private Set<String> vocabulary;
+    private int maxTokenLength = 50;
+
+    public WordpieceTokenizer(Set<String> vocabulary) {
+        this.vocabulary = vocabulary;
+    }
+
+    public WordpieceTokenizer(Set<String> vocabulary, int maxTokenLength) {
+        this.vocabulary = vocabulary;
+        this.maxTokenLength = maxTokenLength;
+    }
+
+    // https://www.tensorflow.org/text/guide/subwords_tokenizer#applying_wordpiece
+    // https://cran.r-project.org/web/packages/wordpiece/vignettes/basic_usage.html

Review comment:
       Move these two comments up to the class javadoc?
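   
   In the class Javadoc the two links could become `@see` tags, as in the sketch below, which also illustrates the greedy longest-match-first splitting those pages describe. `WordpieceSketch` and `split` are hypothetical names; this is not the PR's code (which is adapted from easy-bert) and it only handles one already whitespace-separated word. The `maxTokenLength` parameter mirrors the field above, but mapping over-long words to `[UNK]` is an assumption borrowed from BERT's reference tokenizer.
   
   ```java
   import java.util.ArrayList;
   import java.util.Arrays;
   import java.util.Collections;
   import java.util.HashSet;
   import java.util.List;
   import java.util.Set;
   
   /**
    * Hypothetical sketch of greedy longest-match-first WordPiece splitting.
    *
    * @see <a href="https://www.tensorflow.org/text/guide/subwords_tokenizer#applying_wordpiece">Applying WordPiece</a>
    * @see <a href="https://cran.r-project.org/web/packages/wordpiece/vignettes/basic_usage.html">wordpiece basic usage</a>
    */
   public class WordpieceSketch {
   
       static final String UNKNOWN_TOKEN = "[UNK]";
   
       static List<String> split(String word, Set<String> vocabulary, int maxTokenLength) {
           if (word.length() > maxTokenLength) {
               return Collections.singletonList(UNKNOWN_TOKEN);
           }
           final List<String> pieces = new ArrayList<>();
           int start = 0;
           while (start < word.length()) {
               int end = word.length();
               String match = null;
               // Try the longest candidate first and shrink until the vocabulary has it.
               while (start < end) {
                   String candidate = word.substring(start, end);
                   if (start > 0) {
                       candidate = "##" + candidate; // continuation pieces carry "##"
                   }
                   if (vocabulary.contains(candidate)) {
                       match = candidate;
                       break;
                   }
                   end--;
               }
               if (match == null) {
                   // Nothing matches at this position: the whole word becomes [UNK].
                   return Collections.singletonList(UNKNOWN_TOKEN);
               }
               pieces.add(match);
               start = end;
           }
           return pieces;
       }
   
       public static void main(String[] args) {
           Set<String> vocab = new HashSet<>(Arrays.asList("un", "##aff", "##able"));
           System.out.println(split("unaffable", vocab, 50)); // [un, ##aff, ##able]
       }
   }
   ```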

##########
File path: opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
##########
@@ -0,0 +1,178 @@
+    private int maxIndex(double[] arr) {
+
+        double max = Double.NEGATIVE_INFINITY;
+        int index = -1;
+
+        for(int x = 0; x < arr.length; x++) {
+            if(arr[x] > max) {
+                index = x;
+                max = arr[x];
+            }
+        }
+
+        return index;
+
+    }

Review comment:
       I think it does the same. Or use the `max` method in streams.
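   
   For the `max` route, `Arrays.stream(arr).max()` returns the largest value rather than its position, so the index has to come from streaming over positions. A sketch with a hypothetical class name:
   
   ```java
   import java.util.stream.IntStream;
   
   public class ArgMaxWithMax {
   
       // Index of the largest value, or -1 for an empty array.
       static int maxIndex(double[] arr) {
           return IntStream.range(0, arr.length)
                   .boxed() // Stream<Integer>, so max(Comparator) is available
                   .max((i, j) -> Double.compare(arr[i], arr[j]))
                   .orElse(-1);
       }
   
       public static void main(String[] args) {
           System.out.println(maxIndex(new double[]{0.1, 0.7, 0.2})); // 1
       }
   }
   ```
   
   In OpenJDK, `Stream.max` keeps the first of several equal elements, which matches the original loop on ties, but the `Stream.max` Javadoc does not promise which one is returned. It also boxes every index, which may or may not matter for the number of labels involved here.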

##########
File path: opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
##########
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.doccat;
+
+import opennlp.tools.doccat.DocumentCategorizer;
+
+import java.io.File;
+import java.util.*;
+
+/**
+ * An implementation of {@link DocumentCategorizer} that performs document classification
+ * using ONNX models.
+ */
+public class DocumentCategorizerDL implements DocumentCategorizer {
+
+    private final File model;
+    private final File vocab;
+    private final Map<Integer, String> categories;
+
+    /**
+     * Creates a new document categorizer using ONNX models.
+     * @param model The ONNX model file.
+     * @param vocab The model's vocabulary file.
+     * @param categories The categories.
+     */
+    public DocumentCategorizerDL(File model, File vocab, Map<Integer, String> categories) {
+
+        this.model = model;
+        this.vocab = vocab;
+        this.categories = categories;
+
+    }
+
+    @Override
+    public double[] categorize(String[] strings) {
+
+        try {
+
+            final DocumentCategorizerInference inference = new DocumentCategorizerInference(model, vocab);
+
+            final double[][] vectors = inference.infer(strings[0]);
+            final double[] results = inference.softmax(vectors[0]);
+
+            return results;
+
+        } catch (Exception ex) {
+            System.err.println("Unable to perform document classification inference: " + ex.getMessage());
+        }
+
+        return new double[]{};
+
+    }
+
+    @Override
+    public double[] categorize(String[] strings, Map<String, Object> map) {
+        return categorize(strings);
+    }
+
+    @Override
+    public String getBestCategory(double[] doubles) {
+        return categories.get(maxIndex(doubles));
+    }
+
+    @Override
+    public int getIndex(String s) {
+        return getKey(s);
+    }
+
+    @Override
+    public String getCategory(int i) {
+        return categories.get(i);
+    }
+
+    @Override
+    public int getNumberOfCategories() {
+        return categories.size();
+    }
+
+    @Override
+    public String getAllResults(double[] doubles) {
+        return null;
+    }
+
+    @Override
+    public Map<String, Double> scoreMap(String[] strings) {
+
+        final double[] scores = categorize(strings);
+
+        final Map<String, Double> scoreMap = new HashMap<>();
+
+        for(int x : categories.keySet()) {
+            scoreMap.put(categories.get(x), scores[x]);
+        }
+
+        return scoreMap;
+
+    }
+
+    @Override
+    public SortedMap<Double, Set<String>> sortedScoreMap(String[] strings) {
+
+        final double[] scores = categorize(strings);
+
+        final SortedMap<Double, Set<String>> scoreMap = new TreeMap<>();
+
+        for(int x : categories.keySet()) {
+
+            if(scoreMap.get(scores[x]) == null) {
+                scoreMap.put(scores[x], new HashSet<>());
+            }
+
+            scoreMap.get(scores[x]).add(categories.get(x));
+
+        }
+
+        return scoreMap;
+
+    }
+
+    private int getKey(String value) {
+
+        for (Map.Entry<Integer, String> entry : categories.entrySet()) {
+
+            if (entry.getValue().equals(value)) {
+                return entry.getKey();
+            }
+
+        }
+
+        // The String wasn't found as a value in the map.
+        return -1;
+
+    }
+
+    private int maxIndex(double[] arr) {
+
+        double max = Double.NEGATIVE_INFINITY;
+        int index = -1;
+
+        for(int x = 0; x < arr.length; x++) {
+            if(arr[x] > max) {
+                index = x;
+                max = arr[x];
+            }
+        }
+
+        return index;

Review comment:
       Ditto about calculating max here. Maybe this could even be a common static method somewhere, and re-use it instead of maintaining the duplicated code.
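       Here is one way to act on this. The sketch assumes that both `DocumentCategorizerDL` and the name finder can see a shared helper: the loop moves into a single static method, and both classes call it. The class name `ArgMax` is hypothetical. The sketch also takes the first element as the starting maximum instead of relying on `Double.NEGATIVE_INFINITY`, so an array whose values are all negative infinity still yields index 0 rather than -1.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical shared helper; the name and location are placeholders, not part of the PR. */
public final class ArgMax {

    private ArgMax() {
        // Static utility class; not meant to be instantiated.
    }

    /**
     * Returns the index of the largest value in arr, or -1 if arr is empty.
     * On ties the first (lowest) index wins, as in the loops in the PR.
     */
    public static int maxIndex(double[] arr) {
        int index = -1;
        for (int x = 0; x < arr.length; x++) {
            // Take the first element unconditionally, then only strictly larger ones.
            if (index == -1 || arr[x] > arr[index]) {
                index = x;
            }
        }
        return index;
    }

    public static void main(String[] args) {
        // Illustrative categories and scores, not taken from a real model.
        Map<Integer, String> categories = new HashMap<>();
        categories.put(0, "negative");
        categories.put(1, "positive");
        double[] scores = {0.2, 0.8};

        // What DocumentCategorizerDL.getBestCategory could delegate to.
        System.out.println(categories.get(ArgMax.maxIndex(scores)));
    }
}
```

       With this in place, `getBestCategory` would reduce to `return categories.get(ArgMax.maxIndex(doubles));`, and the name finder's `findSpanEnd` would call `ArgMax.maxIndex(arr)` in the same way. Both private copies could then be deleted.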




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscribe@opennlp.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org