Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal] fix_slice #8470

Merged
merged 5 commits into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,18 @@ kernel void buf_to_tex_c_n(const device float* input[[buffer(0)]],
kernel void buf_h_to_tex_h(const device half* input[[buffer(0)]],
texture2d_array<half, access::write> outTexture[[texture(0)]],
uint3 gid[[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height()) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
int gidz = gid.z * 4;
int output_size = outTexture.get_width() * outTexture.get_height();
int output_index = outTexture.get_width() * gid.y + gid.x;

half y0 = input[(gidz)*output_size + output_index];
half y1 = input[(gidz + 1) * output_size + output_index];
half y2 = input[(gidz + 2) * output_size + output_index];
half y3 = input[(gidz + 3) * output_size + output_index];

half y = input[outTexture.get_width() * gid.y + gid.x];
outTexture.write(half4(y, 0.0f, 0.0f, 0.0f), gid.xy, gid.z);
outTexture.write(half4(y0, y1, y2, y3), gid.xy, gid.z);
}
46 changes: 19 additions & 27 deletions lite/backends/metal/metal_kernel/texture/SliceKernel.metal
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,28 @@
using namespace metal;

struct MetalSliceParam {
short start0;
short start1;
short start2;
short start3;
short end0;
short end1;
short end2;
short end3;
int iC;
int oC;
int iW;
int iH;
int oW;
int oH;
int isize;
int osize;
int oarraysize;
int start[4];
int endC;
};

kernel void slice(texture2d_array<ftype, access::sample> inTexture[[texture(0)]],
texture2d_array<ftype, access::write> outTexture[[texture(1)]],
constant MetalSliceParam& param[[buffer(0)]],
kernel void slice(device ftype* input[[buffer(0)]],
device ftype* output[[buffer(1)]],
constant MetalSliceParam& param[[buffer(2)]],
uint3 gid[[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size())
if (gid.x >= param.oW || gid.y >= param.oH || gid.z >= param.oarraysize) {
return;
ftype4 output;
for (int i = 0; i < 4; ++i) {
int tmp = gid.z * 4 + i;
int output_c = tmp % param.oC;
int output_n = tmp / param.oC;
int c = output_c + param.start1;
tmp = output_n * param.iC + c;
int input_z = tmp / 4;
int input_c = tmp % 4;
const ftype4 input = inTexture.read(gid.xy, input_z);
output[i] = input[input_c % 4];
}
outTexture.write(output, gid.xy, gid.z);
for (int i = param.start[1], j = 0; i < param.endC; i++, j++) {
int in_idx =
i * param.isize + (param.start[2] + gid.y) * param.iW + (param.start[3] + gid.x);
int out_idx = j * param.osize + gid.y * param.oW + gid.x;
output[out_idx] = input[in_idx];
}
}
19 changes: 9 additions & 10 deletions lite/kernels/metal/image_op/metal_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,15 @@ struct ConvTransposeAddMetalParam {
};

struct SliceMetalParam {
uint16_t start0;
uint16_t start1;
uint16_t start2;
uint16_t start3;
uint16_t end0;
uint16_t end1;
uint16_t end2;
uint16_t end3;
int iC;
int oC;
int iW;
int iH;
int oW;
int oH;
int isize;
int osize;
int oarraysize;
int start[4];
int endC;
};

struct FeedMetalParam {
Expand Down
11 changes: 10 additions & 1 deletion lite/kernels/metal/image_op/slice_image_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,23 @@ class SliceImageCompute
private:
void run_without_mps();
void setup_without_mps();
void reset_data();
void run_tex_to_buf();
void run_buf_to_tex();

const MetalImage* input_buffer_;
MetalImage* output_buffer_{nullptr};
std::shared_ptr<MetalBuffer> params_buffer_;

id<MTLComputePipelineState> pipline_;
std::string function_name_;
MetalContext* metal_context_;

id<MTLBuffer> intermediate_input_;
id<MTLBuffer> intermediate_output_;

id<MTLComputePipelineState> pipline_;
id<MTLComputePipelineState> pipline_tex_to_buf;
id<MTLComputePipelineState> pipline_buf_to_tex;
};

} // namespace metal
Expand Down
139 changes: 96 additions & 43 deletions lite/kernels/metal/image_op/slice_image_compute.mm
Original file line number Diff line number Diff line change
Expand Up @@ -35,80 +35,133 @@
#else
input_buffer_ = param.X->data<MetalHalf, MetalImage>();
output_buffer_ = param.Out->mutable_data<MetalHalf, MetalImage>(metal_context_, output_dims);
#endif

#endif
setup_without_mps();
}

void SliceImageCompute::Run() {
@autoreleasepool {
reset_data();
run_tex_to_buf();
run_without_mps();
run_buf_to_tex();
}
}

void SliceImageCompute::reset_data() {
TargetWrapperMetal::MemsetSync(intermediate_input_.contents, 0, intermediate_input_.length);
TargetWrapperMetal::MemsetSync(intermediate_output_.contents, 0, intermediate_output_.length);
}

void SliceImageCompute::run_tex_to_buf() {
auto pipline = pipline_tex_to_buf;
auto outTexture = input_buffer_->image();
auto backend = (__bridge MetalContextImp*)metal_context_->backend();
auto encoder = [backend commandEncoder];
[encoder setTexture:input_buffer_->image() atIndex:(0)];
[encoder setBuffer:intermediate_input_ offset:(0) atIndex:(0)];

[backend dispatchEncoder:encoder pipline:pipline outTexture:outTexture];
[backend commit];
}

void SliceImageCompute::run_buf_to_tex() {
auto pipline = pipline_buf_to_tex;
auto outTexture = output_buffer_->image();
auto backend = (__bridge MetalContextImp*)metal_context_->backend();

auto encoder = [backend commandEncoder];
[encoder setBuffer:intermediate_output_ offset:(0) atIndex:(0)];
[encoder setTexture:output_buffer_->image() atIndex:(0)];
[backend dispatchEncoder:encoder pipline:pipline outTexture:outTexture];
[backend commit];
}

void SliceImageCompute::run_without_mps() {
auto pipline = pipline_;
auto outTexture = output_buffer_->image();
auto backend = (__bridge MetalContextImp*)metal_context_->backend();

auto encoder = [backend commandEncoder];
[encoder setTexture:input_buffer_->image() atIndex:(0)];
[encoder setTexture:output_buffer_->image() atIndex:(1)];
[encoder setBuffer:(params_buffer_->buffer()) offset:(0) atIndex:(0)];
[encoder setBuffer:intermediate_input_ offset:(0) atIndex:(0)];
[encoder setBuffer:intermediate_output_ offset:(0) atIndex:(1)];
[encoder setBuffer:(params_buffer_->buffer()) offset:(0) atIndex:(2)];

[backend dispatchEncoder:encoder pipline:pipline outTexture:outTexture];
auto N = input_buffer_->pad_to_four_dim_[0];
auto C = input_buffer_->pad_to_four_dim_[1];
auto H = input_buffer_->pad_to_four_dim_[2];
auto W = input_buffer_->pad_to_four_dim_[3];

auto slices = (N + 3) / 4;

auto width = MIN(W, pipline.threadExecutionWidth);
auto height = MIN(H, pipline.maxTotalThreadsPerThreadgroup / width);
auto threadsPerGroup = MTLSizeMake(width, height, 1);

auto groupWidth = (W + width - 1) / width;
auto groupHeight = (H + height - 1) / height;
auto groups = MTLSizeMake(groupWidth, groupHeight, N ? N : slices);

[backend dispatchEncoder:encoder pipline:pipline threadsPerGroup:threadsPerGroup groups:groups];
[backend commit];
}

void SliceImageCompute::setup_without_mps() {
auto& context = ctx_->As<MTLContext>();
metal_context_ = (MetalContext*)context.context();
auto backend = (__bridge MetalContextImp*)metal_context_->backend();
const auto& param = this->Param<param_t>();
const auto in_dims = input_buffer_->pad_to_four_dim_;
const auto out_dims = output_buffer_->pad_to_four_dim_;
const auto in_tensor_dims = input_buffer_->tensor_dim_;
const auto out_tensor_dims = output_buffer_->tensor_dim_;

std::vector<int> axes = {};
for (int i = 0; i < input_buffer_->tensor_dim_.size(); i++) {
if (input_buffer_->tensor_dim_[i] != output_buffer_->tensor_dim_[i]) {
axes.push_back(i);
}
}
// only support C channel slice
if (axes.size() == 1 && axes[0] == 1) {
} else {
LOG(FATAL) << "slice: only support channel axe";
}
auto axes = param.axes;
auto starts = param.starts;
auto ends = param.ends;
std::map<int, std::vector<uint16_t>> ranges = {};
for (int j = 0; j < axes.size(); j++) {
ranges[uint16_t(axes[j])] = {uint16_t(starts[j]), uint16_t(ends[j])};
}
//
int iC = (int)input_buffer_->tensor_dim_[1];
int oC = (int)output_buffer_->tensor_dim_[1];
uint16_t param_rangs[4][2] = {};
for (int k = 0; k < 4; k++) {
if (ranges.find(k) != ranges.end()) {
param_rangs[k][0] = (ranges[k])[0];
param_rangs[k][1] = (ranges[k])[1];
} else {
param_rangs[k][0] = 0;
param_rangs[k][1] = (uint16_t)(input_buffer_->tensor_dim_[k]);
std::vector<int> real_starts(in_tensor_dims.size(), 0);
std::vector<int> real_ends(in_tensor_dims.size(), 0);
for (int i = 0; i < axes.size(); i++) {
int dim_value = in_tensor_dims[axes[i]];
if (dim_value > 0) {
int start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
int end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
real_starts[axes[i]] = start;
real_ends[axes[i]] = end;
}
}
SliceMetalParam params{param_rangs[0][0],
param_rangs[1][0],
param_rangs[2][0],
param_rangs[3][0],
param_rangs[0][1],
param_rangs[1][1],
param_rangs[2][1],
param_rangs[3][1],
iC,
oC};
for (int i = out_tensor_dims.size(); i < 4; i++) {
real_starts.insert(real_starts.begin(), 0);
real_ends.insert(real_ends.begin(), 0);
}
real_ends[1] = real_ends[1] == 0 ? in_dims[1] : real_ends[1];
SliceMetalParam params{(int)in_dims[3],
(int)in_dims[2],
(int)out_dims[3],
(int)out_dims[2],
(int)(in_dims[3] * in_dims[2]),
(int)(out_dims[3] * out_dims[2]),
((int)out_dims[1] + 3) / 4,
{real_starts[0], real_starts[1], real_starts[2], real_starts[3]},
real_ends[1]};

params_buffer_ = std::make_shared<MetalBuffer>(metal_context_, sizeof(params), &params);

auto inputLength = input_buffer_->dim_.production() * sizeof(MetalHalf);
intermediate_input_ =
[backend newDeviceBuffer:inputLength access:METAL_ACCESS_FLAG::CPUWriteOnly];
auto outputLength = output_buffer_->dim_.production() * sizeof(MetalHalf);
intermediate_output_ =
[backend newDeviceBuffer:outputLength access:METAL_ACCESS_FLAG::CPUWriteOnly];

function_name_ = "slice";
// pipline
auto backend = (__bridge MetalContextImp*)metal_context_->backend();
pipline_ = [backend pipline:function_name_];
pipline_tex_to_buf = [backend pipline:"tex2d_ary_to_buf"];
pipline_buf_to_tex = [backend pipline:"buf_h_to_tex_h"];
}

SliceImageCompute::~SliceImageCompute() {
Expand Down