首页 > 程序开发 > 综合编程 > 其他综合 >

深度学习开源库tiny-dnn的使用(MNIST)

2016-12-05

tiny-dnn是一个基于DNN的深度学习开源库,它的License是BSD 3-Clause。它之前的名字是tiny-cnn,是基于CNN的;tiny-dnn在tiny-cnn的基础上又增加了一些新层。

tiny-dnn是一个基于DNN的深度学习开源库,它的License是BSD 3-Clause。它之前的名字是tiny-cnn,是基于CNN的;tiny-dnn在tiny-cnn的基础上又增加了一些新层。此开源库很活跃,几乎每天都有新的提交,因此下面详细介绍一下tiny-dnn在Windows 7 64bit、VS2013下的编译及使用。

1.从https://github.com/tiny-dnn/tiny-dnn下载源码

$ git clone https://github.com/tiny-dnn/tiny-dnn.git

版本号为6281c1b,更新日期为2016.12.03。

2.源文件中已经包含了vs2013工程,vc/vc12/tiny-dnn.sln,默认是win32的,这里新建一个x64的控制台工程tiny-dnn;

3.仿照源工程,将相应.h文件加入到新控制台工程中,新加一个test_tiny-dnn.cpp文件;

4.仿照examples/mnist中test.cpp和train.cpp文件中的代码添加测试代码;

#include "funset.hpp"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <iostream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

#include "tiny_dnn/tiny_dnn.h"

// Appends the six-layer LeNet-style stack (C1, S2, C3, S4, C5, F6) to `nn`.
//
// nn: sequential network that receives the layers via operator<<.
//
// Every layer is instantiated with the tanh activation, matching the
// upstream tiny-dnn examples/mnist code this function is modelled on.
static void construct_net(tiny_dnn::network<tiny_dnn::sequential>& nn)
{
	using tan_h = tiny_dnn::activation::tan_h;

	// connection table [Y.Lecun, 1998 Table.1]
	// 6 rows x 16 columns, handed to connection_table(tbl, 6, 16) for layer C3
#define O true
#define X false
	static const bool tbl[] = {
		O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
		O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
		O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
		X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
		X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
		X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
	};
#undef O
#undef X

	// by default will use backend_t::tiny_dnn unless you compiled
	// with -DUSE_AVX=ON and your device supports AVX intrinsics
	tiny_dnn::core::backend_t backend_type = tiny_dnn::core::default_engine();

	// construct nets: C: convolution; S: sub-sampling; F: fully connected
	nn << tiny_dnn::convolutional_layer<tan_h>(32, 32, 5, 1, 6,  // C1, 1@32x32-in, 6@28x28-out
		tiny_dnn::padding::valid, true, 1, 1, backend_type)
		<< tiny_dnn::average_pooling_layer<tan_h>(28, 28, 6, 2)   // S2, 6@28x28-in, 6@14x14-out
		<< tiny_dnn::convolutional_layer<tan_h>(14, 14, 5, 6, 16, // C3, 6@14x14-in, 16@10x10-out
		connection_table(tbl, 6, 16),
		tiny_dnn::padding::valid, true, 1, 1, backend_type)
		<< tiny_dnn::average_pooling_layer<tan_h>(10, 10, 16, 2)  // S4, 16@10x10-in, 16@5x5-out
		<< tiny_dnn::convolutional_layer<tan_h>(5, 5, 5, 16, 120, // C5, 16@5x5-in, 120@1x1-out
		tiny_dnn::padding::valid, true, 1, 1, backend_type)
		<< tiny_dnn::fully_connected_layer<tan_h>(120, 10,        // F6, 120-in, 10-out
		true, backend_type);
}

static void train_lenet(const std::string& data_dir_path)
{
	// specify loss-function and learning strategy
	tiny_dnn::network nn;
	tiny_dnn::adagrad optimizer;

	construct_net(nn);

	std::cout << "load models..." << std::endl;

	// load MNIST dataset
	std::vector train_labels, test_labels;
	std::vector train_images, test_images;

	tiny_dnn::parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte", &train_labels);
	tiny_dnn::parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte", &train_images, -1.0, 1.0, 2, 2);
	tiny_dnn::parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte", &test_labels);
	tiny_dnn::parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte", &test_images, -1.0, 1.0, 2, 2);

	std::cout << "start training" << std::endl;

	tiny_dnn::progress_display disp(static_cast(train_images.size()));
	tiny_dnn::timer t;
	int minibatch_size = 10;
	int num_epochs = 30;

	optimizer.alpha *= static_cast(std::sqrt(minibatch_size));

	// create callback
	auto on_enumerate_epoch = [&](){
		std::cout << t.elapsed() << "s elapsed." << std::endl;
		tiny_dnn::result res = nn.test(test_images, test_labels);
		std::cout << res.num_success << "/" << res.num_total << std::endl;

		disp.restart(static_cast(train_images.size()));
		t.restart();
	};

	auto on_enumerate_minibatch = [&](){
		disp += minibatch_size;
	};

	// training
	nn.train(optimizer, train_images, train_labels, minibatch_size, num_epochs, on_enumerate_minibatch, on_enumerate_epoch);

	std::cout << "end training." << std::endl;

	// test and show results
	nn.test(test_images, test_labels).print_detail(std::cout);

	// save network model & trained weights
	nn.save(data_dir_path + "/LeNet-model");
}

