mirror of https://github.com/alibaba/MNN.git
add perplexity and llm dataset processing, supports wikitext and shareGPT
This commit is contained in:
parent a16fc02c0b
commit 028f09a7c9
@@ -365,4 +365,5 @@ MNN_compression_pb2.py
model/

# datasets
datasets/
datasets/*
!datasets/*.sh
@@ -1,135 +0,0 @@
#include <algorithm>
#include <vector>
#include <cmath>
#include <evaluation/perplexity.hpp>
#include <llm/llm.hpp>
#include <iostream>
#include <iomanip>
#include <sampler/sampler.hpp>

namespace MNN{
namespace Transformer{

void PPLMeasurer::init(Llm* llm, std::vector<std::vector<int>> prompts, int max_len, int stride) {
    if (stride == 0) {
        // default stride for sliding window.
        stride = max_len / 2;
    }
    mLlm = llm;
    mMaxLen = max_len;
    mStride = stride;
    mPrompts = prompts;
}

PPLMeasurer::PPLMeasurer(Llm* llm, std::vector<std::vector<int>> prompts, int max_len, int stride) {
    init(llm, manager, prompts, max_len, stride);
}

PPLMeasurer::PPLMeasurer(Llm* llm, std::vector<std::string> prompts, int max_len, int stride) {
    std::vector<std::vector<int>> tokens(prompts.size());
    for (int p = 0; p < prompts.size(); ++p) tokens[p] = llm->encode(prompts[p]);
    init(llm, manager, tokens, max_len, stride);
}

/* Implemented based on https://huggingface.co/docs/transformers/perplexity

******************** HuggingFace Python Version ************************

import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())

******************** HuggingFace Python Version ************************
*/

float PPLMeasurer::perplexity_one(const std::vector<int>& prompt) {
    int seq_len = prompt.size();
    std::vector<float> nlls;
    float ppl = 0.f;

    // start calculation
    int prev_end_loc = 1; // the first token start from id=1, do not count the first one.
    for (int begin_loc = 0; begin_loc < seq_len; begin_loc += mStride) {
        mStateCacheManager->setCurrentReference(mCandidate);
        int end_loc = std::min(begin_loc + mMaxLen, seq_len);
        // first token
        std::vector<int> tokens(prev_end_loc - begin_loc);
        for (int it = begin_loc; it < prev_end_loc; ++it) tokens[it - begin_loc] = prompt[it];
        auto logits = mLlm->forward(tokens, true);
        logits = MNN::Express::_Softmax(logits);
        nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]));
        // std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[prev_end_loc]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]) << std::endl;
        // decode following tokens
        for (int it = prev_end_loc+1; it < end_loc; ++it) {
            auto logits = mLlm->forward({prompt[it-1]}, false);
            logits = MNN::Express::_Softmax(logits);
            nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[it]]));
            // std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[it]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[it]]) << std::endl;
        }
        // clean up once
        reset();
        mLlm->reset();
        prev_end_loc = end_loc;
        if (end_loc == seq_len) break;
    }

    // calculate ppl
    for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
    ppl /= nlls.size();
    ppl = std::exp(ppl);

    // print
    std::cout << "PPL: " << std::setprecision(9) << ppl << std::endl;
    return ppl;
}

std::vector<float> PPLMeasurer::perplexity() {
    std::vector<float> ppls;
    for (auto prompt : mPrompts) {
        ppls.push_back(perplexity_one(prompt));
        reset();
        mLlm->reset();
    }
    return ppls;
}

void PPLMeasurer::reset() {
    // in the future, only reset its own.
}

void PPLMeasurer::reset(int max_len, int stride) {
    mMaxLen = max_len;
    mStride = stride;
    reset();
}

} // Transformer
} // MNN
@@ -1,50 +0,0 @@
#ifndef PERPLEXITY_hpp
#define PERPLEXITY_hpp

#include <vector>
#include <memory>
#include <string>
#include <fstream>
#include <sstream>
#include <iostream>
#include <streambuf>
#include <functional>
#include <unordered_map>
#include <utility>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include <MNN/StateCacheManager.hpp>

namespace MNN {
namespace Transformer {
class Llm;

class MNN_PUBLIC PPLMeasurer {
protected:
    Llm* mLlm;
    StateCacheManager* mStateCacheManager;
    std::vector<std::vector<int>> mPrompts;
    std::shared_ptr<StateCacheReference> mCandidate;
    int mStride, mMaxLen;
    void init(Llm* llm, StateCacheManager* manager, std::vector<std::vector<int>> prompts, int max_len, int stride);
public:
    PPLMeasurer(Llm* llm, StateCacheManager* manager, std::vector<std::vector<int>> prompts, int max_len=2048, int stride=0);
    PPLMeasurer(Llm* llm, StateCacheManager* manager, std::vector<std::string> prompts, int max_len=2048, int stride=0);
    float perplexity_one(const std::vector<int>& prompt);
    std::vector<float> perplexity();
    // prepare for another round of sampling
    // in the future, only reset its own.
    void reset();
    void reset(int max_len, int stride);
};

} // Transformer
} // MNN

#endif // SAMPLER_hpp
@@ -248,7 +248,7 @@ Direct inference on PC
```bash
# Use adb push to push the shared libraries onto the phone
adb shell mkdir /data/local/tmp/llm
adb push chat_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
```

#### Loading GPTQ weights
@@ -0,0 +1,7 @@
datasets/*
!datasets/*.sh


!datasets/visualization/
datasets/visualization/data
datasets/visualization/pic
@@ -0,0 +1,2 @@
git lfs install
git clone https://huggingface.co/datasets/shareAI/ShareGPT-Chinese-English-90k
@@ -0,0 +1,2 @@
wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
unzip wikitext-2-raw-v1.zip
@@ -0,0 +1,116 @@
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from matplotlib import cbook
from matplotlib.axes import Axes
import pandas as pd
import numpy as np
import argparse
import os

vis_root = "pic"

def remove_blanks(df: pd.DataFrame) -> pd.DataFrame:
    # Removing unnamed columns using drop function
    df.drop(df.columns[df.columns.str.contains(
        'unnamed', case=False)], axis=1, inplace=True)
    return df

def add_turns(df: pd.DataFrame) -> pd.DataFrame:
    df["turns"] = (1 - df.isnull()).sum(axis=1) // 2
    return df

def get_max_turn(df: pd.DataFrame) -> int:
    keys = list(df.keys())
    return max([int(key.replace("decode", "")) for key in keys if "decode" in key]) + 1

def add_pd_ratio(df: pd.DataFrame) -> pd.DataFrame:
    max_turns = get_max_turn(df)
    for i in range(max_turns):
        df["pd_ratio{}".format(i)] = df["prefill{}".format(i)] / df["decode{}".format(i)]
    return df

def preprocess(file_path: str) -> pd.DataFrame:
    table = pd.read_csv(file_path)
    table = remove_blanks(table)
    table = add_turns(table)
    table = add_pd_ratio(table)
    print(table)
    return table

def draw_distribution(df: pd.DataFrame, file_path: str):
    turns_bin = df.value_counts(subset=["turns"], sort=False)
    print(turns_bin)
    plt.close()
    plt.rcParams['font.size'] = 10
    _, ax = plt.subplots()
    # N is the count in each bin, bins is the lower-limit of the bin
    N, bins, patches = ax.hist(df["turns"], bins=get_max_turn(df), density=True, align="left", label=True)
    # We'll color code by height, but you could use any scalar
    fracs = N / N.max()
    # we need to normalize the data to 0..1 for the full range of the colormap
    norm = colors.Normalize(fracs.min(), fracs.max())
    # Now, we'll loop through our objects and set the color of each accordingly
    for thisfrac, thispatch in zip(fracs, patches):
        color = plt.cm.viridis(norm(thisfrac))
        thispatch.set_facecolor(color)
    # Now we format the y-axis to display percentage
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
    ax.set_xlim((0.5, get_max_turn(df) - 0.5))
    ax.set_xticks(np.arange(1, get_max_turn(df) + 1), np.arange(1, get_max_turn(df) + 1), rotation=60, fontsize=9)
    ax.set_ylabel("frequency", fontsize=14)
    ax.set_xlabel("num of turns", fontsize=14)
    plt.savefig(file_path, dpi=600)
    plt.close()

def draw_prefill(df: pd.DataFrame, ax: Axes):
    stats = [cbook.boxplot_stats(df[df["prefill{}".format(i)].notna()]["prefill{}".format(i)], labels=[i+1])[0]
             for i in range(get_max_turn(df))]
    print(stats)
    ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
    ax.set_ylim(0, 600)
    ax.set_yticks(np.arange(0, 700, 100), np.arange(0, 700, 100), fontsize=9)
    ax.set_ylabel("prefill", fontsize=12, rotation=90)
    return

def draw_decode(df: pd.DataFrame, ax: Axes):
    stats = [cbook.boxplot_stats(df[df["decode{}".format(i)].notna()]["decode{}".format(i)], labels=[i+1])[0]
             for i in range(get_max_turn(df))]
    print(stats)
    ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
    ax.set_ylim(0, 600)
    ax.set_yticks(np.arange(0, 700, 100), np.arange(0, 700, 100), fontsize=9)
    ax.set_ylabel("decode", fontsize=12, rotation=90)
    return

def draw_pd_ratio(df: pd.DataFrame, ax: Axes):
    stats = [cbook.boxplot_stats(df[df["pd_ratio{}".format(i)].notna()]["pd_ratio{}".format(i)], labels=[i+1])[0]
             for i in range(get_max_turn(df))]
    print(stats)
    ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
    ax.plot(np.arange(0, get_max_turn(df) + 2), np.ones_like(np.arange(0, get_max_turn(df) + 2), dtype=float))
    ax.set_xlim(0, get_max_turn(df) + 1)
    ax.set_ylim(0, 2.)
    ax.set_xticks(np.arange(1, get_max_turn(df)), np.arange(1, get_max_turn(df)), rotation=60, fontsize=9)
    ax.set_yticks([0, 0.5, 1, 2], [0, 0.5, 1, 2], fontsize=9)
    ax.set_xlabel("round", fontsize=12)
    ax.set_ylabel("prefill/decode", fontsize=12, rotation=90)
    return

def draw_reuse_kv(df: pd.DataFrame, file_path: str):
    plt.close()
    _, axs = plt.subplots(3, 1, sharex="col")
    draw_prefill(df, axs[0])
    draw_decode(df, axs[1])
    draw_pd_ratio(df, axs[2])
    plt.savefig(file_path, dpi=1200)
    plt.close()
    return

def draw_no_reuse_kv():
    return

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="./data")
    parser.add_argument("--name", type=str, default="shareGPT_dialog_stats_common_en.csv")
    args = parser.parse_args()

    file_path = os.path.join(args.root, args.name)
    dist_path = os.path.join(vis_root, args.name.split('.')[0] + "_dist.png")
    pd_dist_path = os.path.join(vis_root, args.name.split('.')[0] + "_pd_dist.png")
    table = preprocess(file_path)
    draw_distribution(table, dist_path)
    draw_reuse_kv(table, pd_dist_path)
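The plotting script above is driven entirely by its argparse defaults. A minimal sketch of an invocation, assuming the script is saved as `visualization.py` under `datasets/visualization/` (the directory whitelisted in the new .gitignore) and that the stats CSV produced by `ChatPPLMeasurer::getStats()` later in this commit has been copied into `./data` (all file and directory names here are illustrative, not fixed by the commit):

```bash
cd datasets/visualization
# --root and --name simply restate the argparse defaults above.
python3 visualization.py --root ./data --name shareGPT_dialog_stats_common_en.csv
# Figures land in pic/ (vis_root): shareGPT_dialog_stats_common_en_dist.png
# and shareGPT_dialog_stats_common_en_pd_dist.png
```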
@@ -26,8 +26,11 @@ else()
endif()

add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/app/llm_demo.cpp)
add_executable(ppl_demo ${CMAKE_CURRENT_LIST_DIR}/app/ppl_demo.cpp)
IF (NOT MNN_SEP_BUILD)
    target_link_libraries(llm_demo ${MNN_DEPS})
    target_link_libraries(ppl_demo ${MNN_DEPS})
ELSE ()
    target_link_libraries(llm_demo ${MNN_DEPS} llm)
    target_link_libraries(ppl_demo ${MNN_DEPS} llm)
ENDIF ()
@@ -6,6 +6,7 @@
//

#include "llm/llm.hpp"
#include "evaluation/dataset.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include <MNN/expr/ExecutorScope.hpp>

@@ -23,49 +24,6 @@ static void trace_prepare(Llm* llm) {
    llm->reset();
}

std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines) {
    std::vector<std::vector<std::string>> csv_data;
    std::string line;
    std::vector<std::string> row;
    std::string cell;
    bool insideQuotes = false;
    bool startCollecting = false;

    // content to stream
    std::string content = "";
    for (auto line : lines) {
        content = content + line + "\n";
    }
    std::istringstream stream(content);

    while (stream.peek() != EOF) {
        char c = stream.get();
        if (c == '"') {
            if (insideQuotes && stream.peek() == '"') { // quote
                cell += '"';
                stream.get(); // skip quote
            } else {
                insideQuotes = !insideQuotes; // start or end text in quote
            }
            startCollecting = true;
        } else if (c == ',' && !insideQuotes) { // end element, start new element
            row.push_back(cell);
            cell.clear();
            startCollecting = false;
        } else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line
            row.push_back(cell);
            csv_data.push_back(row);
            cell.clear();
            row.clear();
            startCollecting = false;
        } else {
            cell += c;
            startCollecting = true;
        }
    }
    return csv_data;
}

static int benchmark(Llm* llm, const std::vector<std::string>& prompts) {
    for (int i = 0; i < prompts.size(); i++) {
        const auto& prompt = prompts[i];
@@ -0,0 +1,61 @@
//
//  ppl_demo.cpp
//
//  Created by MNN on 2023/03/24.
//  ZhaodeWang
//

#include "llm/llm.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
#include <stdlib.h>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>
using namespace MNN::Transformer;
static void trace_prepare(Llm* llm) {
    MNN_PRINT("Prepare for resize opt Begin\n");
    llm->trace(true);
    std::ostringstream cacheOs;
    llm->response("Hello", &cacheOs);
    MNN_PRINT("Prepare for resize opt End\n");
    llm->trace(false);
    llm->reset();
}

// parse json

static int ppl_eval(Llm* llm, std::string prompt_file, std::ofstream* perfOS) {
    std::cout << "prompt file is " << prompt_file << std::endl;
    // ppl evaluation
    std::vector<float> ppls = llm->perplexity(prompt_file, perfOS);
    float mean_ppl = 0.f;
    for (int j = 0; j < ppls.size(); ++j) mean_ppl += ppls[j];
    mean_ppl /= ppls.size();
    std::cout << mean_ppl << std::endl;
    return 0;
}

int main(int argc, const char* argv[]) {
    if (argc < 3) {
        std::cout << "Usage: " << argv[0] << " config.json ppl-prompt.txt [perf.txt]" << std::endl;
        return 0;
    }
    std::string config_path = argv[1];
    std::cout << "config path is " << config_path << std::endl;
    std::unique_ptr<Llm> llm(Llm::createLLM(config_path));
    {
        AUTOTIME;
        llm->load();
    }
    {
        AUTOTIME;
        trace_prepare(llm.get());
    }
    std::string prompt_file = argv[2];
    std::unique_ptr<std::ofstream> perfOS(nullptr);
    if (argc == 4) { perfOS.reset(new std::ofstream(argv[3])); }
    return ppl_eval(llm.get(), prompt_file, perfOS.get());
}
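The Usage string in `main` above documents the command line: a config JSON, a prompt/dataset file, and an optional performance log. A minimal sketch of an invocation, assuming the wikitext archive fetched by the dataset script earlier unpacks to a `wikitext-2-raw/` directory (the exact paths depend on where the download script was run):

```bash
# config.json must select the perplexity app type and dataset (see the LlmConfig additions below).
# The third argument (performance log) is optional.
./ppl_demo config.json datasets/wikitext-2-raw/wiki.test.raw perf.txt
```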
@@ -0,0 +1,33 @@
#ifndef LLM_DATASET_hpp
#define LLM_DATASET_hpp

#include <vector>
#include <string>
#include <iostream>
#include <sstream>
#include <fstream>
#include <rapidjson/document.h>
#include <rapidjson/writer.h>
#include <rapidjson/stringbuffer.h>
#include "llm/llm.hpp"

#include <MNN/MNNDefine.h>

namespace MNN {
namespace Transformer {

// parse csv
MNN_PUBLIC std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines);
void parse_jsonl(std::string prompt_file, std::vector<std::vector<std::vector<PromptItem>>>& dialogs);

std::string getPPLType(std::string dataset_name);
std::vector<std::string> rowsplit(std::string prompt_file);
std::vector<std::string> plaintext(std::string prompt_file);
std::vector<std::string> wikitext(std::string prompt_file);
std::vector<std::vector<std::vector<PromptItem>>> shareGPT(std::string prompt_file, int sample_size=-1); // -1: no sampling

} // Transformer
} // MNN

#endif // LLM_DATASET_hpp
@@ -41,6 +41,8 @@ struct TimePerformance {
    std::vector<int> prompt_record_;
};

void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv);

struct PrefillMemPerformance {
    size_t prefill_prev_token_ = 0;
    size_t prefill_token_ = 0;
@@ -63,6 +63,10 @@
};

class MNN_PUBLIC Llm {
public:
    std::shared_ptr<Sampler> mSampler;
    std::shared_ptr<PromptLib> mPromptLib;
    std::vector<LlmSessionInfo> mLlmSessionInfos; // Llm conversation session information. Currently, only mLlmSessionInfos[0] is allowed!
public:
    Llm(std::shared_ptr<LlmConfig> config) : config_(config) {}
    virtual ~Llm();

@@ -81,6 +85,7 @@
    std::string generateTrace(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
    void print_speed();
    void print_speed(std::ostream* os);
    std::vector<float> perplexity(std::string prompt_file, std::ostream* statsOS = nullptr);
    // config function
    std::string dump_config();
    bool set_config(const std::string& content);

@@ -91,7 +96,6 @@
    bool select_module(size_t index);
    friend class Pipeline;
public:
    std::vector<LlmSessionInfo> mLlmSessionInfos; // currently, only mLlmSessionInfos[0] is allowed!
    bool is_single_ = true;
    bool attention_fused_ = true;
    virtual std::vector<int> tokenizer(const std::string& query);

@@ -109,8 +113,6 @@
protected:
    std::shared_ptr<LlmConfig> config_;
    std::shared_ptr<Tokenizer> tokenizer_;
    std::shared_ptr<Sampler> mSampler;
    std::shared_ptr<PromptLib> mPromptLib;
    std::vector<int> key_value_shape_ = {};
    std::vector<MNN::Express::VARP> past_key_values_;
    MNN::Express::VARP inputs_embeds_, attention_mask_, position_ids_;
@@ -0,0 +1,83 @@

#include "llm/llm.hpp"

namespace MNN {
namespace Transformer {

// LlmSessionInfo starts
void LlmSessionInfo::resetSamplerFields() {
    all_seq_len_ = 0;
    gen_seq_len_ = 0;
    tokens.clear();
}
void LlmSessionInfo::resetPromptFields() {
    mHistory.clear();
    mInputs.clear();
}
void LlmSessionInfo::resetPerformanceFields() {
    clearPerformance(&mTimePerformance);
}
float LlmSessionInfo::average_total_speed() {
    return (getTotalPromptLen()+getTotalDecodeLen())/(getTotalPrefillTime()+getTotalDecodeTime());
}
float LlmSessionInfo::average_prefill_speed() {
    // prefill response rate
    return getTotalPromptLen()/getTotalPrefillTime();
}
float LlmSessionInfo::average_decode_speed() {
    return getTotalDecodeLen()/getTotalDecodeTime();
}
float LlmSessionInfo::getTotalPrefillTime() {
    float sum = 0.f;
    for (auto record : mTimePerformance.prefill_record_) {
        sum += ((float)record.prefill_us_)*MICRO_TO_SEC;
    }
    return sum;
}
float LlmSessionInfo::getTotalDecodeTime() {
    float sum = 0.0f;
    for (auto record : mTimePerformance.decode_record_) {
        sum += ((float)record.decode_us_)*MICRO_TO_SEC;
    }
    return sum;
}
int LlmSessionInfo::getTotalPromptLen() {
    int prompt_len = 0;
    if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
        for (auto record : mTimePerformance.prefill_record_) {
            prompt_len += record.prefill_token_;
        }
    } else {
        for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
            prompt_len += mTimePerformance.prompt_record_[r];
        }
    }
    return prompt_len;
}
int LlmSessionInfo::getTotalDecodeLen() {
    return mTimePerformance.decode_record_.size();
}
void LlmSessionInfo::print_speed(std::ostream* os) {
    (*os) << "prefill " << mTimePerformance.prefill_record_.size() << std::endl;
    if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
        (*os) << "prev_token input_token speed(token/s)" << std::endl;
        for (auto record : mTimePerformance.prefill_record_) {
            (*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
        }
    } else {
        (*os) << "prev_token input_token prompt_token response_speed(token/s)" << std::endl;
        for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
            auto record = mTimePerformance.prefill_record_[r];
            auto prompt_len = mTimePerformance.prompt_record_[r];
            (*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << prompt_len << " " << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
        }
    }
    (*os) << "decode " << mTimePerformance.decode_record_.size() << std::endl;
    (*os) << "prev_token speed(token/s)" << std::endl;
    for (auto record : mTimePerformance.decode_record_) {
        (*os) << record.decode_prev_token_ << " " << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl;
    }
}

} // Transformer
} // MNN
@@ -0,0 +1,221 @@
#include <algorithm>
#include <vector>
#include <cmath>
#include <llm/llm.hpp>
#include <iostream>
#include <fstream>
#include <iomanip>
#include <string>
#include <iterator>
#include <random>
#include "evaluation/dataset.hpp"
#include <rapidjson/document.h>
#include <rapidjson/writer.h>
#include <rapidjson/stringbuffer.h>

namespace MNN {
namespace Transformer {

// parse file
// csv json

// parse csv
std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines) {
    std::vector<std::vector<std::string>> csv_data;
    std::string line;
    std::vector<std::string> row;
    std::string cell;
    bool insideQuotes = false;
    bool startCollecting = false;

    // content to stream
    std::string content = "";
    for (auto line : lines) {
        content = content + line + "\n";
    }
    std::istringstream stream(content);

    while (stream.peek() != EOF) {
        char c = stream.get();
        if (c == '"') {
            if (insideQuotes && stream.peek() == '"') { // quote
                cell += '"';
                stream.get(); // skip quote
            } else {
                insideQuotes = !insideQuotes; // start or end text in quote
            }
            startCollecting = true;
        } else if (c == ',' && !insideQuotes) { // end element, start new element
            row.push_back(cell);
            cell.clear();
            startCollecting = false;
        } else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line
            row.push_back(cell);
            csv_data.push_back(row);
            cell.clear();
            row.clear();
            startCollecting = false;
        } else {
            cell += c;
            startCollecting = true;
        }
    }
    return csv_data;
}

// dialog, turn,
void parse_jsonl(std::string prompt_file, std::vector<std::vector<std::vector<PromptItem>>>& dialogs) {
    std::ifstream prompt_fs(prompt_file);
    std::string prompt;
    while(std::getline(prompt_fs, prompt)) {
        rapidjson::Document document;
        document.Parse(prompt.c_str());
        std::vector<std::vector<PromptItem>> cnv;
        if(document.HasMember("conversation")) {
            auto& value = document["conversation"];
            if (value.IsArray()) {
                for (auto& v : value.GetArray()) {
                    if (v.IsObject()) {
                        std::vector<PromptItem> result;
                        for (auto itr = v.MemberBegin(); itr != v.MemberEnd(); ++itr) {
                            // {"human"/"user": , "assistant": }
                            result.push_back(std::make_pair(itr->name.GetString(), itr->value.GetString()));
                        }
                        cnv.push_back(result);
                    }
                }
            }
        }
        dialogs.push_back(cnv);
    }
}

void write_jsonl(std::string prompt_file, const std::vector<std::vector<std::vector<PromptItem>>>& dialogs) {
    std::ofstream prompt_fs(prompt_file);
    for(auto& dialog : dialogs) {
        rapidjson::Document document;
        document.SetObject();
        rapidjson::Value conversation(rapidjson::kArrayType);
        conversation.SetArray();
        for (auto& turn : dialog) {
            rapidjson::Value sentence(rapidjson::kObjectType);
            sentence.SetObject();
            for (auto& role : turn) {
                sentence.AddMember(rapidjson::Value(role.first.c_str(), document.GetAllocator()),
                                   rapidjson::Value(role.second.c_str(), document.GetAllocator()), document.GetAllocator());
            }
            conversation.PushBack(sentence, document.GetAllocator());
        }
        document.AddMember("conversation", conversation, document.GetAllocator());
        // write to file
        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        document.Accept(writer);
        prompt_fs << buffer.GetString() << std::endl;
    }
}

// dataset
// wikitext, ShareGPT

std::string getPPLType(std::string dataset_name) {
    if (dataset_name == "wikitext"
        || dataset_name == "plaintext"
        || dataset_name == "rowsplit") {
        return "text";
    } else if (dataset_name == "shareGPT") {
        return "chat";
    } else {
        // default chat
        return "chat";
    }
}

std::vector<std::string> plaintext(std::string prompt_file) {
    // split by line
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    prompts.push_back("");
    while (std::getline(prompt_fs, prompt)) {
        if (prompt.back() == '\r' || prompt.back() == '\n') {
            prompt.pop_back();
        }
        // concatenate.
        prompts.back() += prompt + "\n";
    }
    return prompts;
}

std::vector<std::string> rowsplit(std::string prompt_file) {
    // split by line
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    while (std::getline(prompt_fs, prompt)) {
        if (prompt.back() == '\r' || prompt.back() == '\n') {
            prompt.pop_back();
        }
        prompts.push_back(prompt);
    }
    return prompts;
}

// wikitext
void removeSubstrs(std::string& s, std::string p) {
    std::string::size_type n = p.length();
    for (std::string::size_type i = s.find(p); i != std::string::npos; i = s.find(p))
        s.erase(i, n);
}
std::vector<std::string> wikitext(std::string prompt_file) {
    // split wiki text into " = " first-level column.
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    while (std::getline(prompt_fs, prompt)) {
        if (prompt.back() == '\r' || prompt.back() == '\n') {
            prompt.pop_back();
        }
        if (prompt.size() < 4) continue;
        removeSubstrs(prompt, "@-@");
        if ((prompts.size() == 0) \
            || (prompt.size() >= 4 \
                && prompt.at(0) == ' ' \
                && prompt.at(1) == '=' \
                && prompt.at(2) == ' ' \
                && prompt.at(3) != '=')) {
            // first-level column.
            prompts.push_back(prompt);
        } else {
            // concatenate.
            prompts.back() += "\n" + prompt;
        }
    }
    return prompts;
}

std::string genSampleName(std::string oriName, int sample_size) {
    const size_t last_slash_idx = oriName.rfind('.');
    auto stem = oriName.substr(0, last_slash_idx);
    return stem + "_sample" + std::to_string(sample_size) + ".jsonl";
}

std::vector<std::vector<std::vector<PromptItem>>> shareGPT(std::string prompt_file, int sample_size) {
    std::vector<std::vector<std::vector<PromptItem>>> dialogs, dataset;
    parse_jsonl(prompt_file, dialogs);
    // randomly sample a subset
    if (sample_size > 0 && sample_size < dialogs.size()){
        std::sample(dialogs.begin(), dialogs.end(), std::back_inserter(dataset),
                    sample_size, std::mt19937 {std::random_device{}()});
        dialogs = dataset;
        // store dialogs to file
        write_jsonl(genSampleName(prompt_file, sample_size), dialogs);
    }
    return dialogs;
}

} // Transformer
} // MNN
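For reference, `parse_jsonl` above expects one JSON object per line, each carrying a `conversation` array whose entries map role keys (`human`/`user` and `assistant`) to the turn text, mirroring the ShareGPT export format. A hand-made one-line sample for quick testing could look like this (file name and dialog content are purely illustrative):

```bash
cat > datasets/tiny_sample.jsonl << 'EOF'
{"conversation": [{"human": "What is MNN?", "assistant": "A lightweight deep learning inference engine."}, {"human": "Can it run LLMs on device?", "assistant": "Yes, through its transformer LLM runtime."}]}
EOF
```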
@@ -12,6 +12,19 @@ void clearPerformance(struct TimePerformance* perf) {
    perf->decode_record_.clear();
    perf->prompt_record_.clear();
}
void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv) {
    if (reuse_kv) {
        perf->prompt_record_.push_back(input_len);
    } else {
        // not reuse kv
        if (!perf->decode_record_.empty()) {
            perf->prompt_record_.push_back(input_len - (perf->decode_record_.back().decode_prev_token_+1));
        } else {
            // first prefill
            perf->prompt_record_.push_back(input_len);
        }
    }
}

} // Transformer
} // MNN
@@ -33,82 +33,6 @@ using namespace MNN::Express;
namespace MNN {
namespace Transformer {

// LlmSessionInfo starts
void LlmSessionInfo::resetSamplerFields() {
    all_seq_len_ = 0;
    gen_seq_len_ = 0;
    tokens.clear();
}
void LlmSessionInfo::resetPromptFields() {
    mHistory.clear();
    mInputs.clear();
}
void LlmSessionInfo::resetPerformanceFields() {
    clearPerformance(&mTimePerformance);
}
float LlmSessionInfo::average_total_speed() {
    return (getTotalPromptLen()+getTotalDecodeLen())/(getTotalPrefillTime()+getTotalDecodeTime());
}
float LlmSessionInfo::average_prefill_speed() {
    // prefill response rate
    return getTotalPromptLen()/getTotalPrefillTime();
}
float LlmSessionInfo::average_decode_speed() {
    return getTotalDecodeLen()/getTotalDecodeTime();
}
float LlmSessionInfo::getTotalPrefillTime() {
    float sum = 0.f;
    for (auto record : mTimePerformance.prefill_record_) {
        sum += ((float)record.prefill_us_)*MICRO_TO_SEC;
    }
    return sum;
}
float LlmSessionInfo::getTotalDecodeTime() {
    float sum = 0.0f;
    for (auto record : mTimePerformance.decode_record_) {
        sum += ((float)record.decode_us_)*MICRO_TO_SEC;
    }
    return sum;
}
int LlmSessionInfo::getTotalPromptLen() {
    int prompt_len = 0;
    if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
        for (auto record : mTimePerformance.prefill_record_) {
            prompt_len += record.prefill_token_;
        }
    } else {
        for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
            prompt_len += mTimePerformance.prompt_record_[r];
        }
    }
    return prompt_len;
}
int LlmSessionInfo::getTotalDecodeLen() {
    return mTimePerformance.decode_record_.size();
}
void LlmSessionInfo::print_speed(std::ostream* os) {
    (*os) << "prefill " << mTimePerformance.prefill_record_.size() << std::endl;
    if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
        (*os) << "prev_token input_token speed(token/s)" << std::endl;
        for (auto record : mTimePerformance.prefill_record_) {
            (*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
        }
    } else {
        (*os) << "prev_token input_token prompt_token response_speed(token/s)" << std::endl;
        for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
            auto record = mTimePerformance.prefill_record_[r];
            auto prompt_len = mTimePerformance.prompt_record_[r];
            (*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << prompt_len << " " << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
        }
    }
    (*os) << "decode " << mTimePerformance.decode_record_.size() << std::endl;
    (*os) << "prev_token speed(token/s)" << std::endl;
    for (auto record : mTimePerformance.decode_record_) {
        (*os) << record.decode_prev_token_ << " " << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl;
    }
}


// Lvlm starts
class Lvlm : public Llm {
@@ -408,7 +332,6 @@ void Llm::chat(bool session_by_line, bool from_file,
}

std::string Llm::response(const std::string& user_str, std::ostream* os, const char* end_with) {
    mLlmSessionInfos[0].mTimePerformance.prompt_record_.push_back(tokenizer(user_str).size());
    mPromptLib->appendUserPrompt(user_str);
    auto assistant_str = generate(mPromptLib->getLLMInput(), os, end_with);
    mPromptLib->appendLLMOutput(assistant_str);
@@ -487,6 +410,13 @@ std::vector<int> Llm::tokenizer(const std::string& prompt) {
// generate >

// < evaluation
std::vector<float> Llm::perplexity(std::string prompt_file, std::ostream* perfOS) {
    return mSampler->perplexity(prompt_file, perfOS);
}
// evaluation >

Llm::~Llm() {
#if DEBUG_MODE==1
    if (nullptr != gTimeTraceInfo) {
@@ -12,7 +12,7 @@
#include <iostream>
#include <sstream>
#include <fstream>
#include "rapidjson/document.h"
#include <rapidjson/document.h>
#include <rapidjson/writer.h>
#include <rapidjson/stringbuffer.h>

@@ -394,6 +394,18 @@ public:
        return config_.value("system_prompt", "You are a helpful assistant!\n");
    }
    // app config end >

    // < evaluation config start
    int ppl_stride() const {
        return config_.value("ppl_stride", 0);
    }
    std::string dataset() const {
        return config_.value("dataset", "wikitext");
    }
    int dataset_sample_size() const {
        return config_.value("dataset_sample_size", -1); // -1 stands for no sampling, use all.
    }
    // evaluation config end
};
} // Transformer
} // MNN
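Put together with the `app_type() == "perplexity"` checks added in prompt.cpp and sampler.cpp below, an evaluation run is selected purely from the config JSON. A sketch of the relevant fragment is shown here; the values are the defaults returned by the accessors above, and the `app_type` key name is an assumption since its accessor is defined outside this diff (merge these keys into an existing llm `config.json` rather than using the fragment on its own):

```json
{
    "app_type": "perplexity",
    "dataset": "wikitext",
    "ppl_stride": 0,
    "dataset_sample_size": -1
}
```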
@@ -0,0 +1,321 @@
#include <algorithm>
#include <vector>
#include <cmath>
#include <llm/llm.hpp>
#include <iostream>
#include <iomanip>

#include "sampler.hpp"
#include "perplexity.hpp"
#include "llmconfig.hpp"
#include "prompt.hpp"

namespace MNN{
namespace Transformer{

/* -----------TextPPLMeasurer---------- */
TextPPLMeasurer::TextPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> llmConfig) {
    mLlm = llm;
    mConfig.max_all_tokens = llmConfig->max_all_tokens();
    mConfig.max_new_tokens = llmConfig->max_new_tokens();
    mDatasetType = llmConfig->dataset();
    mStride = llmConfig->ppl_stride();
    if (mStride == 0) {
        // default stride for sliding window.
        mStride = mConfig.max_all_tokens / 2;
    }
}

/* Implemented based on https://huggingface.co/docs/transformers/perplexity

******************** HuggingFace Python Version ************************

import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())

******************** HuggingFace Python Version ************************
*/

float TextPPLMeasurer::perplexity_one(const std::vector<int>& prompt) {
    int seq_len = prompt.size();
    std::vector<float> nlls;
    float ppl = 0.f;

    // start calculation
    int prev_end_loc = 1; // the first token start from id=1, do not count the first one.
    for (int begin_loc = 0; begin_loc < seq_len; begin_loc += mStride) {
        int end_loc = std::min(begin_loc + mConfig.max_all_tokens, seq_len);
        // first token
        std::vector<int> tokens(prev_end_loc - begin_loc);
        for (int it = begin_loc; it < prev_end_loc; ++it) tokens[it - begin_loc] = prompt[it];
        mLlm->mLlmSessionInfos[0].all_seq_len_ = tokens.size();
        mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_;
        auto logits = mLlm->forward(tokens, mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, true);
        logits = MNN::Express::_Softmax(logits);
        nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]));
        // std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[prev_end_loc]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]) << std::endl;
        std::cout << -std::log(((float*)(logits->readMap<float>()))[prompt[prev_end_loc]]) << std::endl;
        // decode following tokens
        for (int it = prev_end_loc+1; it < end_loc; ++it) {
            mLlm->mLlmSessionInfos[0].all_seq_len_ += 1;
            mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_;
            auto logits = mLlm->forward({prompt[it-1]}, mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, false);
            logits = MNN::Express::_Softmax(logits);
            nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[it]]));
            // std::cout << mLlm->decode(argmax(logits)) << " " << mLlm->decode(prompt[it]) << " " << -std::log(((float*)(logits->readMap<float>()))[prompt[it]]) << std::endl;
            std::cout << -std::log(((float*)(logits->readMap<float>()))[prompt[it]]) << std::endl;
        }
        // clean up once
        mLlm->reset();
        prev_end_loc = end_loc;
        if (end_loc == seq_len) break;
    }

    // calculate ppl
    for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
    ppl /= nlls.size();
    ppl = std::exp(ppl);

    // print
    std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl;
    return ppl;
}

std::vector<float> TextPPLMeasurer::perplexity(std::vector<std::vector<int>> prompts) {
    std::vector<float> ppls;
    for (auto prompt : prompts) {
        ppls.push_back(perplexity_one(prompt));
        mLlm->reset();
    }
    return ppls;
}

std::vector<float> TextPPLMeasurer::perplexity(std::vector<std::string> prompts) {
    std::vector<std::vector<int>> tokens(prompts.size());
    for (int p = 0; p < prompts.size(); ++p) tokens[p] = mLlm->tokenizer(prompts[p]);
    return perplexity(tokens);
}

std::vector<float> TextPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) {
    // No performance will be printed!
    std::vector<std::string> prompts;
    if (mDatasetType == "wikitext") {
        prompts = wikitext(prompt_file);
    }
    else if (mDatasetType == "plaintext") {
        prompts = plaintext(prompt_file);
    }
    else if (mDatasetType == "rowsplit") {
        prompts = rowsplit(prompt_file);
    }
    else {
        MNN_ERROR("Dataset not supported");
        exit(1);
    }
    std::cout << "prompt file loaded!" << std::endl;
    return perplexity(prompts);
}

/* -----------ChatPPLMeasurer---------- */
ChatPPLMeasurer::ChatPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> llmConfig) {
    mLlm = llm;
    mConfig.max_all_tokens = llmConfig->max_all_tokens();
    mConfig.max_new_tokens = llmConfig->max_new_tokens();
    mDatasetType = llmConfig->dataset();
    mDatasetSampleSize = llmConfig->dataset_sample_size();
}

void ChatPPLMeasurer::handleToken(int token) {
    // CommonPrefix and Candidates managements
    mLlm->mLlmSessionInfos[0].tokens.push_back(token);
    mLlm->mLlmSessionInfos[0].all_seq_len_++;
    mLlm->mLlmSessionInfos[0].gen_seq_len_++;
}

std::vector<float> ChatPPLMeasurer::sample(const std::vector<int>& input_ids, const std::vector<int>& prompt, struct TimePerformance* time_perf) {
    std::vector<float> nlls;
    // initialization for time performance
    PrefillTimePerformance prefill_time;
    prefill_time.prefill_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
    prefill_time.prefill_token_ = input_ids.size();
    appendNewPromptRecord(time_perf, input_ids.size(), mLlm->reuse_kv());
    // initialization
    mLlm->mLlmSessionInfos[0].tokens.insert(mLlm->mLlmSessionInfos[0].tokens.end(), input_ids.begin(), input_ids.end());
    // all_seq_len_ in sampler functions as kv_seq_len_, prev_seq_len_ = all_seq_len_ - seq_len
    mLlm->mLlmSessionInfos[0].all_seq_len_ = mLlm->mLlmSessionInfos[0].tokens.size();
    mLlm->mLlmSessionInfos[0].gen_seq_len_ = 0;
    // prefill
    auto st = std::chrono::system_clock::now();
    auto logits = mLlm->forward(input_ids, mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, true);
    logits = MNN::Express::_Softmax(logits);
    nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]]));
    // record time
    auto et = std::chrono::system_clock::now();
    prefill_time.prefill_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
    time_perf->prefill_record_.push_back(prefill_time);
    // handle the new token
    handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]);
    // decode
    while (mLlm->mLlmSessionInfos[0].gen_seq_len_ < prompt.size()) {
        DecodeTimePerformance decode_time;
        decode_time.decode_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
        st = std::chrono::system_clock::now();
        // next token
        logits = mLlm->forward({mLlm->mLlmSessionInfos[0].tokens.back()}, mLlm->mLlmSessionInfos[0].all_seq_len_, mLlm->mLlmSessionInfos[0].gen_seq_len_, false);
        logits = MNN::Express::_Softmax(logits);
        nlls.push_back(-std::log(((float*)(logits->readMap<float>()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]]));
        et = std::chrono::system_clock::now();
        decode_time.decode_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
        time_perf->decode_record_.push_back(decode_time);
        handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]);
    }
    // return nlls
    return nlls;
}

float ChatPPLMeasurer::perplexity_one(const std::vector<std::vector<PromptItem>>& prompt, std::ostream* perfOS) {
    // (turns, roles)
    std::vector<float> nlls;
    float ppl = 0.f;

    // < simulated chat
    mLlm->reset();
    for (auto& turn : prompt) {
        mLlm->mPromptLib->appendUserPrompt(turn[0].second);
        std::vector<int> input_ids = mLlm->tokenizer(mLlm->mPromptLib->getLLMInput());
        mLlm->generate_init();
        auto turn_nlls = sample(input_ids, mLlm->tokenizer(turn[1].second), &(mLlm->mLlmSessionInfos[0].mTimePerformance));
        nlls.insert(nlls.end(), turn_nlls.begin(), turn_nlls.end());
        mLlm->mPromptLib->appendLLMOutput(turn[1].second);
    }

    // record time performance to file
    if (perfOS != nullptr) {
        (*perfOS) << "<chat>" << std::endl;
        mLlm->mLlmSessionInfos[0].print_speed(perfOS);
    }

    mLlm->reset();
    // simulated chat >

    // calculate ppl
    for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
    ppl /= nlls.size();
    ppl = std::exp(ppl);

    // print
    std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl;
    return ppl;
}

std::vector<float> ChatPPLMeasurer::perplexity(const std::vector<std::vector<std::vector<PromptItem>>>& prompts, std::ostream* perfOS) {
    std::vector<float> ppls;
    for (auto& prompt : prompts) {
        ppls.push_back(perplexity_one(prompt, perfOS));
        mLlm->reset();
    }
    return ppls;
}

void ChatPPLMeasurer::getStats(const std::vector<std::vector<std::vector<PromptItem>>>& prompts) {
    std::ofstream total_stats("total_stats.csv");
    std::ofstream dialog_stats("dialog_stats.csv");
    float average_turns=0, average_prefill=0, average_decode=0, average_total_tokens=0;
    int max_turns=0;
    std::vector<std::vector<std::vector<int>>> stats; // (dialog, turn, (prefill, decode))
    std::cout << prompts.size() << std::endl;
    int counter = 0;
    for (auto& dialog : prompts) {
        std::vector<std::vector<int>> dialog_stats;
        if ((counter++) % std::max((int)prompts.size()/200, 1) == 0) std::cout << "*" << std::flush;
        float prefill_len_turn = 0;
        float decode_len_turn = 0;
        for (auto& turn : dialog) {
            // turn: prefill, decode
            int prefill_len = mLlm->tokenizer(turn[0].second).size();
            int decode_len = mLlm->tokenizer(turn[1].second).size();
            prefill_len_turn += prefill_len;
            decode_len_turn += decode_len;
            average_total_tokens += prefill_len + decode_len;
            dialog_stats.push_back({prefill_len, decode_len});
        }
        stats.push_back(dialog_stats);
        average_prefill += prefill_len_turn / dialog.size(); // average over turns
        average_decode += decode_len_turn / dialog.size(); // average over turns
        average_turns += dialog.size();
        max_turns = std::max(max_turns, (int)dialog.size());
    }
    average_turns /= prompts.size();
    average_prefill /= prompts.size();
    average_decode /= prompts.size();
    average_total_tokens /= prompts.size();
    total_stats << "total_dialogs," << "max_turns," << "avg_turns," \
                << "avg_prefill_tokens/turn," << "avg_decode_tokens/turn," \
                << "avg_total_tokens/dialog" << std::endl;
    total_stats << prompts.size() << "," << max_turns << "," << average_turns << "," \
                << average_prefill << "," << average_decode << "," \
                << average_total_tokens << std::endl;
    for (int i=0; i<max_turns; ++i) dialog_stats << "prefill" << i << "," << "decode" << i << ","; // this creates an extra blank column at the end.
    dialog_stats << std::endl;
    for (auto& dialog : stats) {
        for (auto& turn : dialog){
            dialog_stats << turn[0] << "," << turn[1] << ",";
        }
        for (int i=dialog.size(); i<max_turns; ++i) {
            dialog_stats << ",,";
        }
        dialog_stats << std::endl;
    }
}

std::vector<float> ChatPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) {
    // No performance will be printed!
    std::vector<std::vector<std::vector<PromptItem>>> prompts;
    if (mDatasetType == "shareGPT") {
        prompts = shareGPT(prompt_file, mDatasetSampleSize);
    }
    else {
        MNN_ERROR("Dataset not supported");
        exit(1);
    }
    std::cout << "prompt file loaded!" << std::endl;
    getStats(prompts);
    std::cout << "\nshareGPT statistics counted!" << std::endl;
    return perplexity(prompts, perfOS);
}

} // Transformer
} // MNN
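Restated as a formula, both measurers implement exactly what the HuggingFace reference in the comment above describes: every predicted token contributes one negative log-likelihood to `nlls` (the first token of each window is never scored, since `prev_end_loc` starts at 1), and the reported value is

```latex
\mathrm{PPL} = \exp\left( \frac{1}{|\mathrm{nlls}|} \sum_{t} -\log p(x_t \mid x_{<t}) \right)
```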
@@ -0,0 +1,65 @@
#ifndef PERPLEXITY_hpp
#define PERPLEXITY_hpp

#include <vector>
#include <memory>
#include <string>
#include <fstream>
#include <sstream>
#include <iostream>
#include <streambuf>
#include <functional>
#include <unordered_map>
#include <utility>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>

#include "sampler.hpp"
#include "evaluation/dataset.hpp"

namespace MNN {
namespace Transformer {
class Llm;

class MNN_PUBLIC TextPPLMeasurer : public Sampler {
protected:
    Llm* mLlm;
    int mStride;
    std::string mDatasetType;
    LlmSamplerConfig mConfig;
public:
    TextPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> config);
    float perplexity_one(const std::vector<int>& prompt);
    std::vector<float> perplexity(std::vector<std::vector<int>> prompts);
    std::vector<float> perplexity(std::vector<std::string> prompts);
    virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; }
    virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override;
};

class MNN_PUBLIC ChatPPLMeasurer : public Sampler {
protected:
    Llm* mLlm;
    std::string mDatasetType;
    int mDatasetSampleSize;
    LlmSamplerConfig mConfig;
    void handleToken(int token);
    std::vector<float> sample(const std::vector<int>& input_ids, const std::vector<int>& prompt, struct TimePerformance* time_perf);
public:
    ChatPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> config);
    void getStats(const std::vector<std::vector<std::vector<PromptItem>>>& prompts);
    float perplexity_one(const std::vector<std::vector<PromptItem>>& prompt, std::ostream* perfOS);
    std::vector<float> perplexity(const std::vector<std::vector<std::vector<PromptItem>>>& prompts, std::ostream* perfOS);
    virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; }
    virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override;
};

} // Transformer
} // MNN

#endif // PERPLEXITY_hpp
@@ -8,7 +8,7 @@ PromptLib* PromptLib::createPromptLib(Llm* llm, const std::string& config_path)
    return createPromptLib(llm, std::shared_ptr<LlmConfig>(new LlmConfig(config_path)));
}
PromptLib* PromptLib::createPromptLib(Llm* llm, std::shared_ptr<LlmConfig> config) {
    if (config->app_type() == "chat") {
    if (config->app_type() == "chat" || config->app_type() == "perplexity") {
        return new BaseChatPromptLib(llm, config);
    } else {
        std::cout << "PromptLib not Implemented!\n" << std::endl;
@@ -6,8 +6,11 @@

#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>

#include "llm/llm.hpp"
#include "evaluation/dataset.hpp"
#include "sampler.hpp"
#include "perplexity.hpp"
#include "llmconfig.hpp"

namespace MNN{
@@ -107,6 +110,10 @@ Sampler* Sampler::createSampler(Llm* llm, std::shared_ptr<LlmConfig> config) {
        || sampler_type == "mixed"
        ) {
        return new LocalSampler(llm, config);
    } else if (config->app_type() == "perplexity") {
        std::string ppl_type = getPPLType(config->dataset());
        if (ppl_type == "text") { return new TextPPLMeasurer(llm, config); }
        else if (ppl_type == "chat") { return new ChatPPLMeasurer(llm, config); }
    } else {
        std::cout << "Designated Sampler Not Supported yet!";
        exit(1);
@@ -489,6 +496,7 @@ std::string LocalSampler::sample(const std::vector<int>& input_ids, std::ostream
    PrefillTimePerformance prefill_time;
    prefill_time.prefill_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
    prefill_time.prefill_token_ = input_ids.size();
    appendNewPromptRecord(time_perf, input_ids.size(), mLlm->reuse_kv());
    // initialization
    std::string output_str;
    mLlm->mLlmSessionInfos[0].tokens.insert(mLlm->mLlmSessionInfos[0].tokens.end(), input_ids.begin(), input_ids.end());
@@ -73,6 +73,7 @@ public:
    static Sampler* createSampler(Llm* llm, const std::string& config_path);
    static Sampler* createSampler(Llm* llm, std::shared_ptr<LlmConfig> config);
    virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) = 0;
    virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS) { return std::vector<float>(); }
    // prepare for another round of sampling
    // in the future, only reset its own.
    virtual void reset(Llm* llm) { mLlm = llm; }
@@ -475,7 +475,7 @@ void Tiktoken::encode(const std::string& str, std::vector<int>& ids) {
        } else {
            // If no matching symbol is found, this typically means an error in the encoding
            // or the input text contains characters that the encoder doesn't know how to handle
            std::cerr << "Error: No encoding found for the sequence starting at position " << i << std::endl;
            std::cerr << "Error: No encoding found for the sequence starting at position " << i << " , symbol: " << str[i-2] << std::endl;
            return;
        }
    }