//
//  embedding.cpp
//
//  Created by MNN on 2025/04/08.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "llm/llm.hpp"
|
|
|
|
#include "llmconfig.hpp"
|
|
|
|
#include "prompt.hpp"
|
|
|
|
#include "tokenizer.hpp"
|
|
|
|
#include "diskembedding.hpp"
|
|
|
|
|
|
|
|
namespace MNN {
|
|
|
|
using namespace Express;
|
|
|
|
namespace Transformer {
|
|
|
|
|
|
|
|
// Euclidean (L2) distance between two variables.
//
// Builds the expression sqrt(sum((var0 - var1)^2)) and reads back the first
// float of the result.
float Embedding::dist(VARP var0, VARP var1) {
    auto difference = var0 - var1;
    auto sumOfSquares = _ReduceSum(_Square(difference));
    auto norm = _Sqrt(sumOfSquares);
    return norm->readMap<float>()[0];
}
|
|
|
|
|
2025-07-29 13:38:28 +08:00
|
|
|
// Cosine similarity between two variables: dot(var0, var1) / (|var0| * |var1|).
//
// Returns 0 when the product of the norms is zero (i.e. at least one input
// has zero length). The ratio is mathematically undefined there and would
// otherwise evaluate to NaN/inf, which poisons comparisons and sorting in
// callers that rank results by similarity.
float Embedding::cos_sim(VARP var0, VARP var1) {
    auto innerProd = _ReduceSum(_Multiply(var0, var1))->readMap<float>()[0];
    auto len0 = _Sqrt(_ReduceSum(_Square(var0)))->readMap<float>()[0];
    auto len1 = _Sqrt(_ReduceSum(_Square(var1)))->readMap<float>()[0];
    const float denom = len0 * len1;
    // Norms are non-negative, so `<=` only ever triggers on an exact zero.
    if (denom <= 0.0f) {
        return 0.0f;
    }
    return innerProd / denom;
}
|
|
|
|
|
2025-05-08 12:39:44 +08:00
|
|
|
// Factory: builds an Embedding from the configuration at `config_path`.
//
// When `load` is true, load() is invoked on the new object before it is
// returned. The boolean result of load() is not inspected here, so the
// returned pointer is non-null even if loading failed.
// The caller receives a raw pointer allocated with `new`.
Embedding* Embedding::createEmbedding(const std::string& config_path, bool load) {
    std::shared_ptr<LlmConfig> llmConfig(new LlmConfig(config_path));
    auto* instance = new Embedding(llmConfig);
    if (!load) {
        return instance;
    }
    instance->load();
    return instance;
}
|
|
|
|
|
|
|
|
// Constructs the embedding model by forwarding the shared configuration to
// the Llm base class; no additional initialisation is done here.
Embedding::Embedding(std::shared_ptr<LlmConfig> config) : Llm(config) {}
|
|
|
|
|
|
|
|
int Embedding::dim() const {
|
|
|
|
return mConfig->hidden_size();
|
|
|
|
}
|
|
|
|
|
2025-09-05 17:37:34 +08:00
|
|
|
bool Embedding::load() {
|
2025-05-08 12:39:44 +08:00
|
|
|
initRuntime();
|
|
|
|
printf("load tokenizer\n");
|
|
|
|
std::cout << mConfig->tokenizer_file() << std::endl;
|
|
|
|
// 1. load vocab
|
|
|
|
mTokenizer.reset(Tokenizer::createTokenizer(mConfig->tokenizer_file()));
|
|
|
|
printf("load tokenizer Done\n");
|
|
|
|
mDiskEmbedding.reset(new DiskEmbedding(mConfig));
|
|
|
|
// 2. load model
|
|
|
|
Module::Config module_config;
|
|
|
|
module_config.shapeMutable = true;
|
|
|
|
module_config.rearrange = true;
|
|
|
|
auto model_path = mConfig->llm_model();
|
|
|
|
MNN_PRINT("load %s ... ", model_path.c_str());
|
2025-09-05 17:37:34 +08:00
|
|
|
mModule.reset(Module::load({"input_ids", "attention_mask", "position_ids"}, {"sentence_embeddings"},
|
2025-05-08 12:39:44 +08:00
|
|
|
model_path.c_str(), mRuntimeManager, &module_config));
|
2025-09-05 17:37:34 +08:00
|
|
|
if (nullptr == mModule.get()) {
|
|
|
|
return false;
|
|
|
|
}
|
2025-05-08 12:39:44 +08:00
|
|
|
MNN_PRINT("Done!\n");
|
2025-09-05 17:37:34 +08:00
|
|
|
return true;
|
2025-05-08 12:39:44 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Runs the loaded module on a sequence of token ids and returns its first
// output (the module is loaded with the single output "sentence_embeddings").
//
// The module inputs are, in order: the result of embedding(ids), an attention
// mask and position ids, both sized by the number of ids.
//
// Returns a null VARP when no module is loaded (load() was not called or
// failed) or when the forward pass yields no outputs, instead of
// dereferencing a null module / indexing an empty vector.
VARP Embedding::ids_embedding(const std::vector<int>& ids) {
    if (nullptr == mModule.get()) {
        return nullptr;
    }
    // Explicit narrowing: the helpers below take the sequence length as int.
    const int prompt_len = static_cast<int>(ids.size());
    auto inputs_ids = embedding(ids);
    auto attention_mask = gen_attention_mask(prompt_len);
    auto position_ids = gen_position_ids(prompt_len);
    auto outputs = mModule->onForward({inputs_ids, attention_mask, position_ids});
    if (outputs.empty()) {
        return nullptr;
    }
    return outputs[0];
}
|
|
|
|
|
2025-08-08 12:24:23 +08:00
|
|
|
// Encodes `txt` with tokenizer_encode() and feeds the resulting ids to
// ids_embedding(), returning whatever that call returns.
VARP Embedding::txt_embedding(const std::string& txt) {
    const auto tokenIds = tokenizer_encode(txt);
    return ids_embedding(tokenIds);
}
|
|
|
|
|
|
|
|
// Builds a float attention mask of shape {1, 1, seq_len, seq_len}.
//
// - If the configuration's attention_mask() equals "float": entry (row, col)
//   is (col > row) * lowest-float, i.e. the most negative finite float above
//   the diagonal and zero on or below it.
// - Otherwise: every entry is 1.0.
VARP Embedding::gen_attention_mask(int seq_len) {
    auto mask = _Input({1, 1, seq_len, seq_len}, NCHW, halide_type_of<float>());
    auto data = mask->writeMap<float>();
    const bool floatMask = (mConfig->attention_mask() == "float");

    if (!floatMask) {
        for (int row = 0; row < seq_len; ++row) {
            for (int col = 0; col < seq_len; ++col) {
                data[row * seq_len + col] = 1.0;
            }
        }
        return mask;
    }

    for (int row = 0; row < seq_len; ++row) {
        for (int col = 0; col < seq_len; ++col) {
            // Multiplication form kept on purpose so the stored value is
            // bit-identical to the previous implementation.
            data[row * seq_len + col] = (col > row) * std::numeric_limits<float>::lowest();
        }
    }
    return mask;
}
|
|
|
|
|
|
|
|
// Builds an int variable of shape {1, seq_len} holding 0, 1, ..., seq_len - 1.
VARP Embedding::gen_position_ids(int seq_len) {
    auto ids = _Input({1, seq_len}, NCHW, halide_type_of<int>());
    auto data = ids->writeMap<int>();
    int pos = 0;
    while (pos < seq_len) {
        data[pos] = pos;
        ++pos;
    }
    return ids;
}
|
|
|
|
|
|
|
|
}
|
2025-08-08 12:24:23 +08:00
|
|
|
}
|