Skip to content

Commit

Permalink
support TextVectorStore.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Jan 10, 2024
1 parent b399e55 commit 4cb5c44
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 10 deletions.
18 changes: 10 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ else()
endif()
endif()

add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp)
add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp)
add_executable(store_demo ${CMAKE_CURRENT_LIST_DIR}/demo/store_demo.cpp)

if (BUILD_FOR_ANDROID)
add_library(MNN SHARED IMPORTED)
add_library(MNN_Express SHARED IMPORTED)
Expand All @@ -57,22 +61,20 @@ if (BUILD_FOR_ANDROID)
PROPERTIES IMPORTED_LOCATION
${CMAKE_CURRENT_LIST_DIR}/libs/libMNN_Express.so
)
# just cli demo
add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp)
target_link_libraries(cli_demo llm log)
target_link_libraries(embedding_demo llm log)
target_link_libraries(store_demo llm log)
else()
# cli demo
add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp)
# web demo
add_executable(web_demo ${CMAKE_CURRENT_LIST_DIR}/demo/web_demo.cpp)
add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp)
target_link_libraries(cli_demo llm)
target_link_libraries(embedding_demo llm)
target_link_libraries(store_demo llm)
if (MSVC)
target_link_libraries(cli_demo llm)
target_link_libraries(web_demo llm pthreadVC2)
# copy all lib to target dir
file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/libs/ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/Debug/)
else()
target_link_libraries(cli_demo llm)
target_link_libraries(embedding_demo llm)
target_link_libraries(web_demo llm pthread)
endif()
endif()
50 changes: 50 additions & 0 deletions demo/store_demo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//
// store_demo.cpp
//
// Created by MNN on 2024/01/10.
// ZhaodeWang
//

#include "llm.hpp"
#include <fstream>
#include <stdlib.h>

// Prints a short preview of a tensor's contents, read as floats, to stdout.
//
// Tensors with at least 10 elements are shown as the first five values,
// "... ", then the last five values. Smaller tensors are printed in full so
// that no element outside the buffer is ever read and no element is printed
// twice.
//
// var: variable whose data is mapped for reading as float.
static void dumpVARP(VARP var) {
    const int kEdge = 5;  // number of values shown at each end
    auto info = var->getInfo();
    auto ptr = var->readMap<float>();
    // NOTE(review): presumably getInfo()/readMap() yield null when the variable
    // cannot be evaluated -- confirm against the MNN Express documentation.
    if (info == nullptr || ptr == nullptr) {
        printf("[  ]\n");
        return;
    }
    const int size = static_cast<int>(info->size);
    printf("[ ");
    if (size < 2 * kEdge) {
        // Too few elements to split into head and tail: print everything once.
        for (int i = 0; i < size; i++) {
            printf("%f, ", ptr[i]);
        }
    } else {
        for (int i = 0; i < kEdge; i++) {
            printf("%f, ", ptr[i]);
        }
        printf("... ");
        for (int i = size - kEdge; i < size; i++) {
            printf("%f, ", ptr[i]);
        }
    }
    printf(" ]\n");
}

