MATLAB R2017b: Deep Learning with CNN

今天读了这篇文章后发现MATLAB的Deep learning原来可以这么简单,有点像Keras,封装的比较好。想当初刚接触tensor flow的时候真的有点头大。

知乎那篇文章中只介绍了CPU的版本,正好手头有块老旧的GPU,拿来试试。

  1. 首先去这里下载数据
  2. 解压
    tar xzvf notMNIST_large.tar.gz
  3. 按照教程说的一步步做,最后trainingOption那里改成GPU就好
% Load data
ds = imageDatastore('notMNIST_large/','LabelSource','foldernames','IncludeSubfolders',true);

% Prepare data
[trainDigitData,valDigitData,testData]=ds.splitEachLabel(0.5,0.3,0.2,'Randomize');

% Define network layers
layers = [...
          imageInputLayer([28,28,1]);
          batchNormalizationLayer();
          convolution2dLayer(5,20);
          batchNormalizationLayer();
          reluLayer()
          maxPooling2dLayer(2,'Stride',2);
          fullyConnectedLayer(10);
          softmaxLayer();
          classificationLayer(),...
    ];

% Customize training option
options = trainingOptions('sgdm',...
                          'ValidationData',valDigitData,...
                          'Plots','training-progress',...
                          'ExecutionEnvironment','gpu');

% Train
net = trainNetwork(trainDigitData,layers,options);

% Test
testLabel = classify(net,testData);
precision = sum(testLabel==testData.Labels)/numel(testLabel)

第一次实验,发现速度非常快

《MATLAB R2017b: Deep Learning with CNN》 jianshunotMnist1.png

但是默认的validation频率太高了,导致很多时间都花费在了数据与GPU通讯上面,8分半跑了2000多个iteration(250循环/分钟)

于是减慢validation频率至每一千个循环验证一次

options = trainingOptions('sgdm',...
                          'ValidationData',valDigitData,...
                          'ValidationFrequency',1000,...
                          'Plots','training-progress',...
                          'ExecutionEnvironment','gpu');

《MATLAB R2017b: Deep Learning with CNN》 jianshunotMnist2.png

可以看到13分钟跑了22000多次循环(1700循环/分钟),可谓效率大大提升。

下面问题就来了,该用这个做什么呢。。。

    原文作者:MATLAB笔记
    原文地址: https://www.jianshu.com/p/b30743359798
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