add perplexity evaluation and LLM dataset processing; supports wikitext and ShareGPT

This commit is contained in:
hzx 2024-11-12 13:58:59 +08:00
parent a16fc02c0b
commit 028f09a7c9
25 changed files with 969 additions and 313 deletions

3
.gitignore vendored
View File

@ -365,4 +365,5 @@ MNN_compression_pb2.py
model/
# datasets
datasets/
datasets/*
!datasets/*.sh

View File

@ -1,135 +0,0 @@
#include <algorithm>
#include <vector>
#include <cmath>
#include <evaluation/perplexity.hpp>
#include <llm/llm.hpp>
#include <iostream>
#include <iomanip>
#include <sampler/sampler.hpp>
namespace MNN{
namespace Transformer{
void PPLMeasurer::init(Llm* llm, StateCacheManager* manager, std::vector<std::vector<int>> prompts, int max_len, int stride) {
if (stride == 0) {
// default stride for sliding window.
stride = max_len / 2;
}
mLlm = llm;
mStateCacheManager = manager;
mMaxLen = max_len;
mStride = stride;
mPrompts = prompts;
}
PPLMeasurer::PPLMeasurer(Llm* llm, StateCacheManager* manager, std::vector<std::vector<int>> prompts, int max_len, int stride) {
init(llm, manager, prompts, max_len, stride);
}
PPLMeasurer::PPLMeasurer(Llm* llm, StateCacheManager* manager, std::vector<std::string> prompts, int max_len, int stride) {
std::vector<std::vector<int>> tokens(prompts.size());
for (int p = 0; p < prompts.size(); ++p) tokens[p] = llm->encode(prompts[p]);
init(llm, manager, tokens, max_len, stride);
}
/* Implemented based on https://huggingface.co/docs/transformers/perplexity
******************** HuggingFace Python Version ************************
import torch
from tqdm import tqdm
max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)
nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
end_loc = min(begin_loc + max_length, seq_len)
trg_len = end_loc - prev_end_loc # may be different from stride on last loop
input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100
with torch.no_grad():
outputs = model(input_ids, labels=target_ids)
# loss is calculated using CrossEntropyLoss which averages over valid labels
# N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
# to the left by 1.
neg_log_likelihood = outputs.loss
nlls.append(neg_log_likelihood)
prev_end_loc = end_loc
if end_loc == seq_len:
break
ppl = torch.exp(torch.stack(nlls).mean())
******************** HuggingFace Python Version ************************
*/
float PPLMeasurer::perplexity_one(const std::vector<int>& prompt) {
int seq_len = prompt.size();
std::vector<float> nlls;
float ppl = 0.f;
// start calculation
int prev_end_loc = 1; // scoring starts at index 1: the first token has no preceding context, so it is not counted.
for (int begin_loc = 0; begin_loc < seq_len; begin_loc += mStride) {
mStateCacheManager->setCurrentReference(mCandidate);
int end_loc = std::min(begin_loc + mMaxLen, seq_len);
// first token
std::vector<int> tokens(prev_end_loc - begin_loc);
for (int it = begin_loc; it < prev_end_loc; ++it) tokens[it - begin_loc] = prompt[it];
auto logits = mLlm->forward(tokens, true);
logits = MNN::Express::_Softmax(logits);
nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]));
// std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[prev_end_loc]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]) << std::endl;
// decode following tokens
for (int it = prev_end_loc+1; it < end_loc; ++it) {
auto logits = mLlm->forward({prompt[it-1]}, false);
logits = MNN::Express::_Softmax(logits);
nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[it]]));
// std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[it]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[it]]) << std::endl;
}
// clean up once
reset();
mLlm->reset();
prev_end_loc = end_loc;
if (end_loc == seq_len) break;
}
// calculate ppl
for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
ppl /= nlls.size();
ppl = std::exp(ppl);
// print
std::cout << "PPL: " << std::setprecision(9) << ppl << std::endl;
return ppl;
}
std::vector<float> PPLMeasurer::perplexity() {
std::vector<float> ppls;
for (auto prompt : mPrompts) {
ppls.push_back(perplexity_one(prompt));
reset();
mLlm->reset();
}
return ppls;
}
void PPLMeasurer::reset() {
// in the future, only reset its own.
}
void PPLMeasurer::reset(int max_len, int stride) {
mMaxLen = max_len;
mStride = stride;
reset();
}
} // Transformer
} // MNN

View File

@ -1,50 +0,0 @@
#ifndef PERPLEXITY_hpp
#define PERPLEXITY_hpp
#include <vector>
#include <memory>
#include <string>
#include <fstream>
#include <sstream>
#include <iostream>
#include <streambuf>
#include <functional>
#include <unordered_map>
#include <utility>
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include <MNN/StateCacheManager.hpp>
namespace MNN {
namespace Transformer {
class Llm;
class MNN_PUBLIC PPLMeasurer {
protected:
Llm* mLlm;
StateCacheManager* mStateCacheManager;
std::vector<std::vector<int>> mPrompts;
std::shared_ptr<StateCacheReference> mCandidate;
int mStride, mMaxLen;
void init(Llm* llm, StateCacheManager* manager, std::vector<std::vector<int>> prompts, int max_len, int stride);
public:
PPLMeasurer(Llm* llm, StateCacheManager* manager, std::vector<std::vector<int>> prompts, int max_len=2048, int stride=0);
PPLMeasurer(Llm* llm, StateCacheManager* manager, std::vector<std::string> prompts, int max_len=2048, int stride=0);
float perplexity_one(const std::vector<int>& prompt);
std::vector<float> perplexity();
// prepare for another round of sampling
// in the future, only reset its own.
void reset();
void reset(int max_len, int stride);
};
} // Transformer
} // MNN
#endif // PERPLEXITY_hpp

View File

@ -248,7 +248,7 @@ Direct inference on PC
```bash
# Use adb push to push the shared libraries onto the device
adb shell mkdir /data/local/tmp/llm
adb push chat_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
```
#### GPTQ weight loading

7
transformers/llm/.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
datasets/*
!datasets/*.sh
!datasets/visualization/
datasets/visualization/data
datasets/visualization/pic

View File

@ -0,0 +1,2 @@
git lfs install
git clone https://huggingface.co/datasets/shareAI/ShareGPT-Chinese-English-90k

View File

@ -0,0 +1,2 @@
wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
unzip wikitext-2-raw-v1.zip

View File

@ -0,0 +1,116 @@
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from matplotlib import cbook
from matplotlib.axes import Axes
import pandas as pd
import numpy as np
import argparse
import os
vis_root = "pic"
def remove_blanks(df: pd.DataFrame) -> pd.DataFrame:
# Removing unnamed columns using drop function
df.drop(df.columns[df.columns.str.contains(
'unnamed', case=False)], axis=1, inplace=True)
return df
def add_turns(df: pd.DataFrame) -> pd.DataFrame:
df["turns"] = (1-df.isnull()).sum(axis=1) // 2
return df
def get_max_turn(df: pd.DataFrame) -> int:
keys = list(df.keys())
return max([int(key.replace("decode", "")) for key in keys if "decode" in key]) + 1
def add_pd_ratio(df: pd.DataFrame) -> pd.DataFrame:
max_turns = get_max_turn(df)
for i in range(max_turns):
df["pd_ratio{}".format(i)] = df["prefill{}".format(i)] / df["decode{}".format(i)]
return df
def preprocess(file_path: str) -> pd.DataFrame:
table = pd.read_csv(file_path)
table = remove_blanks(table)
table = add_turns(table)
table = add_pd_ratio(table)
print(table)
return table
def draw_distribution(df: pd.DataFrame, file_path: str):
turns_bin = df.value_counts(subset=["turns"], sort=False)
print(turns_bin)
plt.close()
plt.rcParams['font.size'] = 10
_, ax = plt.subplots()
# N is the count in each bin, bins is the lower-limit of the bin
N, bins, patches = ax.hist(df["turns"], bins=get_max_turn(df), density=True, align="left", label=True)
# We'll color code by height, but you could use any scalar
fracs = N / N.max()
# we need to normalize the data to 0..1 for the full range of the colormap
norm = colors.Normalize(fracs.min(), fracs.max())
# Now, we'll loop through our objects and set the color of each accordingly
for thisfrac, thispatch in zip(fracs, patches):
color = plt.cm.viridis(norm(thisfrac))
thispatch.set_facecolor(color)
# Now we format the y-axis to display percentage
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
ax.set_xlim((0.5, get_max_turn(df)-0.5))
ax.set_xticks(np.arange(1,get_max_turn(df)+1),np.arange(1,get_max_turn(df)+1),rotation=60, fontsize=9)
ax.set_ylabel("frequency", fontsize=14)
ax.set_xlabel("num of turns", fontsize=14)
plt.savefig(file_path, dpi=600)
plt.close()
def draw_prefill(df: pd.DataFrame, ax: Axes):
stats = [cbook.boxplot_stats(df[df["prefill{}".format(i)].notna()]["prefill{}".format(i)], labels=[i+1])[0]
for i in range(get_max_turn(df))]
print(stats)
ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
ax.set_ylim(0,600)
ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9)
ax.set_ylabel("prefill", fontsize=12, rotation=90)
return
def draw_decode(df: pd.DataFrame, ax: Axes):
stats = [cbook.boxplot_stats(df[df["decode{}".format(i)].notna()]["decode{}".format(i)], labels=[i+1])[0]
for i in range(get_max_turn(df))]
print(stats)
ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
ax.set_ylim(0,600)
ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9)
ax.set_ylabel("decode", fontsize=12, rotation=90)
return
def draw_pd_ratio(df: pd.DataFrame, ax: Axes):
stats = [cbook.boxplot_stats(df[df["pd_ratio{}".format(i)].notna()]["pd_ratio{}".format(i)], labels=[i+1])[0]
for i in range(get_max_turn(df))]
print(stats)
ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
ax.plot(np.arange(0,get_max_turn(df)+2), np.ones_like(np.arange(0,get_max_turn(df)+2),dtype=float))
ax.set_xlim(0, get_max_turn(df)+1)
ax.set_ylim(0, 2.)
ax.set_xticks(np.arange(1,get_max_turn(df)), np.arange(1,get_max_turn(df)), rotation=60, fontsize=9)
ax.set_yticks([0,0.5,1,2], [0,0.5,1,2], fontsize=9)
ax.set_xlabel("round", fontsize=12)
ax.set_ylabel("prefill/decode", fontsize=12, rotation=90)
return
def draw_reuse_kv(df: pd.DataFrame, file_path: str):
plt.close()
_, axs = plt.subplots(3,1,sharex="col")
draw_prefill(df, axs[0])
draw_decode(df, axs[1])
draw_pd_ratio(df, axs[2])
plt.savefig(file_path, dpi=1200)
plt.close()
return
def draw_no_reuse_kv():
return
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="./data")
parser.add_argument("--name", type=str, default="shareGPT_dialog_stats_common_en.csv")
args = parser.parse_args()
file_path = os.path.join(args.root, args.name)
dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_dist.png")
pd_dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_pd_dist.png")
table = preprocess(file_path)
draw_distribution(table, dist_path)
draw_reuse_kv(table, pd_dist_path)

View File

@ -26,8 +26,11 @@ else()
endif()
add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/app/llm_demo.cpp)
add_executable(ppl_demo ${CMAKE_CURRENT_LIST_DIR}/app/ppl_demo.cpp)
IF (NOT MNN_SEP_BUILD)
target_link_libraries(llm_demo ${MNN_DEPS})
target_link_libraries(ppl_demo ${MNN_DEPS})
ELSE ()
target_link_libraries(llm_demo ${MNN_DEPS} llm)
target_link_libraries(ppl_demo ${MNN_DEPS} llm)
ENDIF ()

View File

@ -6,6 +6,7 @@
//
#include "llm/llm.hpp"
#include "evaluation/dataset.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include <MNN/expr/ExecutorScope.hpp>
@ -23,49 +24,6 @@ static void trace_prepare(Llm* llm) {
llm->reset();
}
std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines) {
std::vector<std::vector<std::string>> csv_data;
std::string line;
std::vector<std::string> row;
std::string cell;
bool insideQuotes = false;
bool startCollecting = false;
// content to stream
std::string content = "";
for (auto line : lines) {
content = content + line + "\n";
}
std::istringstream stream(content);
while (stream.peek() != EOF) {
char c = stream.get();
if (c == '"') {
if (insideQuotes && stream.peek() == '"') { // quote
cell += '"';
stream.get(); // skip quote
} else {
insideQuotes = !insideQuotes; // start or end text in quote
}
startCollecting = true;
} else if (c == ',' && !insideQuotes) { // end element, start new element
row.push_back(cell);
cell.clear();
startCollecting = false;
} else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line
row.push_back(cell);
csv_data.push_back(row);
cell.clear();
row.clear();
startCollecting = false;
} else {
cell += c;
startCollecting = true;
}
}
return csv_data;
}
static int benchmark(Llm* llm, const std::vector<std::string>& prompts) {
for (int i = 0; i < prompts.size(); i++) {
const auto& prompt = prompts[i];

View File

@ -0,0 +1,61 @@
//
// ppl_demo.cpp
//
// Created by MNN on 2023/03/24.
// ZhaodeWang
//
#include "llm/llm.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
#include <stdlib.h>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>
using namespace MNN::Transformer;
static void trace_prepare(Llm* llm) {
MNN_PRINT("Prepare for resize opt Begin\n");
llm->trace(true);
std::ostringstream cacheOs;
llm->response("Hello", &cacheOs);
MNN_PRINT("Prepare for resize opt End\n");
llm->trace(false);
llm->reset();
}
// parse json
static int ppl_eval(Llm* llm, std::string prompt_file, std::ofstream* perfOS) {
std::cout << "prompt file is " << prompt_file << std::endl;
// ppl evaluation
std::vector<float> ppls = llm->perplexity(prompt_file, perfOS);
float mean_ppl = 0.f;
for (int j = 0; j < ppls.size(); ++j) mean_ppl += ppls[j];
mean_ppl /= ppls.size();
std::cout << mean_ppl << std::endl;
return 0;
}
int main(int argc, const char* argv[]) {
if (argc < 3) {
std::cout << "Usage: " << argv[0] << " config.json ppl-prompt.txt [perf.txt]" << std::endl;
return 0;
}
std::string config_path = argv[1];
std::cout << "config path is " << config_path << std::endl;
std::unique_ptr<Llm> llm(Llm::createLLM(config_path));
{
AUTOTIME;
llm->load();
}
{
AUTOTIME;
trace_prepare(llm.get());
}
std::string prompt_file = argv[2];
std::unique_ptr<std::ofstream> perfOS(nullptr);
if (argc == 4) { perfOS.reset(new std::ofstream(argv[3])); }
return ppl_eval(llm.get(), prompt_file, perfOS.get());
}

View File

@ -0,0 +1,33 @@
#ifndef LLM_DATASET_hpp
#define LLM_DATASET_hpp
#include <vector>
#include <string>
#include <iostream>
#include <sstream>
#include <fstream>
#include <rapidjson/document.h>
#include <rapidjson/writer.h>
#include <rapidjson/stringbuffer.h>
#include "llm/llm.hpp"
#include <MNN/MNNDefine.h>
namespace MNN {
namespace Transformer {
// parse csv
MNN_PUBLIC std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines);
void parse_jsonl(std::string prompt_file, std::vector<std::vector<std::vector<PromptItem>>>& dialogs);
std::string getPPLType(std::string dataset_name);
std::vector<std::string> rowsplit(std::string prompt_file);
std::vector<std::string> plaintext(std::string prompt_file);
std::vector<std::string> wikitext(std::string prompt_file);
std::vector<std::vector<std::vector<PromptItem>>> shareGPT(std::string prompt_file, int sample_size=-1); // -1: no sampling
} // Transformer
} // MNN
#endif // LLM_DATASET_hpp
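For orientation, here is a minimal usage sketch of the dataset helpers declared above (not part of this commit). The file paths are placeholders, and it assumes the library built from this repo exports these symbols:

```cpp
// Hypothetical driver for the dataset helpers; paths are illustrative only.
#include <iostream>
#include "evaluation/dataset.hpp"

int main() {
    using namespace MNN::Transformer;
    // Split a raw WikiText dump into first-level " = Title = " sections.
    auto sections = wikitext("datasets/wikitext-2-raw/wiki.test.raw");
    std::cout << "wikitext sections: " << sections.size() << std::endl;
    // Load ShareGPT-style jsonl dialogs, randomly keeping at most 100 of them
    // (the sampled subset is also written back to disk by shareGPT).
    auto dialogs = shareGPT("datasets/sharegpt.jsonl", 100);
    std::cout << "sampled dialogs: " << dialogs.size() << std::endl;
    return 0;
}
```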

View File

@ -41,6 +41,8 @@ struct TimePerformance {
std::vector<int> prompt_record_;
};
void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv);
struct PrefillMemPerformance {
size_t prefill_prev_token_ = 0;
size_t prefill_token_ = 0;

View File

@ -63,6 +63,10 @@ public:
};
class MNN_PUBLIC Llm {
public:
std::shared_ptr<Sampler> mSampler;
std::shared_ptr<PromptLib> mPromptLib;
std::vector<LlmSessionInfo> mLlmSessionInfos; // Llm conversation session information. Currently, only mLlmSessionInfos[0] is allowed!
public:
Llm(std::shared_ptr<LlmConfig> config) : config_(config) {}
virtual ~Llm();
@ -81,6 +85,7 @@ public:
std::string generateTrace(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
void print_speed();
void print_speed(std::ostream* os);
std::vector<float> perplexity(std::string prompt_file, std::ostream* statsOS = nullptr);
// config function
std::string dump_config();
bool set_config(const std::string& content);
@ -91,7 +96,6 @@ public:
bool select_module(size_t index);
friend class Pipeline;
public:
std::vector<LlmSessionInfo> mLlmSessionInfos; // currently, only mLlmSessionInfos[0] is allowed!
bool is_single_ = true;
bool attention_fused_ = true;
virtual std::vector<int> tokenizer(const std::string& query);
@ -109,8 +113,6 @@ public:
protected:
std::shared_ptr<LlmConfig> config_;
std::shared_ptr<Tokenizer> tokenizer_;
std::shared_ptr<Sampler> mSampler;
std::shared_ptr<PromptLib> mPromptLib;
std::vector<int> key_value_shape_ = {};
std::vector<MNN::Express::VARP> past_key_values_;
MNN::Express::VARP inputs_embeds_, attention_mask_, position_ids_;

View File

@ -0,0 +1,83 @@
#include "llm/llm.hpp"
namespace MNN {
namespace Transformer {
// LlmSessionInfo starts
void LlmSessionInfo::resetSamplerFields() {
all_seq_len_ = 0;
gen_seq_len_ = 0;
tokens.clear();
}
void LlmSessionInfo::resetPromptFields() {
mHistory.clear();
mInputs.clear();
}
void LlmSessionInfo::resetPerformanceFields() {
clearPerformance(&mTimePerformance);
}
float LlmSessionInfo::average_total_speed() {
return (getTotalPromptLen()+getTotalDecodeLen())/(getTotalPrefillTime()+getTotalDecodeTime());
}
float LlmSessionInfo::average_prefill_speed() {
// prefill response rate
return getTotalPromptLen()/getTotalPrefillTime();
}
float LlmSessionInfo::average_decode_speed() {
return getTotalDecodeLen()/getTotalDecodeTime();
}
float LlmSessionInfo::getTotalPrefillTime() {
float sum = 0.f;
for (auto record : mTimePerformance.prefill_record_) {
sum += ((float)record.prefill_us_)*MICRO_TO_SEC;
}
return sum;
}
float LlmSessionInfo::getTotalDecodeTime() {
float sum = 0.0f;
for (auto record : mTimePerformance.decode_record_) {
sum += ((float)record.decode_us_)*MICRO_TO_SEC;
}
return sum;
}
int LlmSessionInfo::getTotalPromptLen() {
int prompt_len = 0;
if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
for (auto record : mTimePerformance.prefill_record_) {
prompt_len += record.prefill_token_;
}
} else {
for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
prompt_len += mTimePerformance.prompt_record_[r];
}
}
return prompt_len;
}
int LlmSessionInfo::getTotalDecodeLen() {
return mTimePerformance.decode_record_.size();
}
void LlmSessionInfo::print_speed(std::ostream* os) {
(*os) << "prefill " << mTimePerformance.prefill_record_.size() << std::endl;
if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
(*os) << "prev_token input_token speed(token/s)" << std::endl;
for (auto record : mTimePerformance.prefill_record_) {
(*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
}
} else {
(*os) << "prev_token input_token prompt_token response_speed(token/s)" << std::endl;
for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
auto record = mTimePerformance.prefill_record_[r];
auto prompt_len = mTimePerformance.prompt_record_[r];
(*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << prompt_len << " " << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
}
}
(*os) << "decode " << mTimePerformance.decode_record_.size() << std::endl;
(*os) << "prev_token speed(token/s)" << std::endl;
for (auto record : mTimePerformance.decode_record_) {
(*os) << record.decode_prev_token_ << " " << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl;
}
}
} // Transformer
} // MNN

View File

@ -0,0 +1,221 @@
#include <algorithm>
#include <vector>
#include <cmath>
#include <llm/llm.hpp>
#include <iostream>
#include <fstream>
#include <iomanip>
#include <string>
#include <iterator>
#include <random>
#include "evaluation/dataset.hpp"
#include <rapidjson/document.h>
#include <rapidjson/writer.h>
#include <rapidjson/stringbuffer.h>
namespace MNN {
namespace Transformer {
// parse file
// csv json
// parse csv
std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines) {
std::vector<std::vector<std::string>> csv_data;
std::string line;
std::vector<std::string> row;
std::string cell;
bool insideQuotes = false;
bool startCollecting = false;
// content to stream
std::string content = "";
for (auto line : lines) {
content = content + line + "\n";
}
std::istringstream stream(content);
while (stream.peek() != EOF) {
char c = stream.get();
if (c == '"') {
if (insideQuotes && stream.peek() == '"') { // quote
cell += '"';
stream.get(); // skip quote
} else {
insideQuotes = !insideQuotes; // start or end text in quote
}
startCollecting = true;
} else if (c == ',' && !insideQuotes) { // end element, start new element
row.push_back(cell);
cell.clear();
startCollecting = false;
} else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line
row.push_back(cell);
csv_data.push_back(row);
cell.clear();
row.clear();
startCollecting = false;
} else {
cell += c;
startCollecting = true;
}
}
return csv_data;
}
// dialog, turn,
void parse_jsonl(std::string prompt_file, std::vector<std::vector<std::vector<PromptItem>>>& dialogs) {
std::ifstream prompt_fs(prompt_file);
std::string prompt;
while(std::getline(prompt_fs, prompt)) {
rapidjson::Document document;
document.Parse(prompt.c_str());
std::vector<std::vector<PromptItem>> cnv;
if(document.HasMember("conversation")) {
auto& value = document["conversation"];
if (value.IsArray()) {
for (auto& v : value.GetArray()) {
if (v.IsObject()) {
std::vector<PromptItem> result;
for (auto itr = v.MemberBegin(); itr != v.MemberEnd(); ++itr) {
// {"human"/"user": , "assistant": }
result.push_back(std::make_pair(itr->name.GetString(), itr->value.GetString()));
}
cnv.push_back(result);
}
}
}
}
dialogs.push_back(cnv);
}
}
void write_jsonl(std::string prompt_file, const std::vector<std::vector<std::vector<PromptItem>>>& dialogs) {
std::ofstream prompt_fs(prompt_file);
for(auto& dialog : dialogs) {
rapidjson::Document document;
document.SetObject();
rapidjson::Value conversation(rapidjson::kArrayType);
conversation.SetArray();
for (auto& turn : dialog) {
rapidjson::Value sentence(rapidjson::kObjectType);
sentence.SetObject();
for (auto& role : turn) {
sentence.AddMember(rapidjson::Value(role.first.c_str(), document.GetAllocator()),
rapidjson::Value(role.second.c_str(), document.GetAllocator()), document.GetAllocator());
}
conversation.PushBack(sentence, document.GetAllocator());
}
document.AddMember("conversation", conversation, document.GetAllocator());
// write to file
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
document.Accept(writer);
prompt_fs << buffer.GetString() << std::endl;
}
}
// dataset
// wikitext, ShareGPT
std::string getPPLType(std::string dataset_name) {
if (dataset_name == "wikitext"
|| dataset_name == "plaintext"
|| dataset_name == "rowsplit") {
return "text";
} else if (dataset_name == "shareGPT") {
return "chat";
} else {
// default chat
return "chat";
}
}
std::vector<std::string> plaintext(std::string prompt_file) {
// split by line
std::ifstream prompt_fs(prompt_file);
std::vector<std::string> prompts;
std::string prompt;
prompts.push_back("");
while (std::getline(prompt_fs, prompt)) {
if (!prompt.empty() && (prompt.back() == '\r' || prompt.back() == '\n')) {
prompt.pop_back();
}
// concatenate.
prompts.back() += prompt + "\n";
}
return prompts;
}
std::vector<std::string> rowsplit(std::string prompt_file) {
// split by line
std::ifstream prompt_fs(prompt_file);
std::vector<std::string> prompts;
std::string prompt;
while (std::getline(prompt_fs, prompt)) {
if (!prompt.empty() && (prompt.back() == '\r' || prompt.back() == '\n')) {
prompt.pop_back();
}
prompts.push_back(prompt);
}
return prompts;
}
// wikitext
void removeSubstrs(std::string& s, std::string p) {
std::string::size_type n = p.length();
for (std::string::size_type i = s.find(p); i != std::string::npos; i = s.find(p))
s.erase(i, n);
}
std::vector<std::string> wikitext(std::string prompt_file) {
// split wiki text into " = " first-level column.
std::ifstream prompt_fs(prompt_file);
std::vector<std::string> prompts;
std::string prompt;
while (std::getline(prompt_fs, prompt)) {
if (!prompt.empty() && (prompt.back() == '\r' || prompt.back() == '\n')) {
prompt.pop_back();
}
if (prompt.size() < 4) continue;
removeSubstrs(prompt, "@-@");
if ((prompts.size() == 0) \
|| (prompt.size() >= 4 \
&& prompt.at(0) == ' ' \
&& prompt.at(1) == '=' \
&& prompt.at(2) == ' ' \
&& prompt.at(3) != '=')) {
// first-level column.
prompts.push_back(prompt);
} else {
// concatenate.
prompts.back() += "\n" + prompt;
}
}
return prompts;
}
std::string genSampleName(std::string oriName, int sample_size) {
const size_t last_dot_idx = oriName.rfind('.');
auto stem = oriName.substr(0, last_dot_idx);
return stem + "_sample" + std::to_string(sample_size) + ".jsonl";
}
std::vector<std::vector<std::vector<PromptItem>>> shareGPT(std::string prompt_file, int sample_size) {
std::vector<std::vector<std::vector<PromptItem>>> dialogs, dataset;
parse_jsonl(prompt_file, dialogs);
// randomly sample a subset
if (sample_size > 0 && sample_size < dialogs.size()){
std::sample(dialogs.begin(), dialogs.end(), std::back_inserter(dataset),
sample_size, std::mt19937 {std::random_device{}()});
dialogs = dataset;
// store dialogs to file
write_jsonl(genSampleName(prompt_file, sample_size), dialogs);
}
return dialogs;
}
} // Transformer
} // MNN

View File

@ -12,6 +12,19 @@ void clearPerformance(struct TimePerformance* perf) {
perf->decode_record_.clear();
perf->prompt_record_.clear();
}
void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv) {
if (reuse_kv) {
perf->prompt_record_.push_back(input_len);
} else {
// not reuse kv
if (!perf->decode_record_.empty()) {
perf->prompt_record_.push_back(input_len - (perf->decode_record_.back().decode_prev_token_+1));
} else {
// first prefill
perf->prompt_record_.push_back(input_len);
}
}
}
} // Transformer
} // MNN

View File

@ -33,82 +33,6 @@ using namespace MNN::Express;
namespace MNN {
namespace Transformer {
// LlmSessionInfo starts
void LlmSessionInfo::resetSamplerFields() {
all_seq_len_ = 0;
gen_seq_len_ = 0;
tokens.clear();
}
void LlmSessionInfo::resetPromptFields() {
mHistory.clear();
mInputs.clear();
}
void LlmSessionInfo::resetPerformanceFields() {
clearPerformance(&mTimePerformance);
}
float LlmSessionInfo::average_total_speed() {
return (getTotalPromptLen()+getTotalDecodeLen())/(getTotalPrefillTime()+getTotalDecodeTime());
}
float LlmSessionInfo::average_prefill_speed() {
// prefill response rate
return getTotalPromptLen()/getTotalPrefillTime();
}
float LlmSessionInfo::average_decode_speed() {
return getTotalDecodeLen()/getTotalDecodeTime();
}
float LlmSessionInfo::getTotalPrefillTime() {
float sum = 0.f;
for (auto record : mTimePerformance.prefill_record_) {
sum += ((float)record.prefill_us_)*MICRO_TO_SEC;
}
return sum;
}
float LlmSessionInfo::getTotalDecodeTime() {
float sum = 0.0f;
for (auto record : mTimePerformance.decode_record_) {
sum += ((float)record.decode_us_)*MICRO_TO_SEC;
}
return sum;
}
int LlmSessionInfo::getTotalPromptLen() {
int prompt_len = 0;
if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
for (auto record : mTimePerformance.prefill_record_) {
prompt_len += record.prefill_token_;
}
} else {
for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
prompt_len += mTimePerformance.prompt_record_[r];
}
}
return prompt_len;
}
int LlmSessionInfo::getTotalDecodeLen() {
return mTimePerformance.decode_record_.size();
}
void LlmSessionInfo::print_speed(std::ostream* os) {
(*os) << "prefill " << mTimePerformance.prefill_record_.size() << std::endl;
if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
(*os) << "prev_token input_token speed(token/s)" << std::endl;
for (auto record : mTimePerformance.prefill_record_) {
(*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
}
} else {
(*os) << "prev_token input_token prompt_token response_speed(token/s)" << std::endl;
for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
auto record = mTimePerformance.prefill_record_[r];
auto prompt_len = mTimePerformance.prompt_record_[r];
(*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << prompt_len << " " << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
}
}
(*os) << "decode " << mTimePerformance.decode_record_.size() << std::endl;
(*os) << "prev_token speed(token/s)" << std::endl;
for (auto record : mTimePerformance.decode_record_) {
(*os) << record.decode_prev_token_ << " " << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl;
}
}
// Lvlm starts
class Lvlm : public Llm {
@ -408,7 +332,6 @@ void Llm::chat(bool session_by_line, bool from_file,
}
std::string Llm::response(const std::string& user_str, std::ostream* os, const char* end_with) {
mLlmSessionInfos[0].mTimePerformance.prompt_record_.push_back(tokenizer(user_str).size());
mPromptLib->appendUserPrompt(user_str);
auto assistant_str = generate(mPromptLib->getLLMInput(), os, end_with);
mPromptLib->appendLLMOutput(assistant_str);
@ -487,6 +410,13 @@ std::vector<int> Llm::tokenizer(const std::string& prompt) {
// generate >
// < evaluation
std::vector<float> Llm::perplexity(std::string prompt_file, std::ostream* perfOS) {
return mSampler->perplexity(prompt_file, perfOS);
}
// evaluation >
Llm::~Llm() {
#if DEBUG_MODE==1
if (nullptr != gTimeTraceInfo) {

View File

@ -12,7 +12,7 @@
#include <iostream>
#include <sstream>
#include <fstream>
#include "rapidjson/document.h"
#include <rapidjson/document.h>
#include <rapidjson/writer.h>
#include <rapidjson/stringbuffer.h>
@ -394,6 +394,18 @@ public:
return config_.value("system_prompt", "You are a helpful assistant!\n");
}
// app config end >
// < evaluation config start
int ppl_stride() const {
return config_.value("ppl_stride", 0);
}
std::string dataset() const {
return config_.value("dataset", "wikitext");
}
int dataset_sample_size() const {
return config_.value("dataset_sample_size", -1); // -1 stands for no sampling, use all.
}
// evaluation config end >
};
} // Transformer
} // MNN
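The keys consumed by these getters are not documented elsewhere in this commit. Below is a hedged sketch of selecting them at runtime through the existing `set_config()` API; the key names mirror the getters above, the values are only examples, and passing a partial JSON object this way is an assumption:

```cpp
// Hypothetical sketch: put an Llm instance into perplexity-evaluation mode.
// Key names follow the config getters above; values are illustrative.
#include "llm/llm.hpp"

void configureForPPL(MNN::Transformer::Llm* llm) {
    llm->set_config(R"({
        "app_type": "perplexity",
        "dataset": "wikitext",
        "ppl_stride": 512,
        "dataset_sample_size": -1
    })");
}
```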

View File

@ -0,0 +1,321 @@
#include <algorithm>
#include <vector>
#include <cmath>
#include <llm/llm.hpp>
#include <iostream>
#include <iomanip>
#include "sampler.hpp"
#include "perplexity.hpp"
#include "llmconfig.hpp"
#include "prompt.hpp"
namespace MNN{
namespace Transformer{
/* -----------TextPPLMeasurer---------- */
TextPPLMeasurer::TextPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> llmConfig) {
mLlm = llm;
mConfig.max_all_tokens = llmConfig->max_all_tokens();
mConfig.max_new_tokens = llmConfig->max_new_tokens();
mDatasetType = llmConfig->dataset();
mStride = llmConfig->ppl_stride();
if (mStride == 0) {
// default stride for sliding window.
mStride = mConfig.max_all_tokens / 2;
}
}
/* Implemented based on https://huggingface.co/docs/transformers/perplexity
******************** HuggingFace Python Version ************************
import torch
from tqdm import tqdm
max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)
nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
end_loc = min(begin_loc + max_length, seq_len)
trg_len = end_loc - prev_end_loc # may be different from stride on last loop
input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100
with torch.no_grad():
outputs = model(input_ids, labels=target_ids)
# loss is calculated using CrossEntropyLoss which averages over valid labels
# N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
# to the left by 1.
neg_log_likelihood = outputs.loss
nlls.append(neg_log_likelihood)
prev_end_loc = end_loc
if end_loc == seq_len:
break
ppl = torch.exp(torch.stack(nlls).mean())
******************** HuggingFace Python Version ************************
*/
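In both the Python reference above and the C++ implementation below, the reported score is the exponentiated mean of the per-token negative log-likelihoods:

$$
\mathrm{PPL} = \exp\!\left(\frac{1}{N}\sum_{i=1}^{N}\mathrm{NLL}_i\right),
\qquad \mathrm{NLL}_i = -\log p_\theta\!\left(x_i \mid x_{<i}\right)
$$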
float TextPPLMeasurer::perplexity_one(const std::vector<int>& prompt) {
int seq_len = prompt.size();
std::vector<float> nlls;
float ppl = 0.f;
// start calculation
int prev_end_loc = 1; // scoring starts at index 1: the first token has no preceding context, so it is not counted.
for (int begin_loc = 0; begin_loc < seq_len; begin_loc += mStride) {
int end_loc = std::min(begin_loc + mConfig.max_all_tokens, seq_len);
// first token
std::vector<int> tokens(prev_end_loc - begin_loc);
for (int it = begin_loc; it < prev_end_loc; ++it) tokens[it - begin_loc] = prompt[it];
mLlm->mLlmSessionInfos[0].all_seq_len_ = tokens.size();
mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_;
auto logits = mLlm->forward(tokens, mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, true);
logits = MNN::Express::_Softmax(logits);
nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]));
// std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[prev_end_loc]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]) << std::endl;
std::cout << -std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]) << std::endl;
// decode following tokens
for (int it = prev_end_loc+1; it < end_loc; ++it) {
mLlm->mLlmSessionInfos[0].all_seq_len_ += 1;
mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_;
auto logits = mLlm->forward({prompt[it-1]},mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, false);
logits = MNN::Express::_Softmax(logits);
nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[it]]));
// std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[it]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[it]]) << std::endl;
std::cout << -std::log(((float*)(logits->readMap<float>()))[prompt[it]]) << std::endl;
}
// clean up once
mLlm->reset();
prev_end_loc = end_loc;
if (end_loc == seq_len) break;
}
// calculate ppl
for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
ppl /= nlls.size();
ppl = std::exp(ppl);
// print
std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl;
return ppl;
}
std::vector<float> TextPPLMeasurer::perplexity(std::vector<std::vector<int>> prompts) {
std::vector<float> ppls;
for (auto prompt : prompts) {
ppls.push_back(perplexity_one(prompt));
mLlm->reset();
}
return ppls;
}
std::vector<float> TextPPLMeasurer::perplexity(std::vector<std::string> prompts) {
std::vector<std::vector<int>> tokens(prompts.size());
for (int p = 0; p < prompts.size(); ++p) tokens[p] = mLlm->tokenizer(prompts[p]);
return perplexity(tokens);
}
std::vector<float> TextPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) {
// No performance will be printed!
std::vector<std::string> prompts;
if (mDatasetType == "wikitext") {
prompts = wikitext(prompt_file);
}
else if (mDatasetType == "plaintext") {
prompts = plaintext(prompt_file);
}
else if (mDatasetType == "rowsplit") {
prompts = rowsplit(prompt_file);
}
else {
MNN_ERROR("Dataset not suppoted");
exit(1);
}
std::cout << "prompt file loaded!" << std::endl;
return perplexity(prompts);
}
/* -----------ChatPPLMeasurer---------- */
ChatPPLMeasurer::ChatPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> llmConfig) {
mLlm = llm;
mConfig.max_all_tokens = llmConfig->max_all_tokens();
mConfig.max_new_tokens = llmConfig->max_new_tokens();
mDatasetType = llmConfig->dataset();
mDatasetSampleSize = llmConfig->dataset_sample_size();
}
void ChatPPLMeasurer::handleToken(int token) {
// CommonPrefix and Candidates managements
mLlm->mLlmSessionInfos[0].tokens.push_back(token);
mLlm->mLlmSessionInfos[0].all_seq_len_++;
mLlm->mLlmSessionInfos[0].gen_seq_len_++;
}
std::vector<float> ChatPPLMeasurer::sample(const std::vector<int>& input_ids, const std::vector<int>& prompt, struct TimePerformance* time_perf) {
std::vector<float> nlls;
// initialization for time performance
PrefillTimePerformance prefill_time;
prefill_time.prefill_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
prefill_time.prefill_token_ = input_ids.size();
appendNewPromptRecord(time_perf, input_ids.size(), mLlm->reuse_kv());
// initialization
mLlm->mLlmSessionInfos[0].tokens.insert(mLlm->mLlmSessionInfos[0].tokens.end(), input_ids.begin(), input_ids.end());
// all_seq_len_ in sampler functions as kv_seq_len_, prev_seq_len_ = all_seq_len_ - seq_len
mLlm->mLlmSessionInfos[0].all_seq_len_ = mLlm->mLlmSessionInfos[0].tokens.size();
mLlm->mLlmSessionInfos[0].gen_seq_len_ = 0;
// prefill
auto st = std::chrono::system_clock::now();
auto logits = mLlm->forward(input_ids, mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, true);
logits = MNN::Express::_Softmax(logits);
nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]]));
// record time
auto et = std::chrono::system_clock::now();
prefill_time.prefill_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
time_perf->prefill_record_.push_back(prefill_time);
// handle the new token
handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]);
// decode
while (mLlm->mLlmSessionInfos[0].gen_seq_len_ < prompt.size()) {
DecodeTimePerformance decode_time;
decode_time.decode_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
st = std::chrono::system_clock::now();
// next token
logits = mLlm->forward({mLlm->mLlmSessionInfos[0].tokens.back()}, mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, false);
logits = MNN::Express::_Softmax(logits);
nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]]));
et = std::chrono::system_clock::now();
decode_time.decode_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
time_perf->decode_record_.push_back(decode_time);
handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]);
}
// return nlls
return nlls;
}
float ChatPPLMeasurer::perplexity_one(const std::vector<std::vector<PromptItem>>& prompt, std::ostream* perfOS) {
// (turns, roles)
std::vector<float> nlls;
float ppl = 0.f;
// < simulated chat
mLlm->reset();
for (auto& turn : prompt) {
mLlm->mPromptLib->appendUserPrompt(turn[0].second);
std::vector<int> input_ids = mLlm->tokenizer(mLlm->mPromptLib->getLLMInput());
mLlm->generate_init();
auto turn_nlls = sample(input_ids, mLlm->tokenizer(turn[1].second), &(mLlm->mLlmSessionInfos[0].mTimePerformance));
nlls.insert(nlls.end(), turn_nlls.begin(), turn_nlls.end());
mLlm->mPromptLib->appendLLMOutput(turn[1].second);
}
// record time performance to file
if (perfOS != nullptr) {
(*perfOS) << "<chat>" << std::endl;
mLlm->mLlmSessionInfos[0].print_speed(perfOS);
}
mLlm->reset();
// simulated chat >
// calculate ppl
for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
ppl /= nlls.size();
ppl = std::exp(ppl);
// print
std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl;
return ppl;
}
std::vector<float> ChatPPLMeasurer::perplexity(const std::vector<std::vector<std::vector<PromptItem>>>& prompts, std::ostream* perfOS) {
std::vector<float> ppls;
for (auto& prompt : prompts) {
ppls.push_back(perplexity_one(prompt, perfOS));
mLlm->reset();
}
return ppls;
}
void ChatPPLMeasurer::getStats(const std::vector<std::vector<std::vector<PromptItem>>>& prompts) {
std::ofstream total_stats("total_stats.csv");
std::ofstream dialog_stats("dialog_stats.csv");
float average_turns=0, average_prefill=0, average_decode=0, average_total_tokens=0;
int max_turns=0;
std::vector<std::vector<std::vector<int>>> stats; // (dialog, turn, (prefill, decode))
std::cout << prompts.size() << std::endl;
int counter = 0;
for (auto& dialog : prompts) {
std::vector<std::vector<int>> dialog_stats;
if ((counter++) % std::max((int)prompts.size()/200, 1) == 0) std::cout << "*" << std::flush;
float prefill_len_turn = 0;
float decode_len_turn = 0;
for (auto& turn : dialog) {
// turn: prefill, decode
int prefill_len = mLlm->tokenizer(turn[0].second).size();
int decode_len = mLlm->tokenizer(turn[1].second).size();
prefill_len_turn += prefill_len;
decode_len_turn += decode_len;
average_total_tokens += prefill_len + decode_len;
dialog_stats.push_back({prefill_len, decode_len});
}
stats.push_back(dialog_stats);
average_prefill += prefill_len_turn / dialog.size(); // average over turns
average_decode += decode_len_turn / dialog.size(); // average over turns
average_turns += dialog.size();
max_turns = std::max(max_turns, (int)dialog.size());
}
average_turns /= prompts.size();
average_prefill /= prompts.size();
average_decode /= prompts.size();
average_total_tokens /= prompts.size();
total_stats << "total_dialogs," << "max_turns," << "avg_turns," \
<< "avg_prefill_tokens/turn," << "avg_decode_tokens/turn," \
<< "avg_total_tokens/dialog" << std::endl;
total_stats << prompts.size() << "," << max_turns << "," << average_turns << "," \
<< average_prefill << "," << average_decode << "," \
<< average_total_tokens << std::endl;
for (int i=0; i<max_turns; ++i) dialog_stats << "prefill" << i << "," << "decode" << i << ","; // this creates an extra blank column at the end.
dialog_stats << std::endl;
for (auto& dialog : stats) {
for (auto& turn : dialog){
dialog_stats << turn[0] << "," << turn[1] << ",";
}
for (int i=dialog.size(); i<max_turns; ++i) {
dialog_stats << ",,";
}
dialog_stats << std::endl;
}
}
std::vector<float> ChatPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) {
// Per-dialog speed statistics are written to perfOS if provided.
std::vector<std::vector<std::vector<PromptItem>>> prompts;
if (mDatasetType == "shareGPT") {
prompts = shareGPT(prompt_file, mDatasetSampleSize);
}
else {
MNN_ERROR("Dataset not suppoted");
exit(1);
}
std::cout << "prompt file loaded!" << std::endl;
getStats(prompts);
std::cout << "\nshareGPT statistics counted!" << std::endl;
return perplexity(prompts, perfOS);
}
} // Transformer
} // MNN

View File

@ -0,0 +1,65 @@
#ifndef PERPLEXITY_hpp
#define PERPLEXITY_hpp
#include <vector>
#include <memory>
#include <string>
#include <fstream>
#include <sstream>
#include <iostream>
#include <streambuf>
#include <functional>
#include <unordered_map>
#include <utility>
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include "sampler.hpp"
#include "evaluation/dataset.hpp"
namespace MNN {
namespace Transformer {
class Llm;
class MNN_PUBLIC TextPPLMeasurer : public Sampler {
protected:
Llm* mLlm;
int mStride;
std::string mDatasetType;
LlmSamplerConfig mConfig;
public:
TextPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> config);
float perplexity_one(const std::vector<int>& prompt);
std::vector<float> perplexity(std::vector<std::vector<int>> prompts);
std::vector<float> perplexity(std::vector<std::string> prompts);
virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; }
virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override;
};
class MNN_PUBLIC ChatPPLMeasurer : public Sampler {
protected:
Llm* mLlm;
std::string mDatasetType;
int mDatasetSampleSize;
LlmSamplerConfig mConfig;
void handleToken(int token);
std::vector<float> sample(const std::vector<int>& input_ids, const std::vector<int>& prompt, struct TimePerformance* time_perf);
public:
ChatPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> config);
void getStats(const std::vector<std::vector<std::vector<PromptItem>>>& prompts);
float perplexity_one(const std::vector<std::vector<PromptItem>>& prompt, std::ostream* perfOS);
std::vector<float> perplexity(const std::vector<std::vector<std::vector<PromptItem>>>& prompts, std::ostream* perfOS);
virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; }
virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override;
};
} // Transformer
} // MNN
#endif // PERPLEXITY_hpp

View File

@ -8,7 +8,7 @@ PromptLib* PromptLib::createPromptLib(Llm* llm, const std::string& config_path)
return createPromptLib(llm, std::shared_ptr<LlmConfig>(new LlmConfig(config_path)));
}
PromptLib* PromptLib::createPromptLib(Llm* llm, std::shared_ptr<LlmConfig> config) {
if (config->app_type() == "chat") {
if (config->app_type() == "chat" || config->app_type() == "perplexity") {
return new BaseChatPromptLib(llm, config);
} else {
std::cout << "PromptLib not Implemented!\n" << std::endl;

View File

@ -6,8 +6,11 @@
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>
#include "llm/llm.hpp"
#include "evaluation/dataset.hpp"
#include "sampler.hpp"
#include "perplexity.hpp"
#include "llmconfig.hpp"
namespace MNN{
@ -107,6 +110,10 @@ Sampler* Sampler::createSampler(Llm* llm, std::shared_ptr<LlmConfig> config) {
|| sampler_type == "mixed"
) {
return new LocalSampler(llm, config);
} else if (config->app_type() == "perplexity") {
std::string ppl_type = getPPLType(config->dataset());
if (ppl_type == "text") { return new TextPPLMeasurer(llm, config); }
else if (ppl_type == "chat") { return new ChatPPLMeasurer(llm, config); }
} else {
std::cout << "Designated Sampler Not Supported yet!";
exit(1);
@ -489,6 +496,7 @@ std::string LocalSampler::sample(const std::vector<int>& input_ids, std::ostream
PrefillTimePerformance prefill_time;
prefill_time.prefill_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
prefill_time.prefill_token_ = input_ids.size();
appendNewPromptRecord(time_perf, input_ids.size(), mLlm->reuse_kv());
// initialization
std::string output_str;
mLlm->mLlmSessionInfos[0].tokens.insert(mLlm->mLlmSessionInfos[0].tokens.end(), input_ids.begin(), input_ids.end());

View File

@ -73,6 +73,7 @@ public:
static Sampler* createSampler(Llm* llm, const std::string& config_path);
static Sampler* createSampler(Llm* llm, std::shared_ptr<LlmConfig> config);
virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) = 0;
virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS) { return std::vector<float>(); }
// prepare for another round of sampling
// in the future, only reset its own.
virtual void reset(Llm* llm) { mLlm = llm; }

View File

@ -475,7 +475,7 @@ void Tiktoken::encode(const std::string& str, std::vector<int>& ids) {
} else {
// If no matching symbol is found, this typically means an error in the encoding
// or the input text contains characters that the encoder doesn't know how to handle
std::cerr << "Error: No encoding found for the sequence starting at position " << i << std::endl;
std::cerr << "Error: No encoding found for the sequence starting at position " << i << " , symbol: " << str[i-2] << std::endl;
return;
}
}