You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@camel.apache.org by ac...@apache.org on 2020/08/28 05:32:40 UTC

[camel] branch master updated: CAMEL-15476: Upgrade to Deep Java Library 0.6.0 (#4140)

This is an automated email from the ASF dual-hosted git repository.

acosentino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/camel.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b2ba8c  CAMEL-15476: Upgrade to Deep Java Library 0.6.0 (#4140)
4b2ba8c is described below

commit 4b2ba8cc475df20c2a68d0ed4544bf379d50fb18
Author: Marat Gubaidullin <ma...@gmail.com>
AuthorDate: Fri Aug 28 01:32:26 2020 -0400

    CAMEL-15476: Upgrade to Deep Java Library 0.6.0 (#4140)
---
 camel-dependencies/pom.xml                         |  6 ++--
 .../model/CustomImageClassificationPredictor.java  | 22 +++++++--------
 .../djl/model/CustomObjectDetectionPredictor.java  | 22 +++++++--------
 .../djl/model/ZooImageClassificationPredictor.java | 29 +++++++++----------
 .../djl/model/ZooObjectDetectionPredictor.java     | 28 +++++++++---------
 .../djl/ImageClassificationLocalTest.java          | 17 ++++++-----
 .../component/djl/training/MnistTraining.java      | 33 ++--------------------
 .../src/test/resources/models/mnist/synset.txt     | 10 -------
 parent/pom.xml                                     |  6 ++--
 9 files changed, 66 insertions(+), 107 deletions(-)

diff --git a/camel-dependencies/pom.xml b/camel-dependencies/pom.xml
index 3f6fa69..f21e575 100644
--- a/camel-dependencies/pom.xml
+++ b/camel-dependencies/pom.xml
@@ -170,10 +170,10 @@
     <digitalocean-api-client-version>2.17</digitalocean-api-client-version>
     <directory-watcher-version>0.10.0</directory-watcher-version>
     <disruptor-version>3.4.2</disruptor-version>
-    <djl-mxnet-native-version>1.7.0-a</djl-mxnet-native-version>
+    <djl-mxnet-native-version>1.7.0-b</djl-mxnet-native-version>
     <djl-pytorch-native-version>1.5.0</djl-pytorch-native-version>
-    <djl-tensorflow-native-version>2.1.0</djl-tensorflow-native-version>
-    <djl-version>0.5.0</djl-version>
+    <djl-tensorflow-native-version>2.2.0</djl-tensorflow-native-version>
+    <djl-version>0.6.0</djl-version>
     <dnsjava-version>3.2.2</dnsjava-version>
     <docker-java-version>3.2.5</docker-java-version>
     <dozer-version>6.5.0</dozer-version>
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java
index 17fc27d..0856caa 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java
@@ -16,20 +16,16 @@
  */
 package org.apache.camel.component.djl.model;
 
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
+import java.io.*;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
-import javax.imageio.ImageIO;
-
 import ai.djl.Model;
 import ai.djl.inference.Predictor;
 import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
 import ai.djl.translate.TranslateException;
 import ai.djl.translate.Translator;
 import org.apache.camel.Exchange;
@@ -69,7 +65,8 @@ public class CustomImageClassificationPredictor extends AbstractPredictor {
 
     private Map<String, Float> classify(Model model, Translator translator, File input) throws Exception {
         try {
-            return classify(model, translator, ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+            return classify(model, translator, image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -78,16 +75,17 @@ public class CustomImageClassificationPredictor extends AbstractPredictor {
 
     private Map<String, Float> classify(Model model, Translator translator, InputStream input) throws Exception {
         try {
-            return classify(model, translator, ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(input);
+            return classify(model, translator, image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
         }
     }
 
-    private Map<String, Float> classify(Model model, Translator translator, BufferedImage input) throws Exception {
-        try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) {
-            Classifications classifications = predictor.predict(input);
+    private Map<String, Float> classify(Model model, Translator translator, Image image) throws Exception {
+        try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
+            Classifications classifications = predictor.predict(image);
             List<Classifications.Classification> list = classifications.items();
             return list.stream()
                     .collect(Collectors.toMap(Classifications.Classification::getClassName, x -> (float) x.getProbability()));
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java
index 6ad600b..b661b19 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java
@@ -16,16 +16,12 @@
  */
 package org.apache.camel.component.djl.model;
 
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-
-import javax.imageio.ImageIO;
+import java.io.*;
 
 import ai.djl.Model;
 import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
 import ai.djl.modality.cv.output.DetectedObjects;
 import ai.djl.translate.TranslateException;
 import ai.djl.translate.Translator;
@@ -65,9 +61,9 @@ public class CustomObjectDetectionPredictor extends AbstractPredictor {
         }
     }
 
-    public DetectedObjects classify(Model model, Translator translator, BufferedImage input) throws Exception {
-        try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor(translator)) {
-            DetectedObjects detectedObjects = predictor.predict(input);
+    public DetectedObjects classify(Model model, Translator translator, Image image) throws Exception {
+        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor(translator)) {
+            DetectedObjects detectedObjects = predictor.predict(image);
             return detectedObjects;
         } catch (TranslateException e) {
             LOG.error("Could not process input or output", e);
@@ -77,7 +73,8 @@ public class CustomObjectDetectionPredictor extends AbstractPredictor {
 
     public DetectedObjects classify(Model model, Translator translator, File input) throws Exception {
         try {
-            return classify(model, translator, ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+            return classify(model, translator, image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -86,7 +83,8 @@ public class CustomObjectDetectionPredictor extends AbstractPredictor {
 
     public DetectedObjects classify(Model model, Translator translator, InputStream input) throws Exception {
         try {
-            return classify(model, translator, ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(input);
+            return classify(model, translator, image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java
index 8d1354c..843d255 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java
@@ -16,20 +16,16 @@
  */
 package org.apache.camel.component.djl.model;
 
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
+import java.io.*;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
-import javax.imageio.ImageIO;
-
 import ai.djl.Application;
 import ai.djl.inference.Predictor;
 import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
 import ai.djl.repository.zoo.Criteria;
 import ai.djl.repository.zoo.ModelZoo;
 import ai.djl.repository.zoo.ZooModel;
@@ -42,16 +38,17 @@ import org.slf4j.LoggerFactory;
 public class ZooImageClassificationPredictor extends AbstractPredictor {
     private static final Logger LOG = LoggerFactory.getLogger(ZooImageClassificationPredictor.class);
 
-    private final ZooModel<BufferedImage, Classifications> model;
+    private final ZooModel<Image, Classifications> model;
 
     public ZooImageClassificationPredictor(String artifactId) throws Exception {
-        Criteria<BufferedImage, Classifications> criteria = Criteria.builder()
+        Criteria<Image, Classifications> criteria = Criteria.builder()
                 .optApplication(Application.CV.IMAGE_CLASSIFICATION)
-                .setTypes(BufferedImage.class, Classifications.class)
+                .setTypes(Image.class, Classifications.class)
                 .optArtifactId(artifactId)
                 .optProgress(new ProgressBar())
                 .build();
         this.model = ModelZoo.loadModel(criteria);
+        //        model.save(Paths.get("src/test/resources/models/mnist"), "mlp");
     }
 
     @Override
@@ -73,7 +70,8 @@ public class ZooImageClassificationPredictor extends AbstractPredictor {
 
     public Map<String, Float> classify(File input) throws Exception {
         try {
-            return classify(ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+            return classify(image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -82,16 +80,17 @@ public class ZooImageClassificationPredictor extends AbstractPredictor {
 
     public Map<String, Float> classify(InputStream input) throws Exception {
         try {
-            return classify(ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(input);
+            return classify(image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
         }
     }
 
-    public Map<String, Float> classify(BufferedImage input) throws Exception {
-        try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor()) {
-            Classifications classifications = predictor.predict(input);
+    public Map<String, Float> classify(Image image) throws Exception {
+        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
+            Classifications classifications = predictor.predict(image);
             List<Classifications.Classification> list = classifications.items();
             return list.stream()
                     .collect(Collectors.toMap(Classifications.Classification::getClassName, x -> (float) x.getProbability()));
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java
index 05828ec..493a835 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java
@@ -16,16 +16,12 @@
  */
 package org.apache.camel.component.djl.model;
 
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-
-import javax.imageio.ImageIO;
+import java.io.*;
 
 import ai.djl.Application;
 import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
 import ai.djl.modality.cv.output.DetectedObjects;
 import ai.djl.repository.zoo.Criteria;
 import ai.djl.repository.zoo.ModelZoo;
@@ -40,12 +36,12 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
 
     private static final Logger LOG = LoggerFactory.getLogger(ZooObjectDetectionPredictor.class);
 
-    private final ZooModel<BufferedImage, DetectedObjects> model;
+    private final ZooModel<Image, DetectedObjects> model;
 
     public ZooObjectDetectionPredictor(String artifactId) throws Exception {
-        Criteria<BufferedImage, DetectedObjects> criteria = Criteria.builder()
+        Criteria<Image, DetectedObjects> criteria = Criteria.builder()
                 .optApplication(Application.CV.OBJECT_DETECTION)
-                .setTypes(BufferedImage.class, DetectedObjects.class)
+                .setTypes(Image.class, DetectedObjects.class)
                 .optArtifactId(artifactId)
                 .optProgress(new ProgressBar())
                 .build();
@@ -69,9 +65,9 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
         }
     }
 
-    public DetectedObjects classify(BufferedImage input) throws Exception {
-        try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()) {
-            DetectedObjects detectedObjects = predictor.predict(input);
+    public DetectedObjects classify(Image image) throws Exception {
+        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
+            DetectedObjects detectedObjects = predictor.predict(image);
             return detectedObjects;
         } catch (TranslateException e) {
             throw new RuntimeException("Could not process input or output", e);
@@ -80,7 +76,8 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
 
     public DetectedObjects classify(File input) throws Exception {
         try {
-            return classify(ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+            return classify(image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -89,7 +86,8 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
 
     public DetectedObjects classify(InputStream input) throws Exception {
         try {
-            return classify(ImageIO.read(input));
+            Image image = ImageFactory.getInstance().fromInputStream(input);
+            return classify(image);
         } catch (IOException e) {
             LOG.error("Couldn't transform input into a BufferedImage");
             throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
diff --git a/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java b/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java
index 08b011e..5d33661 100644
--- a/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java
+++ b/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java
@@ -21,15 +21,20 @@ import java.io.IOException;
 import java.nio.file.Paths;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import ai.djl.MalformedModelException;
 import ai.djl.Model;
 import ai.djl.basicmodelzoo.basic.Mlp;
+import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
 import ai.djl.modality.cv.transform.ToTensor;
 import ai.djl.modality.cv.translator.ImageClassificationTranslator;
-import ai.djl.translate.Pipeline;
 import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
 import org.apache.camel.builder.RouteBuilder;
 import org.apache.camel.component.mock.MockEndpoint;
 import org.apache.camel.test.junit5.CamelTestSupport;
@@ -83,15 +88,13 @@ public class ImageClassificationLocalTest extends CamelTestSupport {
 
     private void loadLocalModel() throws IOException, MalformedModelException, TranslateException {
         // create deep learning model
-        Model model = Model.newInstance();
+        Model model = Model.newInstance(MODEL_NAME);
         model.setBlock(new Mlp(28 * 28, 10, new int[] { 128, 64 }));
         model.load(Paths.get(MODEL_DIR), MODEL_NAME);
         // create translator for pre-processing and postprocessing
-        ImageClassificationTranslator.Builder builder = ImageClassificationTranslator.builder();
-        builder.setSynsetArtifactName("synset.txt");
-        builder.setPipeline(new Pipeline(new ToTensor()));
-        builder.optApplySoftmax(true);
-        ImageClassificationTranslator translator = new ImageClassificationTranslator(builder);
+        List<String> classes = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
+        Translator<Image, Classifications> translator
+                = ImageClassificationTranslator.builder().addTransform(new ToTensor()).optSynset(classes).build();
 
         // Bind model beans
         context.getRegistry().bind("MyModel", model);
diff --git a/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java b/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java
index 0222fdc..e8e8def 100644
--- a/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java
+++ b/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java
@@ -28,8 +28,8 @@ import ai.djl.metric.Metrics;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.nn.Block;
 import ai.djl.training.DefaultTrainingConfig;
+import ai.djl.training.EasyTrain;
 import ai.djl.training.Trainer;
-import ai.djl.training.dataset.Batch;
 import ai.djl.training.dataset.Dataset;
 import ai.djl.training.dataset.RandomAccessDataset;
 import ai.djl.training.evaluator.Accuracy;
@@ -54,7 +54,7 @@ public final class MnistTraining {
         // Construct neural network
         Block block = new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES, new int[] { 128, 64 });
 
-        try (Model model = Model.newInstance()) {
+        try (Model model = Model.newInstance(MODEL_NAME)) {
             model.setBlock(block);
 
             // get training and validation dataset
@@ -73,39 +73,12 @@ public final class MnistTraining {
 
                 // initialize trainer with proper input shape
                 trainer.initialize(inputShape);
-                fit(trainer, 10, trainingSet, validateSet, MODEL_DIR, MODEL_NAME);
+                EasyTrain.fit(trainer, 10, trainingSet, validateSet);
             }
             model.save(Paths.get(MODEL_DIR), MODEL_NAME);
         }
     }
 
-    private static void fit(
-            Trainer trainer, int numEpoch, Dataset trainingSet, Dataset validateSet, String outputDir, String modelName)
-            throws IOException {
-        for (int epoch = 0; epoch < numEpoch; epoch++) {
-            for (Batch batch : trainer.iterateDataset(trainingSet)) {
-                trainer.trainBatch(batch);
-                trainer.step();
-                batch.close();
-            }
-
-            if (validateSet != null) {
-                for (Batch batch : trainer.iterateDataset(validateSet)) {
-                    trainer.validateBatch(batch);
-                    batch.close();
-                }
-            }
-            // reset training and validation evaluators at end of epoch
-            trainer.endEpoch();
-            // save model at end of each epoch
-            if (outputDir != null) {
-                Model model = trainer.getModel();
-                model.setProperty("Epoch", String.valueOf(epoch));
-                model.save(Paths.get(outputDir), modelName);
-            }
-        }
-    }
-
     private static RandomAccessDataset prepareDataset(Dataset.Usage usage, int batchSize, long limit) throws IOException {
         Mnist mnist = Mnist.builder().optUsage(usage).setSampling(batchSize, true).optLimit(limit).build();
         mnist.prepare(new ProgressBar());
diff --git a/components/camel-djl/src/test/resources/models/mnist/synset.txt b/components/camel-djl/src/test/resources/models/mnist/synset.txt
deleted file mode 100644
index f55b5c9..0000000
--- a/components/camel-djl/src/test/resources/models/mnist/synset.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-0
-1
-2
-3
-4
-5
-6
-7
-8
-9
\ No newline at end of file
diff --git a/parent/pom.xml b/parent/pom.xml
index c92a21a..6ae8e64 100644
--- a/parent/pom.xml
+++ b/parent/pom.xml
@@ -150,10 +150,10 @@
         <directory-watcher-version>0.10.0</directory-watcher-version>
         <disruptor-version>3.4.2</disruptor-version>
         <dnsjava-version>3.2.2</dnsjava-version>
-        <djl-version>0.5.0</djl-version>
-        <djl-mxnet-native-version>1.7.0-a</djl-mxnet-native-version>
+        <djl-version>0.6.0</djl-version>
+        <djl-mxnet-native-version>1.7.0-b</djl-mxnet-native-version>
         <djl-pytorch-native-version>1.5.0</djl-pytorch-native-version>
-        <djl-tensorflow-native-version>2.1.0</djl-tensorflow-native-version>
+        <djl-tensorflow-native-version>2.2.0</djl-tensorflow-native-version>
         <docker-java-version>3.2.5</docker-java-version>
         <dozer-version>6.5.0</dozer-version>
         <drools-version>7.42.0.Final</drools-version>