mirror of https://github.com/alibaba/MNN.git
132 lines
3.5 KiB
C++
132 lines
3.5 KiB
C++
//
|
|
// Execution.hpp
|
|
// MNN
|
|
//
|
|
// Created by MNN on 2018/07/06.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#ifndef Execution_hpp
|
|
#define Execution_hpp
|
|
|
|
#include <MNN/MNNForwardType.h>
|
|
#include <MNN/ErrorCode.hpp>
|
|
#include <MNN/Tensor.hpp>
|
|
#include <memory>
|
|
#include <string>
|
|
#include "NonCopyable.hpp"
|
|
|
|
namespace MNN {
|
|
class Backend;
|
|
struct Op;
|
|
|
|
/** abstract execution */
|
|
class Execution : public NonCopyable {
|
|
public:
|
|
/**
|
|
* @brief initializer.
|
|
* @param backend backend that exection will running on.
|
|
*/
|
|
Execution() = delete;
|
|
Execution(Backend *backend) : mBackEnd(backend) {
|
|
// nothing to do
|
|
}
|
|
/**
|
|
* @brief deinitializer.
|
|
*/
|
|
virtual ~Execution() = default;
|
|
|
|
/**
|
|
* @brief response shape change of input or output tensors.
|
|
* @param inputs input tensors
|
|
* @param outputs output tensors
|
|
* @return resize result
|
|
*/
|
|
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
|
return NO_ERROR;
|
|
}
|
|
|
|
/**
|
|
* @brief perform execution.
|
|
* @param inputs input tensors
|
|
* @param outputs output tensors
|
|
* @return execution result
|
|
*/
|
|
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) = 0;
|
|
|
|
/**
|
|
* @brief clone execution, new execution will share weight from this execution
|
|
* @param bn the cloned' execution's backend
|
|
* @param dst if dst = nullptr, just return whether execution can clone, otherwise clone the execution into dst
|
|
* @return execution result
|
|
*/
|
|
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) {
|
|
return false;
|
|
}
|
|
public:
|
|
/**
|
|
* @brief designed for plugin system. not ready yet.
|
|
*/
|
|
class Creator : public NonCopyable {
|
|
public:
|
|
/**
|
|
* @brief deinitializer.
|
|
*/
|
|
virtual ~Creator() = default;
|
|
/**
|
|
* @brief create execution for given op on given backend.
|
|
* @param backend given backend.
|
|
* @param op given op.
|
|
* @return execution.
|
|
*/
|
|
virtual Execution *onCreate(Backend *backend, const Op *op) const = 0;
|
|
};
|
|
|
|
// Search for extra creator, if not found, return nullptr
|
|
MNN_PUBLIC static const Creator *searchExtraCreator(const std::string &key, MNNForwardType type);
|
|
|
|
/**
|
|
* @brief register creator for given key and backend type.
|
|
* @param creator registering creator.
|
|
* @param key given key.
|
|
* @param type given backend type.
|
|
* @return false if registered creator for same key and type exists, true otherwise.
|
|
*/
|
|
MNN_PUBLIC static bool insertExtraCreator(std::shared_ptr<Creator> creator, const std::string &key,
|
|
MNNForwardType type);
|
|
|
|
/**
|
|
* @brief unregister creator for given key and backend type.
|
|
* @param key given key.
|
|
* @param type given backend type.
|
|
* @return true if registered creator for given key and type exists, false otherwise.
|
|
*/
|
|
MNN_PUBLIC static bool removeExtraCreator(const std::string &key, MNNForwardType type);
|
|
|
|
public:
|
|
/**
|
|
* @brief check if execution is valid.
|
|
* @return valid or not.
|
|
*/
|
|
inline bool valid() const {
|
|
return mValid;
|
|
}
|
|
/**
|
|
* @brief get backend.
|
|
* @return backend.
|
|
*/
|
|
Backend *backend() const {
|
|
return mBackEnd;
|
|
}
|
|
|
|
protected:
|
|
bool mValid = true;
|
|
|
|
private:
|
|
Backend *mBackEnd;
|
|
};
|
|
|
|
} // namespace MNN
|
|
|
|
#endif /* Execution_hpp */
|