From 1df84a11ccbc6e2a99fecf115546964cae8534ad Mon Sep 17 00:00:00 2001
From: Adam Debreceni
Date: Mon, 13 Jan 2025 10:39:13 +0100
Subject: [PATCH] Refactor

---
 extensions/llamacpp/CMakeLists.txt            |   2 +-
 .../llamacpp/processors/LlamaContext.cpp      | 148 ++++++++++++++++++
 extensions/llamacpp/processors/LlamaContext.h |  42 +++++
 .../llamacpp/processors/LlamaCppProcessor.cpp | 124 ++-------------
 .../llamacpp/processors/LlamaCppProcessor.h   |   8 +-
 extensions/llamacpp/tests/CMakeLists.txt      |  37 +++++
 extensions/llamacpp/tests/LlamaCppTests.cpp   |  65 ++++++++
 7 files changed, 308 insertions(+), 118 deletions(-)
 create mode 100644 extensions/llamacpp/processors/LlamaContext.cpp
 create mode 100644 extensions/llamacpp/processors/LlamaContext.h
 create mode 100644 extensions/llamacpp/tests/CMakeLists.txt
 create mode 100644 extensions/llamacpp/tests/LlamaCppTests.cpp

diff --git a/extensions/llamacpp/CMakeLists.txt b/extensions/llamacpp/CMakeLists.txt
index e995c8d5de..8935d67297 100644
--- a/extensions/llamacpp/CMakeLists.txt
+++ b/extensions/llamacpp/CMakeLists.txt
@@ -33,6 +33,6 @@ target_include_directories(minifi-llamacpp PUBLIC "${CMAKE_SOURCE_DIR}/extension
 
 target_link_libraries(minifi-llamacpp ${LIBMINIFI} llamacpp)
 
-register_extension(minifi-llamacpp "AI PROCESSORS" AI-PROCESSORS "Provides AI processors")
+register_extension(minifi-llamacpp "LLAMACPP EXTENSION" LLAMACPP-EXTENSION "Provides LlamaCpp support" "extensions/llamacpp/tests")
 
 register_extension_linter(minifi-llamacpp-linter)
\ No newline at end of file
diff --git a/extensions/llamacpp/processors/LlamaContext.cpp b/extensions/llamacpp/processors/LlamaContext.cpp
new file mode 100644
index 0000000000..531747cdf1
--- /dev/null
+++ b/extensions/llamacpp/processors/LlamaContext.cpp
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LlamaContext.h"
+#include "Exception.h"
+#include "fmt/format.h"
+#pragma push_macro("DEPRECATED")
+#include "llama.h"
+#pragma pop_macro("DEPRECATED")
+
+namespace org::apache::nifi::minifi::processors::llamacpp {
+
+static std::function<std::unique_ptr<LlamaContext>(const std::filesystem::path&, float)> test_provider;
+
+void LlamaContext::testSetProvider(std::function<std::unique_ptr<LlamaContext>(const std::filesystem::path&, float)> provider) {
+  test_provider = provider;
+}
+
+class DefaultLlamaContext : public LlamaContext {
+ public:
+  DefaultLlamaContext(const std::filesystem::path& model_path, float temperature) {
+    llama_backend_init();
+
+    llama_model_params model_params = llama_model_default_params();
+    llama_model_ = llama_load_model_from_file(model_path.c_str(), model_params);
+    if (!llama_model_) {
+      throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION, fmt::format("Failed to load model from '{}'", model_path.c_str()));
+    }
+
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = 0;
+    llama_ctx_ = llama_new_context_with_model(llama_model_, ctx_params);
+
+    auto sparams = llama_sampler_chain_default_params();
+    llama_sampler_ = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_k(50));
+    llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_p(0.9, 1));
+    llama_sampler_chain_add(llama_sampler_, llama_sampler_init_temp(temperature));
+    llama_sampler_chain_add(llama_sampler_, llama_sampler_init_dist(1234));
+  }
+
+  std::string applyTemplate(const std::vector<LlamaChatMessage>& messages) override {
+    std::vector<llama_chat_message> msgs;
+    for (auto& msg : messages) {
+      msgs.push_back(llama_chat_message{.role = msg.role.c_str(), .content = msg.content.c_str()});
+    }
+    std::string text;
+    int32_t res_size = llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
+    if (res_size > gsl::narrow<int32_t>(text.size())) {
+      text.resize(res_size);
+      llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
+    }
+    text.resize(res_size);
+
+//    utils::string::replaceAll(text, "", "\n");
+
+    return text;
+  }
+
+  void generate(const std::string& input, std::function<bool(std::string_view)> cb) override {
+    std::vector<llama_token> enc_input = [&] {
+      int32_t n_tokens = input.length() + 2;
+      std::vector<llama_token> enc_input(n_tokens);
+      n_tokens = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
+      if (n_tokens < 0) {
+        enc_input.resize(-n_tokens);
+        int check = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
+        gsl_Assert(check == -n_tokens);
+      } else {
+        enc_input.resize(n_tokens);
+      }
+      return enc_input;
+    }();
+
+
+    llama_batch batch = llama_batch_get_one(enc_input.data(), enc_input.size());
+
+    llama_token new_token_id;
+
+    bool terminate = false;
+
+    while (!terminate) {
+      if (int32_t res = llama_decode(llama_ctx_, batch); res < 0) {
+        throw std::logic_error("failed to execute decode");
+      }
+
+      new_token_id = llama_sampler_sample(llama_sampler_, llama_ctx_, -1);
+
+      if (llama_token_is_eog(llama_model_, new_token_id)) {
+        break;
+      }
+
+      llama_sampler_accept(llama_sampler_, new_token_id);
+
+      std::array<char, 128> buf;
+      int32_t len = llama_token_to_piece(llama_model_, new_token_id, buf.data(), buf.size(), 0, true);
+      if (len < 0) {
+        throw std::logic_error("failed to convert to text");
+      }
+      gsl_Assert(len < 128);
+
+      std::string_view token_str{buf.data(), gsl::narrow<size_t>(len)};
+
+      batch = llama_batch_get_one(&new_token_id, 1);
+
+      terminate = cb(token_str);
+    }
+  }
+
+  ~DefaultLlamaContext() override {
+    llama_sampler_free(llama_sampler_);
+    llama_sampler_ = nullptr;
+    llama_free(llama_ctx_);
+    llama_ctx_ = nullptr;
+    llama_free_model(llama_model_);
+    llama_model_ = nullptr;
+    llama_backend_free();
+  }
+
+ private:
+  llama_model* llama_model_{nullptr};
+  llama_context* llama_ctx_{nullptr};
+  llama_sampler* llama_sampler_{nullptr};
+};
+
+std::unique_ptr<LlamaContext> LlamaContext::create(const std::filesystem::path& model_path, float temperature) {
+  if (test_provider) {
+    return test_provider(model_path, temperature);
+  }
+  return std::make_unique<DefaultLlamaContext>(model_path, temperature);
+}
+
+} // namespace org::apache::nifi::minifi::processors::llamacpp
diff --git a/extensions/llamacpp/processors/LlamaContext.h b/extensions/llamacpp/processors/LlamaContext.h
new file mode 100644
index 0000000000..1d625445a1
--- /dev/null
+++ b/extensions/llamacpp/processors/LlamaContext.h
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <filesystem>
+#include <functional>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace org::apache::nifi::minifi::processors::llamacpp {
+
+struct LlamaChatMessage {
+  std::string role;
+  std::string content;
+};
+
+class LlamaContext {
+ public:
+  static void testSetProvider(std::function<std::unique_ptr<LlamaContext>(const std::filesystem::path&, float)> provider);
+  static std::unique_ptr<LlamaContext> create(const std::filesystem::path& model_path, float temperature);
+  virtual std::string applyTemplate(const std::vector<LlamaChatMessage>& messages) = 0;
+  virtual void generate(const std::string& input, std::function<bool(std::string_view)> cb) = 0;
+  virtual ~LlamaContext() = default;
+};
+
+} // namespace org::apache::nifi::minifi::processors::llamacpp
diff --git a/extensions/llamacpp/processors/LlamaCppProcessor.cpp b/extensions/llamacpp/processors/LlamaCppProcessor.cpp
index 14786b8548..b2316500e0 100644
--- a/extensions/llamacpp/processors/LlamaCppProcessor.cpp
+++ b/extensions/llamacpp/processors/LlamaCppProcessor.cpp
@@ -23,32 +23,10 @@
 #include "rapidjson/document.h"
 #include "rapidjson/error/en.h"
+#include "LlamaContext.h"
 
 namespace org::apache::nifi::minifi::processors {
 
-namespace {
-
-struct LlamaChatMessage {
-  std::string role;
-  std::string content;
-
-  operator llama_chat_message() const {
-    return llama_chat_message{
-      .role = role.c_str(),
-      .content = content.c_str()
-    };
-  }
-};
-
-//constexpr const char* relationship_prompt = R"(You are a helpful assistant helping to analyze the user's description of a data transformation and routing algorithm.
-//The data consists of attributes and a content encapsulated in what is called a flowfile.
-//The routing targets are called relationships.
-//You have to extract the comma separated list of all possible relationships one can route to based on the user's description.
-//Output only the list and nothing else.
-//)";
-
-} // namespace
-
 void LlamaCppProcessor::initialize() {
   setSupportedProperties(Properties);
   setSupportedRelationships(Relationships);
 }
@@ -129,25 +107,7 @@ void LlamaCppProcessor::onSchedule(core::ProcessContext& context, core::ProcessS
     examples_.push_back(LLMExample{.input = std::move(input), .output = std::move(output)});
   }
 
-  llama_backend_init();
-
-  llama_model_params model_params = llama_model_default_params();
-  llama_model_ = llama_load_model_from_file(model_name_.c_str(), model_params);
-  if (!llama_model_) {
-    throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION, fmt::format("Failed to load model from '{}'", model_name_));
-  }
-
-  llama_context_params ctx_params = llama_context_default_params();
-  ctx_params.n_ctx = 0;
-  llama_ctx_ = llama_new_context_with_model(llama_model_, ctx_params);
-
-  auto sparams = llama_sampler_chain_default_params();
-  llama_sampler_ = llama_sampler_chain_init(sparams);
-
-  llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_k(50));
-  llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_p(0.9, 1));
-  llama_sampler_chain_add(llama_sampler_, llama_sampler_init_temp(gsl::narrow_cast<float>(temperature_)));
-  llama_sampler_chain_add(llama_sampler_, llama_sampler_init_dist(1234));
+  llama_ctx_ = llamacpp::LlamaContext::create(model_name_, gsl::narrow_cast<float>(temperature_));
 }
 
 void LlamaCppProcessor::onTrigger(core::ProcessContext& context, core::ProcessSession& session) {
@@ -172,76 +132,24 @@ void LlamaCppProcessor::onTrigger(core::ProcessContext& context, core::ProcessSe
 
   std::string input = [&] {
-    std::vector<llama_chat_message> msgs;
-    msgs.push_back(llama_chat_message{.role = "system", .content = full_prompt_.c_str()});
+    std::vector<llamacpp::LlamaChatMessage> msgs;
+    msgs.push_back({.role = "system", .content = full_prompt_.c_str()});
     for (auto& ex : examples_) {
-      msgs.push_back(llama_chat_message{.role = "user", .content = ex.input.c_str()});
-      msgs.push_back(llama_chat_message{.role = "assistant", .content = ex.output.c_str()});
-    }
-    msgs.push_back(llama_chat_message{.role = "user", .content = msg.c_str()});
-
-    std::string text;
-    int32_t res_size = llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
-    if (res_size > gsl::narrow<int32_t>(text.size())) {
-      text.resize(res_size);
-      llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
+      msgs.push_back({.role = "user", .content = ex.input.c_str()});
+      msgs.push_back({.role = "assistant", .content = ex.output.c_str()});
     }
-    text.resize(res_size);
+    msgs.push_back({.role = "user", .content = msg.c_str()});
 
-//    utils::string::replaceAll(text, "", "\n");
-
-    return text;
+    return llama_ctx_->applyTemplate(msgs);
   }();
 
   logger_->log_debug("AI model input: {}", input);
 
-  std::vector<llama_token> enc_input = [&] {
-    int32_t n_tokens = input.length() + 2;
-    std::vector<llama_token> enc_input(n_tokens);
-    n_tokens = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
-    if (n_tokens < 0) {
-      enc_input.resize(-n_tokens);
-      int check = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
-      gsl_Assert(check == -n_tokens);
-    } else {
-      enc_input.resize(n_tokens);
-    }
-    return enc_input;
-  }();
-
-
-  llama_batch batch = llama_batch_get_one(enc_input.data(), enc_input.size());
-
-  llama_token new_token_id;
-
   std::string text;
-
-  while (true) {
-    if (int32_t res = llama_decode(llama_ctx_, batch); res < 0) {
-      throw std::logic_error("failed to execute decode");
std::logic_error("failed to execute decode"); - } - - new_token_id = llama_sampler_sample(llama_sampler_, llama_ctx_, -1); - - if (llama_token_is_eog(llama_model_, new_token_id)) { - break; - } - - llama_sampler_accept(llama_sampler_, new_token_id); - - std::array buf; - int32_t len = llama_token_to_piece(llama_model_, new_token_id, buf.data(), buf.size(), 0, true); - if (len < 0) { - throw std::logic_error("failed to convert to text"); - } - gsl_Assert(len < 128); - - std::string_view token_str{buf.data(), gsl::narrow(len)}; - std::cout << token_str << std::flush; - text += token_str; - - batch = llama_batch_get_one(&new_token_id, 1); - } + llama_ctx_->generate(input, [&] (std::string_view token) { + text += token; + return true; + }); logger_->log_debug("AI model output: {}", text); @@ -316,13 +224,7 @@ void LlamaCppProcessor::onTrigger(core::ProcessContext& context, core::ProcessSe } void LlamaCppProcessor::notifyStop() { - llama_sampler_free(llama_sampler_); - llama_sampler_ = nullptr; - llama_free(llama_ctx_); - llama_ctx_ = nullptr; - llama_free_model(llama_model_); - llama_model_ = nullptr; - llama_backend_free(); + llama_ctx_.reset(); } REGISTER_RESOURCE(LlamaCppProcessor, Processor); diff --git a/extensions/llamacpp/processors/LlamaCppProcessor.h b/extensions/llamacpp/processors/LlamaCppProcessor.h index c22d17ae00..c14782c257 100644 --- a/extensions/llamacpp/processors/LlamaCppProcessor.h +++ b/extensions/llamacpp/processors/LlamaCppProcessor.h @@ -20,9 +20,7 @@ #include "core/Processor.h" #include "core/logging/LoggerFactory.h" #include "core/PropertyDefinitionBuilder.h" -#pragma push_macro("DEPRECATED") -#include "llama.h" -#pragma pop_macro("DEPRECATED") +#include "LlamaContext.h" namespace org::apache::nifi::minifi::processors { @@ -112,9 +110,7 @@ What now follows is a description of how the user would like you to transform/ro std::string full_prompt_; std::vector examples_; - llama_sampler* llama_sampler_{nullptr}; - llama_model* llama_model_{nullptr}; - llama_context* llama_ctx_{nullptr}; + std::unique_ptr llama_ctx_; }; } // namespace org::apache::nifi::minifi::processors diff --git a/extensions/llamacpp/tests/CMakeLists.txt b/extensions/llamacpp/tests/CMakeLists.txt new file mode 100644 index 0000000000..d1cd79e337 --- /dev/null +++ b/extensions/llamacpp/tests/CMakeLists.txt @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+
+file(GLOB LLAMACPP_TESTS "*.cpp")
+
+SET(EXTENSIONS_TEST_COUNT 0)
+FOREACH(testfile ${LLAMACPP_TESTS})
+  get_filename_component(testfilename "${testfile}" NAME_WE)
+  add_minifi_executable(${testfilename} "${testfile}")
+  target_include_directories(${testfilename} BEFORE PRIVATE "${CMAKE_SOURCE_DIR}/libminifi/include")
+  target_include_directories(${testfilename} BEFORE PRIVATE "${CMAKE_SOURCE_DIR}/extensions/llamacpp/processors")
+  createTests(${testfilename})
+  target_link_libraries(${testfilename} Catch2WithMain)
+  target_link_libraries(${testfilename} minifi-llamacpp)
+  target_link_libraries(${testfilename} minifi-standard-processors)
+  target_compile_definitions("${testfilename}" PRIVATE TZ_DATA_DIR="${CMAKE_BINARY_DIR}/tzdata")
+
+  MATH(EXPR EXTENSIONS_TEST_COUNT "${EXTENSIONS_TEST_COUNT}+1")
+  add_test(NAME ${testfilename} COMMAND ${testfilename} WORKING_DIRECTORY ${TEST_DIR})
+ENDFOREACH()
+message("-- Finished building ${EXTENSIONS_TEST_COUNT} llama.cpp related test file(s)...")
diff --git a/extensions/llamacpp/tests/LlamaCppTests.cpp b/extensions/llamacpp/tests/LlamaCppTests.cpp
new file mode 100644
index 0000000000..7d030455e4
--- /dev/null
+++ b/extensions/llamacpp/tests/LlamaCppTests.cpp
@@ -0,0 +1,65 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "unit/TestBase.h"
+#include "unit/Catch.h"
+#include "LlamaCppProcessor.h"
+#include "unit/SingleProcessorTestController.h"
+#include "core/FlowFile.h"
+
+namespace minifi = org::apache::nifi::minifi;
+
+class MockLlamaContext : public minifi::processors::llamacpp::LlamaContext {
+ public:
+  std::string applyTemplate(const std::vector<minifi::processors::llamacpp::LlamaChatMessage>& messages) override {
+    return "Test Message";
+  }
+  void generate(const std::string& input, std::function<bool(std::string_view)> cb) override {
+    cb(
+      "attributes:\n"
+      "  a: 1\n"
+      "  b: 2\n"
+      "content:\n"
+      "  Test content\n"
+      "relationship:\n"
+      "  banana\n"
+      "attributes:\n"
+      "messed up result\n"
+    );
+  }
+  ~MockLlamaContext() override = default;
+};
+
+TEST_CASE("Output is correctly parsed and routed") {
+  minifi::processors::llamacpp::LlamaContext::testSetProvider([] (const std::filesystem::path&, float) {return std::make_unique<MockLlamaContext>();});
+  minifi::test::SingleProcessorTestController controller(std::make_unique<minifi::processors::LlamaCppProcessor>("LlamaCppProcessor"));
+  controller.addDynamicRelationship("banana");
+  controller.getProcessor()->setProperty(minifi::processors::LlamaCppProcessor::ModelName, "Dummy model");
+  controller.getProcessor()->setProperty(minifi::processors::LlamaCppProcessor::Prompt, "Do whatever");
+  controller.getProcessor()->setProperty(minifi::processors::LlamaCppProcessor::Examples, "[]");
+
+
+  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "some data", .attributes = {}});
+  CHECK(results.size() == 2);
+  CHECK(results[core::Relationship{"malformed", ""}].size() == 1);
+  auto outputs = results[core::Relationship{"banana", ""}];
+  REQUIRE(outputs.size() == 1);
+  CHECK(controller.plan->getContent(outputs[0]) == "Test content");
+  CHECK(outputs[0]->getAttribute("a") == "1");
+  CHECK(outputs[0]->getAttribute("b") == "2");
+}
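
Reviewer note, not part of the patch: below is a minimal sketch of how a caller outside the processor might drive the LlamaContext seam introduced above. The model path "/models/example.gguf" and the 256-character cutoff are made-up placeholders; everything else uses only the interface declared in LlamaContext.h. create() falls back to the llama.cpp-backed DefaultLlamaContext unless a provider was injected via testSetProvider(), and generate() keeps sampling until the callback returns true or an end-of-generation token is produced.

// Illustrative sketch only (not part of the patch); assumes linking against minifi-llamacpp.
#include <cstdio>
#include <string>
#include <string_view>

#include "LlamaContext.h"

using org::apache::nifi::minifi::processors::llamacpp::LlamaChatMessage;
using org::apache::nifi::minifi::processors::llamacpp::LlamaContext;

int main() {
  // Placeholder model path; any GGUF model accepted by llama.cpp would do here.
  auto ctx = LlamaContext::create("/models/example.gguf", 0.8F);

  // Render the conversation through the model's chat template.
  std::string prompt = ctx->applyTemplate({
      LlamaChatMessage{.role = "system", .content = "You are a terse assistant."},
      LlamaChatMessage{.role = "user", .content = "Say hello."}});

  // Stream tokens; returning true from the callback stops generation early.
  std::string output;
  ctx->generate(prompt, [&] (std::string_view token) {
    output.append(token);
    return output.size() > 256;  // arbitrary cutoff for the example
  });

  std::printf("%s\n", output.c_str());
  return 0;
}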