diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f95d8ba..17f9afd9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,6 +44,10 @@ else() endif() endif() +add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp) +add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp) +add_executable(store_demo ${CMAKE_CURRENT_LIST_DIR}/demo/store_demo.cpp) + if (BUILD_FOR_ANDROID) add_library(MNN SHARED IMPORTED) add_library(MNN_Express SHARED IMPORTED) @@ -57,22 +61,20 @@ if (BUILD_FOR_ANDROID) PROPERTIES IMPORTED_LOCATION ${CMAKE_CURRENT_LIST_DIR}/libs/libMNN_Express.so ) - # just cli demo - add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp) target_link_libraries(cli_demo llm log) + target_link_libraries(embedding_demo llm log) + target_link_libraries(store_demo llm log) else() - # cli demo - add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp) + # web demo add_executable(web_demo ${CMAKE_CURRENT_LIST_DIR}/demo/web_demo.cpp) - add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp) + target_link_libraries(cli_demo llm) + target_link_libraries(embedding_demo llm) + target_link_libraries(store_demo llm) if (MSVC) - target_link_libraries(cli_demo llm) target_link_libraries(web_demo llm pthreadVC2) # copy all lib to target dir file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/libs/ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/Debug/) else() - target_link_libraries(cli_demo llm) - target_link_libraries(embedding_demo llm) target_link_libraries(web_demo llm pthread) endif() endif() diff --git a/demo/store_demo.cpp b/demo/store_demo.cpp new file mode 100644 index 00000000..e6a42aac --- /dev/null +++ b/demo/store_demo.cpp @@ -0,0 +1,50 @@ +// +// store_demo.cpp +// +// Created by MNN on 2024/01/10. +// ZhaodeWang +// + +#include "llm.hpp" +#include +#include + +static void dumpVARP(VARP var) { + auto size = var->getInfo()->size; + auto ptr = var->readMap(); + printf("[ "); + for (int i = 0; i < 5; i++) { + printf("%f, ", ptr[i]); + } + printf("... "); + for (int i = size - 5; i < size; i++) { + printf("%f, ", ptr[i]); + } + printf(" ]\n"); +} + +int main(int argc, const char* argv[]) { + if (argc < 2) { + std::cout << "Usage: " << argv[0] << " model.mnn" << std::endl; + return 0; + } + std::string model_dir = argv[1]; + std::cout << "model path is " << model_dir << std::endl; + std::unique_ptr store(new TextVectorStore); + store->bench(); + std::shared_ptr embedding(Embedding::createEmbedding(model_dir)); + embedding->load(model_dir); + std::vector texts = { + "在春暖花开的季节,走在樱花缤纷的道路上,人们纷纷拿出手机拍照留念。樱花树下,情侣手牵手享受着这绝美的春光。孩子们在树下追逐嬉戏,脸上洋溢着纯真的笑容。春天的气息在空气中弥漫,一切都显得那么生机勃勃,充满希望。", + "春天到了,樱花树悄然绽放,吸引了众多游客前来观赏。小朋友们在花瓣飘落的树下玩耍,而恋人们则在这浪漫的景色中尽情享受二人世界。每个人的脸上都挂着幸福的笑容,仿佛整个世界都被春天温暖的阳光和满树的樱花渲染得更加美好。", + "在炎热的夏日里,沙滩上的游客们穿着泳装享受着海水的清凉。孩子们在海边堆沙堡,大人们则在太阳伞下品尝冷饮,享受悠闲的时光。远处,冲浪者们挑战着波涛,体验着与海浪争斗的刺激。夏天的海滩,总是充满了活力和热情。" + }; + store->set_embedding(embedding); + store->add_texts(texts); + std::string text = "春风轻拂过,公园里的花朵竞相开放,五彩斑斓地装点着大自然。游人如织,他们带着相机记录下这些美丽的瞬间。孩童们在花海中欢笑玩耍,无忧无虑地享受着春日的温暖。情侣们依偎在一起,沉醉于这迷人的季节。春天带来了新生与希望的讯息,让人心情愉悦,充满了对未来的美好憧憬。"; + auto similar_texts = store->search_similar_texts(text, 1); + for (const auto& text : similar_texts) { + std::cout << text << std::endl; + } + return 0; +} diff --git a/include/llm.hpp b/include/llm.hpp index ff8e73e9..ba76ed48 100644 --- a/include/llm.hpp +++ b/include/llm.hpp @@ -235,6 +235,7 @@ class Embedding { void load(const std::string& model_dir); VARP embedding(const std::string& txt); void print_speed(); + int dim() { return hidden_size_; } public: // time int64_t embedding_us_ = 0; @@ -276,4 +277,28 @@ class Bge : public Embedding { // Embedding end +// TextVectorStore strat +class TextVectorStore { +public: + TextVectorStore() {} + ~TextVectorStore() {} + static TextVectorStore* load(const std::string& path); + void set_embedding(std::shared_ptr embedding) { + embedding_ = embedding; + } + void save(const std::string& path); + void add_text(const std::string& text); + void add_texts(const std::vector& texts); + std::vector search_similar_texts(const std::string& txt, int topk = 1); + void bench(); +protected: + inline VARP text2vector(const std::string& text); +private: + std::shared_ptr embedding_; + VARP vectors_; + std::vector texts_; + int dim_ = 1024; +}; +// TextVectorStore end + #endif // LLM_hpp diff --git a/src/llm.cpp b/src/llm.cpp index a89ddfb9..79637832 100644 --- a/src/llm.cpp +++ b/src/llm.cpp @@ -685,7 +685,7 @@ bool Llama2_7b::is_stop(int token_id) { // Embedding start float Embedding::dist(VARP var0, VARP var1) { - auto distVar = _ReduceSum(_Square(var0 - var1)); + auto distVar = _Sqrt(_ReduceSum(_Square(var0 - var1))); auto dist = distVar->readMap()[0]; return dist; } @@ -788,4 +788,81 @@ VARP Bge::gen_position_ids(int seq_len) { } return position_ids; } -// Embedding end \ No newline at end of file +// Embedding end + +// TextVectorStore strat + +TextVectorStore* TextVectorStore::load(const std::string& path) { + auto vars = Variable::load(path.c_str()); + return nullptr; + // TODO +} + +void TextVectorStore::save(const std::string& path) { + std::vector vars; + Variable::save(vars, path.c_str()); + // TODO +} + +void TextVectorStore::add_text(const std::string& text) { + auto vector = text2vector(text); + texts_.push_back(text); + if (vectors_ == nullptr) { + vectors_ = vector; + } else { + vectors_ = _Concat({vectors_, vector}, 0); + } +} + +void TextVectorStore::add_texts(const std::vector& texts) { + for (const auto& text : texts) { + add_text(text); + } +} + +std::vector TextVectorStore::search_similar_texts(const std::string& text, int topk) { + auto vector = text2vector(text); + auto dist = _Sqrt(_ReduceSum(_Square(vectors_ - vector), {-1})); + auto indices = _Sort(dist, 0, true); + auto ptr = dist->readMap(); + auto iptr = indices->readMap(); + auto idx_ptr = indices->readMap(); + std::vector res; + for (int i = 0; i < topk; i++) { + int pos = idx_ptr[i]; + if (pos >= 0 && pos < texts_.size()) { + res.push_back(texts_[pos]); + } + } + return res; +} + +void TextVectorStore::bench() { + const int n = 50000; + const int d = 1024; + std::vector shape0_v = {n, d}; + std::vector shape1_v = {1, d}; + auto shape0 = _Const(shape0_v.data(), {2}); + auto shape1 = _Const(shape1_v.data(), {2}); + vectors_ = _RandomUnifom(shape0, halide_type_of()); + auto vec = _RandomUnifom(shape1, halide_type_of()); + auto start = std::chrono::high_resolution_clock::now(); + auto dist = _Sqrt(_ReduceSum(_Square(vectors_ - vec), {-1})); + auto ptr = dist->readMap(); + auto indices = _Sort(dist, 0, true); + auto iptr = indices->readMap(); + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start); + std::cout << "search took " << duration.count() << " milliseconds." << std::endl; + for (int i = 0; i < 5; i++) { + printf("index: %d, distance: %f\n", iptr[i], ptr[iptr[i]]); + } + vectors_ = nullptr; +} + +VARP TextVectorStore::text2vector(const std::string& text) { + auto vector = embedding_->embedding(text); + return vector; +} + +// TextVectorStore end \ No newline at end of file