2019-04-17 10:49:11 +08:00
|
|
|
//
|
|
|
|
// MetalQuantizedSoftmax.mm
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2019/01/30.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
2019-12-27 22:16:57 +08:00
|
|
|
#import "backend/metal/MetalQuantizedSoftmax.hpp"
|
|
|
|
#import "backend/cpu/CPUQuantizationUtils.hpp"
|
|
|
|
#import "backend/metal/MNNMetalContext.h"
|
|
|
|
#import "backend/metal/MetalBackend.hpp"
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
#if MNN_METAL_ENABLED
|
|
|
|
namespace MNN {
|
|
|
|
|
|
|
|
// Integer-bit count handed to PreprocessSoftmaxScaling and CalculateInputRadius
// when deriving the fixed-point softmax parameters (same value used for both).
constexpr int kScaledDiffIntegerBits = 5;
|
|
|
|
|
|
|
|
// Precomputes the fixed-point parameters used by the quantized softmax kernel.
//   backend - owning backend, forwarded to Execution.
//   beta    - softmax beta from the op parameters.
//   scale   - input quantization scale from the op parameters.
// PreprocessSoftmaxScaling writes the multiplier and left shift through the
// two output pointers; mDiffMin is then set to the negated input radius for
// that shift (presumably the smallest input difference the kernel handles —
// TODO confirm against the kernel source).
MetalQuantizedSoftmax::MetalQuantizedSoftmax(Backend *backend, float beta, float scale) : Execution(backend) {
    PreprocessSoftmaxScaling(beta, scale, kScaledDiffIntegerBits, &mInputMultiplier, &mInputLeftShift);
    const auto radius = CalculateInputRadius(kScaledDiffIntegerBits, mInputLeftShift);
    mDiffMin          = -1.0 * radius;
}
|
|
|
|
|
|
|
|
// Computes the softmax geometry for inputs[0] and (re)allocates the constant
// buffer that onExecute binds at index 2.
// Constant layout (five ints):
//   [0] outer           - product of every axis except the last
//   [1] inner           - extent of the last axis
//   [2] mDiffMin
//   [3] mInputMultiplier
//   [4] mInputLeftShift
// Returns NO_ERROR.
ErrorCode MetalQuantizedSoftmax::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    auto input   = inputs[0];

    // Softmax is taken over the last axis: all leading axes fold into `outer`.
    // For rank 4 this is length(0)*length(1)*length(2) x length(3), and for
    // rank 2 it is length(0) x length(1), exactly as before. Other ranks no
    // longer index past the end of the shape (rank 1 previously read
    // length(1); rank 3 previously dropped the last axis).
    const int dims = input->dimensions();
    int outer      = 1;
    for (int i = 0; i < dims - 1; ++i) {
        outer *= input->length(i);
    }
    const int inner = dims > 0 ? input->length(dims - 1) : 1;

    mConst         = [context newDeviceBuffer:5 * sizeof(int) access:CPUWriteOnly];
    auto constants = (int *)mConst.contents;
    constants[0]   = outer;
    constants[1]   = inner;
    constants[2]   = mDiffMin;
    constants[3]   = mInputMultiplier;
    constants[4]   = mInputLeftShift;
    return NO_ERROR;
}
|
|
|
|
|
|
|
|
// Encodes one dispatch of the "quantized_softmax" compute kernel.
// Buffer bindings: 0 = inputs[0], 1 = outputs[0], 2 = constants from onResize.
// The grid is (outputs[0]->size(), 1, 1). NOTE(review): Tensor::size() is
// presumably a byte count, which equals the element count only for 8-bit
// data — confirm against the kernel's indexing.
// Returns NO_ERROR.
ErrorCode MetalQuantizedSoftmax::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto metalBackend = static_cast<MetalBackend *>(this->backend());
    auto ctx          = (__bridge MNNMetalContext *)metalBackend->context();
    auto src          = inputs[0];
    auto dst          = outputs[0];

    auto computeEncoder = [ctx encoder];
    auto pipelineInfo   = [ctx load:@"quantized_softmax" encoder:computeEncoder];

    id<MTLBuffer> srcBuffer = (__bridge id<MTLBuffer>)(void *)src->deviceId();
    id<MTLBuffer> dstBuffer = (__bridge id<MTLBuffer>)(void *)dst->deviceId();
    [computeEncoder setBuffer:srcBuffer offset:0 atIndex:0];
    [computeEncoder setBuffer:dstBuffer offset:0 atIndex:1];
    [computeEncoder setBuffer:mConst offset:0 atIndex:2];

    const MTLSize grid = {(NSUInteger)dst->size(), 1, 1};
    [ctx dispatchEncoder:computeEncoder threads:grid bandwidth:pipelineInfo];
    [computeEncoder endEncoding];

    MNN_PRINT_ENCODER(ctx, computeEncoder);
    return NO_ERROR;
}
|
|
|
|
|
|
|
|
// Factory that builds MetalQuantizedSoftmax executions for the Metal backend.
class MetalQuantizedSoftmaxCreator : public MetalBackend::Creator {
public:
    // Reads beta and inputScale from the op's QuantizedSoftmax table and
    // returns a newly allocated MetalQuantizedSoftmax bound to `backend`.
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend) const {
        auto params           = op->main_as_QuantizedSoftmax();
        const auto beta       = params->beta();
        const auto inputScale = params->inputScale();
        return new MetalQuantizedSoftmax(backend, beta, inputScale);
    }
};
|
2019-05-08 15:44:57 +08:00
|
|
|
// Pairs OpType_QuantizedSoftmax with MetalQuantizedSoftmaxCreator through the
// Metal backend's registration macro (presumably adds the creator to the
// backend's op-creator table — see the macro definition to confirm).
REGISTER_METAL_OP_CREATOR(MetalQuantizedSoftmaxCreator, OpType_QuantizedSoftmax);
|
2019-04-17 10:49:11 +08:00
|
|
|
} // namespace MNN
|
|
|
|
#endif /* MNN_METAL_ENABLED */
|