// Source view: MNN/source/backend/cpu/CPURandomUniform.cpp (C++, 84 lines, 2.8 KiB)
// Blame annotation: 2020-11-05 16:41:56 +08:00
//
// CPURandomUniform.cpp
// MNN
//
// Created by MNN on 2020/8/14.
// Copyright © 2018, Alibaba Group Holding Limited
//
// Blame annotation: 2021-04-28 18:02:10 +08:00
#include <random>
// Blame annotation: 2020-11-05 16:41:56 +08:00
#include "backend/cpu/CPURandomUniform.hpp"
#include "core/Macro.h"
#include "backend/cpu/CPUBackend.hpp"
namespace MNN {
// Resize hook for the uniform sampler. This execution allocates no
// intermediate buffers, so there is nothing to prepare here; all of the
// work happens in onExecute.
ErrorCode CPURandomUniform::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    (void)inputs;
    (void)outputs;
    return NO_ERROR;
}
// Fills the single output tensor with float samples drawn from
// U(low, high). The low, high, seed and seed2 values are read from the op's
// RandomUniform parameter table.
//
// Seeding:
//  - If either seed or seed2 is non-zero, both values are mixed through
//    std::seed_seq into an mt19937, so distinct seed pairs give distinct
//    streams.
//  - If both are zero, a default-constructed std::default_random_engine is
//    used. Its output is deterministic for a given standard-library
//    implementation.
ErrorCode CPURandomUniform::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    MNN_ASSERT(outputs.size() == 1);
    auto output     = outputs[0];
    const int size  = output->elementSize();
    auto parameter  = mOp->main_as_RandomUniform();
    auto outputPtr  = output->host<float>();
    std::uniform_real_distribution<float> distribution(parameter->low(), parameter->high());
    const int seed  = parameter->seed();
    const int seed1 = parameter->seed2();
    if (seed || seed1) {
        // Seed from both values. Passing `seed || seed1` (a bool, always 1
        // here) would make every non-zero seed produce the same sequence.
        std::seed_seq seedSequence{seed, seed1};
        std::mt19937 generator(seedSequence);
        for (int i = 0; i < size; i++) {
            outputPtr[i] = distribution(generator);
        }
    } else {
        std::default_random_engine generator;
        for (int i = 0; i < size; i++) {
            outputPtr[i] = distribution(generator);
        }
    }
    return NO_ERROR;
}
// Resize hook for the normal sampler. No intermediate buffers are needed, so
// this does no work; sampling is done entirely in onExecute.
ErrorCode CPURandomNormal::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    (void)inputs;
    (void)outputs;
    return NO_ERROR;
}
// Fills the single output tensor with float samples drawn from a normal
// distribution.
//
// RandomUniform and RandomNormal share the same parameter table:
//  - low  -> mean
//  - high -> scale, passed to std::normal_distribution as its stddev
//
// Seeding:
//  - If either seed or seed2 is non-zero, both values are mixed through
//    std::seed_seq into an mt19937, so distinct seed pairs give distinct
//    streams.
//  - If both are zero, a default-constructed std::default_random_engine is
//    used.
ErrorCode CPURandomNormal::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    MNN_ASSERT(outputs.size() == 1);
    auto output     = outputs[0];
    const int size  = output->elementSize();
    auto parameter  = mOp->main_as_RandomUniform();
    auto outputPtr  = output->host<float>();
    // RandomUniform and RandomNormal use same param table. low -> mean, high -> scale
    std::normal_distribution<float> distribution(parameter->low(), parameter->high());
    const int seed  = parameter->seed();
    const int seed1 = parameter->seed2();
    if (seed || seed1) {
        // Seed from both values. Passing `seed || seed1` (a bool, always 1
        // here) would make every non-zero seed produce the same sequence.
        std::seed_seq seedSequence{seed, seed1};
        std::mt19937 generator(seedSequence);
        for (int i = 0; i < size; i++) {
            outputPtr[i] = distribution(generator);
        }
    } else {
        std::default_random_engine generator;
        for (int i = 0; i < size; i++) {
            outputPtr[i] = distribution(generator);
        }
    }
    return NO_ERROR;
}
class CPURandomCreator : public CPUBackend::Creator {
2020-11-05 16:41:56 +08:00
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
2022-01-04 10:50:40 +08:00
if (op->type() == OpType_RandomUniform) {
return new CPURandomUniform(backend, op);
} else {
return new CPURandomNormal(backend, op);
}
2020-11-05 16:41:56 +08:00
}
};
2022-01-04 10:50:40 +08:00
REGISTER_CPU_OP_CREATOR(CPURandomCreator, OpType_RandomUniform);
REGISTER_CPU_OP_CREATOR(CPURandomCreator, OpType_RandomNormal);
2020-11-05 16:41:56 +08:00
} // namespace MNN