Skip to content

Commit

Permalink
[Enhancement] Add optional softmax in LinearClsHead (#1858)
Browse files Browse the repository at this point in the history
* add softmax in cls postprocess

* minor
  • Loading branch information
lzhangzz authored Mar 9, 2023
1 parent f69c636 commit bcb93ea
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions csrc/mmdeploy/codebase/mmcls/linear_cls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification {
public:
explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) {
if (cfg.contains("params")) {
softmax_ = cfg["params"].value("softmax", false);
topk_ = cfg["params"].value("topk", 1);
if (topk_ <= 0) {
MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_);
Expand Down Expand Up @@ -54,17 +55,32 @@ class LinearClsHead : public MMClassification {
iota(begin(idx), end(idx), 0);
partial_sort(begin(idx), begin(idx) + topk, end(idx),
[&](int i, int j) { return scores_data[i] > scores_data[j]; });

auto sum_exp = 0.f;
std::vector<float> exp_scores;
if (softmax_) {
exp_scores.reserve(class_num);
auto max_val = scores_data[idx[0]];
for (int i = 0; i < class_num; ++i) {
sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val));
}
}
for (int i = 0; i < topk; ++i) {
auto label = Label{idx[i], scores_data[idx[i]]};
MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score);
output.push_back(label);
float score = 0.f;
if (softmax_) {
score = exp_scores[idx[i]] / sum_exp;
} else {
score = scores_data[idx[i]];
}
output.push_back({idx[i], score});
}
return to_value(std::move(output));
}

private:
static constexpr const auto kHost = Device{0};

bool softmax_{false};
int topk_{1};
};

Expand Down

0 comments on commit bcb93ea

Please sign in to comment.