mirror of https://github.com/alibaba/MNN.git
325 lines
10 KiB
C++
325 lines
10 KiB
C++
#include <random>
|
|
#include <fstream>
|
|
#include <chrono>
|
|
#include "pipeline.hpp"
|
|
#include "tokenizer.hpp"
|
|
#define MNN_OPEN_TIME_TRACE
|
|
#include <MNN/AutoTime.hpp>
|
|
#include <cv/cv.hpp>
|
|
#include <fstream>
|
|
#include <sstream>
|
|
#include <MNN/expr/ExecutorScope.hpp>
|
|
|
|
#if defined(_MSC_VER)
|
|
#include <Windows.h>
|
|
#undef min
|
|
#undef max
|
|
#else
|
|
#include <sys/time.h>
|
|
#endif
|
|
|
|
//#define TEXT_MAX_LEN 512 // Taiyi_SD
|
|
#define TEXT_MAX_LEN 77 // SD_1.5
|
|
|
|
//#define MNN_DUMP_DATA
|
|
|
|
using namespace CV;
|
|
|
|
namespace diffusion {
|
|
|
|
static inline int64_t getTime() {
|
|
uint64_t time;
|
|
#if defined(_MSC_VER)
|
|
LARGE_INTEGER now, freq;
|
|
QueryPerformanceCounter(&now);
|
|
QueryPerformanceFrequency(&freq);
|
|
uint64_t sec = now.QuadPart / freq.QuadPart;
|
|
uint64_t usec = (now.QuadPart % freq.QuadPart) * 1000000 / freq.QuadPart;
|
|
time = sec * 1000000 + usec;
|
|
#else
|
|
struct timeval tv;
|
|
gettimeofday(&tv, nullptr);
|
|
time = static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
|
|
#endif
|
|
return time;
|
|
}
|
|
|
|
void display_progress(int cur, int total){
|
|
putchar('\r');
|
|
printf("[");
|
|
for (int i = 0; i < cur; i++) putchar('#');
|
|
for (int i = 0; i < total - cur; i++) putchar('-');
|
|
printf("]");
|
|
fprintf(stdout, " [%3d%%]", cur * 100 / total);
|
|
if (cur == total) putchar('\n');
|
|
fflush(stdout);
|
|
}
|
|
|
|
Pipeline::Pipeline(std::string modelPath) : mModelPath(modelPath) {
|
|
std::ifstream alphaFile(modelPath + "/alphas.txt");
|
|
int index = 0;
|
|
float alpha;
|
|
while (alphaFile >> alpha) {
|
|
mAlphas.push_back(alpha);
|
|
}
|
|
#if 0 // Different steps setting
|
|
mTimeSteps = { // 50 steps
|
|
981, 961, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741,
|
|
721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461,
|
|
441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181,
|
|
161, 141, 121, 101, 81, 61, 41, 21, 1
|
|
};
|
|
mTimeSteps = { // 20 steps
|
|
951, 901, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451,
|
|
401, 351, 301, 251, 201, 151, 101, 51, 1
|
|
};
|
|
#endif
|
|
mTimeSteps = {
|
|
801, 601, 601,
|
|
401, 201, 1
|
|
};
|
|
}
|
|
|
|
bool Pipeline::load_modules(std::string modelPath) {
|
|
AUTOTIME;
|
|
|
|
ScheduleConfig config;
|
|
BackendConfig backendConfig;
|
|
// config.type = MNN_FORWARD_CPU;
|
|
config.type = MNN_FORWARD_OPENCL;
|
|
config.numThread = 65;
|
|
backendConfig.precision = BackendConfig::Precision_Normal;
|
|
backendConfig.memory = BackendConfig::Memory_Low;
|
|
config.backendConfig = &backendConfig;
|
|
|
|
|
|
Module::Config module_config;
|
|
module_config.shapeMutable = false;
|
|
// module_config.rearrange = true;
|
|
runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config));
|
|
|
|
// if (config.type == MNN_FORWARD_OPENCL) {
|
|
// const char* cacheFileName = ".tempcache";
|
|
// runtime_manager_->setCache(cacheFileName);
|
|
// }
|
|
|
|
mModules.resize(3);
|
|
// load text_encoder model
|
|
{
|
|
std::string model_path = modelPath + "/text_encoder.mnn";
|
|
mModules[0].reset(Module::load(
|
|
{"input_ids"}, {"last_hidden_state", "pooler_output"}, model_path.c_str(), runtime_manager_, &module_config));
|
|
|
|
}
|
|
// load unet model
|
|
{
|
|
std::string model_path = modelPath + "/unet.mnn";
|
|
mModules[1].reset(Module::load(
|
|
{"sample", "timestep", "encoder_hidden_states"}, {"out_sample"}, model_path.c_str(), runtime_manager_, &module_config));
|
|
}
|
|
// load vae_decoder model
|
|
{
|
|
std::string model_path = modelPath + "/vae_decoder.mnn";
|
|
mModules[2].reset(Module::load(
|
|
{"latent_sample"}, {"sample"}, model_path.c_str(), runtime_manager_, &module_config));
|
|
}
|
|
|
|
auto exe = ExecutorScope::Current();
|
|
exe->lazyEval = false;
|
|
exe->setGlobalExecutorConfig(config.type, backendConfig, config.numThread);
|
|
|
|
return true;
|
|
}
|
|
|
|
VARP Pipeline::text_encoder(const std::vector<int>& ids) {
|
|
auto inputs_ids_ = _Const(ids.data(), {2, TEXT_MAX_LEN}, NCHW, halide_type_of<int>());
|
|
|
|
auto outputs = mModules[0]->onForward({inputs_ids_});
|
|
auto output = _Convert(outputs[0], NCHW);
|
|
output.fix(VARP::CONSTANT);
|
|
|
|
#ifdef MNN_DUMP_DATA
|
|
auto xx = output->readMap<float>();
|
|
for(int i=0; i<10; i+=2) {
|
|
printf("%f %f ", xx[i], xx[i+TEXT_MAX_LEN*768]);
|
|
}
|
|
printf("\n\n");
|
|
#endif
|
|
return output;
|
|
}
|
|
|
|
VARP Pipeline::step_plms(VARP sample, VARP model_output, int index) {
|
|
int timestep = mTimeSteps[index];
|
|
int prev_timestep = 0;
|
|
if (index + 1 < mTimeSteps.size()) {
|
|
prev_timestep = mTimeSteps[index + 1];
|
|
}
|
|
if (index != 1) {
|
|
if (mEts.size() >= 4) {
|
|
mEts[mEts.size() - 4] = nullptr;
|
|
}
|
|
mEts.push_back(model_output);
|
|
} else {
|
|
timestep = mTimeSteps[0];
|
|
prev_timestep = mTimeSteps[1];
|
|
}
|
|
int ets = mEts.size() - 1;
|
|
if (index == 0) {
|
|
mSample = sample;
|
|
} else if (index == 1) {
|
|
model_output = (model_output + mEts[ets]) * _Const(0.5);
|
|
sample = mSample;
|
|
} else if (ets == 1) {
|
|
model_output = (_Const(3.0) * mEts[ets] - mEts[ets-1]) * _Const(0.5);
|
|
} else if (ets == 2) {
|
|
model_output = (_Const(23.0) * mEts[ets] - _Const(16.0) * mEts[ets-1] + _Const(5.0) * mEts[ets-2]) * _Const(1.0 / 12.0);
|
|
} else if (ets >= 3) {
|
|
model_output = _Const(1. / 24.) * (_Const(55.0) * mEts[ets] - _Const(59.0) * mEts[ets-1] + _Const(37.0) * mEts[ets-2] - _Const(9.0) * mEts[ets-3]);
|
|
}
|
|
auto alpha_prod_t = mAlphas[timestep];
|
|
auto alpha_prod_t_prev = mAlphas[prev_timestep];
|
|
auto beta_prod_t = 1 - alpha_prod_t;
|
|
auto beta_prod_t_prev = 1 - alpha_prod_t_prev;
|
|
auto sample_coeff = std::sqrt(alpha_prod_t_prev / alpha_prod_t);
|
|
auto model_output_denom_coeff = alpha_prod_t * std::sqrt(beta_prod_t_prev) + std::sqrt(alpha_prod_t * beta_prod_t * alpha_prod_t_prev);
|
|
auto prev_sample = _Scalar(sample_coeff) * sample - _Scalar((alpha_prod_t_prev - alpha_prod_t)/model_output_denom_coeff) * model_output;
|
|
return prev_sample;
|
|
}
|
|
|
|
VARP Pipeline::unet(VARP text_embeddings) {
|
|
mModules[0].reset();
|
|
std::mt19937 rng;
|
|
rng.seed(std::random_device()());
|
|
std::normal_distribution<float> normal(0, 1);
|
|
std::vector<float> initVal(16384);
|
|
|
|
#ifdef MNN_DUMP_DATA
|
|
std::ostringstream fileName;
|
|
fileName << "random.txt";
|
|
std::ifstream input(fileName.str().c_str());
|
|
for (int i = 0; i < 16384; ++i) {
|
|
input >> initVal[i];
|
|
}
|
|
#else
|
|
for (int i = 0; i < 16384; i++) {
|
|
initVal[i] = normal(rng);
|
|
}
|
|
#endif
|
|
|
|
// VARP latentvar = _Const(initVal.data(), {1, 4, 64, 64}, NCHW);
|
|
VARP latentvar = _Input({1, 4, 64, 64}, NCHW, halide_type_of<float>());
|
|
memcpy((void *)latentvar->writeMap<int8_t>(), initVal.data(), 16384*sizeof(float));
|
|
|
|
VARP scalevar = _Input({1}, NCHW, halide_type_of<float>());
|
|
auto scaleptr = scalevar->writeMap<float>();
|
|
scaleptr[0] = 7.5;
|
|
|
|
|
|
VARP timestepvar = _Input({1}, NCHW, halide_type_of<int>());
|
|
auto floatVar = _Input({1}, NCHW, halide_type_of<float>());
|
|
auto ptr = floatVar->writeMap<float>();
|
|
|
|
for (int i = 0; i < mTimeSteps.size(); i++) {
|
|
AUTOTIME;
|
|
display_progress(i, mTimeSteps.size());
|
|
|
|
auto t0 = getTime();
|
|
|
|
int timestep = mTimeSteps[i];
|
|
|
|
ptr[0] = timestep;
|
|
auto temp = _Cast(floatVar, halide_type_of<int>());
|
|
timestepvar->input(temp);
|
|
|
|
VARP samplevar = _Concat({latentvar, latentvar}, 0);
|
|
auto outputs = mModules[1]->onForward({samplevar, timestepvar, text_embeddings});
|
|
|
|
auto output = _Convert(outputs[0], NCHW);
|
|
auto t1 = getTime();
|
|
|
|
auto noise_pred = output;
|
|
|
|
auto splitvar = _Split(noise_pred, {2}, 0);
|
|
auto noise_pred_uncond = splitvar[0];
|
|
auto noise_pred_text = splitvar[1];
|
|
|
|
noise_pred = scalevar * (noise_pred_text - noise_pred_uncond) + noise_pred_uncond;
|
|
auto t2 = getTime();
|
|
|
|
latentvar = step_plms(latentvar, noise_pred, i);
|
|
auto t3 = getTime();
|
|
MNN_PRINT("Times: %f %f %f ms\n", (t1-t0)/ 1000.0f, (t2-t1)/ 1000.0f, (t3-t2)/ 1000.0f);
|
|
// latentvar.fix(VARP::CONSTANT);
|
|
#ifdef MNN_DUMP_DATA
|
|
auto xx = output->readMap<float>();
|
|
auto yy = samplevar->readMap<float>();
|
|
auto zz = text_embeddings->readMap<float>();
|
|
|
|
for(int i=0; i<6; i+=2) {
|
|
printf("%f %f %f ", xx[i], yy[i], zz[i]);
|
|
}
|
|
for(int i=0; i<6; i+=2) {
|
|
printf("%f %f %f ", xx[16384+i], yy[16384+i], zz[TEXT_MAX_LEN*768+i]);
|
|
}
|
|
printf("\n\n");
|
|
#endif
|
|
}
|
|
latentvar.fix(VARP::CONSTANT);
|
|
|
|
#ifdef MNN_DUMP_DATA
|
|
auto xx = latentvar->readMap<float>();
|
|
for(int i=0; i<10; i+=2) {
|
|
printf("%f ", xx[i]);
|
|
}
|
|
printf("\n\n");
|
|
#endif
|
|
return latentvar;
|
|
}
|
|
|
|
VARP Pipeline::vae_decoder(VARP latent) {
|
|
mModules[1].reset();
|
|
latent = latent * _Const(1 / 0.18215);
|
|
|
|
auto outputs = mModules[2]->onForward({latent});
|
|
auto output = _Convert(outputs[0], NCHW);
|
|
|
|
#ifdef MNN_DUMP_DATA
|
|
auto xx = output->readMap<float>();
|
|
for(int i=0; i<320; i+=32) {
|
|
printf("%f ", xx[i]);
|
|
}
|
|
printf("\n\n");
|
|
#endif
|
|
|
|
auto image = output;
|
|
image = _Relu6(image * _Const(0.5) + _Const(0.5), 0, 1);
|
|
image = _Squeeze(_Transpose(image, {0, 2, 3, 1}));
|
|
image = _Cast(_Round(image * _Const(255.0)), halide_type_of<uint8_t>());
|
|
image = cvtColor(image, COLOR_BGR2RGB);
|
|
image.fix(VARP::CONSTANT);
|
|
return image;
|
|
}
|
|
|
|
bool Pipeline::run(const std::string& sentence, const std::string& img_name) {
|
|
std::unique_ptr<diffusion::Tokenizer> tok;
|
|
// tok.reset(new diffusion::BertTokenizer);
|
|
tok.reset(new diffusion::CLIPTokenizer);
|
|
tok->load(mModelPath);
|
|
load_modules(mModelPath);
|
|
|
|
AUTOTIME;
|
|
auto ids = tok->encode(sentence, TEXT_MAX_LEN);
|
|
auto text_embeddings = text_encoder(ids);
|
|
|
|
auto latent = unet(text_embeddings);
|
|
auto image = vae_decoder(latent);
|
|
bool res = imwrite(img_name, image);
|
|
if (res) {
|
|
printf("SUCCESS! write to %s\n", img_name.c_str());
|
|
}
|
|
// runtime_manager_->updateCache();
|
|
return res;
|
|
}
|
|
|
|
}
|