Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Add accuracy func and test mnist accuracy #40

Merged
merged 2 commits into from
Jun 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/example/mnist/create_mnist_recordio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def create_mnist_recordio_files():
# Convert mnist to recordio file
# Convert mnist training set to recordio files
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(
Expand All @@ -30,7 +30,21 @@ def create_mnist_recordio_files():
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
'/tmp/mnist.recordio', reader, feeder)
'/tmp/mnist_train.recordio', reader, feeder)

# Convert mnist testing set to recordio files
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.test(), batch_size=32)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
'/tmp/mnist_test.recordio', reader, feeder)


if __name__ == "__main__":
Expand Down
96 changes: 70 additions & 26 deletions src/example/mnist/test_mnist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>
#include <numeric>
#include <vector>

#include "gtest/gtest.h"
#include "src/function.h"

using paddle::tape::VariableHandle;
using paddle::tape::Linear;
using paddle::tape::SGD;
using paddle::tape::Adam;
using paddle::tape::accuracy;
using paddle::tape::mean;
using paddle::tape::softmax;
using paddle::tape::cross_entropy;
Expand All @@ -29,41 +32,36 @@ using paddle::tape::get_global_tape;
using paddle::tape::CreateRecordioFileReader;
using paddle::tape::ReadNext;

// Reports whether `fileName` can be opened as an input file stream.
// The function attempts to open the path for reading and returns the
// resulting stream state; nothing is read from the file.
bool is_file_exist(const std::string& fileName) {
  std::ifstream probe;
  probe.open(fileName);
  const bool opened_cleanly = probe.good();
  return opened_cleanly;
}

