fgerlits commented on code in PR #1903:
URL: https://github.com/apache/nifi-minifi-cpp/pull/1903#discussion_r2033337827
##########
PROCESSORS.md:
##########
@@ -1727,7 +1728,42 @@ In the list below, the names of required properties
appear in bold. Any other pr
| lastModifiedTime | success | The timestamp of when the file's content
changed in the filesystem as 'yyyy-MM-dd'T'HH:mm:ss'.
|
| creationTime | success | The timestamp of when the file was created
in the filesystem as 'yyyy-MM-dd'T'HH:mm:ss'.
|
| lastAccessTime | success | The timestamp of when the file was
accessed in the filesystem as 'yyyy-MM-dd'T'HH:mm:ss'.
|
-| size | success | The size of the file in bytes.
|
+| size | success | The size of the file in bytes.
Review Comment:
The `|` at the end of the line got deleted.
##########
docker/test/integration/cluster/containers/MinifiContainer.py:
##########
@@ -47,6 +47,7 @@ def __init__(self):
self.enable_openssl_fips_mode = True
else:
self.enable_openssl_fips_mode = False
+ self.download_llama_model = True
Review Comment:
Should this default to False, like the other options (except fips)?
##########
extensions/llamacpp/tests/RunLlamaCppInferenceTests.cpp:
##########
@@ -0,0 +1,257 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "unit/TestBase.h"
+#include "unit/Catch.h"
+#include "RunLlamaCppInference.h"
+#include "unit/SingleProcessorTestController.h"
+#include "core/FlowFile.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::test {
+
+class MockLlamaContext : public processors::LlamaContext {
+ public:
+ std::string applyTemplate(const std::vector<processors::LlamaChatMessage>&
messages) override {
+ messages_ = messages;
+ return "Test input";
+ }
+
+ nonstd::expected<uint64_t, std::string> generate(const std::string& input,
std::function<void(std::string_view/*token*/)> token_handler) override {
+ if (fail_generation_) {
+ return nonstd::make_unexpected("Generation failed");
+ }
+ input_ = input;
+ token_handler("Test ");
+ token_handler("generated");
+ token_handler(" content");
+ return 3;
+ }
+
+ [[nodiscard]] const std::vector<processors::LlamaChatMessage>& getMessages()
const {
+ return messages_;
+ }
+
+ [[nodiscard]] const std::string& getInput() const {
+ return input_;
+ }
+
+ void setFailure() {
+ fail_generation_ = true;
+ }
+
+ private:
+ bool fail_generation_{false};
+ std::vector<processors::LlamaChatMessage> messages_;
+ std::string input_;
+};
+
+TEST_CASE("Prompt is generated correctly with default parameters") {
+ auto mock_llama_context = std::make_unique<MockLlamaContext>();
+ auto mock_llama_context_ptr = mock_llama_context.get();
+ std::filesystem::path test_model_path;
+ processors::LlamaSamplerParams test_sampler_params;
+ processors::LlamaContextParams test_context_params;
+ processors::LlamaContext::testSetProvider(
+ [&](const std::filesystem::path& model_path, const
processors::LlamaSamplerParams& sampler_params, const
processors::LlamaContextParams& context_params) {
+ test_model_path = model_path;
+ test_sampler_params = sampler_params;
+ test_context_params = context_params;
+ return std::move(mock_llama_context);
+ });
+ minifi::test::SingleProcessorTestController
controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference"));
+
LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath,
"Dummy model");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt,
"Question: What is the answer to life, the universe and everything?");
+
+ auto results = controller.trigger(minifi::test::InputFlowFileData{.content =
"42", .attributes = {}});
+ CHECK(test_model_path == "Dummy model");
+ CHECK(test_sampler_params.temperature == 0.8F);
+ CHECK(test_sampler_params.top_k == 40);
+ CHECK(test_sampler_params.top_p == 0.9F);
+ CHECK(test_sampler_params.min_p == std::nullopt);
+ CHECK(test_sampler_params.min_keep == 0);
+ CHECK(test_context_params.n_ctx == 4096);
+ CHECK(test_context_params.n_batch == 2048);
+ CHECK(test_context_params.n_ubatch == 512);
+ CHECK(test_context_params.n_seq_max == 1);
+ CHECK(test_context_params.n_threads == 4);
+ CHECK(test_context_params.n_threads_batch == 4);
+
+ REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
+ auto& output_flow_file =
results.at(processors::RunLlamaCppInference::Success)[0];
+ CHECK(controller.plan->getContent(output_flow_file) == "Test generated
content");
+ CHECK(mock_llama_context_ptr->getInput() == "Test input");
+ REQUIRE(mock_llama_context_ptr->getMessages().size() == 2);
+ CHECK(mock_llama_context_ptr->getMessages()[0].role == "system");
+ CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful
assistant. You are given a question with some possible input data otherwise
called flow file content. "
+ "You are expected
to generate a response based on the question and the input data.");
+ CHECK(mock_llama_context_ptr->getMessages()[1].role == "user");
+ CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or
flow file content):\n42\n\nQuestion: What is the answer to life, the universe
and everything?");
+}
+
+TEST_CASE("Prompt is generated correctly with custom parameters") {
+ auto mock_llama_context = std::make_unique<MockLlamaContext>();
+ auto mock_llama_context_ptr = mock_llama_context.get();
+ std::filesystem::path test_model_path;
+ processors::LlamaSamplerParams test_sampler_params;
+ processors::LlamaContextParams test_context_params;
+ processors::LlamaContext::testSetProvider(
+ [&](const std::filesystem::path& model_path, const
processors::LlamaSamplerParams& sampler_params, const
processors::LlamaContextParams& context_params) {
+ test_model_path = model_path;
+ test_sampler_params = sampler_params;
+ test_context_params = context_params;
+ return std::move(mock_llama_context);
+ });
+ minifi::test::SingleProcessorTestController
controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference"));
+
LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath,
"/path/to/model");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt,
"Question: What is the answer to life, the universe and everything?");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature,
"0.4");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK,
"20");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP,
"");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP,
"0.1");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinKeep,
"1");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TextContextSize,
"4096");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::LogicalMaximumBatchSize,
"1024");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::PhysicalMaximumBatchSize,
"796");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MaxNumberOfSequences,
"2");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForGeneration,
"12");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForBatchProcessing,
"8");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::SystemPrompt,
"Whatever");
+
+ auto results = controller.trigger(minifi::test::InputFlowFileData{.content =
"42", .attributes = {}});
+ CHECK(test_model_path == "/path/to/model");
+ CHECK(test_sampler_params.temperature == 0.4F);
+ CHECK(test_sampler_params.top_k == 20);
+ CHECK(test_sampler_params.top_p == std::nullopt);
+ CHECK(test_sampler_params.min_p == 0.1F);
+ CHECK(test_sampler_params.min_keep == 1);
+ CHECK(test_context_params.n_ctx == 4096);
+ CHECK(test_context_params.n_batch == 1024);
+ CHECK(test_context_params.n_ubatch == 796);
+ CHECK(test_context_params.n_seq_max == 2);
+ CHECK(test_context_params.n_threads == 12);
+ CHECK(test_context_params.n_threads_batch == 8);
+
+ REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
+ auto& output_flow_file =
results.at(processors::RunLlamaCppInference::Success)[0];
+ CHECK(controller.plan->getContent(output_flow_file) == "Test generated
content");
+ CHECK(mock_llama_context_ptr->getInput() == "Test input");
+ REQUIRE(mock_llama_context_ptr->getMessages().size() == 2);
+ CHECK(mock_llama_context_ptr->getMessages()[0].role == "system");
+ CHECK(mock_llama_context_ptr->getMessages()[0].content == "Whatever");
+ CHECK(mock_llama_context_ptr->getMessages()[1].role == "user");
+ CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or
flow file content):\n42\n\nQuestion: What is the answer to life, the universe
and everything?");
+}
+
+TEST_CASE("Empty flow file does not include input data in prompt") {
+ auto mock_llama_context = std::make_unique<MockLlamaContext>();
+ auto mock_llama_context_ptr = mock_llama_context.get();
+ processors::LlamaContext::testSetProvider(
+ [&](const std::filesystem::path&, const processors::LlamaSamplerParams&,
const processors::LlamaContextParams&) {
+ return std::move(mock_llama_context);
+ });
+ minifi::test::SingleProcessorTestController
controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference"));
+
LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath,
"Dummy model");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt,
"Question: What is the answer to life, the universe and everything?");
+
+ auto results = controller.trigger(minifi::test::InputFlowFileData{.content =
"", .attributes = {}});
+
+ REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
+ auto& output_flow_file =
results.at(processors::RunLlamaCppInference::Success)[0];
+ CHECK(controller.plan->getContent(output_flow_file) == "Test generated
content");
+ CHECK(mock_llama_context_ptr->getInput() == "Test input");
+ REQUIRE(mock_llama_context_ptr->getMessages().size() == 2);
+ CHECK(mock_llama_context_ptr->getMessages()[0].role == "system");
+ CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful
assistant. You are given a question with some possible input data otherwise
called flow file content. "
+ "You are expected
to generate a response based on the question and the input data.");
+ CHECK(mock_llama_context_ptr->getMessages()[1].role == "user");
+ CHECK(mock_llama_context_ptr->getMessages()[1].content == "Question: What is
the answer to life, the universe and everything?");
+}
+
+TEST_CASE("Invalid values for optional double type properties throw
exception") {
+ processors::LlamaContext::testSetProvider(
+ [&](const std::filesystem::path&, const processors::LlamaSamplerParams&,
const processors::LlamaContextParams&) {
+ return std::make_unique<MockLlamaContext>();
+ });
+ minifi::test::SingleProcessorTestController
controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference"));
+
LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath,
"Dummy model");
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt,
"Question: What is the answer to life, the universe and everything?");
+
+ std::string property_name;
+ SECTION("Invalid value for Temperature property") {
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature,
"invalid_value");
+ property_name = processors::RunLlamaCppInference::Temperature.name;
+ }
+ SECTION("Invalid value for Top P property") {
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP,
"invalid_value");
+ property_name = processors::RunLlamaCppInference::TopP.name;
+ }
+ SECTION("Invalid value for Min P property") {
+
controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP,
"invalid_value");
+ property_name = processors::RunLlamaCppInference::MinP.name;
+ }
+
+
REQUIRE_THROWS_WITH(controller.trigger(minifi::test::InputFlowFileData{.content
= "42", .attributes = {}}),
+ fmt::format("Process Schedule Operation: Property '{}'
has invalid value 'invalid_value'", property_name));
+}
+
+TEST_CASE("Top K property empty and invalid values are handled properly") {
+ std::optional<int32_t> test_top_k;
Review Comment:
The "empty value" test would be more convincing if `test_top_k` started out
with a non-null value here.
##########
extensions/llamacpp/processors/RunLlamaCppInference.cpp:
##########
@@ -0,0 +1,166 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "RunLlamaCppInference.h"
+#include "core/ProcessContext.h"
+#include "core/ProcessSession.h"
+#include "core/Resource.h"
+#include "Exception.h"
+
+#include "rapidjson/document.h"
+#include "rapidjson/error/en.h"
+#include "LlamaContext.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
+
+namespace {
+
+std::optional<float> parseOptionalFloatProperty(const core::ProcessContext&
context, const core::PropertyReference& property) {
+ std::string str_value;
+ if (!context.getProperty(property, str_value) || str_value.empty()) {
+ return std::nullopt;
+ }
+ try {
+ return std::stof(str_value);
+ } catch(const std::exception&) {
+ throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("Property '{}' has
invalid value '{}'", property.name, str_value));
+ }
+}
+
+std::optional<int32_t> parseOptionalInt32Property(const core::ProcessContext&
context, const core::PropertyReference& property) {
+ std::string str_value;
+ if (!context.getProperty(property, str_value) || str_value.empty()) {
+ return std::nullopt;
+ }
+ try {
+ return gsl::narrow<int32_t>(std::stoi(str_value));
+ } catch(const std::exception&) {
+ throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("Property '{}' has
invalid value '{}'", property.name, str_value));
+ }
+}
+
+} // namespace
+
+void RunLlamaCppInference::initialize() {
+ setSupportedProperties(Properties);
+ setSupportedRelationships(Relationships);
+}
+
+void RunLlamaCppInference::onSchedule(core::ProcessContext& context,
core::ProcessSessionFactory&) {
+ model_path_.clear();
+ context.getProperty(ModelPath, model_path_);
+ context.getProperty(SystemPrompt, system_prompt_);
+
+ LlamaSamplerParams llama_sampler_params;
+ llama_sampler_params.temperature = parseOptionalFloatProperty(context,
Temperature);
+ llama_sampler_params.top_k = parseOptionalInt32Property(context, TopK);
+ llama_sampler_params.top_p = parseOptionalFloatProperty(context, TopP);
+ llama_sampler_params.min_p = parseOptionalFloatProperty(context, MinP);
Review Comment:
These, and the rest of the property parsing below, should be replaced with
the new parsing utilities in `ProcessorConfigUtils.h`, after a rebase.
##########
extensions/llamacpp/processors/RunLlamaCppInference.h:
##########
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "core/Processor.h"
+#include "core/logging/LoggerFactory.h"
+#include "core/PropertyDefinitionBuilder.h"
+#include "LlamaContext.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
+
+class RunLlamaCppInference : public core::ProcessorImpl {
+ struct LLMExample {
+ std::string input_role;
+ std::string input;
+ std::string output_role;
+ std::string output;
+ };
+
+ public:
+ explicit RunLlamaCppInference(std::string_view name, const
utils::Identifier& uuid = {})
+ : core::ProcessorImpl(name, uuid) {
+ }
+ ~RunLlamaCppInference() override = default;
+
+ EXTENSIONAPI static constexpr const char* Description = "LlamaCpp processor
to use llama.cpp library for running language model inference. "
+ "The final prompt used for the inference created using the System Prompt
and Prompt proprerty values and the content of the flowfile referred to as
input data or flow file content.";
+
+ EXTENSIONAPI static constexpr auto ModelPath =
core::PropertyDefinitionBuilder<>::createProperty("Model Path")
+ .withDescription("The filesystem path of the model file in gguf format.")
+ .isRequired(true)
+ .build();
+ EXTENSIONAPI static constexpr auto Temperature =
core::PropertyDefinitionBuilder<>::createProperty("Temperature")
+ .withDescription("The temperature to use for sampling.")
+ .withDefaultValue("0.8")
+ .build();
+ EXTENSIONAPI static constexpr auto TopK =
core::PropertyDefinitionBuilder<>::createProperty("Top K")
+ .withDescription("Limit the next token selection to the K most probable
tokens. Set <= 0 value to use vocab size.")
+ .withDefaultValue("40")
+ .build();
+ EXTENSIONAPI static constexpr auto TopP =
core::PropertyDefinitionBuilder<>::createProperty("Top P")
+ .withDescription("Limit the next token selection to a subset of tokens
with a cumulative probability above a threshold P. 1.0 = disabled.")
+ .withDefaultValue("0.9")
+ .build();
+ EXTENSIONAPI static constexpr auto MinP =
core::PropertyDefinitionBuilder<>::createProperty("Min P")
+ .withDescription("Sets a minimum base probability threshold for token
selection. 0.0 = disabled.")
+ .build();
+ EXTENSIONAPI static constexpr auto MinKeep =
core::PropertyDefinitionBuilder<>::createProperty("Min Keep")
+ .withDescription("If greater than 0, force samplers to return N possible
tokens at minimum.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("0")
+ .build();
+ EXTENSIONAPI static constexpr auto TextContextSize =
core::PropertyDefinitionBuilder<>::createProperty("Text Context Size")
+ .withDescription("Size of the text context, use 0 to use size set in
model.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("4096")
+ .build();
+ EXTENSIONAPI static constexpr auto LogicalMaximumBatchSize =
core::PropertyDefinitionBuilder<>::createProperty("Logical Maximum Batch Size")
+ .withDescription("Logical maximum batch size that can be submitted to
the llama.cpp decode function.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("2048")
+ .build();
+ EXTENSIONAPI static constexpr auto PhysicalMaximumBatchSize =
core::PropertyDefinitionBuilder<>::createProperty("Physical Maximum Batch Size")
+ .withDescription("Physical maximum batch size.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("512")
+ .build();
+ EXTENSIONAPI static constexpr auto MaxNumberOfSequences =
core::PropertyDefinitionBuilder<>::createProperty("Max Number Of Sequences")
+ .withDescription("Maximum number of sequences (i.e. distinct states for
recurrent models).")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("1")
+ .build();
+ EXTENSIONAPI static constexpr auto ThreadsForGeneration =
core::PropertyDefinitionBuilder<>::createProperty("Threads For Generation")
+ .withDescription("Number of threads to use for generation.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::INTEGER_TYPE)
+ .withDefaultValue("4")
+ .build();
+ EXTENSIONAPI static constexpr auto ThreadsForBatchProcessing =
core::PropertyDefinitionBuilder<>::createProperty("Threads For Batch
Processing")
+ .withDescription("Number of threads to use for batch processing.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::INTEGER_TYPE)
+ .withDefaultValue("4")
+ .build();
+ EXTENSIONAPI static constexpr auto Prompt =
core::PropertyDefinitionBuilder<>::createProperty("Prompt")
+ .withDescription("The user prompt for the inference.")
+ .supportsExpressionLanguage(true)
+ .isRequired(true)
+ .build();
+ EXTENSIONAPI static constexpr auto SystemPrompt =
core::PropertyDefinitionBuilder<>::createProperty("System Prompt")
+ .withDescription("The system prompt for the inference.")
+ .withDefaultValue("You are a helpful assistant. You are given a question
with some possible input data otherwise called flow file content. "
+ "You are expected to generate a response based on the
question and the input data.")
+ .isRequired(true)
+ .build();
+
+ EXTENSIONAPI static constexpr auto Properties =
std::to_array<core::PropertyReference>({
+ ModelPath,
+ Temperature,
+ TopK,
+ TopP,
+ MinP,
+ MinKeep,
+ TextContextSize,
+ LogicalMaximumBatchSize,
+ PhysicalMaximumBatchSize,
+ MaxNumberOfSequences,
+ ThreadsForGeneration,
+ ThreadsForBatchProcessing,
+ Prompt,
+ SystemPrompt
+ });
+
+
+ EXTENSIONAPI static constexpr auto Success =
core::RelationshipDefinition{"success", "Generated results from the model"};
+ EXTENSIONAPI static constexpr auto Failure =
core::RelationshipDefinition{"failure", "Generation failed"};
+ EXTENSIONAPI static constexpr auto Relationships = std::array{Success,
Failure};
+
+ EXTENSIONAPI static constexpr bool SupportsDynamicProperties = false;
+ EXTENSIONAPI static constexpr bool SupportsDynamicRelationships = true;
+ EXTENSIONAPI static constexpr core::annotation::Input InputRequirement =
core::annotation::Input::INPUT_REQUIRED;
+ EXTENSIONAPI static constexpr bool IsSingleThreaded = true;
+
+ ADD_COMMON_VIRTUAL_FUNCTIONS_FOR_PROCESSORS
+
+ void onSchedule(core::ProcessContext& context, core::ProcessSessionFactory&
session_factory) override;
+ void onTrigger(core::ProcessContext& context, core::ProcessSession& session)
override;
+ void initialize() override;
+ void notifyStop() override;
+
+ private:
+ std::shared_ptr<core::logging::Logger> logger_ =
core::logging::LoggerFactory<RunLlamaCppInference>::getLogger(uuid_);
+
+ std::string model_path_;
+ std::vector<LLMExample> examples_;
Review Comment:
I can't see `examples_`, or `LLMExample`, used anywhere.
##########
extensions/llamacpp/processors/RunLlamaCppInference.h:
##########
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "core/Processor.h"
+#include "core/logging/LoggerFactory.h"
+#include "core/PropertyDefinitionBuilder.h"
+#include "LlamaContext.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
+
+class RunLlamaCppInference : public core::ProcessorImpl {
+ struct LLMExample {
+ std::string input_role;
+ std::string input;
+ std::string output_role;
+ std::string output;
+ };
+
+ public:
+ explicit RunLlamaCppInference(std::string_view name, const
utils::Identifier& uuid = {})
+ : core::ProcessorImpl(name, uuid) {
+ }
+ ~RunLlamaCppInference() override = default;
+
+ EXTENSIONAPI static constexpr const char* Description = "LlamaCpp processor
to use llama.cpp library for running language model inference. "
+ "The final prompt used for the inference created using the System Prompt
and Prompt proprerty values and the content of the flowfile referred to as
input data or flow file content.";
+
+ EXTENSIONAPI static constexpr auto ModelPath =
core::PropertyDefinitionBuilder<>::createProperty("Model Path")
+ .withDescription("The filesystem path of the model file in gguf format.")
+ .isRequired(true)
+ .build();
+ EXTENSIONAPI static constexpr auto Temperature =
core::PropertyDefinitionBuilder<>::createProperty("Temperature")
+ .withDescription("The temperature to use for sampling.")
+ .withDefaultValue("0.8")
+ .build();
+ EXTENSIONAPI static constexpr auto TopK =
core::PropertyDefinitionBuilder<>::createProperty("Top K")
+ .withDescription("Limit the next token selection to the K most probable
tokens. Set <= 0 value to use vocab size.")
+ .withDefaultValue("40")
+ .build();
+ EXTENSIONAPI static constexpr auto TopP =
core::PropertyDefinitionBuilder<>::createProperty("Top P")
+ .withDescription("Limit the next token selection to a subset of tokens
with a cumulative probability above a threshold P. 1.0 = disabled.")
+ .withDefaultValue("0.9")
+ .build();
+ EXTENSIONAPI static constexpr auto MinP =
core::PropertyDefinitionBuilder<>::createProperty("Min P")
+ .withDescription("Sets a minimum base probability threshold for token
selection. 0.0 = disabled.")
+ .build();
+ EXTENSIONAPI static constexpr auto MinKeep =
core::PropertyDefinitionBuilder<>::createProperty("Min Keep")
+ .withDescription("If greater than 0, force samplers to return N possible
tokens at minimum.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("0")
+ .build();
+ EXTENSIONAPI static constexpr auto TextContextSize =
core::PropertyDefinitionBuilder<>::createProperty("Text Context Size")
+ .withDescription("Size of the text context, use 0 to use size set in
model.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("4096")
+ .build();
+ EXTENSIONAPI static constexpr auto LogicalMaximumBatchSize =
core::PropertyDefinitionBuilder<>::createProperty("Logical Maximum Batch Size")
+ .withDescription("Logical maximum batch size that can be submitted to
the llama.cpp decode function.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("2048")
+ .build();
+ EXTENSIONAPI static constexpr auto PhysicalMaximumBatchSize =
core::PropertyDefinitionBuilder<>::createProperty("Physical Maximum Batch Size")
+ .withDescription("Physical maximum batch size.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("512")
+ .build();
+ EXTENSIONAPI static constexpr auto MaxNumberOfSequences =
core::PropertyDefinitionBuilder<>::createProperty("Max Number Of Sequences")
+ .withDescription("Maximum number of sequences (i.e. distinct states for
recurrent models).")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::UNSIGNED_INT_TYPE)
+ .withDefaultValue("1")
+ .build();
+ EXTENSIONAPI static constexpr auto ThreadsForGeneration =
core::PropertyDefinitionBuilder<>::createProperty("Threads For Generation")
+ .withDescription("Number of threads to use for generation.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::INTEGER_TYPE)
+ .withDefaultValue("4")
+ .build();
+ EXTENSIONAPI static constexpr auto ThreadsForBatchProcessing =
core::PropertyDefinitionBuilder<>::createProperty("Threads For Batch
Processing")
+ .withDescription("Number of threads to use for batch processing.")
+ .isRequired(true)
+ .withPropertyType(core::StandardPropertyTypes::INTEGER_TYPE)
+ .withDefaultValue("4")
+ .build();
+ EXTENSIONAPI static constexpr auto Prompt =
core::PropertyDefinitionBuilder<>::createProperty("Prompt")
+ .withDescription("The user prompt for the inference.")
+ .supportsExpressionLanguage(true)
+ .isRequired(true)
+ .build();
+ EXTENSIONAPI static constexpr auto SystemPrompt =
core::PropertyDefinitionBuilder<>::createProperty("System Prompt")
+ .withDescription("The system prompt for the inference.")
+ .withDefaultValue("You are a helpful assistant. You are given a question
with some possible input data otherwise called flow file content. "
+ "You are expected to generate a response based on the
question and the input data.")
+ .isRequired(true)
+ .build();
+
+ EXTENSIONAPI static constexpr auto Properties =
std::to_array<core::PropertyReference>({
+ ModelPath,
+ Temperature,
+ TopK,
+ TopP,
+ MinP,
+ MinKeep,
+ TextContextSize,
+ LogicalMaximumBatchSize,
+ PhysicalMaximumBatchSize,
+ MaxNumberOfSequences,
+ ThreadsForGeneration,
+ ThreadsForBatchProcessing,
+ Prompt,
+ SystemPrompt
+ });
+
+
+ EXTENSIONAPI static constexpr auto Success =
core::RelationshipDefinition{"success", "Generated results from the model"};
+ EXTENSIONAPI static constexpr auto Failure =
core::RelationshipDefinition{"failure", "Generation failed"};
+ EXTENSIONAPI static constexpr auto Relationships = std::array{Success,
Failure};
+
+ EXTENSIONAPI static constexpr bool SupportsDynamicProperties = false;
+ EXTENSIONAPI static constexpr bool SupportsDynamicRelationships = true;
Review Comment:
This processor doesn't seem to support dynamic relationships.
##########
cmake/LlamaCpp.cmake:
##########
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+include(FetchContent)
+
+set(BUILD_SHARED_LIBS "OFF" CACHE STRING "" FORCE)
+set(LLAMA_BUILD_TESTS "OFF" CACHE STRING "" FORCE)
+set(LLAMA_BUILD_EXAMPLES "OFF" CACHE STRING "" FORCE)
+set(LLAMA_BUILD_SERVER "OFF" CACHE STRING "" FORCE)
+set(GGML_OPENMP "OFF" CACHE STRING "" FORCE)
+
+set(PATCH_FILE_1 "${CMAKE_SOURCE_DIR}/thirdparty/llamacpp/metal.patch")
+set(PATCH_FILE_2
"${CMAKE_SOURCE_DIR}/thirdparty/llamacpp/lu8_macro_fix.patch") #
https://github.com/ggml-org/llama.cpp/issues/12740
Review Comment:
Can you add a link or explanation for `metal.patch`, as well, please?
##########
extensions/llamacpp/processors/DefaultLlamaContext.cpp:
##########
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DefaultLlamaContext.h"
+#include "Exception.h"
+#include "fmt/format.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
+
+namespace {
+std::vector<llama_token> tokenizeInput(const llama_vocab* vocab, const
std::string& input) {
+ int32_t number_of_tokens = gsl::narrow<int32_t>(input.length()) + 2;
+ std::vector<llama_token> tokenized_input(number_of_tokens);
+ number_of_tokens = llama_tokenize(vocab, input.data(),
gsl::narrow<int32_t>(input.length()), tokenized_input.data(),
gsl::narrow<int32_t>(tokenized_input.size()), true, true);
+ if (number_of_tokens < 0) {
+ tokenized_input.resize(-number_of_tokens);
+ [[maybe_unused]] int32_t check = llama_tokenize(vocab, input.data(),
gsl::narrow<int32_t>(input.length()), tokenized_input.data(),
gsl::narrow<int32_t>(tokenized_input.size()), true, true);
+ gsl_Assert(check == -number_of_tokens);
+ } else {
+ tokenized_input.resize(number_of_tokens);
+ }
+ return tokenized_input;
+}
+} // namespace
+
+
+DefaultLlamaContext::DefaultLlamaContext(const std::filesystem::path&
model_path, const LlamaSamplerParams& llama_sampler_params, const
LlamaContextParams& llama_ctx_params) {
+ llama_backend_init();
+
+ llama_model_ = llama_model_load_from_file(model_path.string().c_str(),
llama_model_default_params()); //
NOLINT(cppcoreguidelines-prefer-member-initializer)
+ if (!llama_model_) {
+ throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION,
fmt::format("Failed to load model from '{}'", model_path.string()));
+ }
+
+ llama_context_params ctx_params = llama_context_default_params();
+ ctx_params.n_ctx = llama_ctx_params.n_ctx;
+ ctx_params.n_batch = llama_ctx_params.n_batch;
+ ctx_params.n_ubatch = llama_ctx_params.n_ubatch;
+ ctx_params.n_seq_max = llama_ctx_params.n_seq_max;
+ ctx_params.n_threads = llama_ctx_params.n_threads;
+ ctx_params.n_threads_batch = llama_ctx_params.n_threads_batch;
+ ctx_params.flash_attn = false;
+ llama_ctx_ = llama_init_from_model(llama_model_, ctx_params);
+
+ auto sparams = llama_sampler_chain_default_params();
+ llama_sampler_ = llama_sampler_chain_init(sparams);
+
+ if (llama_sampler_params.min_p) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_min_p(*llama_sampler_params.min_p,
llama_sampler_params.min_keep));
+ }
+ if (llama_sampler_params.top_k) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_top_k(*llama_sampler_params.top_k));
+ }
+ if (llama_sampler_params.top_p) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_top_p(*llama_sampler_params.top_p,
llama_sampler_params.min_keep));
+ }
+ if (llama_sampler_params.temperature) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_temp(*llama_sampler_params.temperature));
+ }
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+}
+
+DefaultLlamaContext::~DefaultLlamaContext() {
+ llama_sampler_free(llama_sampler_);
+ llama_sampler_ = nullptr;
+ llama_free(llama_ctx_);
+ llama_ctx_ = nullptr;
+ llama_model_free(llama_model_);
+ llama_model_ = nullptr;
+ llama_backend_free();
+}
+
+std::string DefaultLlamaContext::applyTemplate(const
std::vector<LlamaChatMessage>& messages) {
+ std::vector<llama_chat_message> llama_messages;
+ llama_messages.reserve(messages.size());
+ for (auto& msg : messages) {
+ llama_messages.push_back(llama_chat_message{.role = msg.role.c_str(),
.content = msg.content.c_str()});
+ }
+ std::string text;
Review Comment:
We could start out with some reasonably large buffer of null characters, so
we won't have to call `llama_chat_apply_template` twice every time.
##########
extensions/llamacpp/processors/RunLlamaCppInference.cpp:
##########
@@ -0,0 +1,166 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "RunLlamaCppInference.h"
+#include "core/ProcessContext.h"
+#include "core/ProcessSession.h"
+#include "core/Resource.h"
+#include "Exception.h"
+
+#include "rapidjson/document.h"
+#include "rapidjson/error/en.h"
+#include "LlamaContext.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
+
+namespace {
+
+std::optional<float> parseOptionalFloatProperty(const core::ProcessContext&
context, const core::PropertyReference& property) {
+ std::string str_value;
+ if (!context.getProperty(property, str_value) || str_value.empty()) {
+ return std::nullopt;
+ }
+ try {
+ return std::stof(str_value);
+ } catch(const std::exception&) {
+ throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("Property '{}' has
invalid value '{}'", property.name, str_value));
+ }
+}
+
+std::optional<int32_t> parseOptionalInt32Property(const core::ProcessContext&
context, const core::PropertyReference& property) {
+ std::string str_value;
+ if (!context.getProperty(property, str_value) || str_value.empty()) {
+ return std::nullopt;
+ }
+ try {
+ return gsl::narrow<int32_t>(std::stoi(str_value));
+ } catch(const std::exception&) {
+ throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("Property '{}' has
invalid value '{}'", property.name, str_value));
+ }
+}
+
+} // namespace
+
+void RunLlamaCppInference::initialize() {
+ setSupportedProperties(Properties);
+ setSupportedRelationships(Relationships);
+}
+
+void RunLlamaCppInference::onSchedule(core::ProcessContext& context,
core::ProcessSessionFactory&) {
+ model_path_.clear();
+ context.getProperty(ModelPath, model_path_);
+ context.getProperty(SystemPrompt, system_prompt_);
+
+ LlamaSamplerParams llama_sampler_params;
+ llama_sampler_params.temperature = parseOptionalFloatProperty(context,
Temperature);
+ llama_sampler_params.top_k = parseOptionalInt32Property(context, TopK);
+ llama_sampler_params.top_p = parseOptionalFloatProperty(context, TopP);
+ llama_sampler_params.min_p = parseOptionalFloatProperty(context, MinP);
+
+ uint64_t uint_value = 0;
+ if (context.getProperty(MinKeep, uint_value)) {
+ llama_sampler_params.min_keep = uint_value;
+ }
+
+ LlamaContextParams llama_ctx_params;
+ if (context.getProperty(TextContextSize, uint_value)) {
+ llama_ctx_params.n_ctx = gsl::narrow_cast<uint32_t>(uint_value);
+ }
+ if (context.getProperty(LogicalMaximumBatchSize, uint_value)) {
+ llama_ctx_params.n_batch = gsl::narrow_cast<uint32_t>(uint_value);
+ }
+ if (context.getProperty(PhysicalMaximumBatchSize, uint_value)) {
+ llama_ctx_params.n_ubatch = gsl::narrow_cast<uint32_t>(uint_value);
+ }
+ if (context.getProperty(MaxNumberOfSequences, uint_value)) {
+ llama_ctx_params.n_seq_max = gsl::narrow_cast<uint32_t>(uint_value);
+ }
+ int32_t int_value = 0;
+ if (context.getProperty(ThreadsForGeneration, int_value)) {
+ llama_ctx_params.n_threads = gsl::narrow_cast<int32_t>(int_value);
+ }
+ if (context.getProperty(ThreadsForBatchProcessing, int_value)) {
+ llama_ctx_params.n_threads_batch = gsl::narrow_cast<int32_t>(int_value);
+ }
+
+ llama_ctx_ = LlamaContext::create(model_path_, llama_sampler_params,
llama_ctx_params);
+}
+
+void RunLlamaCppInference::onTrigger(core::ProcessContext& context,
core::ProcessSession& session) {
+ auto input_ff = session.get();
+ if (!input_ff) {
+ context.yield();
+ return;
+ }
+
+ std::string prompt;
+ context.getProperty(Prompt, prompt, input_ff.get());
+
+ auto read_result = session.readBuffer(input_ff);
+ std::string input_data_and_prompt;
+ if (!read_result.buffer.empty()) {
+ input_data_and_prompt.append("Input data (or flow file content):\n");
+ input_data_and_prompt.append({reinterpret_cast<const
char*>(read_result.buffer.data()), read_result.buffer.size()});
+ input_data_and_prompt.append("\n\n");
+ }
+ input_data_and_prompt.append(prompt);
+
+ std::string input = [&] {
+ std::vector<LlamaChatMessage> messages;
+ messages.push_back({.role = "system", .content = system_prompt_});
+ messages.push_back({.role = "user", .content = input_data_and_prompt});
+
+ return llama_ctx_->applyTemplate(messages);
+ }();
+
+ logger_->log_debug("AI model input: {}", input);
+
+ auto start_time = std::chrono::steady_clock::now();
+
+ std::string text;
+ auto number_of_tokens_generated = llama_ctx_->generate(input, [&]
(std::string_view token) {
+ text += token;
+ });
+
+ auto elapsed_time =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now()
- start_time).count();
+
+ if (!number_of_tokens_generated) {
+ logger_->log_error("Inference failed with generation error: '{}'",
number_of_tokens_generated.error());
+ session.transfer(input_ff, Failure);
+ return;
+ }
+
+ auto ff_guard = gsl::finally([&] {
+ session.remove(input_ff);
+ });
+
+ logger_->log_debug("Number of tokens generated: {}",
*number_of_tokens_generated);
+ logger_->log_debug("AI model inference time: {} ms", elapsed_time);
+ logger_->log_debug("AI model output: {}", text);
+
+ auto result = session.create();
+ session.writeBuffer(result, text);
+ session.transfer(result, Success);
Review Comment:
Why do we remove the incoming flow file and create a new one? If we replaced
the content of the flow file, then the attributes would be preserved, which may
be useful.
##########
extensions/llamacpp/processors/DefaultLlamaContext.cpp:
##########
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DefaultLlamaContext.h"
+#include "Exception.h"
+#include "fmt/format.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
+
+namespace {
+std::vector<llama_token> tokenizeInput(const llama_vocab* vocab, const
std::string& input) {
+ int32_t number_of_tokens = gsl::narrow<int32_t>(input.length()) + 2;
+ std::vector<llama_token> tokenized_input(number_of_tokens);
+ number_of_tokens = llama_tokenize(vocab, input.data(),
gsl::narrow<int32_t>(input.length()), tokenized_input.data(),
gsl::narrow<int32_t>(tokenized_input.size()), true, true);
+ if (number_of_tokens < 0) {
+ tokenized_input.resize(-number_of_tokens);
+ [[maybe_unused]] int32_t check = llama_tokenize(vocab, input.data(),
gsl::narrow<int32_t>(input.length()), tokenized_input.data(),
gsl::narrow<int32_t>(tokenized_input.size()), true, true);
+ gsl_Assert(check == -number_of_tokens);
+ } else {
+ tokenized_input.resize(number_of_tokens);
+ }
+ return tokenized_input;
+}
+} // namespace
+
+
+DefaultLlamaContext::DefaultLlamaContext(const std::filesystem::path&
model_path, const LlamaSamplerParams& llama_sampler_params, const
LlamaContextParams& llama_ctx_params) {
+ llama_backend_init();
+
+ llama_model_ = llama_model_load_from_file(model_path.string().c_str(),
llama_model_default_params()); //
NOLINT(cppcoreguidelines-prefer-member-initializer)
+ if (!llama_model_) {
+ throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION,
fmt::format("Failed to load model from '{}'", model_path.string()));
+ }
+
+ llama_context_params ctx_params = llama_context_default_params();
+ ctx_params.n_ctx = llama_ctx_params.n_ctx;
+ ctx_params.n_batch = llama_ctx_params.n_batch;
+ ctx_params.n_ubatch = llama_ctx_params.n_ubatch;
+ ctx_params.n_seq_max = llama_ctx_params.n_seq_max;
+ ctx_params.n_threads = llama_ctx_params.n_threads;
+ ctx_params.n_threads_batch = llama_ctx_params.n_threads_batch;
+ ctx_params.flash_attn = false;
+ llama_ctx_ = llama_init_from_model(llama_model_, ctx_params);
+
+ auto sparams = llama_sampler_chain_default_params();
+ llama_sampler_ = llama_sampler_chain_init(sparams);
+
+ if (llama_sampler_params.min_p) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_min_p(*llama_sampler_params.min_p,
llama_sampler_params.min_keep));
+ }
+ if (llama_sampler_params.top_k) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_top_k(*llama_sampler_params.top_k));
+ }
+ if (llama_sampler_params.top_p) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_top_p(*llama_sampler_params.top_p,
llama_sampler_params.min_keep));
+ }
+ if (llama_sampler_params.temperature) {
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_temp(*llama_sampler_params.temperature));
+ }
+ llama_sampler_chain_add(llama_sampler_,
llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+}
+
+DefaultLlamaContext::~DefaultLlamaContext() {
+ llama_sampler_free(llama_sampler_);
+ llama_sampler_ = nullptr;
+ llama_free(llama_ctx_);
+ llama_ctx_ = nullptr;
+ llama_model_free(llama_model_);
+ llama_model_ = nullptr;
+ llama_backend_free();
+}
+
+std::string DefaultLlamaContext::applyTemplate(const
std::vector<LlamaChatMessage>& messages) {
+ std::vector<llama_chat_message> llama_messages;
+ llama_messages.reserve(messages.size());
+ for (auto& msg : messages) {
+ llama_messages.push_back(llama_chat_message{.role = msg.role.c_str(),
.content = msg.content.c_str()});
+ }
+ std::string text;
+ const char * chat_template = llama_model_chat_template(llama_model_,
nullptr);
+ int32_t res_size = llama_chat_apply_template(chat_template,
llama_messages.data(), llama_messages.size(), true, text.data(),
gsl::narrow<int32_t>(text.size()));
+ if (res_size > gsl::narrow<int32_t>(text.size())) {
+ text.resize(res_size);
+ llama_chat_apply_template(chat_template, llama_messages.data(),
llama_messages.size(), true, text.data(), gsl::narrow<int32_t>(text.size()));
+ }
+ text.resize(res_size);
Review Comment:
`llama_chat_apply_template` can return a negative value to indicate an
error; we should handle that somehow.
##########
extensions/llamacpp/processors/RunLlamaCppInference.h:
##########
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "core/Processor.h"
+#include "core/logging/LoggerFactory.h"
+#include "core/PropertyDefinitionBuilder.h"
+#include "LlamaContext.h"
+
+namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
+
+class RunLlamaCppInference : public core::ProcessorImpl {
+ struct LLMExample {
+ std::string input_role;
+ std::string input;
+ std::string output_role;
+ std::string output;
+ };
+
+ public:
+ explicit RunLlamaCppInference(std::string_view name, const
utils::Identifier& uuid = {})
+ : core::ProcessorImpl(name, uuid) {
+ }
+ ~RunLlamaCppInference() override = default;
+
+ EXTENSIONAPI static constexpr const char* Description = "LlamaCpp processor
to use llama.cpp library for running language model inference. "
+ "The final prompt used for the inference created using the System Prompt
and Prompt proprerty values and the content of the flowfile referred to as
input data or flow file content.";
Review Comment:
This is not clear to me. I would write something like this:
```suggestion
"The inference will be based on the System Prompt and the Prompt
property values, together with the content of the incoming flow file. "
"In the Prompt, the content of the incoming flow file can be referred
to as 'the input data' or 'the flow file content'.";
```
##########
extensions/llamacpp/processors/LlamaContext.cpp:
##########
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LlamaContext.h"
+#include "Exception.h"
+#include "fmt/format.h"
+#include "llama.h"
Review Comment:
These includes are unused and can be removed.
```suggestion
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]