//
// mls.cpp
//
// Created by MNN on 2023/03/24.
// Jinde.Song
// LLM command line tool, based on llm_demo.cpp
//
#include "llm/llm.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include <fstream>
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include "file_utils.hpp"
#include "remote_model_downloader.hpp"
#include "llm_benchmark.hpp"
#include "mls_server.hpp"
using namespace MNN::Transformer;
static void tuning_prepare(Llm *llm) {
    MNN_PRINT("Prepare for tuning opt Begin\n");
    llm->tuning(OP_ENCODER_NUMBER, {1, 5, 10, 20, 30, 50, 100});
    MNN_PRINT("Prepare for tuning opt End\n");
}
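// Create an Llm from the given config, load its weights, and run the tuning pass above.
// `use_template` controls whether the built-in chat template is applied (disabled for R1-style models).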
static std::unique_ptr<Llm> create_and_prepare_llm(const char *config_path, bool use_template) {
    std::unique_ptr<Llm> llm(Llm::createLLM(config_path));
    if (use_template) {
        llm->set_config("{\"tmp_path\":\"tmp\"}");
    } else {
        llm->set_config("{\"tmp_path\":\"tmp\",\"use_template\":false}");
    }
    {
        AUTOTIME;
        llm->load();
    }
    {
        AUTOTIME;
        tuning_prepare(llm.get());
    }
    return llm;
}
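// Collect model names from `directory_path`; each downloaded model is expected to appear as a symlink.
// Returns 0 on success, 1 on any filesystem error.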
int list_local_models(const std::string &directory_path, std::vector<std::string> &model_names, bool sort = true) {
    std::error_code ec;
    if (!fs::exists(directory_path, ec)) {
        return 1;
    }
    if (!fs::is_directory(directory_path, ec)) {
        return 1;
    }
    for (const auto &entry : fs::directory_iterator(directory_path, ec)) {
        if (ec) {
            return 1;
        }
        if (fs::is_symlink(entry, ec)) {
            if (ec) {
                return 1;
            }
            std::string file_name = entry.path().filename().string();
            model_names.emplace_back(file_name);
        }
    }
    if (sort) {
        std::sort(model_names.begin(), model_names.end());
    }
    return 0;
}
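// Run each prompt through the model and print aggregate prefill/decode token counts and timings.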
static int eval_prompts(Llm *llm, const std::vector<std::string> &prompts) {
    int prompt_len = 0;
    int decode_len = 0;
    int64_t vision_time = 0;
    int64_t audio_time = 0;
    int64_t prefill_time = 0;
    int64_t decode_time = 0;
    // llm->warmup();
    for (size_t i = 0; i < prompts.size(); i++) {
        const auto &prompt = prompts[i];
        // prompts starting with '#' are treated as comments and skipped
        if (prompt.substr(0, 1) == "#") {
            continue;
        }
        llm->response(prompt);
        auto context = llm->getContext();
        prompt_len += context->prompt_len;
        decode_len += context->gen_seq_len;
        prefill_time += context->prefill_us;
        decode_time += context->decode_us;
    }
    // vision_time and audio_time are placeholders for multimodal inputs and are never accumulated here
    float vision_s = vision_time / 1e6;
    float audio_s = audio_time / 1e6;
    float prefill_s = prefill_time / 1e6;
    float decode_s = decode_time / 1e6;
    printf("\n#################################\n");
    printf("prompt tokens num = %d\n", prompt_len);
    printf("decode tokens num = %d\n", decode_len);
    printf(" vision time = %.2f s\n", vision_s);
    printf(" audio time = %.2f s\n", audio_s);
    printf("prefill time = %.2f s\n", prefill_s);
    printf(" decode time = %.2f s\n", decode_s);
    printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
    printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
    printf("##################################\n");
    return 0;
}
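// Read prompts from a file (one per line, CRLF tolerated) and evaluate them with eval_prompts.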
static int eval_file(Llm *llm, std::string prompt_file) {
    std::cout << "prompt file is " << prompt_file << std::endl;
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    while (std::getline(prompt_fs, prompt)) {
        // strip the trailing '\r' left over from CRLF line endings
        if (!prompt.empty() && prompt.back() == '\r') {
            prompt.pop_back();
        }
        prompts.push_back(prompt);
    }
    prompt_fs.close();
    if (prompts.empty()) {
        return 1;
    }
    return eval_prompts(llm, prompts);
}
static int print_usage() {
    std::cout << "Available Commands:" << std::endl;
    std::cout << " mls list: list all local model names" << std::endl;
    std::cout << " mls search keyword: search available remote models by keyword" << std::endl;
    std::cout << " mls download model_name: download the model" << std::endl;
    std::cout << " mls run model_name: run the model" << std::endl;
    std::cout << " mls benchmark model_name: run the benchmark for a model" << std::endl;
    std::cout << " mls serve model_name: serve the model with an OpenAI-compatible API" << std::endl;
    std::cout << " mls delete model_name: remove the downloaded model" << std::endl;
    return 0;
}
// list the model folders in the local cache directory, e.g. ~/.cache/modelscope/hub/MNN/
static int list_models(int argc, const char *argv[]) {
    std::vector<std::string> model_names;
    list_local_models(mls::FileUtils::GetBaseCacheDir(), model_names);
    if (!model_names.empty()) {
        for (auto &name : model_names) {
            printf("%s\n", name.c_str());
        }
    } else {
        printf("no local models; use 'mls search' to search remote models and download\n");
    }
    return 0;
}
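// Heuristic: treat any config path containing "deepseek-r1" (case-insensitive) as a DeepSeek-R1
// model, which is served without the default chat template.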
static bool IsR1(const std::string& path) {
    std::string lowerModelName = path;
    std::transform(lowerModelName.begin(), lowerModelName.end(), lowerModelName.begin(), ::tolower);
    return lowerModelName.find("deepseek-r1") != std::string::npos;
}
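// `mls serve`: resolve the model's config.json (positional model name or -c <path>) and start
// an OpenAI-compatible server on top of it.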
static int serve(int argc, const char *argv[]) {
    bool invalid_param{false};
    std::string config_path{};
    std::string arg{};
    if (argc < 3) {
        print_usage();
        return 1;
    }
    arg = argv[2];
    if (arg.find('-') != 0) {
        config_path = (fs::path(mls::FileUtils::GetBaseCacheDir()) / arg / "config.json").string();
    }
    for (int i = 2; i < argc; i++) {
        arg = argv[i];
        if (arg == "-c") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            config_path = mls::FileUtils::ExpandTilde(argv[i]);
        }
    }
    if (invalid_param || config_path.empty()) {
        fprintf(stderr, "error: invalid or missing config path\n");
        print_usage();
        return 1;
    }
    mls::MlsServer server;
    bool is_r1 = IsR1(config_path);
    auto llm = create_and_prepare_llm(config_path.c_str(), !is_r1);
    server.Start(llm.get(), is_r1);
    return 0;
}
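// `mls benchmark`: load the model and hand it to the benchmark runner.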
static int benchmark(int argc, const char *argv[]) {
    std::string arg{};
    bool invalid_param{false};
    std::string config_path{};
    if (argc < 3) {
        print_usage();
        return 1;
    }
    arg = argv[2];
    if (arg.find('-') != 0) {
        config_path = mls::FileUtils::GetConfigPath(arg);
    }
    for (int i = 2; i < argc; i++) {
        arg = argv[i];
        if (arg == "-c") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            config_path = argv[i];
        }
    }
    if (invalid_param) {
        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
        print_usage();
        exit(1);
    }
    if (config_path.empty()) {
        fprintf(stderr, "error: config path is empty\n");
        print_usage();
        exit(1);
    }
    auto llm = create_and_prepare_llm(config_path.c_str(), true);
    mls::LLMBenchmark benchmark;
    benchmark.Start(llm.get(), {});
    return 0;
}
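// Interactive chat loop; "/exit" quits and "/reset" clears the conversation state.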
void chat(Llm* llm) {
    ChatMessages messages;
    messages.emplace_back("system", "You are a helpful assistant.");
    auto context = llm->getContext();
    while (true) {
        std::cout << "\nUser: ";
        std::string user_str;
        std::getline(std::cin, user_str);
        if (user_str == "/exit") {
            return;
        }
        if (user_str == "/reset") {
            llm->reset();
            std::cout << "\nA: reset done." << std::endl;
            continue;
        }
        messages.emplace_back("user", user_str);
        std::cout << "\nA: " << std::flush;
        llm->response(messages);
        auto assistant_str = context->generate_str;
        messages.emplace_back("assistant", assistant_str);
    }
}
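// `mls run`: chat interactively by default, or evaluate a single prompt (-p) / a prompt file (-pf).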
static int run(int argc, const char *argv[]) {
    std::cout << "Start run..." << std::endl;
    std::string arg{};
    std::string config_path{};
    std::string prompt;
    std::string prompt_file;
    bool invalid_param = false;
    if (argc < 3) {
        print_usage();
        return 1;
    }
    arg = argv[2];
    if (arg.find('-') != 0) {
        config_path = mls::FileUtils::GetConfigPath(arg);
    }
    for (int i = 2; i < argc; i++) {
        arg = argv[i];
        if (arg == "-c") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            config_path = mls::FileUtils::ExpandTilde(argv[i]);
        } else if (arg == "-p") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            prompt = argv[i];
        } else if (arg == "-pf") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            prompt_file = argv[i];
        }
    }
    if (invalid_param) {
        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
        print_usage();
        exit(1);
    }
    if (config_path.empty()) {
        fprintf(stderr, "error: config path is empty\n");
        print_usage();
        exit(1);
    }
    std::cout << "config path is " << config_path << std::endl;
    auto llm = create_and_prepare_llm(config_path.c_str(), true);
    if (prompt.empty() && prompt_file.empty()) {
        chat(llm.get());
    } else if (!prompt.empty()) {
        eval_prompts(llm.get(), {prompt});
    } else {
        eval_file(llm.get(), prompt_file);
    }
    return 0;
}
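// `mls download`: fetch a model repository from the remote hub; bare names are prefixed with "taobao-mnn/".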
int download(int argc, const char *argv[]) {
    if (argc < 3) {
        print_usage();
        return 1;
    }
    std::string repo_name = argv[2];
    std::cout << "download repo: " << repo_name << std::endl;
    mls::HfApiClient api_client;
    std::string error_info;
    if (repo_name.find("taobao-mnn/") != 0) {
        repo_name = "taobao-mnn/" + repo_name;
    }
    const auto repo_info = api_client.GetRepoInfo(repo_name, "main", error_info);
    if (!error_info.empty()) {
        std::cout << "get repo info error: " << error_info << std::endl;
        return 1;
    }
    api_client.DownloadRepo(repo_info);
    return 0;
}
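// `mls search`: query remote repositories by keyword and print the bare model names.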
int search(int argc, const char *argv[]) {
    if (argc < 3) {
        print_usage();
        return 1;
    }
    const std::string key = argv[2];
    mls::HfApiClient client;
    auto repos = client.SearchRepos(key);
    for (auto &repo : repos) {
        auto pos = repo.model_id.rfind('/');
        if (pos != std::string::npos) {
            printf("%s\n", repo.model_id.substr(pos + 1).c_str());
        }
    }
    return 0;
}
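// `mls delete`: remove both the local symlink and the underlying storage folder of a downloaded model.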
int delete_model(int argc, const char *argv[]) {
    if (argc < 3) {
        print_usage();
        return 1;
    }
    std::string model_name = argv[2];
    std::string linker_path = mls::FileUtils::GetFolderLinkerPath(model_name);
    mls::FileUtils::RemoveFileIfExists(linker_path);
    if (model_name.find("taobao-mnn") != 0) {
        model_name = "taobao-mnn/" + model_name;
    }
    std::string storage_path = mls::FileUtils::GetStorageFolderPath(model_name);
    mls::FileUtils::RemoveFileIfExists(storage_path);
    return 0;
}
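// Command dispatch. Typical invocations (the model name below is only an illustration;
// use `mls list` or `mls search` to find real ones):
//   mls search qwen
//   mls download Qwen2.5-1.5B-Instruct-MNN
//   mls run Qwen2.5-1.5B-Instruct-MNN -p "hello"
//   mls serve Qwen2.5-1.5B-Instruct-MNN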
int main(int argc, const char *argv[]) {
    if (argc < 2) {
        print_usage();
        return 0;
    }
    std::string cmd = argv[1];
    if (cmd == "list") {
        return list_models(argc, argv);
    } else if (cmd == "serve") {
        return serve(argc, argv);
    } else if (cmd == "run") {
        return run(argc, argv);
    } else if (cmd == "benchmark") {
        return benchmark(argc, argv);
    } else if (cmd == "download") {
        return download(argc, argv);
    } else if (cmd == "search") {
        return search(argc, argv);
    } else if (cmd == "delete") {
        return delete_model(argc, argv);
    } else {
        print_usage();
    }
    return 0;
}