This is an automated email from the ASF dual-hosted git repository.
tsato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/camel.git
The following commit(s) were added to refs/heads/main by this push:
new 0ff88a0a302 CAMEL-20905: camel-djl - Support more applications for
custom models
0ff88a0a302 is described below
commit 0ff88a0a302a74aaac0394f95cc57d0b63dc2f43
Author: Tadayoshi Sato <[email protected]>
AuthorDate: Fri Jul 5 18:09:33 2024 +0900
CAMEL-20905: camel-djl - Support more applications for custom models
---
.../djl/model/ModelPredictorProducer.java | 83 +++++++++++++--
.../CustomAudioPredictor.java} | 78 +++++++-------
.../djl/model/audio/ZooAudioPredictor.java | 107 +++++++++++++++++++
...cationPredictor.java => CustomCvPredictor.java} | 60 ++++++-----
.../model/cv/CustomImageGenerationPredictor.java | 63 +++++++++++
.../model/cv/CustomObjectDetectionPredictor.java | 94 -----------------
.../djl/model/cv/ZooImageEnhancementPredictor.java | 2 +-
.../djl/model/nlp/CustomNlpPredictor.java | 62 +++++++++++
.../model/nlp/CustomQuestionAnswerPredictor.java | 72 +++++++++++++
.../CustomWordEmbeddingPredictor.java} | 18 ++--
.../model/nlp/ZooMachineTranslationPredictor.java | 2 +-
.../djl/model/tabular/CustomTabularPredictor.java | 58 +++++++++++
.../tabular/ZooLinearRegressionPredictor.java | 2 +-
.../tabular/ZooSoftmaxRegressionPredictor.java | 2 +-
.../timeseries/CustomForecastingPredictor.java | 64 ++++++++++++
.../apache/camel/component/djl/AudioLocalTest.java | 92 ++++++++++++++++
.../component/djl/CvImageEnhancementLocalTest.java | 116 +++++++++++++++++++++
.../camel/component/djl/CvImageGenerationTest.java | 2 +-
.../djl/model/ModelPredictorProducerTest.java | 71 +++++++++++++
.../src/test/resources/data/enhance/fox.png | Bin 0 -> 32045 bytes
20 files changed, 872 insertions(+), 176 deletions(-)
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/ModelPredictorProducer.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/ModelPredictorProducer.java
index 5c3f2db8930..bf52aaefabf 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/ModelPredictorProducer.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/ModelPredictorProducer.java
@@ -20,11 +20,18 @@ import java.io.IOException;
import ai.djl.Application;
import ai.djl.MalformedModelException;
+import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.output.CategoryMask;
+import ai.djl.modality.cv.output.DetectedObjects;
+import ai.djl.modality.cv.output.Joints;
+import ai.djl.ndarray.NDArray;
import ai.djl.repository.zoo.ModelNotFoundException;
import org.apache.camel.RuntimeCamelException;
-import
org.apache.camel.component.djl.model.audio.ZooAudioClassificationPredictor;
-import
org.apache.camel.component.djl.model.cv.CustomImageClassificationPredictor;
-import org.apache.camel.component.djl.model.cv.CustomObjectDetectionPredictor;
+import org.apache.camel.component.djl.model.audio.CustomAudioPredictor;
+import org.apache.camel.component.djl.model.audio.ZooAudioPredictor;
+import org.apache.camel.component.djl.model.cv.CustomCvPredictor;
+import org.apache.camel.component.djl.model.cv.CustomImageGenerationPredictor;
import org.apache.camel.component.djl.model.cv.ZooActionRecognitionPredictor;
import org.apache.camel.component.djl.model.cv.ZooImageClassificationPredictor;
import org.apache.camel.component.djl.model.cv.ZooImageEnhancementPredictor;
@@ -34,6 +41,9 @@ import
org.apache.camel.component.djl.model.cv.ZooObjectDetectionPredictor;
import org.apache.camel.component.djl.model.cv.ZooPoseEstimationPredictor;
import
org.apache.camel.component.djl.model.cv.ZooSemanticSegmentationPredictor;
import org.apache.camel.component.djl.model.cv.ZooWordRecognitionPredictor;
+import org.apache.camel.component.djl.model.nlp.CustomNlpPredictor;
+import org.apache.camel.component.djl.model.nlp.CustomQuestionAnswerPredictor;
+import org.apache.camel.component.djl.model.nlp.CustomWordEmbeddingPredictor;
import org.apache.camel.component.djl.model.nlp.ZooFillMaskPredictor;
import org.apache.camel.component.djl.model.nlp.ZooMachineTranslationPredictor;
import org.apache.camel.component.djl.model.nlp.ZooMultipleChoicePredictor;
@@ -44,8 +54,10 @@ import
org.apache.camel.component.djl.model.nlp.ZooTextEmbeddingPredictor;
import org.apache.camel.component.djl.model.nlp.ZooTextGenerationPredictor;
import
org.apache.camel.component.djl.model.nlp.ZooTokenClassificationPredictor;
import org.apache.camel.component.djl.model.nlp.ZooWordEmbeddingPredictor;
+import org.apache.camel.component.djl.model.tabular.CustomTabularPredictor;
import
org.apache.camel.component.djl.model.tabular.ZooLinearRegressionPredictor;
import
org.apache.camel.component.djl.model.tabular.ZooSoftmaxRegressionPredictor;
+import
org.apache.camel.component.djl.model.timeseries.CustomForecastingPredictor;
import org.apache.camel.component.djl.model.timeseries.ZooForecastingPredictor;
import static ai.djl.Application.CV.ACTION_RECOGNITION;
@@ -132,7 +144,7 @@ public final class ModelPredictorProducer {
// Audio
if (Application.Audio.ANY.getPath().equals(applicationPath)) {
- return new ZooAudioClassificationPredictor(artifactId);
+ return new ZooAudioPredictor(artifactId);
}
// Time Series
@@ -144,12 +156,67 @@ public final class ModelPredictorProducer {
}
public static AbstractPredictor getCustomPredictor(String applicationPath,
String model, String translator) {
+ // CV
if (applicationPath.equals(IMAGE_CLASSIFICATION.getPath())) {
- return new CustomImageClassificationPredictor(model, translator);
+ return new CustomCvPredictor<Classifications>(model, translator);
} else if (applicationPath.equals(OBJECT_DETECTION.getPath())) {
- return new CustomObjectDetectionPredictor(model, translator);
- } else {
- throw new RuntimeCamelException("Application not supported ");
+ return new CustomCvPredictor<DetectedObjects>(model, translator);
+ } else if (SEMANTIC_SEGMENTATION.getPath().equals(applicationPath)) {
+ return new CustomCvPredictor<CategoryMask>(model, translator);
+ } else if (INSTANCE_SEGMENTATION.getPath().equals(applicationPath)) {
+ return new CustomCvPredictor<DetectedObjects>(model, translator);
+ } else if (POSE_ESTIMATION.getPath().equals(applicationPath)) {
+ return new CustomCvPredictor<Joints>(model, translator);
+ } else if (ACTION_RECOGNITION.getPath().equals(applicationPath)) {
+ return new CustomCvPredictor<Classifications>(model, translator);
+ } else if (WORD_RECOGNITION.getPath().equals(applicationPath)) {
+ return new CustomCvPredictor<String>(model, translator);
+ } else if (IMAGE_GENERATION.getPath().equals(applicationPath)) {
+ return new CustomImageGenerationPredictor(model, translator);
+ } else if (IMAGE_ENHANCEMENT.getPath().equals(applicationPath)) {
+ return new CustomCvPredictor<Image>(model, translator);
+ }
+
+ // NLP
+ if (FILL_MASK.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<String[]>(model, translator);
+ } else if (QUESTION_ANSWER.getPath().equals(applicationPath)) {
+ return new CustomQuestionAnswerPredictor(model, translator);
+ } else if (TEXT_CLASSIFICATION.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<Classifications>(model, translator);
+ } else if (SENTIMENT_ANALYSIS.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<Classifications>(model, translator);
+ } else if (TOKEN_CLASSIFICATION.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<Classifications>(model, translator);
+ } else if (WORD_EMBEDDING.getPath().equals(applicationPath)) {
+ return new CustomWordEmbeddingPredictor(model, translator);
+ } else if (TEXT_GENERATION.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<String>(model, translator);
+ } else if (MACHINE_TRANSLATION.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<String>(model, translator);
+ } else if (MULTIPLE_CHOICE.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<String>(model, translator);
+ } else if (TEXT_EMBEDDING.getPath().equals(applicationPath)) {
+ return new CustomNlpPredictor<NDArray>(model, translator);
+ }
+
+ // Tabular
+ if (LINEAR_REGRESSION.getPath().equals(applicationPath)) {
+ return new CustomTabularPredictor(model, translator);
+ } else if (SOFTMAX_REGRESSION.getPath().equals(applicationPath)) {
+ return new CustomTabularPredictor(model, translator);
+ }
+
+ // Audio
+ if (Application.Audio.ANY.getPath().equals(applicationPath)) {
+ return new CustomAudioPredictor(model, translator);
}
+
+ // Time Series
+ if (FORECASTING.getPath().equals(applicationPath)) {
+ return new CustomForecastingPredictor(model, translator);
+ }
+
+ throw new RuntimeCamelException("Application not supported: " +
applicationPath);
}
}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageClassificationPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/CustomAudioPredictor.java
similarity index 53%
copy from
components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageClassificationPredictor.java
copy to
components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/CustomAudioPredictor.java
index 3e98db60853..a2f390e8815 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageClassificationPredictor.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/CustomAudioPredictor.java
@@ -14,81 +14,89 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.camel.component.djl.model.cv;
+package org.apache.camel.component.djl.model.audio;
-import java.io.*;
+import java.io.ByteArrayInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
import ai.djl.Model;
import ai.djl.inference.Predictor;
-import ai.djl.modality.Classifications;
-import ai.djl.modality.cv.Image;
-import ai.djl.modality.cv.ImageFactory;
+import ai.djl.modality.audio.Audio;
+import ai.djl.modality.audio.AudioFactory;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import org.apache.camel.Exchange;
import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
import org.apache.camel.component.djl.model.AbstractPredictor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class CustomImageClassificationPredictor extends AbstractPredictor {
- private static final Logger LOG =
LoggerFactory.getLogger(CustomImageClassificationPredictor.class);
+public class CustomAudioPredictor extends AbstractPredictor {
- private final String modelName;
- private final String translatorName;
+ private static final Logger LOG =
LoggerFactory.getLogger(CustomAudioPredictor.class);
- public CustomImageClassificationPredictor(String modelName, String
translatorName) {
+ protected final String modelName;
+ protected final String translatorName;
+
+ public CustomAudioPredictor(String modelName, String translatorName) {
this.modelName = modelName;
this.translatorName = translatorName;
}
@Override
public void process(Exchange exchange) throws Exception {
- Model model =
exchange.getContext().getRegistry().lookupByNameAndType(modelName, Model.class);
- @SuppressWarnings("unchecked")
- Translator<Image, Classifications> translator
- =
exchange.getContext().getRegistry().lookupByNameAndType(translatorName,
Translator.class);
-
- if (exchange.getIn().getBody() instanceof byte[]) {
+ Object body = exchange.getIn().getBody();
+ String result;
+ if (body instanceof Audio) {
+ result = predict(exchange, exchange.getIn().getBody(Audio.class));
+ } else if (body instanceof byte[]) {
byte[] bytes = exchange.getIn().getBody(byte[].class);
- Classifications result = classify(model, translator, new
ByteArrayInputStream(bytes));
- exchange.getIn().setBody(result);
- } else if (exchange.getIn().getBody() instanceof File) {
- Classifications result = classify(model, translator,
exchange.getIn().getBody(File.class));
- exchange.getIn().setBody(result);
- } else if (exchange.getIn().getBody() instanceof InputStream) {
- Classifications result = classify(model, translator,
exchange.getIn().getBody(InputStream.class));
- exchange.getIn().setBody(result);
+ result = predict(exchange, new ByteArrayInputStream(bytes));
+ } else if (body instanceof File) {
+ result = predict(exchange, exchange.getIn().getBody(File.class));
+ } else if (body instanceof InputStream) {
+ result = predict(exchange,
exchange.getIn().getBody(InputStream.class));
} else {
- throw new RuntimeCamelException("Data type is not supported. Body
should be byte[], InputStream or File");
+ throw new RuntimeCamelException(
+ "Data type is not supported. Body should be
ai.djl.modality.audio.Audio, byte[], InputStream or File");
}
+ exchange.getIn().setBody(result);
}
- private Classifications classify(Model model, Translator<Image,
Classifications> translator, File input) {
+ protected String predict(Exchange exchange, File input) {
try (InputStream fileInputStream = new FileInputStream(input)) {
- Image image =
ImageFactory.getInstance().fromInputStream(fileInputStream);
- return classify(model, translator, image);
+ Audio audio =
AudioFactory.newInstance().fromInputStream(fileInputStream);
+ return predict(exchange, audio);
} catch (IOException e) {
LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
}
}
- private Classifications classify(Model model, Translator<Image,
Classifications> translator, InputStream input) {
+ protected String predict(Exchange exchange, InputStream input) {
try {
- Image image = ImageFactory.getInstance().fromInputStream(input);
- return classify(model, translator, image);
+ Audio audio = AudioFactory.newInstance().fromInputStream(input);
+ return predict(exchange, audio);
} catch (IOException e) {
LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
}
}
- private Classifications classify(Model model, Translator<Image,
Classifications> translator, Image image) {
- try (Predictor<Image, Classifications> predictor =
model.newPredictor(translator)) {
- return predictor.predict(image);
+ protected String predict(Exchange exchange, Audio audio) {
+ Model model =
exchange.getContext().getRegistry().lookupByNameAndType(modelName, Model.class);
+ @SuppressWarnings("unchecked")
+ Translator<Audio, String> translator
+ =
exchange.getContext().getRegistry().lookupByNameAndType(translatorName,
Translator.class);
+
+ exchange.getIn().setHeader(DJLConstants.INPUT, audio);
+ try (Predictor<Audio, String> predictor =
model.newPredictor(translator)) {
+ return predictor.predict(audio);
} catch (TranslateException e) {
- LOG.error("Could not process input or output", e);
throw new RuntimeCamelException("Could not process input or
output", e);
}
}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioPredictor.java
new file mode 100644
index 00000000000..3bd18a049e6
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioPredictor.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.camel.component.djl.model.audio;
+
+import java.io.ByteArrayInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+import ai.djl.Application;
+import ai.djl.MalformedModelException;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.audio.Audio;
+import ai.djl.modality.audio.AudioFactory;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.training.util.ProgressBar;
+import ai.djl.translate.TranslateException;
+import org.apache.camel.Exchange;
+import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
+import org.apache.camel.component.djl.model.AbstractPredictor;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ZooAudioPredictor extends AbstractPredictor {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(ZooAudioPredictor.class);
+
+ private final ZooModel<Audio, String> model;
+
+ public ZooAudioPredictor(String artifactId) throws ModelNotFoundException,
MalformedModelException, IOException {
+ Criteria<Audio, String> criteria = Criteria.builder()
+ .optApplication(Application.Audio.ANY)
+ .setTypes(Audio.class, String.class)
+ .optArtifactId(artifactId)
+ .optProgress(new ProgressBar())
+ .build();
+ this.model = ModelZoo.loadModel(criteria);
+ }
+
+ @Override
+ public void process(Exchange exchange) throws Exception {
+ Object body = exchange.getIn().getBody();
+ String result;
+ if (body instanceof Audio) {
+ result = predict(exchange, exchange.getIn().getBody(Audio.class));
+ } else if (body instanceof byte[]) {
+ byte[] bytes = exchange.getIn().getBody(byte[].class);
+ result = predict(exchange, new ByteArrayInputStream(bytes));
+ } else if (body instanceof File) {
+ result = predict(exchange, exchange.getIn().getBody(File.class));
+ } else if (body instanceof InputStream) {
+ result = predict(exchange,
exchange.getIn().getBody(InputStream.class));
+ } else {
+ throw new RuntimeCamelException(
+ "Data type is not supported. Body should be
ai.djl.modality.audio.Audio, byte[], InputStream or File");
+ }
+ exchange.getIn().setBody(result);
+ }
+
+ protected String predict(Exchange exchange, File input) {
+ try (InputStream fileInputStream = new FileInputStream(input)) {
+ Audio audio =
AudioFactory.newInstance().fromInputStream(fileInputStream);
+ return predict(exchange, audio);
+ } catch (IOException e) {
+ LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
+ throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
+ }
+ }
+
+ protected String predict(Exchange exchange, InputStream input) {
+ try {
+ Audio audio = AudioFactory.newInstance().fromInputStream(input);
+ return predict(exchange, audio);
+ } catch (IOException e) {
+ LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
+ throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
+ }
+ }
+
+ protected String predict(Exchange exchange, Audio audio) {
+ exchange.getIn().setHeader(DJLConstants.INPUT, audio);
+ try (Predictor<Audio, String> predictor = model.newPredictor()) {
+ return predictor.predict(audio);
+ } catch (TranslateException e) {
+ throw new RuntimeCamelException("Could not process input or
output", e);
+ }
+ }
+}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageClassificationPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomCvPredictor.java
similarity index 64%
rename from
components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageClassificationPredictor.java
rename to
components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomCvPredictor.java
index 3e98db60853..56c3d9a54b0 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageClassificationPredictor.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomCvPredictor.java
@@ -16,76 +16,82 @@
*/
package org.apache.camel.component.djl.model.cv;
-import java.io.*;
+import java.io.ByteArrayInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
import ai.djl.Model;
import ai.djl.inference.Predictor;
-import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import org.apache.camel.Exchange;
import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
import org.apache.camel.component.djl.model.AbstractPredictor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class CustomImageClassificationPredictor extends AbstractPredictor {
- private static final Logger LOG =
LoggerFactory.getLogger(CustomImageClassificationPredictor.class);
+public class CustomCvPredictor<T> extends AbstractPredictor {
- private final String modelName;
- private final String translatorName;
+ private static final Logger LOG =
LoggerFactory.getLogger(CustomCvPredictor.class);
- public CustomImageClassificationPredictor(String modelName, String
translatorName) {
+ protected final String modelName;
+ protected final String translatorName;
+
+ public CustomCvPredictor(String modelName, String translatorName) {
this.modelName = modelName;
this.translatorName = translatorName;
}
@Override
public void process(Exchange exchange) throws Exception {
- Model model =
exchange.getContext().getRegistry().lookupByNameAndType(modelName, Model.class);
- @SuppressWarnings("unchecked")
- Translator<Image, Classifications> translator
- =
exchange.getContext().getRegistry().lookupByNameAndType(translatorName,
Translator.class);
-
- if (exchange.getIn().getBody() instanceof byte[]) {
+ Object body = exchange.getIn().getBody();
+ T result;
+ if (body instanceof byte[]) {
byte[] bytes = exchange.getIn().getBody(byte[].class);
- Classifications result = classify(model, translator, new
ByteArrayInputStream(bytes));
- exchange.getIn().setBody(result);
- } else if (exchange.getIn().getBody() instanceof File) {
- Classifications result = classify(model, translator,
exchange.getIn().getBody(File.class));
- exchange.getIn().setBody(result);
- } else if (exchange.getIn().getBody() instanceof InputStream) {
- Classifications result = classify(model, translator,
exchange.getIn().getBody(InputStream.class));
- exchange.getIn().setBody(result);
+ result = predict(exchange, new ByteArrayInputStream(bytes));
+ } else if (body instanceof File) {
+ result = predict(exchange, exchange.getIn().getBody(File.class));
+ } else if (body instanceof InputStream) {
+ result = predict(exchange,
exchange.getIn().getBody(InputStream.class));
} else {
throw new RuntimeCamelException("Data type is not supported. Body
should be byte[], InputStream or File");
}
+ exchange.getIn().setBody(result);
}
- private Classifications classify(Model model, Translator<Image,
Classifications> translator, File input) {
+ protected T predict(Exchange exchange, File input) {
try (InputStream fileInputStream = new FileInputStream(input)) {
Image image =
ImageFactory.getInstance().fromInputStream(fileInputStream);
- return classify(model, translator, image);
+ return predict(exchange, image);
} catch (IOException e) {
LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
}
}
- private Classifications classify(Model model, Translator<Image,
Classifications> translator, InputStream input) {
+ protected T predict(Exchange exchange, InputStream input) {
try {
Image image = ImageFactory.getInstance().fromInputStream(input);
- return classify(model, translator, image);
+ return predict(exchange, image);
} catch (IOException e) {
LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
}
}
- private Classifications classify(Model model, Translator<Image,
Classifications> translator, Image image) {
- try (Predictor<Image, Classifications> predictor =
model.newPredictor(translator)) {
+ protected T predict(Exchange exchange, Image image) {
+ Model model =
exchange.getContext().getRegistry().lookupByNameAndType(modelName, Model.class);
+ @SuppressWarnings("unchecked")
+ Translator<Image, T> translator
+ =
exchange.getContext().getRegistry().lookupByNameAndType(translatorName,
Translator.class);
+
+ exchange.getIn().setHeader(DJLConstants.INPUT, image);
+ try (Predictor<Image, T> predictor = model.newPredictor(translator)) {
return predictor.predict(image);
} catch (TranslateException e) {
LOG.error("Could not process input or output", e);
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageGenerationPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageGenerationPredictor.java
new file mode 100644
index 00000000000..8bfa1dbcd5e
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomImageGenerationPredictor.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.camel.component.djl.model.cv;
+
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import org.apache.camel.Exchange;
+import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
+import org.apache.camel.component.djl.model.AbstractPredictor;
+
+public class CustomImageGenerationPredictor extends AbstractPredictor {
+
+ private final String modelName;
+ private final String translatorName;
+
+ public CustomImageGenerationPredictor(String modelName, String
translatorName) {
+ this.modelName = modelName;
+ this.translatorName = translatorName;
+ }
+
+ @Override
+ public void process(Exchange exchange) {
+ if (exchange.getIn().getBody() instanceof int[]) {
+ int[] seed = exchange.getIn().getBody(int[].class);
+ Image[] result = predict(exchange, seed);
+ exchange.getIn().setBody(result);
+ } else {
+ throw new RuntimeCamelException("Data type is not supported. Body
should be int[]");
+ }
+ }
+
+ protected Image[] predict(Exchange exchange, int[] seed) {
+ Model model =
exchange.getContext().getRegistry().lookupByNameAndType(modelName, Model.class);
+ @SuppressWarnings("unchecked")
+ Translator<int[], Image[]> translator
+ =
exchange.getContext().getRegistry().lookupByNameAndType(translatorName,
Translator.class);
+
+ exchange.getIn().setHeader(DJLConstants.INPUT, seed);
+ try (Predictor<int[], Image[]> predictor =
model.newPredictor(translator)) {
+ return predictor.predict(seed);
+ } catch (TranslateException e) {
+ throw new RuntimeCamelException("Could not process input or
output", e);
+ }
+ }
+}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomObjectDetectionPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomObjectDetectionPredictor.java
deleted file mode 100644
index f253dd83a1e..00000000000
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomObjectDetectionPredictor.java
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.camel.component.djl.model.cv;
-
-import java.io.*;
-
-import ai.djl.Model;
-import ai.djl.inference.Predictor;
-import ai.djl.modality.cv.Image;
-import ai.djl.modality.cv.ImageFactory;
-import ai.djl.modality.cv.output.DetectedObjects;
-import ai.djl.translate.TranslateException;
-import ai.djl.translate.Translator;
-import org.apache.camel.Exchange;
-import org.apache.camel.RuntimeCamelException;
-import org.apache.camel.component.djl.model.AbstractPredictor;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class CustomObjectDetectionPredictor extends AbstractPredictor {
-
- private static final Logger LOG =
LoggerFactory.getLogger(CustomObjectDetectionPredictor.class);
-
- private final String modelName;
- private final String translatorName;
-
- public CustomObjectDetectionPredictor(String modelName, String
translatorName) {
- this.modelName = modelName;
- this.translatorName = translatorName;
- }
-
- @Override
- public void process(Exchange exchange) {
- Model model =
exchange.getContext().getRegistry().lookupByNameAndType(modelName, Model.class);
- Translator translator =
exchange.getContext().getRegistry().lookupByNameAndType(translatorName,
Translator.class);
-
- if (exchange.getIn().getBody() instanceof byte[]) {
- byte[] bytes = exchange.getIn().getBody(byte[].class);
- DetectedObjects result = classify(model, translator, new
ByteArrayInputStream(bytes));
- exchange.getIn().setBody(result);
- } else if (exchange.getIn().getBody() instanceof File) {
- DetectedObjects result = classify(model, translator,
exchange.getIn().getBody(File.class));
- exchange.getIn().setBody(result);
- } else if (exchange.getIn().getBody() instanceof InputStream) {
- DetectedObjects result = classify(model, translator,
exchange.getIn().getBody(InputStream.class));
- exchange.getIn().setBody(result);
- } else {
- throw new RuntimeCamelException("Data type is not supported. Body
should be byte[], InputStream or File");
- }
- }
-
- public DetectedObjects classify(Model model, Translator translator, Image
image) {
- try (Predictor<Image, DetectedObjects> predictor =
model.newPredictor(translator)) {
- return predictor.predict(image);
- } catch (TranslateException e) {
- LOG.error("Could not process input or output", e);
- throw new RuntimeCamelException("Could not process input or
output", e);
- }
- }
-
- public DetectedObjects classify(Model model, Translator translator, File
input) {
- try (InputStream fileInputStream = new FileInputStream(input)) {
- Image image =
ImageFactory.getInstance().fromInputStream(fileInputStream);
- return classify(model, translator, image);
- } catch (IOException e) {
- LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
- throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
- }
- }
-
- public DetectedObjects classify(Model model, Translator translator,
InputStream input) {
- try {
- Image image = ImageFactory.getInstance().fromInputStream(input);
- return classify(model, translator, image);
- } catch (IOException e) {
- LOG.error(FAILED_TO_TRANSFORM_MESSAGE);
- throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e);
- }
- }
-}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/ZooImageEnhancementPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/ZooImageEnhancementPredictor.java
index a9056497c18..8b4fdc98abe 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/ZooImageEnhancementPredictor.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/ZooImageEnhancementPredictor.java
@@ -31,7 +31,7 @@ public class ZooImageEnhancementPredictor extends
AbstractCvZooPredictor<Image>
public ZooImageEnhancementPredictor(String artifactId) throws
ModelNotFoundException, MalformedModelException,
IOException {
Criteria<Image, Image> criteria = Criteria.builder()
- .optApplication(Application.CV.SEMANTIC_SEGMENTATION)
+ .optApplication(Application.CV.IMAGE_ENHANCEMENT)
.setTypes(Image.class, Image.class)
.optArtifactId(artifactId)
.optProgress(new ProgressBar())
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomNlpPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomNlpPredictor.java
new file mode 100644
index 00000000000..54157ab53b4
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomNlpPredictor.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.camel.component.djl.model.nlp;
+
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import org.apache.camel.Exchange;
+import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
+import org.apache.camel.component.djl.model.AbstractPredictor;
+
+/**
+ * Predictor for custom NLP models whose model and translator beans are looked up from the Camel registry.
+ *
+ * @param <T> the output type produced by the registered translator
+ */
+public class CustomNlpPredictor<T> extends AbstractPredictor {
+
+    protected final String modelName;
+    protected final String translatorName;
+
+    public CustomNlpPredictor(String modelName, String translatorName) {
+        this.modelName = modelName;
+        this.translatorName = translatorName;
+    }
+
+    /**
+     * Runs the custom model on a {@code String} message body and replaces the body with the prediction.
+     *
+     * @param  exchange              the current exchange
+     * @throws RuntimeCamelException if the body is not a {@code String}, a registry bean is missing, or the
+     *                               translator fails to process the input or output
+     */
+    @Override
+    public void process(Exchange exchange) {
+        if (exchange.getIn().getBody() instanceof String) {
+            String input = exchange.getIn().getBody(String.class);
+            T result = predict(exchange, input);
+            exchange.getIn().setBody(result);
+        } else {
+            throw new RuntimeCamelException("Data type is not supported. Body should be String");
+        }
+    }
+
+    /**
+     * Resolves the model and translator beans, records the input in the {@link DJLConstants#INPUT} header and runs
+     * a single prediction with a short-lived predictor.
+     */
+    protected T predict(Exchange exchange, String input) {
+        Model model = lookupRequired(exchange, modelName, Model.class);
+        @SuppressWarnings("unchecked")
+        Translator<String, T> translator = lookupRequired(exchange, translatorName, Translator.class);
+
+        exchange.getIn().setHeader(DJLConstants.INPUT, input);
+        try (Predictor<String, T> predictor = model.newPredictor(translator)) {
+            return predictor.predict(input);
+        } catch (TranslateException e) {
+            throw new RuntimeCamelException("Could not process input or output", e);
+        }
+    }
+
+    /**
+     * Looks up a bean in the Camel registry and fails fast with a descriptive error when it is absent, instead of
+     * letting a {@code null} surface later as a {@link NullPointerException} inside DJL.
+     */
+    private static <B> B lookupRequired(Exchange exchange, String name, Class<B> type) {
+        B bean = exchange.getContext().getRegistry().lookupByNameAndType(name, type);
+        if (bean == null) {
+            throw new RuntimeCamelException(
+                    "No " + type.getSimpleName() + " bean named '" + name + "' found in the registry");
+        }
+        return bean;
+    }
+}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomQuestionAnswerPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomQuestionAnswerPredictor.java
new file mode 100644
index 00000000000..ab162c69567
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomQuestionAnswerPredictor.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.camel.component.djl.model.nlp;
+
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.nlp.qa.QAInput;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import org.apache.camel.Exchange;
+import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
+import org.apache.camel.component.djl.model.AbstractPredictor;
+
+/**
+ * Predictor for custom question-answering models whose model and translator beans are looked up from the Camel
+ * registry.
+ */
+public class CustomQuestionAnswerPredictor extends AbstractPredictor {
+
+    private final String modelName;
+    private final String translatorName;
+
+    public CustomQuestionAnswerPredictor(String modelName, String translatorName) {
+        this.modelName = modelName;
+        this.translatorName = translatorName;
+    }
+
+    /**
+     * Runs the custom model and replaces the message body with the answer string.
+     * <p>
+     * The body is either a {@link QAInput}, or a {@code String[]} whose first two elements are passed, in order, to
+     * {@code new QAInput(first, second)}.
+     *
+     * @param  exchange              the current exchange
+     * @throws RuntimeCamelException if the body type is unsupported, a {@code String[]} body has fewer than two
+     *                               elements, a registry bean is missing, or prediction fails
+     */
+    @Override
+    public void process(Exchange exchange) throws Exception {
+        Object body = exchange.getIn().getBody();
+        QAInput input;
+        if (body instanceof QAInput) {
+            input = exchange.getIn().getBody(QAInput.class);
+        } else if (body instanceof String[]) {
+            String[] strs = exchange.getIn().getBody(String[].class);
+            if (strs.length < 2) {
+                throw new RuntimeCamelException("Input String[] should have two elements");
+            }
+            input = new QAInput(strs[0], strs[1]);
+        } else {
+            throw new RuntimeCamelException("Data type is not supported. Body should be String[] or QAInput");
+        }
+        exchange.getIn().setBody(predict(exchange, input));
+    }
+
+    /**
+     * Resolves the model and translator beans, records the input in the {@link DJLConstants#INPUT} header and runs
+     * a single prediction with a short-lived predictor.
+     */
+    protected String predict(Exchange exchange, QAInput input) {
+        Model model = lookupRequired(exchange, modelName, Model.class);
+        @SuppressWarnings("unchecked")
+        Translator<QAInput, String> translator = lookupRequired(exchange, translatorName, Translator.class);
+
+        exchange.getIn().setHeader(DJLConstants.INPUT, input);
+        try (Predictor<QAInput, String> predictor = model.newPredictor(translator)) {
+            return predictor.predict(input);
+        } catch (TranslateException e) {
+            throw new RuntimeCamelException("Could not process input or output", e);
+        }
+    }
+
+    /**
+     * Looks up a bean in the Camel registry and fails fast with a descriptive error when it is absent, instead of
+     * letting a {@code null} surface later as a {@link NullPointerException} inside DJL.
+     */
+    private static <B> B lookupRequired(Exchange exchange, String name, Class<B> type) {
+        B bean = exchange.getContext().getRegistry().lookupByNameAndType(name, type);
+        if (bean == null) {
+            throw new RuntimeCamelException(
+                    "No " + type.getSimpleName() + " bean named '" + name + "' found in the registry");
+        }
+        return bean;
+    }
+}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioClassificationPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomWordEmbeddingPredictor.java
similarity index 61%
rename from
components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioClassificationPredictor.java
rename to
components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomWordEmbeddingPredictor.java
index bf210081d76..a9df3b4e062 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioClassificationPredictor.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/CustomWordEmbeddingPredictor.java
@@ -14,18 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.camel.component.djl.model.audio;
+package org.apache.camel.component.djl.model.nlp;
+import ai.djl.ndarray.NDList;
import org.apache.camel.Exchange;
-import org.apache.camel.component.djl.model.AbstractPredictor;
-public class ZooAudioClassificationPredictor extends AbstractPredictor {
- public ZooAudioClassificationPredictor(String artifactId) {
- super();
+/**
+ * Custom NLP predictor for word-embedding models.
+ * <p>
+ * Runs the prediction through {@link CustomNlpPredictor}. It then replaces the resulting {@link NDList} body with
+ * the output of {@link NDList#encode()}, so the DJL type itself is not handed to downstream processors.
+ */
+public class CustomWordEmbeddingPredictor extends CustomNlpPredictor<NDList> {
+
+ public CustomWordEmbeddingPredictor(String modelName, String
translatorName) {
+ super(modelName, translatorName);
}
@Override
- public void process(Exchange exchange) throws Exception {
-
+ public void process(Exchange exchange) {
+ super.process(exchange);
+ // DJL NDList should not be exposed outside the endpoint
+ NDList result = exchange.getIn().getBody(NDList.class);
+ exchange.getIn().setBody(result.encode());
}
}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/ZooMachineTranslationPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/ZooMachineTranslationPredictor.java
index 02bac9975ec..e6df6fb109e 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/ZooMachineTranslationPredictor.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/nlp/ZooMachineTranslationPredictor.java
@@ -30,7 +30,7 @@ public class ZooMachineTranslationPredictor extends
AbstractNlpZooPredictor<Stri
public ZooMachineTranslationPredictor(String artifactId) throws
ModelNotFoundException, MalformedModelException,
IOException {
Criteria<String, String> criteria = Criteria.builder()
- .optApplication(Application.NLP.TEXT_GENERATION)
+ .optApplication(Application.NLP.MACHINE_TRANSLATION)
.setTypes(String.class, String.class)
.optArtifactId(artifactId)
.optProgress(new ProgressBar())
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/CustomTabularPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/CustomTabularPredictor.java
new file mode 100644
index 00000000000..c958a715663
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/CustomTabularPredictor.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.camel.component.djl.model.tabular;
+
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import org.apache.camel.Exchange;
+import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
+import org.apache.camel.component.djl.model.AbstractPredictor;
+
+/**
+ * Predictor for custom tabular models whose model and translator beans are looked up from the Camel registry.
+ * <p>
+ * The message body is passed to the translator as-is; its type must match what the registered translator expects.
+ */
+public class CustomTabularPredictor extends AbstractPredictor {
+
+    protected final String modelName;
+    protected final String translatorName;
+
+    public CustomTabularPredictor(String modelName, String translatorName) {
+        this.modelName = modelName;
+        this.translatorName = translatorName;
+    }
+
+    /**
+     * Runs the custom model on the raw message body and replaces the body with the prediction.
+     *
+     * @param  exchange              the current exchange
+     * @throws RuntimeCamelException if a registry bean is missing or prediction fails
+     */
+    @Override
+    public void process(Exchange exchange) throws Exception {
+        Object input = exchange.getIn().getBody();
+        Object result = predict(exchange, input);
+        exchange.getIn().setBody(result);
+    }
+
+    /**
+     * Resolves the model and translator beans, records the input in the {@link DJLConstants#INPUT} header and runs
+     * a single prediction with a short-lived predictor.
+     */
+    protected Object predict(Exchange exchange, Object input) {
+        Model model = lookupRequired(exchange, modelName, Model.class);
+        // The translator's concrete input/output types are only known to the user who registered it
+        @SuppressWarnings("unchecked")
+        Translator<Object, Object> translator = lookupRequired(exchange, translatorName, Translator.class);
+
+        exchange.getIn().setHeader(DJLConstants.INPUT, input);
+        try (Predictor<Object, Object> predictor = model.newPredictor(translator)) {
+            return predictor.predict(input);
+        } catch (TranslateException e) {
+            throw new RuntimeCamelException("Could not process input or output", e);
+        }
+    }
+
+    /**
+     * Looks up a bean in the Camel registry and fails fast with a descriptive error when it is absent, instead of
+     * letting a {@code null} surface later as a {@link NullPointerException} inside DJL.
+     */
+    private static <B> B lookupRequired(Exchange exchange, String name, Class<B> type) {
+        B bean = exchange.getContext().getRegistry().lookupByNameAndType(name, type);
+        if (bean == null) {
+            throw new RuntimeCamelException(
+                    "No " + type.getSimpleName() + " bean named '" + name + "' found in the registry");
+        }
+        return bean;
+    }
+}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooLinearRegressionPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooLinearRegressionPredictor.java
index 3d221f95127..7b7f023ddb6 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooLinearRegressionPredictor.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooLinearRegressionPredictor.java
@@ -26,6 +26,6 @@ public class ZooLinearRegressionPredictor extends
AbstractPredictor {
@Override
public void process(Exchange exchange) throws Exception {
-
+ // TODO: impl
}
}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooSoftmaxRegressionPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooSoftmaxRegressionPredictor.java
index 31dba3a7f77..1e51de5409e 100644
---
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooSoftmaxRegressionPredictor.java
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/tabular/ZooSoftmaxRegressionPredictor.java
@@ -26,6 +26,6 @@ public class ZooSoftmaxRegressionPredictor extends
AbstractPredictor {
@Override
public void process(Exchange exchange) throws Exception {
-
+ // TODO: impl
}
}
diff --git
a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/timeseries/CustomForecastingPredictor.java
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/timeseries/CustomForecastingPredictor.java
new file mode 100644
index 00000000000..d54f7e2a74f
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/timeseries/CustomForecastingPredictor.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.camel.component.djl.model.timeseries;
+
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.timeseries.Forecast;
+import ai.djl.timeseries.TimeSeriesData;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import org.apache.camel.Exchange;
+import org.apache.camel.RuntimeCamelException;
+import org.apache.camel.component.djl.DJLConstants;
+import org.apache.camel.component.djl.model.AbstractPredictor;
+
+/**
+ * Predictor for custom time-series forecasting models whose model and translator beans are looked up from the
+ * Camel registry.
+ */
+public class CustomForecastingPredictor extends AbstractPredictor {
+
+    protected final String modelName;
+    protected final String translatorName;
+
+    public CustomForecastingPredictor(String modelName, String translatorName) {
+        this.modelName = modelName;
+        this.translatorName = translatorName;
+    }
+
+    /**
+     * Runs the custom model on a {@link TimeSeriesData} body and replaces the body with the {@link Forecast}.
+     *
+     * @param  exchange              the current exchange
+     * @throws RuntimeCamelException if the body is not {@link TimeSeriesData}, a registry bean is missing, or
+     *                               prediction fails
+     */
+    @Override
+    public void process(Exchange exchange) throws Exception {
+        if (exchange.getIn().getBody() instanceof TimeSeriesData) {
+            TimeSeriesData input = exchange.getIn().getBody(TimeSeriesData.class);
+            Forecast result = predict(exchange, input);
+            exchange.getIn().setBody(result);
+        } else {
+            throw new RuntimeCamelException("Data type is not supported. Body should be TimeSeriesData");
+        }
+    }
+
+    /**
+     * Resolves the model and translator beans, records the input in the {@link DJLConstants#INPUT} header and runs
+     * a single prediction with a short-lived predictor.
+     */
+    protected Forecast predict(Exchange exchange, TimeSeriesData input) {
+        Model model = lookupRequired(exchange, modelName, Model.class);
+        @SuppressWarnings("unchecked")
+        Translator<TimeSeriesData, Forecast> translator = lookupRequired(exchange, translatorName, Translator.class);
+
+        exchange.getIn().setHeader(DJLConstants.INPUT, input);
+        try (Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor(translator)) {
+            return predictor.predict(input);
+        } catch (TranslateException e) {
+            throw new RuntimeCamelException("Could not process input or output", e);
+        }
+    }
+
+    /**
+     * Looks up a bean in the Camel registry and fails fast with a descriptive error when it is absent, instead of
+     * letting a {@code null} surface later as a {@link NullPointerException} inside DJL.
+     */
+    private static <B> B lookupRequired(Exchange exchange, String name, Class<B> type) {
+        B bean = exchange.getContext().getRegistry().lookupByNameAndType(name, type);
+        if (bean == null) {
+            throw new RuntimeCamelException(
+                    "No " + type.getSimpleName() + " bean named '" + name + "' found in the registry");
+        }
+        return bean;
+    }
+}
diff --git
a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/AudioLocalTest.java
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/AudioLocalTest.java
new file mode 100644
index 00000000000..797b60e0b8e
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/AudioLocalTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.camel.component.djl;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.file.Files;
+
+import ai.djl.MalformedModelException;
+import ai.djl.Model;
+import ai.djl.modality.audio.AudioFactory;
+import ai.djl.modality.audio.translator.SpeechRecognitionTranslator;
+import ai.djl.util.ZipUtils;
+import org.apache.camel.builder.RouteBuilder;
+import org.apache.camel.test.junit5.CamelTestSupport;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Runs a locally loaded PyTorch speech-recognition model (wav2vec2) through the {@code djl:audio} endpoint.
+ */
+public class AudioLocalTest extends CamelTestSupport {
+    private static final Logger LOG = LoggerFactory.getLogger(AudioLocalTest.class);
+
+    private static final String MODEL_URL = "https://resources.djl.ai/test-models/pytorch/wav2vec2.zip";
+    private static final String MODEL_NAME = "wav2vec2.ptl";
+
+    @BeforeAll
+    public static void setupDefaultEngine() {
+        // A PyTorch model
+        System.setProperty("ai.djl.default_engine", "PyTorch");
+    }
+
+    @Test
+    void testDJL() throws Exception {
+        LOG.info("Read and load local model");
+        loadLocalModel();
+
+        // Register the expectation before starting the route, so the latch awaited below
+        // already exists when the first message reaches the mock
+        var mock = getMockEndpoint("mock:result");
+        mock.expectedMinimumMessageCount(1);
+
+        LOG.info("Starting route to infer");
+        context.createProducerTemplate().sendBody("controlbus:route?routeId=audio&action=start", null);
+        mock.await();
+    }
+
+    @Override
+    protected RouteBuilder createRouteBuilder() {
+        return new RouteBuilder() {
+            public void configure() {
+                from("timer:testDJL?repeatCount=1")
+                        .routeId("audio").autoStartup(false)
+                        .process(exchange -> {
+                            var wave = "https://resources.djl.ai/audios/speech.wav";
+                            var audio = AudioFactory.newInstance().fromUrl(wave);
+                            exchange.getIn().setBody(audio);
+                        })
+                        .to("djl:audio?model=MyModel&translator=MyTranslator")
+                        .log("Result: ${body}")
+                        .to("mock:result");
+            }
+        };
+    }
+
+    /**
+     * Downloads and extracts the model archive into a temp directory, loads it, and binds the model and
+     * translator beans referenced by the route.
+     */
+    private void loadLocalModel() throws IOException, MalformedModelException, URISyntaxException {
+        // Load a model
+        var model = Model.newInstance(MODEL_NAME);
+        // The model is loaded from a local directory, so extract the remote archive first
+        var modelDir = Files.createTempDirectory(MODEL_NAME);
+        // The download stream is opened here, so close it here once extraction is done
+        try (var modelStream = new URI(MODEL_URL).toURL().openStream()) {
+            ZipUtils.unzip(modelStream, modelDir);
+        }
+        model.load(modelDir);
+
+        // Bind model beans
+        context.getRegistry().bind("MyModel", model);
+        context.getRegistry().bind("MyTranslator", new SpeechRecognitionTranslator());
+    }
+}
diff --git
a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageEnhancementLocalTest.java
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageEnhancementLocalTest.java
new file mode 100644
index 00000000000..ed04d1183bc
--- /dev/null
+++
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageEnhancementLocalTest.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.camel.component.djl;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.file.Files;
+
+import ai.djl.MalformedModelException;
+import ai.djl.Model;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.TarUtils;
+import org.apache.camel.builder.RouteBuilder;
+import org.apache.camel.test.junit5.CamelTestSupport;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class CvImageEnhancementLocalTest extends CamelTestSupport {
+ private static final Logger LOG =
LoggerFactory.getLogger(CvImageEnhancementLocalTest.class);
+
+ private static final String MODEL_URL =
"https://storage.googleapis.com/tfhub-modules/captain-pool/esrgan-tf2/1.tar.gz";
+ private static final String MODEL_NAME = "esrgan-tf2";
+
+ @BeforeAll
+ public static void setupDefaultEngine() {
+ // ESRGAN-TF2 is a TensorFlow model
+ System.setProperty("ai.djl.default_engine", "TensorFlow");
+ }
+
+ @Test
+ void testDJL() throws Exception {
+ LOG.info("Read and load local model");
+ loadLocalModel();
+
+ LOG.info("Starting route to infer");
+
context.createProducerTemplate().sendBody("controlbus:route?routeId=image_enhancement&action=start",
null);
+ var mock = getMockEndpoint("mock:result");
+ mock.expectedMinimumMessageCount(1);
+ mock.await();
+ }
+
+ @Override
+ protected RouteBuilder createRouteBuilder() {
+ return new RouteBuilder() {
+ public void configure() {
+
from("file:src/test/resources/data/enhance?recursive=true&noop=true")
+ .routeId("image_enhancement").autoStartup(false)
+ .convertBodyTo(byte[].class)
+
.to("djl:cv/image_enhancement?model=MyModel&translator=MyTranslator")
+ .log("${header.CamelFileName} = ${body}")
+ .process(exchange -> {
+ var image = exchange.getIn().getBody(Image.class);
+ var os = new ByteArrayOutputStream();
+ image.save(os, "png");
+ exchange.getIn().setBody(os.toByteArray());
+ })
+
.to("file:target/output?fileName=CvImageEnhancementLocalTest-${date:now:ssSSS}.png")
+ .to("mock:result");
+ }
+ };
+ }
+
+ private void loadLocalModel() throws IOException, MalformedModelException,
URISyntaxException {
+ // Load a model
+ var model = Model.newInstance(MODEL_NAME);
+ // TfModel doesn't allow direct loading from remote input stream yet
+ // https://github.com/deepjavalibrary/djl/issues/3303
+ var modelDir = Files.createTempDirectory(MODEL_NAME);
+ TarUtils.untar(new URI(MODEL_URL).toURL().openStream(), modelDir,
true);
+ model.load(modelDir);
+
+ // Bind model beans
+ context.getRegistry().bind("MyModel", model);
+ context.getRegistry().bind("MyTranslator", new MyTranslator());
+ }
+
+ private static class MyTranslator implements Translator<Image, Image> {
+ @Override
+ public NDList processInput(TranslatorContext ctx, Image input) {
+ NDManager manager = ctx.getNDManager();
+ return new
NDList(input.toNDArray(manager).toType(DataType.FLOAT32, false));
+ }
+
+ @Override
+ public Image processOutput(TranslatorContext ctx, NDList list) {
+ NDArray output = list.get(0).clip(0, 255);
+ return ImageFactory.getInstance().fromNDArray(output.squeeze());
+ }
+ }
+}
diff --git
a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageGenerationTest.java
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageGenerationTest.java
index d866e22518e..5ff283d3b72 100644
---
a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageGenerationTest.java
+++
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageGenerationTest.java
@@ -54,7 +54,7 @@ public class CvImageGenerationTest extends CamelTestSupport {
image.save(os, "png");
exchange.getIn().setBody(os.toByteArray());
})
-
.to("file:target/output?fileName=ImageGenerationTest-${date:now:ssSSS}.png")
+
.to("file:target/output?fileName=CvImageGenerationTest-${date:now:ssSSS}.png")
.to("mock:result");
}
};
diff --git
a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/model/ModelPredictorProducerTest.java
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/model/ModelPredictorProducerTest.java
index 3dd441c0159..136a66ba5f0 100644
---
a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/model/ModelPredictorProducerTest.java
+++
b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/model/ModelPredictorProducerTest.java
@@ -20,6 +20,9 @@ import java.io.IOException;
import ai.djl.MalformedModelException;
import ai.djl.repository.zoo.ModelNotFoundException;
+import org.apache.camel.component.djl.model.audio.CustomAudioPredictor;
+import org.apache.camel.component.djl.model.cv.CustomCvPredictor;
+import org.apache.camel.component.djl.model.cv.CustomImageGenerationPredictor;
import org.apache.camel.component.djl.model.cv.ZooActionRecognitionPredictor;
import org.apache.camel.component.djl.model.cv.ZooImageClassificationPredictor;
import org.apache.camel.component.djl.model.cv.ZooImageGenerationPredictor;
@@ -27,13 +30,19 @@ import
org.apache.camel.component.djl.model.cv.ZooInstanceSegmentationPredictor;
import org.apache.camel.component.djl.model.cv.ZooObjectDetectionPredictor;
import org.apache.camel.component.djl.model.cv.ZooPoseEstimationPredictor;
import
org.apache.camel.component.djl.model.cv.ZooSemanticSegmentationPredictor;
+import org.apache.camel.component.djl.model.nlp.CustomNlpPredictor;
+import org.apache.camel.component.djl.model.nlp.CustomQuestionAnswerPredictor;
+import org.apache.camel.component.djl.model.nlp.CustomWordEmbeddingPredictor;
import org.apache.camel.component.djl.model.nlp.ZooQuestionAnswerPredictor;
import org.apache.camel.component.djl.model.nlp.ZooSentimentAnalysisPredictor;
import org.apache.camel.component.djl.model.nlp.ZooWordEmbeddingPredictor;
+import org.apache.camel.component.djl.model.tabular.CustomTabularPredictor;
+import
org.apache.camel.component.djl.model.timeseries.CustomForecastingPredictor;
import org.apache.camel.component.djl.model.timeseries.ZooForecastingPredictor;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
+import static
org.apache.camel.component.djl.model.ModelPredictorProducer.getCustomPredictor;
import static
org.apache.camel.component.djl.model.ModelPredictorProducer.getZooPredictor;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
@@ -91,4 +100,66 @@ class ModelPredictorProducerTest {
assertInstanceOf(ZooForecastingPredictor.class,
getZooPredictor("timeseries/forecasting",
"ai.djl.pytorch:deepar:0.0.1"));
}
+
+ @Test
+ void testGetCustomPredictor() {
+ var modelName = "MyModel";
+ var translatorName = "MyTranslator";
+
+ // CV
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/image_classification", modelName,
translatorName));
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/object_detection", modelName,
translatorName));
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/semantic_segmentation", modelName,
translatorName));
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/instance_segmentation", modelName,
translatorName));
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/pose_estimation", modelName,
translatorName));
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/action_recognition", modelName,
translatorName));
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/word_recognition", modelName,
translatorName));
+ assertInstanceOf(CustomImageGenerationPredictor.class,
+ getCustomPredictor("cv/image_generation", modelName,
translatorName));
+ assertInstanceOf(CustomCvPredictor.class,
+ getCustomPredictor("cv/image_enhancement", modelName,
translatorName));
+
+ // NLP
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/fill_mask", modelName,
translatorName));
+ assertInstanceOf(CustomQuestionAnswerPredictor.class,
+ getCustomPredictor("nlp/question_answer", modelName,
translatorName));
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/text_classification", modelName,
translatorName));
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/sentiment_analysis", modelName,
translatorName));
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/token_classification", modelName,
translatorName));
+ assertInstanceOf(CustomWordEmbeddingPredictor.class,
+ getCustomPredictor("nlp/word_embedding", modelName,
translatorName));
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/text_generation", modelName,
translatorName));
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/machine_translation", modelName,
translatorName));
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/multiple_choice", modelName,
translatorName));
+ assertInstanceOf(CustomNlpPredictor.class,
+ getCustomPredictor("nlp/text_embedding", modelName,
translatorName));
+
+ // Tabular
+ assertInstanceOf(CustomTabularPredictor.class,
+ getCustomPredictor("tabular/linear_regression", modelName,
translatorName));
+ assertInstanceOf(CustomTabularPredictor.class,
+ getCustomPredictor("tabular/softmax_regression", modelName,
translatorName));
+
+ // Audio
+ assertInstanceOf(CustomAudioPredictor.class,
+ getCustomPredictor("audio", modelName, translatorName));
+
+ // Time Series
+ assertInstanceOf(CustomForecastingPredictor.class,
+ getCustomPredictor("timeseries/forecasting", modelName,
translatorName));
+ }
}
diff --git
a/components/camel-ai/camel-djl/src/test/resources/data/enhance/fox.png
b/components/camel-ai/camel-djl/src/test/resources/data/enhance/fox.png
new file mode 100644
index 00000000000..16459ad7fd9
Binary files /dev/null and
b/components/camel-ai/camel-djl/src/test/resources/data/enhance/fox.png differ