lordgamez commented on code in PR #2107:
URL: https://github.com/apache/nifi-minifi-cpp/pull/2107#discussion_r3258654056


##########
extensions/llamacpp/processors/DefaultLlamaContext.cpp:
##########
@@ -73,9 +78,27 @@ DefaultLlamaContext::DefaultLlamaContext(const 
std::filesystem::path& model_path
     llama_sampler_chain_add(llama_sampler_, 
llama_sampler_init_temp(*llama_sampler_params.temperature));
   }
   llama_sampler_chain_add(llama_sampler_, 
llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+
+  if (!multimodal_model_path) {
+    logger->log_info("No multimodal model path provided");
+    return;
+  }
+
+  mtmd_context_params mparams = mtmd_context_params_default();
+  mparams.use_gpu = false;
+  mparams.flash_attn_type  = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
+  multimodal_ctx_ = 
mtmd_init_from_file(multimodal_model_path->string().c_str(), llama_model_, 
mparams);
+  if (!multimodal_ctx_) {
+    throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION, 
fmt::format("Failed to load multimodal model from '{}'", 
multimodal_model_path->string()));
+  }
+
+  logger->log_info("Successfully loaded multimodal model from '{}'", 
multimodal_model_path->string());

Review Comment:
   I would extract this to a separate function and have something like
   ```
   if (multimodal_model_path) {
     initializeMultimodalContext();
   }
   ```



