Refactor
adamdebreceni committed Jan 13, 2025
1 parent 8612676 commit 1df84a1
Showing 7 changed files with 308 additions and 118 deletions.
2 changes: 1 addition & 1 deletion extensions/llamacpp/CMakeLists.txt
@@ -33,6 +33,6 @@ target_include_directories(minifi-llamacpp PUBLIC "${CMAKE_SOURCE_DIR}/extension

target_link_libraries(minifi-llamacpp ${LIBMINIFI} llamacpp)

register_extension(minifi-llamacpp "AI PROCESSORS" AI-PROCESSORS "Provides AI processors")
register_extension(minifi-llamacpp "LLAMACPP EXTENSION" LLAMACPP-EXTENSION "Provides LlamaCpp support" "extensions/llamacpp/tests")

register_extension_linter(minifi-llamacpp-linter)
148 changes: 148 additions & 0 deletions extensions/llamacpp/processors/LlamaContext.cpp
@@ -0,0 +1,148 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "LlamaContext.h"
#include "Exception.h"
#include "fmt/format.h"
#pragma push_macro("DEPRECATED")
#include "llama.h"
#pragma pop_macro("DEPRECATED")

namespace org::apache::nifi::minifi::processors::llamacpp {

static std::function<std::unique_ptr<LlamaContext>(const std::filesystem::path&, float)> test_provider;

void LlamaContext::testSetProvider(std::function<std::unique_ptr<LlamaContext>(const std::filesystem::path&, float)> provider) {
test_provider = provider;
}

class DefaultLlamaContext : public LlamaContext {
public:
DefaultLlamaContext(const std::filesystem::path& model_path, float temperature) {
llama_backend_init();

llama_model_params model_params = llama_model_default_params();
llama_model_ = llama_load_model_from_file(model_path.c_str(), model_params);
if (!llama_model_) {
throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION, fmt::format("Failed to load model from '{}'", model_path.c_str()));
}

llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 0;
llama_ctx_ = llama_new_context_with_model(llama_model_, ctx_params);

auto sparams = llama_sampler_chain_default_params();
llama_sampler_ = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_k(50));
llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_p(0.9, 1));
llama_sampler_chain_add(llama_sampler_, llama_sampler_init_temp(temperature));
llama_sampler_chain_add(llama_sampler_, llama_sampler_init_dist(1234));
}

std::string applyTemplate(const std::vector<LlamaChatMessage>& messages) override {
std::vector<llama_chat_message> msgs;
for (auto& msg : messages) {
msgs.push_back(llama_chat_message{.role = msg.role.c_str(), .content = msg.content.c_str()});
}
std::string text;
int32_t res_size = llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
if (res_size > gsl::narrow<int32_t>(text.size())) {
text.resize(res_size);
llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
}
text.resize(res_size);

// utils::string::replaceAll(text, "<NEWLINE_CHAR>", "\n");

return text;
}

void generate(const std::string& input, std::function<bool(std::string_view/*token*/)> cb) override {
std::vector<llama_token> enc_input = [&] {
int32_t n_tokens = input.length() + 2;
std::vector<llama_token> enc_input(n_tokens);
n_tokens = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
if (n_tokens < 0) {
enc_input.resize(-n_tokens);
int check = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
gsl_Assert(check == -n_tokens);
} else {
enc_input.resize(n_tokens);
}
return enc_input;
}();


llama_batch batch = llama_batch_get_one(enc_input.data(), enc_input.size());

llama_token new_token_id;

bool terminate = false;

while (!terminate) {
if (int32_t res = llama_decode(llama_ctx_, batch); res < 0) {
throw std::logic_error("failed to execute decode");
}

new_token_id = llama_sampler_sample(llama_sampler_, llama_ctx_, -1);

if (llama_token_is_eog(llama_model_, new_token_id)) {
break;
}

llama_sampler_accept(llama_sampler_, new_token_id);

std::array<char, 128> buf;
int32_t len = llama_token_to_piece(llama_model_, new_token_id, buf.data(), buf.size(), 0, true);
if (len < 0) {
throw std::logic_error("failed to convert to text");
}
gsl_Assert(len < 128);

std::string_view token_str{buf.data(), gsl::narrow<std::string_view::size_type>(len)};

batch = llama_batch_get_one(&new_token_id, 1);

terminate = cb(token_str);
}
}

~DefaultLlamaContext() override {
llama_sampler_free(llama_sampler_);
llama_sampler_ = nullptr;
llama_free(llama_ctx_);
llama_ctx_ = nullptr;
llama_free_model(llama_model_);
llama_model_ = nullptr;
llama_backend_free();
}

private:
llama_model* llama_model_{nullptr};
llama_context* llama_ctx_{nullptr};
llama_sampler* llama_sampler_{nullptr};
};

std::unique_ptr<LlamaContext> LlamaContext::create(const std::filesystem::path& model_path, float temperature) {
if (test_provider) {
return test_provider(model_path, temperature);
}
return std::make_unique<DefaultLlamaContext>(model_path, temperature);
}

} // namespace org::apache::nifi::minifi::processors::llamacpp
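
For orientation, here is a minimal sketch of how a caller might drive the new interface end to end. The model path, prompt strings, and the standalone `main` are illustrative assumptions, not part of this commit. Note that in `DefaultLlamaContext::generate` above the callback's return value is assigned to `terminate`, so the sketch returns `false` to keep receiving tokens.

```cpp
#include <string>
#include <string_view>
#include <vector>

#include "LlamaContext.h"

using org::apache::nifi::minifi::processors::llamacpp::LlamaChatMessage;
using org::apache::nifi::minifi::processors::llamacpp::LlamaContext;

int main() {
  // Hypothetical GGUF model path; any local model file would do.
  auto ctx = LlamaContext::create("/opt/models/example.gguf", /*temperature=*/0.8F);

  // Render the chat template through the virtual applyTemplate, as the processor does.
  std::vector<LlamaChatMessage> msgs{
      {.role = "system", .content = "You are a helpful assistant."},
      {.role = "user", .content = "Summarize the incoming flow file."}};
  const std::string prompt = ctx->applyTemplate(msgs);

  // Accumulate tokens; returning false keeps DefaultLlamaContext's loop running
  // until the model produces an end-of-generation token.
  std::string output;
  ctx->generate(prompt, [&output](std::string_view token) {
    output += token;
    return false;
  });
  return 0;
}
```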
42 changes: 42 additions & 0 deletions extensions/llamacpp/processors/LlamaContext.h
@@ -0,0 +1,42 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <memory>
#include <filesystem>
#include <vector>
#include <string_view>
#include <string>
#include <functional>

namespace org::apache::nifi::minifi::processors::llamacpp {

struct LlamaChatMessage {
std::string role;
std::string content;
};

class LlamaContext {
public:
static void testSetProvider(std::function<std::unique_ptr<LlamaContext>(const std::filesystem::path&, float)> provider);
static std::unique_ptr<LlamaContext> create(const std::filesystem::path& model_path, float temperature);
virtual std::string applyTemplate(const std::vector<LlamaChatMessage>& messages) = 0;
virtual void generate(const std::string& input, std::function<bool(std::string_view/*token*/)> cb) = 0;
virtual ~LlamaContext() = default;
};

} // namespace org::apache::nifi::minifi::processors::llamacpp
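
And a hedged sketch of how the `testSetProvider` hook might be used from a unit test to swap in a stub and avoid loading a real model; the `StubLlamaContext` class and its canned output are assumptions for illustration, not code from this commit's tests.

```cpp
#include <filesystem>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <vector>

#include "LlamaContext.h"

namespace llamacpp = org::apache::nifi::minifi::processors::llamacpp;

// Stub implementation that never touches llama.cpp; it echoes the messages
// back as the "template" and streams a fixed response token by token.
class StubLlamaContext : public llamacpp::LlamaContext {
 public:
  std::string applyTemplate(const std::vector<llamacpp::LlamaChatMessage>& messages) override {
    std::string prompt;
    for (const auto& msg : messages) {
      prompt += "[" + msg.role + "] " + msg.content + "\n";
    }
    return prompt;
  }

  void generate(const std::string& /*input*/, std::function<bool(std::string_view)> cb) override {
    for (std::string_view token : {"canned ", "response"}) {
      if (cb(token)) {  // mirror DefaultLlamaContext: a true return stops generation
        break;
      }
    }
  }
};

// Install the stub so that LlamaContext::create() returns it instead of
// constructing a DefaultLlamaContext.
void installStubProvider() {
  llamacpp::LlamaContext::testSetProvider(
      [](const std::filesystem::path& /*model_path*/, float /*temperature*/) -> std::unique_ptr<llamacpp::LlamaContext> {
        return std::make_unique<StubLlamaContext>();
      });
}
```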
124 changes: 13 additions & 111 deletions extensions/llamacpp/processors/LlamaCppProcessor.cpp
@@ -23,32 +23,10 @@

#include "rapidjson/document.h"
#include "rapidjson/error/en.h"
#include "LlamaContext.h"

namespace org::apache::nifi::minifi::processors {

namespace {

struct LlamaChatMessage {
std::string role;
std::string content;

operator llama_chat_message() const {
return llama_chat_message{
.role = role.c_str(),
.content = content.c_str()
};
}
};

//constexpr const char* relationship_prompt = R"(You are a helpful assistant helping to analyze the user's description of a data transformation and routing algorithm.
//The data consists of attributes and a content encapsulated in what is called a flowfile.
//The routing targets are called relationships.
//You have to extract the comma separated list of all possible relationships one can route to based on the user's description.
//Output only the list and nothing else.
//)";

} // namespace

void LlamaCppProcessor::initialize() {
setSupportedProperties(Properties);
setSupportedRelationships(Relationships);
@@ -129,25 +107,7 @@ void LlamaCppProcessor::onSchedule(core::ProcessContext& context, core::ProcessS
examples_.push_back(LLMExample{.input = std::move(input), .output = std::move(output)});
}

llama_backend_init();

llama_model_params model_params = llama_model_default_params();
llama_model_ = llama_load_model_from_file(model_name_.c_str(), model_params);
if (!llama_model_) {
throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION, fmt::format("Failed to load model from '{}'", model_name_));
}

llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 0;
llama_ctx_ = llama_new_context_with_model(llama_model_, ctx_params);

auto sparams = llama_sampler_chain_default_params();
llama_sampler_ = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_k(50));
llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_p(0.9, 1));
llama_sampler_chain_add(llama_sampler_, llama_sampler_init_temp(gsl::narrow_cast<float>(temperature_)));
llama_sampler_chain_add(llama_sampler_, llama_sampler_init_dist(1234));
llama_ctx_ = llamacpp::LlamaContext::create(model_name_, gsl::narrow_cast<float>(temperature_));
}

void LlamaCppProcessor::onTrigger(core::ProcessContext& context, core::ProcessSession& session) {
@@ -172,76 +132,24 @@ void LlamaCppProcessor::onTrigger(core::ProcessContext& context, core::ProcessSe


std::string input = [&] {
std::vector<llama_chat_message> msgs;
msgs.push_back(llama_chat_message{.role = "system", .content = full_prompt_.c_str()});
std::vector<llamacpp::LlamaChatMessage> msgs;
msgs.push_back({.role = "system", .content = full_prompt_.c_str()});
for (auto& ex : examples_) {
msgs.push_back(llama_chat_message{.role = "user", .content = ex.input.c_str()});
msgs.push_back(llama_chat_message{.role = "assistant", .content = ex.output.c_str()});
}
msgs.push_back(llama_chat_message{.role = "user", .content = msg.c_str()});

std::string text;
int32_t res_size = llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
if (res_size > gsl::narrow<int32_t>(text.size())) {
text.resize(res_size);
llama_chat_apply_template(llama_model_, nullptr, msgs.data(), msgs.size(), true, text.data(), text.size());
msgs.push_back({.role = "user", .content = ex.input.c_str()});
msgs.push_back({.role = "assistant", .content = ex.output.c_str()});
}
text.resize(res_size);
msgs.push_back({.role = "user", .content = msg.c_str()});

// utils::string::replaceAll(text, "<NEWLINE_CHAR>", "\n");

return text;
return llama_ctx_->applyTemplate(msgs);
}();

logger_->log_debug("AI model input: {}", input);

std::vector<llama_token> enc_input = [&] {
int32_t n_tokens = input.length() + 2;
std::vector<llama_token> enc_input(n_tokens);
n_tokens = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
if (n_tokens < 0) {
enc_input.resize(-n_tokens);
int check = llama_tokenize(llama_model_, input.data(), input.length(), enc_input.data(), enc_input.size(), true, true);
gsl_Assert(check == -n_tokens);
} else {
enc_input.resize(n_tokens);
}
return enc_input;
}();


llama_batch batch = llama_batch_get_one(enc_input.data(), enc_input.size());

llama_token new_token_id;

std::string text;

while (true) {
if (int32_t res = llama_decode(llama_ctx_, batch); res < 0) {
throw std::logic_error("failed to execute decode");
}

new_token_id = llama_sampler_sample(llama_sampler_, llama_ctx_, -1);

if (llama_token_is_eog(llama_model_, new_token_id)) {
break;
}

llama_sampler_accept(llama_sampler_, new_token_id);

std::array<char, 128> buf;
int32_t len = llama_token_to_piece(llama_model_, new_token_id, buf.data(), buf.size(), 0, true);
if (len < 0) {
throw std::logic_error("failed to convert to text");
}
gsl_Assert(len < 128);

std::string_view token_str{buf.data(), gsl::narrow<std::string_view::size_type>(len)};
std::cout << token_str << std::flush;
text += token_str;

batch = llama_batch_get_one(&new_token_id, 1);
}
llama_ctx_->generate(input, [&] (std::string_view token) {
text += token;
return true;
});

logger_->log_debug("AI model output: {}", text);

@@ -316,13 +224,7 @@ void LlamaCppProcessor::onTrigger(core::ProcessContext& context, core::ProcessSe
}

void LlamaCppProcessor::notifyStop() {
llama_sampler_free(llama_sampler_);
llama_sampler_ = nullptr;
llama_free(llama_ctx_);
llama_ctx_ = nullptr;
llama_free_model(llama_model_);
llama_model_ = nullptr;
llama_backend_free();
llama_ctx_.reset();
}

REGISTER_RESOURCE(LlamaCppProcessor, Processor);
8 changes: 2 additions & 6 deletions extensions/llamacpp/processors/LlamaCppProcessor.h
@@ -20,9 +20,7 @@
#include "core/Processor.h"
#include "core/logging/LoggerFactory.h"
#include "core/PropertyDefinitionBuilder.h"
#pragma push_macro("DEPRECATED")
#include "llama.h"
#pragma pop_macro("DEPRECATED")
#include "LlamaContext.h"

namespace org::apache::nifi::minifi::processors {

@@ -112,9 +110,7 @@ What now follows is a description of how the user would like you to transform/ro
std::string full_prompt_;
std::vector<LLMExample> examples_;

llama_sampler* llama_sampler_{nullptr};
llama_model* llama_model_{nullptr};
llama_context* llama_ctx_{nullptr};
std::unique_ptr<llamacpp::LlamaContext> llama_ctx_;
};

} // namespace org::apache::nifi::minifi::processors