You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@camel.apache.org by ac...@apache.org on 2020/08/28 05:32:40 UTC
[camel] branch master updated: CAMEL-15476:Upgrade to Deep Java Library 0.6.0 (#4140)
This is an automated email from the ASF dual-hosted git repository.
acosentino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/camel.git
The following commit(s) were added to refs/heads/master by this push:
new 4b2ba8c CAMEL-15476:Upgrade to Deep Java Library 0.6.0 (#4140)
4b2ba8c is described below
commit 4b2ba8cc475df20c2a68d0ed4544bf379d50fb18
Author: Marat Gubaidullin <ma...@gmail.com>
AuthorDate: Fri Aug 28 01:32:26 2020 -0400
CAMEL-15476:Upgrade to Deep Java Library 0.6.0 (#4140)
---
camel-dependencies/pom.xml | 6 ++--
.../model/CustomImageClassificationPredictor.java | 22 +++++++--------
.../djl/model/CustomObjectDetectionPredictor.java | 22 +++++++--------
.../djl/model/ZooImageClassificationPredictor.java | 29 +++++++++----------
.../djl/model/ZooObjectDetectionPredictor.java | 28 +++++++++---------
.../djl/ImageClassificationLocalTest.java | 17 ++++++-----
.../component/djl/training/MnistTraining.java | 33 ++--------------------
.../src/test/resources/models/mnist/synset.txt | 10 -------
parent/pom.xml | 6 ++--
9 files changed, 66 insertions(+), 107 deletions(-)
diff --git a/camel-dependencies/pom.xml b/camel-dependencies/pom.xml
index 3f6fa69..f21e575 100644
--- a/camel-dependencies/pom.xml
+++ b/camel-dependencies/pom.xml
@@ -170,10 +170,10 @@
<digitalocean-api-client-version>2.17</digitalocean-api-client-version>
<directory-watcher-version>0.10.0</directory-watcher-version>
<disruptor-version>3.4.2</disruptor-version>
- <djl-mxnet-native-version>1.7.0-a</djl-mxnet-native-version>
+ <djl-mxnet-native-version>1.7.0-b</djl-mxnet-native-version>
<djl-pytorch-native-version>1.5.0</djl-pytorch-native-version>
- <djl-tensorflow-native-version>2.1.0</djl-tensorflow-native-version>
- <djl-version>0.5.0</djl-version>
+ <djl-tensorflow-native-version>2.2.0</djl-tensorflow-native-version>
+ <djl-version>0.6.0</djl-version>
<dnsjava-version>3.2.2</dnsjava-version>
<docker-java-version>3.2.5</docker-java-version>
<dozer-version>6.5.0</dozer-version>
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java
index 17fc27d..0856caa 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomImageClassificationPredictor.java
@@ -16,20 +16,16 @@
*/
package org.apache.camel.component.djl.model;
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
+import java.io.*;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
-import javax.imageio.ImageIO;
-
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import org.apache.camel.Exchange;
@@ -69,7 +65,8 @@ public class CustomImageClassificationPredictor extends AbstractPredictor {
private Map<String, Float> classify(Model model, Translator translator, File input) throws Exception {
try {
- return classify(model, translator, ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+ return classify(model, translator, image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -78,16 +75,17 @@ public class CustomImageClassificationPredictor extends AbstractPredictor {
private Map<String, Float> classify(Model model, Translator translator, InputStream input) throws Exception {
try {
- return classify(model, translator, ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(input);
+ return classify(model, translator, image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
}
}
- private Map<String, Float> classify(Model model, Translator translator, BufferedImage input) throws Exception {
- try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) {
- Classifications classifications = predictor.predict(input);
+ private Map<String, Float> classify(Model model, Translator translator, Image image) throws Exception {
+ try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
+ Classifications classifications = predictor.predict(image);
List<Classifications.Classification> list = classifications.items();
return list.stream()
.collect(Collectors.toMap(Classifications.Classification::getClassName, x -> (float) x.getProbability()));
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java
index 6ad600b..b661b19 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/CustomObjectDetectionPredictor.java
@@ -16,16 +16,12 @@
*/
package org.apache.camel.component.djl.model;
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-
-import javax.imageio.ImageIO;
+import java.io.*;
import ai.djl.Model;
import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
@@ -65,9 +61,9 @@ public class CustomObjectDetectionPredictor extends AbstractPredictor {
}
}
- public DetectedObjects classify(Model model, Translator translator, BufferedImage input) throws Exception {
- try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor(translator)) {
- DetectedObjects detectedObjects = predictor.predict(input);
+ public DetectedObjects classify(Model model, Translator translator, Image image) throws Exception {
+ try (Predictor<Image, DetectedObjects> predictor = model.newPredictor(translator)) {
+ DetectedObjects detectedObjects = predictor.predict(image);
return detectedObjects;
} catch (TranslateException e) {
LOG.error("Could not process input or output", e);
@@ -77,7 +73,8 @@ public class CustomObjectDetectionPredictor extends AbstractPredictor {
public DetectedObjects classify(Model model, Translator translator, File input) throws Exception {
try {
- return classify(model, translator, ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+ return classify(model, translator, image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -86,7 +83,8 @@ public class CustomObjectDetectionPredictor extends AbstractPredictor {
public DetectedObjects classify(Model model, Translator translator, InputStream input) throws Exception {
try {
- return classify(model, translator, ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(input);
+ return classify(model, translator, image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java
index 8d1354c..843d255 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooImageClassificationPredictor.java
@@ -16,20 +16,16 @@
*/
package org.apache.camel.component.djl.model;
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
+import java.io.*;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
-import javax.imageio.ImageIO;
-
import ai.djl.Application;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
@@ -42,16 +38,17 @@ import org.slf4j.LoggerFactory;
public class ZooImageClassificationPredictor extends AbstractPredictor {
private static final Logger LOG = LoggerFactory.getLogger(ZooImageClassificationPredictor.class);
- private final ZooModel<BufferedImage, Classifications> model;
+ private final ZooModel<Image, Classifications> model;
public ZooImageClassificationPredictor(String artifactId) throws Exception {
- Criteria<BufferedImage, Classifications> criteria = Criteria.builder()
+ Criteria<Image, Classifications> criteria = Criteria.builder()
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
- .setTypes(BufferedImage.class, Classifications.class)
+ .setTypes(Image.class, Classifications.class)
.optArtifactId(artifactId)
.optProgress(new ProgressBar())
.build();
this.model = ModelZoo.loadModel(criteria);
+ // model.save(Paths.get("src/test/resources/models/mnist"), "mlp");
}
@Override
@@ -73,7 +70,8 @@ public class ZooImageClassificationPredictor extends AbstractPredictor {
public Map<String, Float> classify(File input) throws Exception {
try {
- return classify(ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+ return classify(image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -82,16 +80,17 @@ public class ZooImageClassificationPredictor extends AbstractPredictor {
public Map<String, Float> classify(InputStream input) throws Exception {
try {
- return classify(ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(input);
+ return classify(image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
}
}
- public Map<String, Float> classify(BufferedImage input) throws Exception {
- try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor()) {
- Classifications classifications = predictor.predict(input);
+ public Map<String, Float> classify(Image image) throws Exception {
+ try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
+ Classifications classifications = predictor.predict(image);
List<Classifications.Classification> list = classifications.items();
return list.stream()
.collect(Collectors.toMap(Classifications.Classification::getClassName, x -> (float) x.getProbability()));
diff --git a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java
index 05828ec..493a835 100644
--- a/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java
+++ b/components/camel-djl/src/main/java/org/apache/camel/component/djl/model/ZooObjectDetectionPredictor.java
@@ -16,16 +16,12 @@
*/
package org.apache.camel.component.djl.model;
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-
-import javax.imageio.ImageIO;
+import java.io.*;
import ai.djl.Application;
import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
@@ -40,12 +36,12 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
private static final Logger LOG = LoggerFactory.getLogger(ZooObjectDetectionPredictor.class);
- private final ZooModel<BufferedImage, DetectedObjects> model;
+ private final ZooModel<Image, DetectedObjects> model;
public ZooObjectDetectionPredictor(String artifactId) throws Exception {
- Criteria<BufferedImage, DetectedObjects> criteria = Criteria.builder()
+ Criteria<Image, DetectedObjects> criteria = Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
- .setTypes(BufferedImage.class, DetectedObjects.class)
+ .setTypes(Image.class, DetectedObjects.class)
.optArtifactId(artifactId)
.optProgress(new ProgressBar())
.build();
@@ -69,9 +65,9 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
}
}
- public DetectedObjects classify(BufferedImage input) throws Exception {
- try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()) {
- DetectedObjects detectedObjects = predictor.predict(input);
+ public DetectedObjects classify(Image image) throws Exception {
+ try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
+ DetectedObjects detectedObjects = predictor.predict(image);
return detectedObjects;
} catch (TranslateException e) {
throw new RuntimeException("Could not process input or output", e);
@@ -80,7 +76,8 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
public DetectedObjects classify(File input) throws Exception {
try {
- return classify(ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(new FileInputStream(input));
+ return classify(image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
@@ -89,7 +86,8 @@ public class ZooObjectDetectionPredictor extends AbstractPredictor {
public DetectedObjects classify(InputStream input) throws Exception {
try {
- return classify(ImageIO.read(input));
+ Image image = ImageFactory.getInstance().fromInputStream(input);
+ return classify(image);
} catch (IOException e) {
LOG.error("Couldn't transform input into a BufferedImage");
throw new RuntimeException("Couldn't transform input into a BufferedImage", e);
diff --git a/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java b/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java
index 08b011e..5d33661 100644
--- a/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java
+++ b/components/camel-djl/src/test/java/org/apache/camel/component/djl/ImageClassificationLocalTest.java
@@ -21,15 +21,20 @@ import java.io.IOException;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.Comparator;
+import java.util.List;
import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
+import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
-import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
import org.apache.camel.builder.RouteBuilder;
import org.apache.camel.component.mock.MockEndpoint;
import org.apache.camel.test.junit5.CamelTestSupport;
@@ -83,15 +88,13 @@ public class ImageClassificationLocalTest extends CamelTestSupport {
private void loadLocalModel() throws IOException, MalformedModelException, TranslateException {
// create deep learning model
- Model model = Model.newInstance();
+ Model model = Model.newInstance(MODEL_NAME);
model.setBlock(new Mlp(28 * 28, 10, new int[] { 128, 64 }));
model.load(Paths.get(MODEL_DIR), MODEL_NAME);
// create translator for pre-processing and postprocessing
- ImageClassificationTranslator.Builder builder = ImageClassificationTranslator.builder();
- builder.setSynsetArtifactName("synset.txt");
- builder.setPipeline(new Pipeline(new ToTensor()));
- builder.optApplySoftmax(true);
- ImageClassificationTranslator translator = new ImageClassificationTranslator(builder);
+ List<String> classes = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
+ Translator<Image, Classifications> translator
+ = ImageClassificationTranslator.builder().addTransform(new ToTensor()).optSynset(classes).build();
// Bind model beans
context.getRegistry().bind("MyModel", model);
diff --git a/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java b/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java
index 0222fdc..e8e8def 100644
--- a/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java
+++ b/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java
@@ -28,8 +28,8 @@ import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
+import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
-import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
@@ -54,7 +54,7 @@ public final class MnistTraining {
// Construct neural network
Block block = new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES, new int[] { 128, 64 });
- try (Model model = Model.newInstance()) {
+ try (Model model = Model.newInstance(MODEL_NAME)) {
model.setBlock(block);
// get training and validation dataset
@@ -73,39 +73,12 @@ public final class MnistTraining {
// initialize trainer with proper input shape
trainer.initialize(inputShape);
- fit(trainer, 10, trainingSet, validateSet, MODEL_DIR, MODEL_NAME);
+ EasyTrain.fit(trainer, 10, trainingSet, validateSet);
}
model.save(Paths.get(MODEL_DIR), MODEL_NAME);
}
}
- private static void fit(
- Trainer trainer, int numEpoch, Dataset trainingSet, Dataset validateSet, String outputDir, String modelName)
- throws IOException {
- for (int epoch = 0; epoch < numEpoch; epoch++) {
- for (Batch batch : trainer.iterateDataset(trainingSet)) {
- trainer.trainBatch(batch);
- trainer.step();
- batch.close();
- }
-
- if (validateSet != null) {
- for (Batch batch : trainer.iterateDataset(validateSet)) {
- trainer.validateBatch(batch);
- batch.close();
- }
- }
- // reset training and validation evaluators at end of epoch
- trainer.endEpoch();
- // save model at end of each epoch
- if (outputDir != null) {
- Model model = trainer.getModel();
- model.setProperty("Epoch", String.valueOf(epoch));
- model.save(Paths.get(outputDir), modelName);
- }
- }
- }
-
private static RandomAccessDataset prepareDataset(Dataset.Usage usage, int batchSize, long limit) throws IOException {
Mnist mnist = Mnist.builder().optUsage(usage).setSampling(batchSize, true).optLimit(limit).build();
mnist.prepare(new ProgressBar());
diff --git a/components/camel-djl/src/test/resources/models/mnist/synset.txt b/components/camel-djl/src/test/resources/models/mnist/synset.txt
deleted file mode 100644
index f55b5c9..0000000
--- a/components/camel-djl/src/test/resources/models/mnist/synset.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-0
-1
-2
-3
-4
-5
-6
-7
-8
-9
\ No newline at end of file
diff --git a/parent/pom.xml b/parent/pom.xml
index c92a21a..6ae8e64 100644
--- a/parent/pom.xml
+++ b/parent/pom.xml
@@ -150,10 +150,10 @@
<directory-watcher-version>0.10.0</directory-watcher-version>
<disruptor-version>3.4.2</disruptor-version>
<dnsjava-version>3.2.2</dnsjava-version>
- <djl-version>0.5.0</djl-version>
- <djl-mxnet-native-version>1.7.0-a</djl-mxnet-native-version>
+ <djl-version>0.6.0</djl-version>
+ <djl-mxnet-native-version>1.7.0-b</djl-mxnet-native-version>
<djl-pytorch-native-version>1.5.0</djl-pytorch-native-version>
- <djl-tensorflow-native-version>2.1.0</djl-tensorflow-native-version>
+ <djl-tensorflow-native-version>2.2.0</djl-tensorflow-native-version>
<docker-java-version>3.2.5</docker-java-version>
<dozer-version>6.5.0</dozer-version>
<drools-version>7.42.0.Final</drools-version>