// mirror of https://github.com/alibaba/MNN.git

#include <fstream>
#include <iostream>
#include <utility>
#include "tokenizer.hpp"
#include "core/Macro.h"
#include <sstream>
#include <functional>
#include <codecvt>
#include <regex>
#include <set>
#include <cctype>  // std::isalnum / std::ispunct / std::isspace / std::tolower
#include <climits> // INT_MAX, used by CLIPTokenizer::bpe
#include "rapidjson/document.h"
#include "rapidjson/filereadstream.h"
#include "rapidjson/error/en.h"

namespace diffusion {

bool BertTokenizer::load(const std::string& dictPath) {
    std::ifstream dictFile(dictPath + "/vocab.txt");
    if (!dictFile.is_open()) {
        MNN_ERROR("tokenizer load error: vocab.txt not found in %s\n", dictPath.c_str());
        return false;
    }
    // Each line of vocab.txt holds one token; its line number is the token id.
    int index = 0;
    std::string word;
    while (dictFile >> word) {
        mVocabs.insert(std::make_pair(std::move(word), index++));
    }
    mStartIdx = this->word_piece("[CLS]")[0];
    mEndIdx = this->word_piece("[SEP]")[0];
    return true;
}

std::vector<int> BertTokenizer::word_piece(const std::string& token) {
    auto it = mVocabs.find(token);
    if (it != mVocabs.end()) {
        return {it->second};
    }
    // Greedy longest-match-first WordPiece: repeatedly take the longest
    // vocabulary prefix of what remains; continuation pieces carry a "##" prefix.
    std::vector<int> ids;
    std::string current = token;
    while (!current.empty()) {
        int match_id = -1;
        size_t match_pos = 0;
        for (int len = static_cast<int>(current.size()); len > 0; --len) {
            std::string candidate = current.substr(0, len);
            if (!ids.empty()) {
                // Not the first piece of this word: look it up with the "##" prefix.
                candidate = "##" + candidate;
            }
            auto iter = mVocabs.find(candidate);
            if (iter != mVocabs.end()) {
                match_id = iter->second;
                match_pos = len;
                break;
            }
        }
        if (match_id == -1) {
            // No vocabulary piece matches: emit [UNK] (id 100 in the BERT vocab) and stop.
            ids.push_back(100);
            break;
        }
        ids.push_back(match_id);
        // Advance past the matched prefix and continue on the remainder.
        current = current.substr(match_pos);
    }
    return ids;
}

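// For illustration (hypothetical: the exact split depends on the loaded
// vocab.txt), the greedy loop breaks an out-of-vocabulary word into pieces:
//
//   word_piece("playing")      ->  e.g. the ids of {"play", "##ing"}
//   word_piece(<unmatchable>)  ->  {100}, i.e. [UNK], when no prefix of the
//                                  remainder is in the vocabulary
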
std::vector<int> BertTokenizer::encode(const std::string& str, int maxlen) {
    // The buffer holds two sequences of maxlen ids each: the unconditional
    // prompt in [0, maxlen) and the encoded text in [maxlen, 2 * maxlen).
    std::vector<int> ids(maxlen * 2, 0);
    // Unconditional prompt: just [CLS][SEP].
    ids[0] = mStartIdx;
    ids[1] = mEndIdx;
    // Conditional prompt ids.
    int idx = maxlen;
    ids[idx++] = mStartIdx;

    // Basic tokenization: lowercased alphanumeric runs, single punctuation
    // marks, and 3-byte UTF-8 characters (e.g. CJK); whitespace separates tokens.
    std::vector<std::string> tokens;
    std::string current_token;
    size_t i = 0;
    while (i < str.size()) {
        current_token.clear();
        unsigned char c = static_cast<unsigned char>(str[i]);
        if ((c & 0x80) != 0) {
            // Multi-byte UTF-8: only 3-byte sequences (lead byte 1110xxxx)
            // become tokens; other multi-byte lead or continuation bytes are
            // skipped byte by byte.
            if ((c & 0xF0) == 0xE0) {
                current_token = str.substr(i, 3);
                i += 3;
            } else {
                ++i;
                continue;
            }
        }
        // handle continuous sequence of letters and digits
        else if (std::isalnum(c)) {
            while (i < str.size() && std::isalnum(static_cast<unsigned char>(str[i]))) {
                current_token += static_cast<char>(std::tolower(static_cast<unsigned char>(str[i])));
                ++i;
            }
        }
        // handle punctuation and symbols
        else if (std::ispunct(c)) {
            current_token = str[i];
            ++i;
        }
        // handle space, tab, enter
        else if (std::isspace(c)) {
            ++i;
            continue;
        }
        // handle any other single-byte characters
        else {
            current_token = str[i];
            ++i;
        }
        if (!current_token.empty()) {
            tokens.push_back(current_token);
        }
    }

    for (const auto& token : tokens) {
        for (auto id : word_piece(token)) {
            // Leave room for the trailing [SEP]; ids past the buffer are truncated.
            if (idx >= maxlen * 2 - 1) break;
            ids[idx++] = id;
        }
    }

    ids[idx++] = mEndIdx;
    return ids;
}

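// A minimal usage sketch (hypothetical directory and prompt; assumes the
// model folder contains vocab.txt):
//
//   BertTokenizer tokenizer;
//   if (tokenizer.load("./bert_model")) {
//       auto ids = tokenizer.encode("a photo of a cat", 77);
//       // ids.size() == 154: ids[0..77) is the unconditional [CLS][SEP]
//       // sequence, ids[77..154) the encoded prompt.
//   }
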
bool CLIPTokenizer::load(const std::string& filePath) {
    bool res_0 = loadVocab(filePath + "/vocab.json");
    bool res_1 = loadMerges(filePath + "/merges.txt");
    return res_0 && res_1;
}

std::wstring utf8_to_wstring(const std::string& str) {
    std::wstring_convert<std::codecvt_utf8<wchar_t>> myconv;
    return myconv.from_bytes(str);
}

std::string wstring_to_utf8(const std::wstring& str) {
    std::wstring_convert<std::codecvt_utf8<wchar_t>> myconv;
    return myconv.to_bytes(str);
}

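// Round-trip sketch: utf8_to_wstring("café") widens each code point, and
// wstring_to_utf8 inverts it, so wstring_to_utf8(utf8_to_wstring(s)) == s for
// valid UTF-8 input. Note std::wstring_convert and std::codecvt_utf8 are
// deprecated since C++17, though still available here.
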
bool CLIPTokenizer::loadVocab(const std::string& vocabFilePath) {
    FILE* fp = fopen(vocabFilePath.c_str(), "rb");
    if (!fp) {
        MNN_ERROR("File %s open failed.\n", vocabFilePath.c_str());
        return false;
    }

    char readBuffer[65536];
    rapidjson::FileReadStream is(fp, readBuffer, sizeof(readBuffer));

    rapidjson::Document doc;
    doc.ParseStream(is);
    fclose(fp);

    if (doc.HasParseError()) {
        MNN_ERROR("JSON parse error: %s\n", rapidjson::GetParseError_En(doc.GetParseError()));
        return false;
    }

    if (!doc.IsObject()) {
        MNN_ERROR("JSON is not an object.\n");
        return false;
    }

    // vocab.json maps token strings to integer ids.
    for (rapidjson::Value::ConstMemberIterator itr = doc.MemberBegin(); itr != doc.MemberEnd(); ++itr) {
        const char* key = itr->name.GetString();
        if (itr->value.IsInt()) {
            mVocabs[std::string(key)] = itr->value.GetInt();
        } else {
            MNN_ERROR("Value for key: %s is not an integer.\n", key);
        }
    }

    // Build the GPT-2-style byte-to-unicode table: printable bytes map to
    // themselves, and the remaining bytes map to code points 256 and up, so
    // every byte has a visible, unambiguous character representation.
    auto _insert_range = [this](int start, int end) {
        for (int c = start; c <= end; c++) {
            b2u_.insert({uint8_t(c), wchar_t(c)});
        }
    };

    b2u_.clear();
    _insert_range(L'!', L'~');
    _insert_range(L'¡', L'¬');
    _insert_range(L'®', L'ÿ');

    int n = 0;
    for (int b = 0; b < 256; b++) {
        if (b2u_.find(uint8_t(b)) == b2u_.end()) {
            b2u_.insert({uint8_t(b), wchar_t(256 + n)});
            n++;
        }
    }
    // u2b_ is the inverse mapping, used when decoding BPE pieces back to bytes.
    for (auto e : b2u_) {
        u2b_.insert({e.second, e.first});
    }

    return true;
}

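// Illustration of the mapping built above: printable ASCII maps to itself
// ('A' 0x41 -> L'A'), while non-printable bytes take code points from 256 up
// in byte order. The space byte 0x20, for example, becomes U+0120 ('Ġ'),
// which is why GPT-2/CLIP vocabularies show 'Ġ' where a token starts with a
// space.
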
bool CLIPTokenizer::loadMerges(const std::string& mergesFilePath) {
    std::ifstream file(mergesFilePath);
    std::string line;

    if (!file.is_open()) {
        MNN_ERROR("Failed to open merges file: %s\n", mergesFilePath.c_str());
        return false;
    }
    // Each non-comment line holds one merge rule: two tokens separated by a
    // space. The line's position is the rule's rank (lower merges first).
    int count = 0;
    while (getline(file, line)) {
        if (line.empty() || line[0] == '#') {
            continue;
        }

        size_t spacePos = line.find(' ');
        if (spacePos != std::string::npos) {
            std::string token_0 = line.substr(0, spacePos);
            std::string token_1 = line.substr(spacePos + 1);
            bpe_ranks_[std::make_pair(utf8_to_wstring(token_0), utf8_to_wstring(token_1))] = count++;
        }
    }
    return true;
}

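// merges.txt (as shipped with CLIP/GPT-2-style tokenizers) looks like this;
// the pairs shown are illustrative, and rank is implied by line order:
//
//   #version: 0.2
//   i n
//   t h
//   a n
//
// The leading '#version' line is skipped by the '#' check above.
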
void get_pairs(const std::wstring& word, std::vector<std::pair<std::wstring, std::wstring>>* pairs) {
    pairs->clear();

    if (word.size() < 2) return;

    // Collect every adjacent character pair, e.g. "low" -> (l,o), (o,w).
    wchar_t previous = word[0];
    for (size_t i = 1; i < word.size(); i++) {
        pairs->push_back({std::wstring(1, previous), std::wstring(1, word[i])});
        previous = word[i];
    }
}

void CLIPTokenizer::bpe(const std::wstring& token, const BPERanks& bpe_ranks, std::vector<std::wstring>* result) {
    std::set<int> merged; // Records indices in `pairs` that were merged.
    // Nearest unmerged pair to the left of i (or -1 if none).
    auto _left = [](int i, std::set<int>& merged) {
        for (int j = i - 1; j >= -1; j--) {
            if (merged.find(j) == merged.end()) return j;
        }
        return -1;
    };
    // Nearest unmerged pair to the right of i (or cap if none).
    auto _right = [](int i, int cap, std::set<int>& merged) {
        for (int j = i + 1; j < cap; j++) {
            if (merged.find(j) == merged.end()) return j;
        }
        return cap;
    };

    std::vector<std::pair<std::wstring, std::wstring>> pairs;
    get_pairs(token, &pairs);

    // Repeatedly merge the unmerged pair with the best (lowest) rank until no
    // remaining pair appears in bpe_ranks.
    while (true) {
        int min_score = INT_MAX;
        int to_merge = -1; // Index into pairs.

        for (int i = 0; i < static_cast<int>(pairs.size()); ++i) {
            if (merged.find(i) == merged.end()) { // Pair i is not merged.
                auto iter = bpe_ranks.find(pairs[i]);
                int score = iter != bpe_ranks.end() ? iter->second : INT_MAX;
                if (score < min_score) {
                    min_score = score;
                    to_merge = i;
                }
            }
        }

        if (to_merge == -1) break;

        merged.insert(to_merge);
        std::wstring merge_into = pairs[to_merge].first + pairs[to_merge].second;

        // Splice the merged string into the neighboring unmerged pairs.
        int l = _left(to_merge, merged);
        if (l >= 0) pairs[l].second = merge_into;
        int r = _right(to_merge, static_cast<int>(pairs.size()), merged);
        if (r < static_cast<int>(pairs.size())) pairs[r].first = merge_into;
    } // end while (true)

    if (merged.size() == pairs.size()) {
        // Everything merged: the whole token is a single piece.
        result->push_back(token);
    } else {
        // Walk the surviving pairs left to right to emit the final pieces.
        for (int i = 0; i < static_cast<int>(pairs.size()); ++i) {
            if (merged.find(i) == merged.end()) {
                if (_left(i, merged) < 0) result->push_back(pairs[i].first);
                result->push_back(pairs[i].second);
            }
        }
    }
}

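// A worked example with hypothetical ranks: for the token "hugs" with merge
// rules ("h","u") ranked before ("hu","g"), the pairs start as
// (h,u),(u,g),(g,s). The first iteration merges (h,u) into "hu" and rewrites
// the right neighbor to ("hu","g"); the second merges that into "hug" and
// rewrites the last pair to ("hug","s"); no further pair is ranked, so the
// emission loop yields {"hug", "s"}.
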
std::vector<int> CLIPTokenizer::encode(const std::string& text, int maxlen) {
    // GPT-2-style pre-tokenization: contractions, letter runs, digit runs,
    // other non-space symbols, and whitespace.
    std::regex re("('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\\s\\w]+|\\s+)");
    std::string input = text;
    std::vector<std::string> result;
    std::string token;
    std::smatch match;
    while (std::regex_search(input, match, re)) {
        token = match.str(0);
        input = match.suffix().str();

        // Map each byte to its unicode stand-in, then run BPE on the result.
        std::wstring wtoken;
        for (char c : token) {
            wtoken.push_back(b2u_.at(uint8_t(c)));
        }
        std::vector<std::wstring> bpe_tokens;
        bpe(wtoken, bpe_ranks_, &bpe_tokens);

        for (const auto& ws : bpe_tokens) {
            result.push_back(wstring_to_utf8(ws));
        }
    }

    // Same layout as BertTokenizer::encode: unconditional ids in [0, maxlen),
    // prompt ids in [maxlen, 2 * maxlen).
    std::vector<int> ids(maxlen * 2, 0);
    // Unconditional ids: just the start and end tokens.
    ids[0] = mStartIdx;
    ids[1] = mEndIdx;
    // Prompt ids.
    int idx = maxlen;
    ids[idx++] = mStartIdx;
    for (const auto& s : result) {
        // Leave room for the trailing end token; extra ids are truncated.
        if (idx >= maxlen * 2 - 1) break;
        ids[idx++] = mVocabs.at(s); // Throws std::out_of_range if a BPE piece is missing from vocab.json.
    }
    ids[idx++] = mEndIdx;

    return ids;
}

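// A minimal usage sketch (hypothetical paths and prompt; assumes a directory
// holding vocab.json and merges.txt):
//
//   CLIPTokenizer tokenizer;
//   if (tokenizer.load("./clip_tokenizer")) {
//       auto ids = tokenizer.encode("a photo of an astronaut", 77);
//       // Same two-segment layout as BertTokenizer::encode.
//   }
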
} // namespace diffusion