mirror of https://github.com/alibaba/MNN.git
265 lines
9.8 KiB
C++
265 lines
9.8 KiB
C++
//
|
|
// Created by ruoyi.sjd on 2024/12/25.
|
|
// Copyright (c) 2024 Alibaba Group Holding Limited All rights reserved.
|
|
//
|
|
|
|
#include "mls_server.hpp"
|
|
#include <iostream>
|
|
|
|
|
|
|
|
|
|
namespace mls {
|
|
|
|
std::string GetCurrentTimeAsString() {
|
|
// Get the current time since epoch
|
|
auto now = std::chrono::system_clock::now();
|
|
auto duration = now.time_since_epoch();
|
|
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(duration).count();
|
|
|
|
// Convert to string
|
|
return std::to_string(seconds);
|
|
}
|
|
|
|
bool FromJson(const json& j, PromptItem& item) {
|
|
if (!j.is_object()) {
|
|
return false;
|
|
}
|
|
if (!j.contains("role") || !j["role"].is_string()) {
|
|
return false;
|
|
}
|
|
if (!j.contains("content") || !j["content"].is_string()) {
|
|
return false;
|
|
}
|
|
|
|
item.first = j["role"].get<std::string>(); // Role
|
|
item.second = j["content"].get<std::string>(); // Content
|
|
return true;
|
|
}
|
|
|
|
std::string trimLeadingWhitespace(const std::string& str) {
|
|
auto it = std::find_if(str.begin(), str.end(), [](unsigned char ch) {
|
|
return !std::isspace(ch); // Find the first non-whitespace character
|
|
});
|
|
return std::string(it, str.end()); // Create a substring from the first non-whitespace character
|
|
}
|
|
|
|
const std::string getR1AssistantString(std::string assistant_content) {
|
|
std::size_t pos = assistant_content.find("</think>");
|
|
if (pos != std::string::npos) {
|
|
assistant_content.erase(0, pos + std::string("</think>").length());
|
|
}
|
|
return trimLeadingWhitespace(assistant_content) + "<|end_of_sentence|>";
|
|
}
|
|
|
|
std::string GetR1UserString(std::string user_content, bool last) {
|
|
return "<|User|>" + std::string(user_content) + "<|Assistant|>";
|
|
}
|
|
|
|
std::vector<PromptItem> ConvertToR1(std::vector<PromptItem> chat_prompts) {
|
|
std::vector<PromptItem> result_prompts = {};
|
|
std::string prompt_result = "";
|
|
result_prompts.emplace_back("system", "<|begin_of_sentence|>You are a helpful assistant.");
|
|
auto iter = chat_prompts.begin();
|
|
for (; iter != chat_prompts.end() - 1; ++iter) {
|
|
if (iter->first == "system") {
|
|
continue;
|
|
} else if (iter->first == "assistant") {
|
|
result_prompts.emplace_back("assistant", getR1AssistantString(iter->second));
|
|
} else if (iter->first == "user") {
|
|
result_prompts.emplace_back("user", GetR1UserString(iter->second, false));
|
|
}
|
|
}
|
|
if (iter->first == "user") {
|
|
result_prompts.emplace_back("user", GetR1UserString(iter->second, true));
|
|
} else {
|
|
result_prompts.emplace_back("assistant", getR1AssistantString(iter->second));
|
|
}
|
|
return result_prompts;
|
|
}
|
|
|
|
void MlsServer::Answer(MNN::Transformer::Llm* llm, const json &messages, std::function<void(const std::string&)> on_result) {
|
|
std::vector<PromptItem> prompts{};
|
|
if (messages.is_array()) {
|
|
for (const auto& item_json : messages) {
|
|
PromptItem item;
|
|
if (!FromJson(item_json, item)) {
|
|
std::cerr << "Error converting JSON object to PromptItem." << std::endl;
|
|
break;
|
|
}
|
|
prompts.push_back(item);
|
|
}
|
|
}
|
|
std::stringstream response_buffer;
|
|
Utf8StreamProcessor processor([&response_buffer, on_result](const std::string& utf8Char) {
|
|
bool is_eop = utf8Char.find("<eop>") != std::string::npos;
|
|
if (!is_eop) {
|
|
response_buffer << utf8Char;
|
|
} else {
|
|
std::string response_result = response_buffer.str();
|
|
on_result(response_result);
|
|
}
|
|
}
|
|
);
|
|
LlmStreamBuffer stream_buffer{[&processor](const char* str, size_t len){
|
|
processor.processStream(str, len);
|
|
}};
|
|
std::ostream output_ostream(&stream_buffer);std::lock_guard<std::mutex> lock(llm_mutex_);
|
|
llm->response(this->is_r1_ ? ConvertToR1(prompts) : prompts, &output_ostream, "<eop>");
|
|
}
|
|
|
|
void MlsServer::AnswerStreaming(MNN::Transformer::Llm* llm,
|
|
const json& messages,
|
|
std::function<void(const std::string&, bool end)> on_partial) {
|
|
std::vector<PromptItem> prompts;
|
|
if (messages.is_array()) {
|
|
for (const auto& item_json : messages) {
|
|
PromptItem item;
|
|
if (!FromJson(item_json, item)) {
|
|
std::cerr << "Error converting JSON object to PromptItem." << std::endl;
|
|
return;
|
|
}
|
|
prompts.push_back(item);
|
|
}
|
|
}
|
|
std::string answer = "";
|
|
Utf8StreamProcessor processor([&on_partial, &answer](const std::string &utf8Char) {
|
|
bool is_eop = (utf8Char.find("<eop>") != std::string::npos);
|
|
if (is_eop) {
|
|
std::string response_result = answer;
|
|
std::cout<<"response result: "<<response_result<<std::endl;
|
|
on_partial("", true);
|
|
} else {
|
|
answer += utf8Char;
|
|
on_partial(utf8Char, false);
|
|
}
|
|
});
|
|
|
|
// LlmStreamBuffer calls our lambda as new bytes arrive from the LLM
|
|
LlmStreamBuffer stream_buffer([&processor](const char* str, size_t len) {
|
|
processor.processStream(str, len);
|
|
});
|
|
std::ostream output_ostream(&stream_buffer);
|
|
std::lock_guard<std::mutex> lock(llm_mutex_);
|
|
llm->response(this->is_r1_ ? ConvertToR1(prompts) : prompts, &output_ostream, "<eop>");
|
|
}
|
|
|
|
|
|
|
|
void AllowCors(httplib::Response& res) {
|
|
res.set_header("Access-Control-Allow-Origin", "*");
|
|
res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
|
|
res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
|
|
}
|
|
|
|
void MlsServer::Start(MNN::Transformer::Llm* llm, bool is_r1) {
|
|
this->is_r1_ = is_r1;
|
|
// Create a server instance
|
|
httplib::Server server;
|
|
|
|
// Define a route for the GET request on "/"
|
|
server.Get("/", [this](const httplib::Request& req, httplib::Response& res) {
|
|
AllowCors(res);
|
|
res.set_content(html_content, "text/html");
|
|
});
|
|
server.Post("/reset", [&](const httplib::Request &req, httplib::Response &res) {
|
|
printf("POST /reset\n");
|
|
AllowCors(res);
|
|
llm->reset();
|
|
res.set_content("{\"status\": \"ok\"}", "application/json");
|
|
});
|
|
server.Options("/chat/completions", [](const httplib::Request& /*req*/, httplib::Response& res) {
|
|
AllowCors(res);
|
|
res.status = 200;
|
|
});
|
|
server.Post("/chat/completions", [&](const httplib::Request &req, httplib::Response &res) {
|
|
std::cout << "POST /chat/completions, handled by thread: "
|
|
<< std::this_thread::get_id() << std::endl;
|
|
AllowCors(res);
|
|
if (!json::accept(req.body)) {
|
|
json err;
|
|
err["error"] = "Invalid JSON in request body.";
|
|
res.status = 400;
|
|
res.set_content(err.dump(), "application/json");
|
|
return;
|
|
}
|
|
json request_json = json::parse(req.body, nullptr, false);
|
|
json messages = request_json["messages"];
|
|
std::cout<<"received messages:"<<messages.dump(0)<<std::endl;
|
|
std::string model = request_json.value("model", "undefined-model");
|
|
bool stream = request_json.value("stream", false);
|
|
if (!stream) {
|
|
Answer(llm, messages, [&res, model](const std::string& answer) {
|
|
json response_json = {
|
|
{"id", "chatcmpl" + GetCurrentTimeAsString()},
|
|
{"object", "chat.completion"},
|
|
{"created", static_cast<int>(time(nullptr))},
|
|
{"model", model},
|
|
{
|
|
"choices", json::array({
|
|
{
|
|
{"index", 0},
|
|
{
|
|
"message", {
|
|
{"role", "assistant"},
|
|
{"content", answer}
|
|
}
|
|
},
|
|
{"finish_reason", "stop"}
|
|
}
|
|
})
|
|
},
|
|
{
|
|
"usage", {
|
|
{"prompt_tokens", 10},
|
|
{"completion_tokens", 7},
|
|
{"total_tokens", 17}
|
|
}
|
|
}
|
|
};
|
|
res.set_content(response_json.dump(), "application/json");
|
|
});
|
|
return;
|
|
}
|
|
res.set_header("Content-Type", "text/event-stream");
|
|
res.set_header("Cache-Control", "no-cache");
|
|
res.set_header("Connection", "keep-alive");
|
|
res.set_chunked_content_provider(
|
|
"text/event-stream",
|
|
[llm, messages, model, this](size_t /*offset*/, httplib::DataSink &sink) {
|
|
auto sse_callback = [&, this](const std::string &partial_text, bool end) {
|
|
std::string finish_reason = end ? "stop" : "";
|
|
json sse_json = {
|
|
{"id", "chatcmpl-" + GetCurrentTimeAsString()},
|
|
{"object", "chat.completion.chunk"},
|
|
{"created", static_cast<int>(std::time(nullptr))},
|
|
{"model", model},
|
|
{"choices", json::array({
|
|
{
|
|
{"delta", {{"content", partial_text}}},
|
|
{"index", 0},
|
|
{"finish_reason", finish_reason}
|
|
}
|
|
})}
|
|
};
|
|
std::string chunk_str = "data: " + sse_json.dump() + "\n\n";
|
|
sink.os.write(chunk_str.c_str(), chunk_str.size());
|
|
sink.os.flush();
|
|
};
|
|
AnswerStreaming(llm, messages, sse_callback);
|
|
std::string done_str = "data: [DONE]\n\n";
|
|
sink.os.write(done_str.c_str(), done_str.size());
|
|
sink.os.flush();
|
|
sink.done();
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
return false;
|
|
}
|
|
);
|
|
});
|
|
// Start the server on port
|
|
std::cout << "Starting server on http://localhost:9090\n";
|
|
if (!server.listen("0.0.0.0", 9090)) {
|
|
std::cerr << "Error: Could not start server.\n";
|
|
}
|
|
}
|
|
} |