Commit 0a6c344: add config
wangzhaode committed May 9, 2024
1 parent 0f25187 commit 0a6c344
Showing 8 changed files with 206 additions and 190 deletions.
2 changes: 1 addition & 1 deletion demo/cli_demo.cpp
@@ -54,7 +54,7 @@ int main(int argc, const char* argv[]) {
     std::string model_dir = argv[1];
     std::cout << "model path is " << model_dir << std::endl;
     std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
-    llm->load(model_dir);
+    llm->load();
     if (argc < 3) {
         llm->chat();
     }
2 changes: 1 addition & 1 deletion demo/memory_demo.cpp
@@ -22,7 +22,7 @@ int main(int argc, const char* argv[]) {
     if (argc == 4) {
         auto llm_dir = argv[3];
         std::shared_ptr<Llm> llm(Llm::createLLM(llm_dir));
-        llm->load(llm_dir);
+        llm->load();
         chat_memory->summarize(llm);
         chat_memory->save(memory_dir);
     }
6 changes: 4 additions & 2 deletions demo/tokenizer_demo.cpp
@@ -16,14 +16,16 @@ int main(int argc, const char* argv[]) {
     std::unique_ptr<Tokenizer> tokenizer_(new Tiktoken);
     tokenizer_->load(tokenizer_path);
     const std::string system_str = "You are a helpful assistant.";
-    const std::string user_str = "<|endoftext|>";
+    const std::string user_str = "Hello";
     // const std::string query = "\n<|im_start|>system\n" + system_str + "<|im_end|>\n<|im_start|>\n" + user_str + "<|im_end|>\n<|im_start|>assistant\n";
-    const std::string query = system_str + "\n" + user_str;
+    const std::string query = "\n<|im_start|>user\n" + user_str + "<|im_end|>\n<|im_start|>assistant\n";
+    // const std::string query = system_str + "\n" + user_str;
     auto tokens = tokenizer_->encode(query);

+    std::string decode_str;
     printf("encode tokens = [ ");
     for (auto token : tokens) {
         printf("%d, ", token);
+        decode_str += tokenizer_->decode(token);
     }
     printf("]\n");
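With the new values, the demo encodes the Qwen-style chat template rather than a bare concatenation, and the added `decode_str` accumulation lets you check that decoding the ids reproduces the input. A standalone sketch of that same round trip (the `tokenizer.txt` path is illustrative; `Tiktoken` is the concrete class the demo itself instantiates):

```cpp
#include <cstdio>
#include <memory>
#include <string>
#include "tokenizer.hpp"

int main() {
    std::unique_ptr<Tokenizer> tokenizer(new Tiktoken);
    tokenizer->load("./tokenizer.txt"); // illustrative path
    const std::string query = "\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n";
    auto ids = tokenizer->encode(query);
    std::string decoded;
    for (auto id : ids) decoded += tokenizer->decode(id);
    // If encode/decode are lossless, the round trip reproduces the input.
    printf("round trip %s\n", decoded == query ? "ok" : "mismatch");
    return 0;
}
```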
2 changes: 1 addition & 1 deletion demo/web_demo.cpp
@@ -20,7 +20,7 @@ int main(int argc, const char* argv[]) {
     std::string web_dir = argv[2];
     std::cout << "model path is " << model_dir << std::endl;
     std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
-    llm->load(model_dir);
+    llm->load();

     std::stringstream ss;
     httplib::Server svr;
90 changes: 82 additions & 8 deletions include/llm.hpp
@@ -12,6 +12,7 @@
 #include <memory>
 #include <string>
 #include <iostream>
+#include <fstream>
 #include <streambuf>
 #include <functional>
 #include <unordered_map>
@@ -63,6 +64,79 @@ struct Prompt {
     std::vector<int> tokens;
 };

+class LlmConfig {
+public:
+    LlmConfig() {}
+    LlmConfig(const std::string& dir) {
+        base_dir_ = dir + "/";
+        std::ifstream config_file(dir + "/config.json");
+        if (config_file.is_open()) {
+            config_ = json::parse(config_file);
+        } else {
+            std::cerr << "Unable to open config file: " << dir << std::endl;
+        }
+    }
+
+    std::string model_type() const {
+        return config_.value("model_type", "unknow");
+    }
+
+    std::string tokenizer_type() const {
+        return config_.value("tokenizer_type", "tiktoken");
+    }
+
+    std::string llm_model() const {
+        return base_dir_ + config_.value("llm_model", "llm.mnn");
+    }
+
+    std::string llm_weight() const {
+        return base_dir_ + config_.value("llm_weight", "llm.mnn.weight");
+    }
+
+    std::string embedding_file() const {
+        return base_dir_ + config_.value("embedding_file", "embeddings_bf16.bin");
+    }
+
+    std::string tokenizer_file() const {
+        return base_dir_ + config_.value("tokenizer_file", "tokenizer.txt");
+    }
+
+    int hidden_size() const {
+        return config_.value("hidden_size", 4096);
+    }
+
+    std::vector<int> key_value_shape() const {
+        return config_.value("key_value_shape", std::vector<int>{});
+    }
+
+    std::vector<int> stop_ids() const {
+        return config_.value("stop_ids", std::vector<int>{});
+    }
+
+    std::string prompt_template() const {
+        return config_.value("prompt_template", "");
+    }
+
+    std::string backend_type() const {
+        return config_.value("backend_type", "cpu");
+    }
+
+    int thread_num() const {
+        return config_.value("thread_num", 4);
+    }
+
+    std::string precision() const {
+        return config_.value("precision", "low");
+    }
+
+    std::string memory() const {
+        return config_.value("memory", "low");
+    }
+private:
+    std::string base_dir_;
+    json config_;
+};

 class Llm {
 public:
     Llm() {
@@ -75,7 +149,7 @@ class Llm {
         runtime_manager_.reset();
     }
     static Llm* createLLM(const std::string& path, std::string model_type = "auto");
-    void load(const std::string& model_dir);
+    void load();
     void chat();
     void warmup();
     std::string response(const std::string& input_str, std::ostream* os = &std::cout, const char* end_with = nullptr);
@@ -104,6 +178,7 @@ class Llm {
     // time
     int64_t prefill_us_ = 0;
     int64_t decode_us_ = 0;
+    LlmConfig config_;
 protected:
     VARP embedding(const std::vector<int>& input_ids);
     VARP txt_embedding(const std::vector<int>& input_ids);
@@ -112,24 +187,23 @@ class Llm {
 protected:
     VARP inputs_embeds_, attention_mask_, position_ids_;
     // model configs
-    bool is_single_ = false;
-    bool is_disk_embedding_ = false;
+    bool is_single_ = true;
+    bool is_disk_embedding_ = true;
     bool is_visual_ = false;
     int layer_nums_ = 0;
     int hidden_size_ = 4096;
     std::vector<int> key_value_shape_ = {};
-    std::string disk_embedding_file_ = "";
     // gen info
     float load_progress_ = 0.f;
     // tokenizer
     std::unique_ptr<Tokenizer> tokenizer_;
     std::shared_ptr<Module> visual_module_;
 private:
     virtual VARP visual_embedding(const std::vector<int>& input_ids) { return nullptr; }
-    virtual std::vector<int> tokenizer(const std::string& query) = 0;
-    virtual VARP gen_attention_mask(int seq_len) = 0;
-    virtual VARP gen_position_ids(int seq_len) = 0;
-    virtual bool is_stop(int token_id) = 0;
+    virtual std::vector<int> tokenizer(const std::string& query);
+    virtual VARP gen_attention_mask(int seq_len);
+    virtual VARP gen_position_ids(int seq_len);
+    virtual bool is_stop(int token_id);
 private:
     // MNN Modules
     std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
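Taken together, the llm.hpp changes move per-model settings out of hard-coded fields and into a config.json read by the new LlmConfig class, which is why Llm::load() no longer takes a path. A minimal sketch of the resulting flow; the directory name, the example config.json contents, and the prompt are illustrative, and any key missing from config.json falls back to the defaults in LlmConfig:

```cpp
// Hypothetical <model_dir>/config.json (every key optional; the defaults
// hard-coded in LlmConfig apply when a key is absent):
//   {
//     "model_type": "qwen",
//     "tokenizer_file": "tokenizer.txt",
//     "backend_type": "cpu",
//     "thread_num": 4
//   }
#include <memory>
#include <string>
#include "llm.hpp"

int main() {
    std::string model_dir = "./qwen-1.8b"; // illustrative path
    std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
    llm->load();            // reads <model_dir>/config.json via LlmConfig
    llm->response("Hello"); // writes the reply to std::cout by default
    return 0;
}
```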
1 change: 1 addition & 0 deletions include/tokenizer.hpp
@@ -19,6 +19,7 @@ class Tokenizer {
 public:
     Tokenizer() = default;
     virtual ~Tokenizer() = default;
+    static Tokenizer* createTokenizer(const std::string& type);
     virtual bool load(const std::string& filename) = 0;
     virtual std::vector<int> encode(const std::string& str) = 0;
     virtual std::string decode(int id) = 0;
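The new static factory pairs naturally with LlmConfig::tokenizer_type(). A sketch, assuming "tiktoken" (the default value in LlmConfig) is an accepted type string; other accepted values are not visible in this diff:

```cpp
#include <memory>
#include "tokenizer.hpp"

int main() {
    // "tiktoken" matches the default in LlmConfig::tokenizer_type();
    // the set of supported type strings is an assumption here.
    std::unique_ptr<Tokenizer> tokenizer(Tokenizer::createTokenizer("tiktoken"));
    if (tokenizer && tokenizer->load("./tokenizer.txt")) { // illustrative path
        auto ids = tokenizer->encode("Hello");
        return ids.empty() ? 1 : 0;
    }
    return 1;
}
```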
(2 more changed files not shown)