##########
extensions/llamacpp/processors/DefaultLlamaContext.cpp:
##########
@@ -85,47 +108,96 @@ DefaultLlamaContext::~DefaultLlamaContext() {
 }
 
 std::optional<std::string> DefaultLlamaContext::applyTemplate(const 
std::vector<LlamaChatMessage>& messages) {
-  std::vector<llama_chat_message> llama_messages;
-  llama_messages.reserve(messages.size());
-  std::transform(messages.begin(), messages.end(), 
std::back_inserter(llama_messages),
-                 [](const LlamaChatMessage& msg) { return 
llama_chat_message{.role = msg.role.c_str(), .content = msg.content.c_str()}; 
});
-  std::string text;
-  text.resize(DEFAULT_BUFFER_SIZE);
-  const char * chat_template = llama_model_chat_template(llama_model_, 
nullptr);
-  int32_t res_size = llama_chat_apply_template(chat_template, 
llama_messages.data(), llama_messages.size(), true, text.data(), 
gsl::narrow<int32_t>(text.size()));
-  if (res_size < 0) {
+  if (!chat_template_) {
     return std::nullopt;
   }
-  if (res_size > gsl::narrow<int32_t>(text.size())) {
-    text.resize(res_size);
-    res_size = llama_chat_apply_template(chat_template, llama_messages.data(), 
llama_messages.size(), true, text.data(), gsl::narrow<int32_t>(text.size()));
-    if (res_size < 0) {
-      return std::nullopt;
-    }
+  common_chat_templates_inputs inputs;
+  for (auto& msg : messages) {
+    common_chat_msg chat_msg;
+    chat_msg.role = msg.role;
+    chat_msg.content = msg.content;
+    inputs.messages.push_back(std::move(chat_msg));
   }
-  text.resize(res_size);
+  inputs.enable_thinking = false;  // TODO(adebreceni): MINIFICPP-2800 
common_chat_templates_support_enable_thinking(chat_template_.get());
 
-  return text;
+  return common_chat_templates_apply(chat_template_.get(), inputs).prompt;
 }
 
-std::expected<GenerationResult, std::string> 
DefaultLlamaContext::generate(const std::string& input, 
std::function<void(std::string_view/*token*/)> token_handler) {
+namespace {
+
+struct mtmd_bitmap_deleter {
+  void operator()(mtmd_bitmap* val) { mtmd_bitmap_free(val); }
+};
+using unique_bitmap_ptr = std::unique_ptr<mtmd_bitmap, mtmd_bitmap_deleter>;
+
+struct mtmd_input_chunks_deleter {
+  void operator()(mtmd_input_chunks* val) { mtmd_input_chunks_free(val); }
+};
+using unique_mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, 
mtmd_input_chunks_deleter>;
+
+}  // namespace
+
+std::expected<GenerationResult, std::string> 
DefaultLlamaContext::generate(const std::string& prompt, const 
std::vector<std::vector<std::byte>>& files,
+      std::function<void(std::string_view/*token*/)> token_handler) {
   GenerationResult result{};
   auto start_time = std::chrono::steady_clock::now();
+  llama_memory_seq_rm(llama_get_memory(llama_ctx_), 0, -1, -1);
   const llama_vocab * vocab = llama_model_get_vocab(llama_model_);
-  std::vector<llama_token> tokenized_input = tokenizeInput(vocab, input);
-  result.num_tokens_in = gsl::narrow<uint64_t>(tokenized_input.size());
+  llama_pos n_past = 0;
+  std::vector<llama_token> tokenized_input;
+  llama_batch batch = llama_batch_init(1, 0, 1);
+  auto batch_deleter = gsl::finally([&] {llama_batch_free(batch);});
+  batch.n_tokens = 1;
+  batch.n_seq_id[0] = 1;
+  batch.seq_id[0][0] = 0;
+  batch.logits[0] = true;

Review Comment:
   This can be moved before the `while (decode_status == 0) {` line, as it is 
only used in the loop. Also, it might be better to use a wrapper object for 
automatic initialization and destruction.



##########
extensions/llamacpp/processors/DefaultLlamaContext.cpp:
##########
@@ -85,47 +108,96 @@ DefaultLlamaContext::~DefaultLlamaContext() {
 }
 
 std::optional<std::string> DefaultLlamaContext::applyTemplate(const 
std::vector<LlamaChatMessage>& messages) {
-  std::vector<llama_chat_message> llama_messages;
-  llama_messages.reserve(messages.size());
-  std::transform(messages.begin(), messages.end(), 
std::back_inserter(llama_messages),
-                 [](const LlamaChatMessage& msg) { return 
llama_chat_message{.role = msg.role.c_str(), .content = msg.content.c_str()}; 
});
-  std::string text;
-  text.resize(DEFAULT_BUFFER_SIZE);
-  const char * chat_template = llama_model_chat_template(llama_model_, 
nullptr);
-  int32_t res_size = llama_chat_apply_template(chat_template, 
llama_messages.data(), llama_messages.size(), true, text.data(), 
gsl::narrow<int32_t>(text.size()));
-  if (res_size < 0) {
+  if (!chat_template_) {
     return std::nullopt;
   }
-  if (res_size > gsl::narrow<int32_t>(text.size())) {
-    text.resize(res_size);
-    res_size = llama_chat_apply_template(chat_template, llama_messages.data(), 
llama_messages.size(), true, text.data(), gsl::narrow<int32_t>(text.size()));
-    if (res_size < 0) {
-      return std::nullopt;
-    }
+  common_chat_templates_inputs inputs;
+  for (auto& msg : messages) {
+    common_chat_msg chat_msg;
+    chat_msg.role = msg.role;
+    chat_msg.content = msg.content;
+    inputs.messages.push_back(std::move(chat_msg));
   }
-  text.resize(res_size);
+  inputs.enable_thinking = false;  // TODO(adebreceni): MINIFICPP-2800 
common_chat_templates_support_enable_thinking(chat_template_.get());
 
-  return text;
+  return common_chat_templates_apply(chat_template_.get(), inputs).prompt;
 }
 
-std::expected<GenerationResult, std::string> 
DefaultLlamaContext::generate(const std::string& input, 
std::function<void(std::string_view/*token*/)> token_handler) {
+namespace {
+
+struct mtmd_bitmap_deleter {
+  void operator()(mtmd_bitmap* val) { mtmd_bitmap_free(val); }
+};
+using unique_bitmap_ptr = std::unique_ptr<mtmd_bitmap, mtmd_bitmap_deleter>;
+
+struct mtmd_input_chunks_deleter {
+  void operator()(mtmd_input_chunks* val) { mtmd_input_chunks_free(val); }
+};
+using unique_mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, 
mtmd_input_chunks_deleter>;
+
+}  // namespace
+
+std::expected<GenerationResult, std::string> 
DefaultLlamaContext::generate(const std::string& prompt, const 
std::vector<std::vector<std::byte>>& files,
+      std::function<void(std::string_view/*token*/)> token_handler) {
   GenerationResult result{};
   auto start_time = std::chrono::steady_clock::now();
+  llama_memory_seq_rm(llama_get_memory(llama_ctx_), 0, -1, -1);
   const llama_vocab * vocab = llama_model_get_vocab(llama_model_);
-  std::vector<llama_token> tokenized_input = tokenizeInput(vocab, input);
-  result.num_tokens_in = gsl::narrow<uint64_t>(tokenized_input.size());
+  llama_pos n_past = 0;
+  std::vector<llama_token> tokenized_input;
+  llama_batch batch = llama_batch_init(1, 0, 1);
+  auto batch_deleter = gsl::finally([&] {llama_batch_free(batch);});
+  batch.n_tokens = 1;
+  batch.n_seq_id[0] = 1;
+  batch.seq_id[0][0] = 0;
+  batch.logits[0] = true;
+  int32_t decode_status = 0;
+  if (multimodal_ctx_) {
+    if (files.empty()) {
+      return std::unexpected{"Multimodal input requires at least one file"};
+    }
+    std::vector<unique_bitmap_ptr> bitmaps;
+    for (auto& file : files) {
+      unique_bitmap_ptr 
bitmap{mtmd_helper_bitmap_init_from_buf(multimodal_ctx_, reinterpret_cast<const 
unsigned char*>(file.data()), file.size())};
+      if (!bitmap) {
+        throw Exception(PROCESSOR_EXCEPTION, "Failed to create multimodal 
bitmap from buffer");
+      }
+      bitmaps.push_back(std::move(bitmap));
+    }
+    mtmd_input_text inp_txt = {
+      .text = prompt.c_str(),
+      .add_special = true,
+      .parse_special = true,
+    };
+    unique_mtmd_input_chunks_ptr chunks{mtmd_input_chunks_init()};
+    auto bitmap_c_ptrs = bitmaps | ranges::views::transform([] (auto& ptr) 
{return static_cast<const mtmd_bitmap*>(ptr.get());}) | 
ranges::to<std::vector>();
+    auto tokenized = mtmd_tokenize(multimodal_ctx_, chunks.get(), &inp_txt, 
bitmap_c_ptrs.data(), bitmap_c_ptrs.size());
+    if (tokenized != 0) {
+      throw Exception(PROCESSOR_EXCEPTION, fmt::format("Failed to tokenize 
multimodal prompt, error: {}", tokenized));
+    }
+    auto status = mtmd_helper_eval_chunks(multimodal_ctx_, llama_ctx_, 
chunks.get(), 0, 0, 1, true, &n_past);
+    if (status != 0) {
+      throw Exception(PROCESSOR_EXCEPTION, fmt::format("Failed to eval 
multimodal chunks, error: {}", status));
+    }

Review Comment:
   I would extract this to a separate function. Additionally, why is 
llama_decode run in the case of string tokenization, but not in the 
multimodal use case?



##########
extensions/llamacpp/tests/RunLlamaCppInferenceTests.cpp:
##########
@@ -37,10 +37,16 @@ class MockLlamaContext : public processors::LlamaContext {
     return "Test input";
   }
 
-  std::expected<processors::GenerationResult, std::string> generate(const 
std::string& input, std::function<void(std::string_view/*token*/)> 
token_handler) override {
+  std::expected<processors::GenerationResult, std::string> generate(const 
std::string& input, const std::vector<std::vector<std::byte>>& files,
+        std::function<void(std::string_view/*token*/)> token_handler) override 
{
     if (fail_generation_) {
       return std::unexpected{"Generation failed"};
     }
+    if (multimodal_) {
+      if (files.empty()) {

Review Comment:
   This could be merged into `if (multimodal_ && files.empty())`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to