Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Add accuracy func and test mnist accuracy #40

Merged
merged 2 commits into from
Jun 28, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/example/mnist/create_mnist_recordio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def create_mnist_recordio_files():
# Convert mnist to recordio file
# Convert mnist training set to recordio files
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(
Expand All @@ -30,7 +30,21 @@ def create_mnist_recordio_files():
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
'/tmp/mnist.recordio', reader, feeder)
'/tmp/mnist_train.recordio', reader, feeder)

# Convert mnist testing set to recordio files
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.test(), batch_size=32)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
'/tmp/mnist_test.recordio', reader, feeder)


if __name__ == "__main__":
Expand Down
96 changes: 76 additions & 20 deletions src/example/mnist/test_mnist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
// limitations under the License.

#include <fstream>
#include <numeric>
#include <vector>

#include "gtest/gtest.h"
#include "src/function.h"

using paddle::tape::VariableHandle;
using paddle::tape::Linear;
using paddle::tape::SGD;
using paddle::tape::Adam;
using paddle::tape::accuracy;
using paddle::tape::mean;
using paddle::tape::softmax;
using paddle::tape::cross_entropy;
Expand All @@ -35,35 +39,41 @@ bool is_file_exist(const std::string& fileName) {
}

TEST(Mnist, TestCPU) {
std::string filename = "/tmp/mnist.recordio";
PADDLE_ENFORCE(is_file_exist(filename),
"file doesn't exist; have you run create_mnist_recordio.py");
auto reader = CreateRecordioFileReader(
filename, {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});

Linear linear1(784, 200, "relu");
Linear linear2(200, 200, "relu");
Linear linear3(200, 10, "relu");
std::string filename1 = "/tmp/mnist_train.recordio";
std::string filename2 = "/tmp/mnist_test.recordio";
PADDLE_ENFORCE(is_file_exist(filename1),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe put is_file_exist inside CreateRecordioFileReader

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"%s doesn't exist; have you run create_mnist_recordio.py",
filename1);
PADDLE_ENFORCE(is_file_exist(filename2),
"%s doesn't exist; have you run create_mnist_recordio.py",
filename2);
auto train_reader = CreateRecordioFileReader(
filename1, {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});
auto test_reader = CreateRecordioFileReader(
filename2, {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});

Linear linear1(784, 200, "tanh");
Linear linear2(200, 200, "tanh");
Linear linear3(200, 10, "softmax");
Adam adam(0.001);

auto forward = [&](VariableHandle input) -> VariableHandle {
return linear3(linear2(linear1(input)));
};

int total_steps = 10000;
int print_step = 100;
float avg_loss = 0.0;
float threshold = 0.90f;

for (int i = 0; i < 1000; ++i) {
for (int i = 0; i < total_steps; ++i) {
reset_global_tape();
auto data_label = ReadNext(reader);
auto data_label = ReadNext(train_reader, true);
auto data = data_label[0];
auto label = data_label[1];

auto predict = softmax(linear3(linear2(linear1(data))));
auto predict = forward(data);
auto loss = mean(cross_entropy(predict, label));

avg_loss +=
loss->Value().Get<paddle::framework::LoDTensor>().data<float>()[0];
if ((i + 1) % print_step == 0) {
LOG(INFO) << avg_loss / print_step;
avg_loss = 0;
}
auto precision = accuracy(predict, label);

get_global_tape().Backward(loss);

Expand All @@ -76,6 +86,52 @@ TEST(Mnist, TestCPU) {
for (auto w : linear3.Params()) {
adam.Update(w);
}

// Every time certain amount of batches have been processed,
// we test the average loss and accuracy on the test data set,
// we stop training when the accuracy hit some threshold
if ((i + 1) % print_step == 0) {
std::vector<float> losses;
std::vector<float> accuracies;

while (true) {
reset_global_tape();

auto data_label = ReadNext(test_reader, false);
if (data_label.empty()) {
break;
}

auto data = data_label[0];
auto label = data_label[1];

auto predict = forward(data);
auto loss = mean(cross_entropy(predict, label));
auto precision = accuracy(predict, label);

get_global_tape().Forward();

losses.push_back(
loss->Get<paddle::framework::LoDTensor>().data<float>()[0]);
accuracies.push_back(
precision->Get<paddle::framework::LoDTensor>().data<float>()[0]);
}

float avg_loss =
std::accumulate(losses.begin(), losses.end(), 0.0f) / losses.size();
float avg_accu =
std::accumulate(accuracies.begin(), accuracies.end(), 0.0f) /
accuracies.size();

LOG(INFO) << "Pass #" << (i + 1) / print_step
<< ", test set evaluation result: Avg loss is " << avg_loss
<< ", Avg accuracy is " << avg_accu;

if (avg_accu >= threshold) {
LOG(INFO) << "Meets target accuracy, stop training";
break;
}
}
}
}

Expand Down
38 changes: 34 additions & 4 deletions src/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void init_params(VariableHandle v,

class Linear {
public:
Linear(int in_dim, int out_dim, const std::string &act)
Linear(int in_dim, int out_dim, const std::string &act = "")
: w_(new Variable("LinearWeight")),
b_(new Variable("LinearBias")),
act_(act) {
Expand Down Expand Up @@ -110,6 +110,9 @@ class Linear {
{{"X", {pre_bias}}, {"Y", {b_}}},
{{"Out", {pre_act}}},
add_op_attrs);
if (act_.empty()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also make Conv support act_.empty()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done for both conv and batchnorm

return pre_act;
}
VariableHandle post_act(new Variable("linear"));
get_global_tape().AddOp(
act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {});
Expand Down Expand Up @@ -337,6 +340,29 @@ class BatchNorm {
std::string act_;
};

// Calculate the top k accuracy of the prediction against the label
VariableHandle accuracy(VariableHandle prediction,
VariableHandle label,
int k = 1) {
// Use top_k op to get top k prediction class labels
VariableHandle topk_values(new Variable("accuracy"));
VariableHandle topk_indices(new Variable("accuracy"));
get_global_tape().AddOp("top_k",
{{"X", {prediction}}},
{{"Out", {topk_values}}, {"Indices", {topk_indices}}},
{{"k", k}});

VariableHandle acc_out(new Variable("accuracy"));
VariableHandle correct(new Variable("accuracy"));
VariableHandle total(new Variable("accuracy"));
get_global_tape().AddOp(
"accuracy",
{{"Out", {topk_values}}, {"Indices", {topk_indices}}, {"Label", {label}}},
{{"Accuracy", {acc_out}}, {"Correct", {correct}}, {"Total", {total}}},
{});
return acc_out;
}

VariableHandle pool2d(VariableHandle x,
const framework::AttributeMap &attrs = {}) {
VariableHandle out(new Variable("pool2d"));
Expand Down Expand Up @@ -397,15 +423,19 @@ VariableHandle CreateRecordioFileReader(std::string filename,
return reader;
}

std::vector<VariableHandle> ReadNext(VariableHandle reader) {
std::vector<VariableHandle> ReadNext(VariableHandle reader, bool repeat) {
PADDLE_ENFORCE(reader->Var().IsType<framework::ReaderHolder>());

paddle::framework::LoDTensorArray data_holder;
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(&data_holder);
if (data_holder.empty()) {
reader->GetMutable<paddle::framework::ReaderHolder>()->ReInit();
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(
&data_holder);
if (repeat) {
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(
&data_holder);
} else {
return {};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this return {} hide the check PADDLE_ENFORCE(!data_holder.empty(), "Error reading file.");?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Revised.

}
}
PADDLE_ENFORCE(!data_holder.empty(), "Error reading file.");

Expand Down
1 change: 1 addition & 0 deletions src/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/operator.h" // framework::kGradVarSuffix
Expand Down