This repository was archived by the owner on Jan 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
Add accuracy func and test mnist accuracy #40
Merged
kexinzhao
merged 2 commits into
PaddlePaddle:develop
from
kexinzhao:add_accuracy_utility
Jun 28, 2018
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,7 +76,7 @@ void init_params(VariableHandle v, | |
|
||
class Linear { | ||
public: | ||
Linear(int in_dim, int out_dim, const std::string &act) | ||
Linear(int in_dim, int out_dim, const std::string &act = "") | ||
: w_(new Variable("LinearWeight")), | ||
b_(new Variable("LinearBias")), | ||
act_(act) { | ||
|
@@ -110,6 +110,9 @@ class Linear { | |
{{"X", {pre_bias}}, {"Y", {b_}}}, | ||
{{"Out", {pre_act}}}, | ||
add_op_attrs); | ||
if (act_.empty()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe also make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done for both conv and batchnorm |
||
return pre_act; | ||
} | ||
VariableHandle post_act(new Variable("linear")); | ||
get_global_tape().AddOp( | ||
act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {}); | ||
|
@@ -337,6 +340,29 @@ class BatchNorm { | |
std::string act_; | ||
}; | ||
|
||
// Calculate the top k accuracy of the prediction against the label | ||
VariableHandle accuracy(VariableHandle prediction, | ||
VariableHandle label, | ||
int k = 1) { | ||
// Use top_k op to get top k prediction class labels | ||
VariableHandle topk_values(new Variable("accuracy")); | ||
VariableHandle topk_indices(new Variable("accuracy")); | ||
get_global_tape().AddOp("top_k", | ||
{{"X", {prediction}}}, | ||
{{"Out", {topk_values}}, {"Indices", {topk_indices}}}, | ||
{{"k", k}}); | ||
|
||
VariableHandle acc_out(new Variable("accuracy")); | ||
VariableHandle correct(new Variable("accuracy")); | ||
VariableHandle total(new Variable("accuracy")); | ||
get_global_tape().AddOp( | ||
"accuracy", | ||
{{"Out", {topk_values}}, {"Indices", {topk_indices}}, {"Label", {label}}}, | ||
{{"Accuracy", {acc_out}}, {"Correct", {correct}}, {"Total", {total}}}, | ||
{}); | ||
return acc_out; | ||
} | ||
|
||
VariableHandle pool2d(VariableHandle x, | ||
const framework::AttributeMap &attrs = {}) { | ||
VariableHandle out(new Variable("pool2d")); | ||
|
@@ -397,15 +423,19 @@ VariableHandle CreateRecordioFileReader(std::string filename, | |
return reader; | ||
} | ||
|
||
std::vector<VariableHandle> ReadNext(VariableHandle reader) { | ||
std::vector<VariableHandle> ReadNext(VariableHandle reader, bool repeat) { | ||
PADDLE_ENFORCE(reader->Var().IsType<framework::ReaderHolder>()); | ||
|
||
paddle::framework::LoDTensorArray data_holder; | ||
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(&data_holder); | ||
if (data_holder.empty()) { | ||
reader->GetMutable<paddle::framework::ReaderHolder>()->ReInit(); | ||
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext( | ||
&data_holder); | ||
if (repeat) { | ||
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext( | ||
&data_holder); | ||
} else { | ||
return {}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. Revised. |
||
} | ||
} | ||
PADDLE_ENFORCE(!data_holder.empty(), "Error reading file."); | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe put
is_file_exist
insideCreateRecordioFileReader
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done