TEST(Mnist, TestCPU) {
std::string filename = "/tmp/mnist.recordio";
PADDLE_ENFORCE(is_file_exist(filename),
"file doesn't exist; have you run create_mnist_recordio.py");
auto reader = CreateRecordioFileReader(
filename, {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});

Linear linear1(784, 200, "relu");
Linear linear2(200, 200, "relu");
Linear linear3(200, 10, "relu");
std::string filename1 = "/tmp/mnist_train.recordio";
std::string filename2 = "/tmp/mnist_test.recordio";
auto train_reader = CreateRecordioFileReader(
filename1, {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});
auto test_reader = CreateRecordioFileReader(
filename2, {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});

Linear linear1(784, 200, "tanh");
Linear linear2(200, 200, "tanh");
Linear linear3(200, 10, "softmax");
Adam adam(0.001);

auto forward = [&](VariableHandle input) -> VariableHandle {
return linear3(linear2(linear1(input)));
};

int total_steps = 10000;
int print_step = 100;
float avg_loss = 0.0;
float threshold = 0.90f;

for (int i = 0; i < 1000; ++i) {
for (int i = 0; i < total_steps; ++i) {
reset_global_tape();
auto data_label = ReadNext(reader);
auto data_label = ReadNext(train_reader, true);
auto data = data_label[0];
auto label = data_label[1];

auto predict = softmax(linear3(linear2(linear1(data))));
auto predict = forward(data);
auto loss = mean(cross_entropy(predict, label));

avg_loss +=
loss->Value().Get<paddle::framework::LoDTensor>().data<float>()[0];
if ((i + 1) % print_step == 0) {
LOG(INFO) << avg_loss / print_step;
avg_loss = 0;
}
auto precision = accuracy(predict, label);

get_global_tape().Backward(loss);

Expand All @@ -76,6 +74,52 @@ TEST(Mnist, TestCPU) {
for (auto w : linear3.Params()) {
adam.Update(w);
}

// Each time print_step batches have been processed, evaluate the average
// loss and accuracy on the test data set, and stop training once the
// average accuracy reaches the threshold.
if ((i + 1) % print_step == 0) {
std::vector<float> losses;
std::vector<float> accuracies;

while (true) {
reset_global_tape();

auto data_label = ReadNext(test_reader, false);
if (data_label.empty()) {
break;
}

auto data = data_label[0];
auto label = data_label[1];

auto predict = forward(data);
auto loss = mean(cross_entropy(predict, label));
auto precision = accuracy(predict, label);

get_global_tape().Forward();

losses.push_back(
loss->Get<paddle::framework::LoDTensor>().data<float>()[0]);
accuracies.push_back(
precision->Get<paddle::framework::LoDTensor>().data<float>()[0]);
}

float avg_loss =
std::accumulate(losses.begin(), losses.end(), 0.0f) / losses.size();
float avg_accu =
std::accumulate(accuracies.begin(), accuracies.end(), 0.0f) /
accuracies.size();

LOG(INFO) << "Pass #" << (i + 1) / print_step
<< ", test set evaluation result: Avg loss is " << avg_loss
<< ", Avg accuracy is " << avg_accu;

if (avg_accu >= threshold) {
LOG(INFO) << "Meets target accuracy, stop training";
break;
}
}
}
}

Expand Down
53 changes: 47 additions & 6 deletions src/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include <cmath>
#include <fstream>
#include <string>
#include <tuple>
#include <vector>
Expand Down Expand Up @@ -76,7 +77,7 @@ void init_params(VariableHandle v,

class Linear {
public:
Linear(int in_dim, int out_dim, const std::string &act)
Linear(int in_dim, int out_dim, const std::string &act = "")
: w_(new Variable("LinearWeight")),
b_(new Variable("LinearBias")),
act_(act) {
Expand Down Expand Up @@ -110,6 +111,9 @@ class Linear {
{{"X", {pre_bias}}, {"Y", {b_}}},
{{"Out", {pre_act}}},
add_op_attrs);
if (act_.empty()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also make Conv support act_.empty()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done for both conv and batchnorm

return pre_act;
}
VariableHandle post_act(new Variable("linear"));
get_global_tape().AddOp(
act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {});
Expand All @@ -126,7 +130,7 @@ class Linear {

class Convolution2D {
public:
Convolution2D(int c_in, int c_out, int f, const std::string &act)
Convolution2D(int c_in, int c_out, int f, const std::string &act = "")
: w_(new Variable("ConvolutionWeight")),
b_(new Variable("ConvolutionBias")),
act_(act) {
Expand Down Expand Up @@ -162,6 +166,9 @@ class Convolution2D {
{{"X", {pre_bias}}, {"Y", {b_}}},
{{"Out", {pre_act}}},
add_op_attrs);
if (act_.empty()) {
return pre_act;
}
VariableHandle post_act(new Variable("conv"));
get_global_tape().AddOp(
act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {});
Expand Down Expand Up @@ -282,7 +289,7 @@ class Adam {

class BatchNorm {
public:
BatchNorm(int channel_in, const std::string &act)
explicit BatchNorm(int channel_in, const std::string &act = "")
: scale_(new Variable("BatchNormScale")),
bias_(new Variable("BatchNormBias")),
mean_(new Variable("BatchNormMean")),
Expand Down Expand Up @@ -319,7 +326,9 @@ class BatchNorm {
{"SavedMean", {tmp_mean}},
{"SavedVariance", {tmp_var}}},
attrs);

if (act_.empty()) {
return pre_act;
}
VariableHandle post_act(new Variable("batch_norm"));
get_global_tape().AddOp(
act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {});
Expand All @@ -337,6 +346,29 @@ class BatchNorm {
std::string act_;
};

// Records the ops that compute the top-k accuracy of `prediction` against
// `label` on the global tape.
//
//   prediction  passed to the "top_k" op as its "X" input.
//   label       passed to the "accuracy" op as its "Label" input.
//   k           value of the "k" attribute given to "top_k"; defaults to 1.
//
// Returns the variable bound to the "Accuracy" output of the "accuracy" op.
// The "Correct" and "Total" outputs are also created, but they are not
// returned to the caller.
VariableHandle accuracy(VariableHandle prediction,
                        VariableHandle label,
                        int k = 1) {
  // Step 1: the "top_k" op yields the k best scores and their class indices.
  VariableHandle best_scores(new Variable("accuracy"));
  VariableHandle best_classes(new Variable("accuracy"));
  get_global_tape().AddOp(
      "top_k",
      {{"X", {prediction}}},
      {{"Out", {best_scores}}, {"Indices", {best_classes}}},
      {{"k", k}});

  // Step 2: the "accuracy" op consumes the top-k results together with the
  // label and writes the accuracy plus the correct/total outputs.
  VariableHandle accuracy_value(new Variable("accuracy"));
  VariableHandle num_correct(new Variable("accuracy"));
  VariableHandle num_total(new Variable("accuracy"));
  get_global_tape().AddOp(
      "accuracy",
      {{"Out", {best_scores}},
       {"Indices", {best_classes}},
       {"Label", {label}}},
      {{"Accuracy", {accuracy_value}},
       {"Correct", {num_correct}},
       {"Total", {num_total}}},
      {});

  return accuracy_value;
}

VariableHandle pool2d(VariableHandle x,
const framework::AttributeMap &attrs = {}) {
VariableHandle out(new Variable("pool2d"));
Expand Down Expand Up @@ -382,6 +414,11 @@ VariableHandle CreateRecordioFileReader(std::string filename,
std::vector<int> shape_concat,
std::vector<int> ranks,
std::vector<int> lod_levels) {
std::ifstream infile(filename);
PADDLE_ENFORCE(infile.good(),
"%s doesn't exist; have you run create_mnist_recordio.py?",
filename);

VariableHandle reader(new paddle::tape::Variable("reader"));

framework::OpDesc op_desc = CreateOpDesc("create_recordio_file_reader",
Expand All @@ -397,7 +434,7 @@ VariableHandle CreateRecordioFileReader(std::string filename,
return reader;
}

std::vector<VariableHandle> ReadNext(VariableHandle reader) {
std::vector<VariableHandle> ReadNext(VariableHandle reader, bool repeat) {
PADDLE_ENFORCE(reader->Var().IsType<framework::ReaderHolder>());

paddle::framework::LoDTensorArray data_holder;
Expand All @@ -406,8 +443,12 @@ std::vector<VariableHandle> ReadNext(VariableHandle reader) {
reader->GetMutable<paddle::framework::ReaderHolder>()->ReInit();
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(
&data_holder);
PADDLE_ENFORCE(!data_holder.empty(), "Error reading file.");
if (!repeat) {
reader->GetMutable<paddle::framework::ReaderHolder>()->ReInit();
return {};
}
}
PADDLE_ENFORCE(!data_holder.empty(), "Error reading file.");

std::vector<VariableHandle> rval;
for (size_t i = 0; i < data_holder.size(); ++i) {
Expand Down
1 change: 1 addition & 0 deletions src/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/operator.h" // framework::kGradVarSuffix
Expand Down