int main(int argc, const char* argv[]) {
if (argc < 2) {
std::cout << "Usage: " << argv[0] << " model.mnn" << std::endl;
return 0;
}
std::string model_dir = argv[1];
std::cout << "model path is " << model_dir << std::endl;
std::unique_ptr<TextVectorStore> store(new TextVectorStore);
store->bench();
std::shared_ptr<Embedding> embedding(Embedding::createEmbedding(model_dir));
embedding->load(model_dir);
std::vector<std::string> texts = {
"在春暖花开的季节,走在樱花缤纷的道路上,人们纷纷拿出手机拍照留念。樱花树下,情侣手牵手享受着这绝美的春光。孩子们在树下追逐嬉戏,脸上洋溢着纯真的笑容。春天的气息在空气中弥漫,一切都显得那么生机勃勃,充满希望。",
"春天到了,樱花树悄然绽放,吸引了众多游客前来观赏。小朋友们在花瓣飘落的树下玩耍,而恋人们则在这浪漫的景色中尽情享受二人世界。每个人的脸上都挂着幸福的笑容,仿佛整个世界都被春天温暖的阳光和满树的樱花渲染得更加美好。",
"在炎热的夏日里,沙滩上的游客们穿着泳装享受着海水的清凉。孩子们在海边堆沙堡,大人们则在太阳伞下品尝冷饮,享受悠闲的时光。远处,冲浪者们挑战着波涛,体验着与海浪争斗的刺激。夏天的海滩,总是充满了活力和热情。"
};
store->set_embedding(embedding);
store->add_texts(texts);
std::string text = "春风轻拂过,公园里的花朵竞相开放,五彩斑斓地装点着大自然。游人如织,他们带着相机记录下这些美丽的瞬间。孩童们在花海中欢笑玩耍,无忧无虑地享受着春日的温暖。情侣们依偎在一起,沉醉于这迷人的季节。春天带来了新生与希望的讯息,让人心情愉悦,充满了对未来的美好憧憬。";
auto similar_texts = store->search_similar_texts(text, 1);
for (const auto& text : similar_texts) {
std::cout << text << std::endl;
}
return 0;
}
25 changes: 25 additions & 0 deletions include/llm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class Embedding {
void load(const std::string& model_dir);
VARP embedding(const std::string& txt);
void print_speed();
int dim() { return hidden_size_; }
public:
// time
int64_t embedding_us_ = 0;
Expand Down Expand Up @@ -276,4 +277,28 @@ class Bge : public Embedding {

// Embedding end

// TextVectorStore start
// In-memory store that pairs texts with their embedding vectors and answers
// nearest-neighbour queries over them.
//
// An embedding model must be supplied through set_embedding() before texts are
// added or searched, because both operations convert text to a vector with it.
class TextVectorStore {
public:
    TextVectorStore() {}
    ~TextVectorStore() {}
    // Intended to restore a store from `path`. Not implemented yet: the
    // current implementation always returns nullptr.
    static TextVectorStore* load(const std::string& path);
    // Sets the embedding model used to turn texts into vectors.
    void set_embedding(std::shared_ptr<Embedding> embedding) {
        embedding_ = embedding;
    }
    // Intended to persist the store to `path`. Not implemented yet: the
    // stored texts and vectors are not written.
    void save(const std::string& path);
    // Embeds `text` and appends it to the store.
    void add_text(const std::string& text);
    // Adds every element of `texts`, in order, via add_text().
    void add_texts(const std::vector<std::string>& texts);
    // Returns stored texts ordered from the smallest Euclidean distance to the
    // embedding of `txt`; `topk` is the number of results requested.
    std::vector<std::string> search_similar_texts(const std::string& txt, int topk = 1);
    // Times a brute-force distance computation and sort over randomly
    // generated vectors and prints the elapsed time and the closest matches.
    void bench();
protected:
    // Converts `text` to its embedding vector using embedding_.
    // Deliberately not declared `inline`: the definition lives in a single
    // source file, and an inline declaration without a definition in every
    // using translation unit can leave callers (e.g. subclasses) unresolved.
    VARP text2vector(const std::string& text);
private:
    std::shared_ptr<Embedding> embedding_;  // model used by text2vector()
    VARP vectors_;                          // embeddings of texts_, concatenated along axis 0
    std::vector<std::string> texts_;        // stored texts, in insertion order
    // NOTE(review): looks like the expected embedding width; it is not read by
    // any code visible here -- confirm whether it should mirror Embedding::dim().
    int dim_ = 1024;
};
// TextVectorStore end

#endif // LLM_hpp
81 changes: 79 additions & 2 deletions src/llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ bool Llama2_7b::is_stop(int token_id) {

// Embedding start
// Returns the Euclidean (L2) distance between two variables: the square root
// of the sum of squared element-wise differences.
//
// Only the sqrt form is kept. The earlier squared-distance declaration of
// `distVar` was a leftover that redeclared the same name in this scope.
float Embedding::dist(VARP var0, VARP var1) {
    auto distVar = _Sqrt(_ReduceSum(_Square(var0 - var1)));
    // The reduction yields a single value; read it back as a float.
    auto dist = distVar->readMap<float>()[0];
    return dist;
}
Expand Down Expand Up @@ -788,4 +788,81 @@ VARP Bge::gen_position_ids(int seq_len) {
}
return position_ids;
}
// Embedding end
// Embedding end

// TextVectorStore start

// Placeholder for restoring a store from disk.
//
// TODO: build a TextVectorStore from the loaded variables. For now the file at
// `path` is read through Variable::load, the result is discarded, and nullptr
// is returned.
TextVectorStore* TextVectorStore::load(const std::string& path) {
    const char* file_name = path.c_str();
    auto loaded_vars = Variable::load(file_name);
    (void)loaded_vars;  // not consumed yet, see TODO above
    return nullptr;
}

// Placeholder for persisting the store to disk.
//
// TODO: serialise the stored vectors and texts. For now an empty variable list
// is handed to Variable::save for `path`.
void TextVectorStore::save(const std::string& path) {
    std::vector<VARP> nothing_to_save;
    const char* file_name = path.c_str();
    Variable::save(nothing_to_save, file_name);
}

void TextVectorStore::add_text(const std::string& text) {
auto vector = text2vector(text);
texts_.push_back(text);
if (vectors_ == nullptr) {
vectors_ = vector;
} else {
vectors_ = _Concat({vectors_, vector}, 0);
}
}

// Adds each element of `texts` to the store, front to back, by delegating to
// add_text().
void TextVectorStore::add_texts(const std::vector<std::string>& texts) {
    const size_t count = texts.size();
    for (size_t idx = 0; idx < count; ++idx) {
        add_text(texts[idx]);
    }
}

// Returns the stored texts closest to `text`, nearest first.
//
// text: query text; it is embedded with text2vector().
// topk: maximum number of results. Values larger than the number of stored
//       entries are clamped so the sorted index buffer is never read past its
//       end; values <= 0 yield an empty result.
//
// Returns an empty vector when the store holds no texts or vectors.
std::vector<std::string> TextVectorStore::search_similar_texts(const std::string& text, int topk) {
    std::vector<std::string> res;
    if (topk <= 0 || texts_.empty() || vectors_ == nullptr) {
        return res;
    }
    auto vector = text2vector(text);
    // Euclidean distance from the query to every stored vector (reduced over
    // the last axis), then the argsort of those distances.
    auto dist = _Sqrt(_ReduceSum(_Square(vectors_ - vector), {-1}));
    auto indices = _Sort(dist, 0, true);
    auto info = indices->getInfo();
    auto idx_ptr = indices->readMap<int>();
    if (info == nullptr || idx_ptr == nullptr) {
        return res;
    }
    const int available = static_cast<int>(info->size);
    const int stored = static_cast<int>(texts_.size());
    const int count = topk < available ? topk : available;
    res.reserve(count);
    for (int i = 0; i < count; i++) {
        const int pos = idx_ptr[i];
        // Defensive: ignore indices that do not map to a stored text.
        if (pos >= 0 && pos < stored) {
            res.push_back(texts_[pos]);
        }
    }
    return res;
}

// Micro-benchmark for the brute-force search used by search_similar_texts().
//
// Generates n = 50000 random vectors of width d = 1024 plus one random query,
// times the distance computation and the sort, and prints the elapsed
// milliseconds followed by the five nearest indices and their distances.
//
// The random data lives in local variables only, so running the benchmark on
// a store that already holds texts leaves its vectors untouched.
void TextVectorStore::bench() {
    const int n = 50000;
    const int d = 1024;
    std::vector<int> shape0_v = {n, d};
    std::vector<int> shape1_v = {1, d};
    // NOTE(review): _Const is called without an explicit element type while the
    // buffers hold ints -- confirm the shape tensors are interpreted as int32.
    auto shape0 = _Const(shape0_v.data(), {2});
    auto shape1 = _Const(shape1_v.data(), {2});
    auto vectors = _RandomUnifom(shape0, halide_type_of<float>());
    auto vec = _RandomUnifom(shape1, halide_type_of<float>());
    auto start = std::chrono::high_resolution_clock::now();
    auto dist = _Sqrt(_ReduceSum(_Square(vectors - vec), {-1}));
    // readMap is called inside the timed region so the measurement includes
    // obtaining the results, not just building the expressions.
    auto ptr = dist->readMap<float>();
    auto indices = _Sort(dist, 0, true);
    auto iptr = indices->readMap<int>();
    auto end = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cout << "search took " << duration.count() << " milliseconds." << std::endl;
    for (int i = 0; i < 5; i++) {
        printf("index: %d, distance: %f\n", iptr[i], ptr[iptr[i]]);
    }
}

// Converts `text` into its embedding vector by forwarding to the embedding
// model held in embedding_.
VARP TextVectorStore::text2vector(const std::string& text) {
    return embedding_->embedding(text);
}

// TextVectorStore end

0 comments on commit 4cb5c44

Please sign in to comment.