mirror of https://github.com/alibaba/MNN.git
38 lines
961 B
C++
38 lines
961 B
C++
#include <map>
|
|
#include <vector>
|
|
#include <MNN/Interpreter.hpp>
|
|
#include <MNN/expr/Expr.hpp>
|
|
#include <MNN/expr/MathOp.hpp>
|
|
#include <MNN/expr/NeuralNetWorkOp.hpp>
|
|
#include <MNN/expr/Module.hpp>
|
|
|
|
using namespace MNN;
|
|
using namespace MNN::Express;
|
|
|
|
namespace diffusion {
|
|
|
|
class Pipeline {
|
|
public:
|
|
Pipeline(std::string modelPath);
|
|
~Pipeline() = default;
|
|
bool run(const std::string& sentence, const std::string& img_name);
|
|
private:
|
|
bool load_modules(std::string modelPath);
|
|
VARP step_plms(VARP sample, VARP model_output, int index);
|
|
VARP text_encoder(const std::vector<int>& ids);
|
|
VARP unet(VARP text_embeddings);
|
|
VARP vae_decoder(VARP latent);
|
|
private:
|
|
std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
|
|
std::vector<std::shared_ptr<Module>> mModules;
|
|
|
|
std::string mModelPath;
|
|
// step_plms
|
|
std::vector<int> mTimeSteps;
|
|
std::vector<float> mAlphas;
|
|
std::vector<VARP> mEts;
|
|
VARP mSample;
|
|
};
|
|
|
|
}
|