// Rescales an activation output to the range 0-100.
//
// Activation: default-constructible type whose scale() returns a pair
//             (first = lower bound, second = upper bound) of its output range.
// x:          raw output value of a unit using that activation.
// Returns 100 * (x - lower) / (upper - lower); x equal to the lower bound
// maps to 0 and x equal to the upper bound maps to 100.
template <typename Activation>
static double rescale(double x)
{
	Activation a;
	const auto range = a.scale();
	return 100.0 * (x - range.first) / (range.second - range.first);
}

static void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, tiny_dnn::vec_t& data)
{
	tiny_dnn::image<> img(imagefilename, tiny_dnn::image_type::grayscale);
	tiny_dnn::image<> resized = resize_image(img, w, h);

	// mnist dataset is "white on black", so negate required
	std::transform(resized.begin(), resized.end(), std::back_inserter(data),
		[=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });
}

// Training demo entry point: runs train_lenet on the hard-coded MNIST data
// directory. Always returns 0.
int test_dnn_mnist_train()
{
	const std::string mnist_dir{ "E:/GitCode/NN_Test/data" };

	train_lenet(mnist_dir);
	return 0;
}

// Prediction demo entry point: loads the saved LeNet model and classifies the
// ten images "0.png" .. "9.png" found in the hard-coded image directory.
//
// For each image it prints the three highest-scoring classes, writes one PNG
// per network layer with that layer's output, writes a PNG of the first
// layer's weights, and prints the top prediction next to the expected digit.
// Always returns 0.
int test_dnn_mnist_predict()
{
	std::string model { "E:/GitCode/NN_Test/data/LeNet-model" };
	std::string image_path { "E:/GitCode/NN_Test/data/images/"};
	int target[10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };

	tiny_dnn::network<tiny_dnn::sequential> nn;
	nn.load(model);

	for (int i = 0; i < 10; i++) {
		std::string str = std::to_string(i);
		str += ".png";
		str = image_path + str;

		// convert imagefile to vec_t
		tiny_dnn::vec_t data;
		convert_image(str, -1.0, 1.0, 32, 32, data);

		// recognize
		auto res = nn.predict(data);
		std::vector<std::pair<double, int>> scores;

		// sort & print top-3
		// each entry is (score rescaled to 0-100, class index)
		for (int j = 0; j < 10; j++)
			scores.emplace_back(rescale<tiny_dnn::activation::tan_h>(res[j]), j);

		// descending order, so scores[0] is the best match
		std::sort(scores.begin(), scores.end(), std::greater<std::pair<double, int>>());

		for (int j = 0; j < 3; j++)
			fprintf(stdout, "%d: %f;  ", scores[j].second, scores[j].first);
		// terminate the line on the same stream the scores were written to
		fprintf(stdout, "\n");

		// save outputs of each layer
		for (size_t j = 0; j < nn.depth(); j++) {
			auto out_img = nn[j]->output_to_image();
			auto filename = image_path + std::to_string(i) + "_layer_" + std::to_string(j) + ".png";
			out_img.save(filename);
		}

		// save filter shape of first convolutional layer
		auto weight = nn.at<tiny_dnn::convolutional_layer<tiny_dnn::activation::tan_h>>(0).weight_to_image();
		auto filename = image_path + std::to_string(i) + "_weights.png";
		weight.save(filename);

		// "actual" is the network's top prediction, "correct" the expected label
		fprintf(stdout, "the actual digit is: %d, correct digit is: %d \n\n", scores[0].second, target[i]);
	}

	return 0;
}

5.运行程序,train时,运行结果如下图所示,准确率达到99%以上:

6. 对生成的model进行测试,通过画图工具,每个数字生成一张图像,共10幅,如下图:

7. 通过导入train时生成的model,对这10张图像进行识别,识别结果如下图,其中0,8,9被误识别为2,2,1.

GitHub:https://github.com/fengbingchun/NN_Test

相关文章
最新文章
热点推荐