mirror of https://github.com/alibaba/MNN.git
242 lines
7.3 KiB
Plaintext
242 lines
7.3 KiB
Plaintext
//
|
|
// MetalArgMax.mm
|
|
// MNN
|
|
//
|
|
// Created by MNN on 2023/12/29.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#import "core/Macro.h"
|
|
#import "MetalCast.hpp"
|
|
#import "MetalBackend.hpp"
|
|
#import "MNNMetalContext.h"
|
|
#include "MNN_generated.h"
|
|
|
|
#if MNN_METAL_ENABLED
|
|
namespace MNN {
|
|
static const char* gArgMaxTemplate = R"metal(
|
|
#include <metal_stdlib>
|
|
#include <simd/simd.h>
|
|
using namespace metal;
|
|
struct constBuffer
|
|
{
|
|
int4 size;
|
|
};
|
|
struct sourceBuffer
|
|
{
|
|
T data[1];
|
|
};
|
|
struct destbuffer
|
|
{
|
|
int data[1];
|
|
};
|
|
|
|
kernel void main0(device destbuffer& uOutput [[buffer(0)]], const device sourceBuffer& uInput [[buffer(1)]], constant constBuffer& uConst [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
|
|
{
|
|
threadgroup T local_buffer[256];
|
|
threadgroup int local_index[256];
|
|
int index = int(gl_GlobalInvocationID.x) / uConst.size.w;
|
|
int lidIndex = int(gl_LocalInvocationID.x);
|
|
int lidx = lidIndex / uConst.size.w;
|
|
int lid = lidIndex % uConst.size.w;
|
|
int x = index % uConst.size.x;
|
|
int y = index / uConst.size.x;
|
|
int W = uConst.size.x;
|
|
int H = uConst.size.y;
|
|
int C = uConst.size.z;
|
|
bool _68 = y < uConst.size.z;
|
|
bool _75;
|
|
if (_68)
|
|
{
|
|
_75 = lid < uConst.size.y;
|
|
}
|
|
else
|
|
{
|
|
_75 = _68;
|
|
}
|
|
if (_75)
|
|
{
|
|
int offset = ((y * H) * W) + x;
|
|
T maxValue = uInput.data[offset + (lid * W)];
|
|
int maxIndex = lid;
|
|
int _107 = lid + uConst.size.w;
|
|
for (int i = _107; i < uConst.size.y; i += uConst.size.w)
|
|
{
|
|
T value = uInput.data[offset + (i * W)];
|
|
#ifdef ARGMIN
|
|
if (value < maxValue)
|
|
{
|
|
maxValue = value;
|
|
maxIndex = i;
|
|
}
|
|
#else
|
|
if (value > maxValue)
|
|
{
|
|
maxValue = value;
|
|
maxIndex = i;
|
|
}
|
|
#endif
|
|
}
|
|
local_buffer[lid + (lidx * uConst.size.w)] = maxValue;
|
|
local_index[lid + (lidx * uConst.size.w)] = maxIndex;
|
|
}
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
if ((y < uConst.size.z) && (lid == 0))
|
|
{
|
|
T maxValue_1 = local_buffer[lidx * uConst.size.w];
|
|
int maxIndex_1 = local_index[lidx * uConst.size.w];
|
|
int t = 1;
|
|
for (;;)
|
|
{
|
|
bool _195 = t < uConst.size.w;
|
|
bool _202;
|
|
if (_195)
|
|
{
|
|
_202 = t < uConst.size.y;
|
|
}
|
|
else
|
|
{
|
|
_202 = _195;
|
|
}
|
|
if (_202)
|
|
{
|
|
T next = local_buffer[t + (lidx * uConst.size.w)];
|
|
#ifdef ARGMIN
|
|
if (next < maxValue_1)
|
|
{
|
|
maxValue_1 = next;
|
|
maxIndex_1 = local_index[t + (lidx * uConst.size.w)];
|
|
}
|
|
#else
|
|
if (next > maxValue_1)
|
|
{
|
|
maxValue_1 = next;
|
|
maxIndex_1 = local_index[t + (lidx * uConst.size.w)];
|
|
}
|
|
#endif
|
|
t++;
|
|
continue;
|
|
}
|
|
else
|
|
{
|
|
break;
|
|
}
|
|
}
|
|
uOutput.data[index] = maxIndex_1;
|
|
}
|
|
}
|
|
)metal";
|
|
|
|
class MetalArgMax : public MetalExecution {
|
|
private:
|
|
int mAxis;
|
|
int mGroupNumber;
|
|
id<MTLBuffer> mParam;
|
|
id<MTLComputePipelineState> mPipeline;
|
|
public:
|
|
MetalArgMax(id<MTLComputePipelineState> pipeline, int axis, Backend *bn) : MetalExecution(bn) {
|
|
mPipeline = pipeline;
|
|
mAxis = axis;
|
|
auto mtbn = static_cast<MetalBackend *>(bn);
|
|
auto context = (__bridge MNNMetalContext *)mtbn->context();
|
|
mParam = [context newDeviceBuffer:sizeof(int) * 4 access:CPUWriteOnly];
|
|
};
|
|
virtual ErrorCode onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override {
|
|
auto input = inputs[0];
|
|
auto output = outputs[0];
|
|
|
|
auto inputFormat = TensorUtils::getDescribe(input)->dimensionFormat;
|
|
auto axis = mAxis;
|
|
if (axis < 0) {
|
|
axis = input->dimensions() + axis;
|
|
}
|
|
int inside = 1;
|
|
int outside = 1;
|
|
int mid = input->length(axis);
|
|
for (int i=0; i<axis; ++i) {
|
|
outside *= input->length(i);
|
|
}
|
|
for (int i=axis+1; i<input->dimensions(); ++i) {
|
|
inside *= input->length(i);
|
|
}
|
|
auto total = outside * inside;
|
|
int outsideParallel = 1;
|
|
int reduceAxis = 1;
|
|
if (total >= 256) {
|
|
reduceAxis = 1;
|
|
outsideParallel = 256;
|
|
} else if (total < 16) {
|
|
reduceAxis = 256;
|
|
outsideParallel = 1;
|
|
} else {
|
|
reduceAxis = 16;
|
|
outsideParallel = 16;
|
|
}
|
|
|
|
// gpu param
|
|
{
|
|
auto Argmax = reinterpret_cast<int*>([mParam contents]);
|
|
Argmax[0] = inside;
|
|
Argmax[1] = mid;
|
|
Argmax[2] = outside;
|
|
Argmax[3] = reduceAxis;
|
|
}
|
|
mGroupNumber = UP_DIV(total, outsideParallel);
|
|
return NO_ERROR;
|
|
}
|
|
virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs,
|
|
id<MTLComputeCommandEncoder> encoder) override {
|
|
[encoder setComputePipelineState:mPipeline];
|
|
MetalBackend::setTensor(outputs[0], encoder, 0);
|
|
MetalBackend::setTensor(inputs[0], encoder, 1);
|
|
[encoder setBuffer:mParam offset:0 atIndex:2];
|
|
[encoder dispatchThreadgroups:MTLSizeMake(mGroupNumber, 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
|
|
}
|
|
|
|
};
|
|
|
|
class MetalArgMaxCreator : public MetalBackend::Creator {
|
|
public:
|
|
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *bn, const std::vector<Tensor *> &outputs) const {
|
|
if (TensorUtils::getDescribe(inputs[0])->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
|
|
// Don't support legency version
|
|
return nullptr;
|
|
}
|
|
auto axis = op->main_as_ArgMax()->axis();
|
|
auto mtbn = static_cast<MetalBackend *>(bn);
|
|
auto type = MetalCast::getScalarType(inputs[0]->getType(), mtbn->useFp16InsteadFp32());
|
|
std::vector<std::string> keys {
|
|
"argmax",
|
|
std::string([type UTF8String])
|
|
};
|
|
if (op->type() == OpType_ArgMin) {
|
|
keys.emplace_back("argmin");
|
|
}
|
|
auto pipeline = mtbn->runtime()->findPipeline(keys);
|
|
if (nil == pipeline) {
|
|
MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
|
|
if (op->type() != OpType_ArgMin) {
|
|
compileOptions.preprocessorMacros = @{
|
|
@"T" : type,
|
|
};
|
|
} else {
|
|
compileOptions.preprocessorMacros = @{
|
|
@"T" : type,
|
|
@"ARGMIN": @"1"
|
|
};
|
|
}
|
|
pipeline = mtbn->makeComputePipelineWithSourceOption(gArgMaxTemplate, "main0", compileOptions);
|
|
mtbn->runtime()->insertPipeline(keys, pipeline);
|
|
}
|
|
if (nil == pipeline) {
|
|
MNN_ERROR("Create ArgMax pipeline error\n");
|
|
return nullptr;
|
|
}
|
|
return new MetalArgMax(pipeline, axis, bn);
|
|
}
|
|
};
|
|
REGISTER_METAL_OP_CREATOR(MetalArgMaxCreator, OpType_ArgMax);
|
|
REGISTER_METAL_OP_CREATOR(MetalArgMaxCreator, OpType_ArgMin);
|
|
};
|
|
#endif
|