Keras的R语言接口

Getting Start

首先,从github上下载keras

devtools::install_github("rstudio/keras")

Keras R界面默认使用TensorFlow后端引擎。 要安装核心Keras库以及TensorFlow后端,请使用install_keras()函数:

library(keras)
install_keras()

这将提供Keras和TensorFlow的默认基于CPU的安装。 如果想要更加自定义的安装,可以看install_keras()的文档。

MNIST

我们可以通过一个简单的例子来学习Keras的基础知识:从MNIST数据集识别手写数字。 MNIST由28 x 28像这样的手写数字的灰度图像组成:

《Keras的R语言接口》 image.png

准备数据

数据在keras的这个包中,

library(keras)
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

x数据是灰度值的三维数组(图像,宽度,高度)。 为了准备训练数据,我们通过将宽度和高度重新整形为一个维度(28×28图像平展成长度为784个向量)将三维数组转换为矩阵。 然后,我们将灰度值从范围在0到255之间的整数转换为介于0和1之间的浮点值:

# reshape
x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))
# rescale
x_train <- x_train / 255
x_test <- x_test / 255

请注意,我们使用array_reshape()函数而不是dim < – ()函数来重新整形数组。 这是为了使用行主语义(与R的默认列主语义相反)重新解释数据,这又与Keras调用的数值库解释数组维度的方式兼容。

y数据是一个整数向量,其取值范围为0到9.为了准备这些数据以进行训练,我们使用Keras to_categorical()函数将向量单向热编码为二进制类矩阵:

y_train <- to_categorical(y_train, 10)
y_test <- to_categorical(y_test, 10)

定义模型

Keras的核心数据结构是一种模型,一种组织图层的方法。 最简单的模型是Sequential模型,这是一个线性的层堆栈。

我们首先创建一个顺序模型,然后使用管道(%>%)运算符添加图层:

model <- keras_model_sequential() 
model %>% 
  layer_dense(units = 256, activation = 'relu', input_shape = c(784)) %>% 
  layer_dropout(rate = 0.4) %>% 
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 10, activation = 'softmax')

他对第一层的input_shape参数指定了输入数据的形状(代表灰度图像的长度为784的数字向量)。 最后一层使用softmax激活函数输出长度为10的数字向量(每个数字的概率)。

使用summary()函数打印模型的详细信息:

summary(model)
_________________________________________________________________________
Layer (type)                    Output Shape                  Param #    
=========================================================================
dense_1 (Dense)                 (None, 256)                   200960     
_________________________________________________________________________
dropout_1 (Dropout)             (None, 256)                   0          
_________________________________________________________________________
dense_2 (Dense)                 (None, 128)                   32896      
_________________________________________________________________________
dropout_2 (Dropout)             (None, 128)                   0          
_________________________________________________________________________
dense_3 (Dense)                 (None, 10)                    1290       
=========================================================================
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
_________________________________________________________________________

接下来,使用适当的损失函数,优化器和指标编译模型:

model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)

训练评估模型

使用fit()函数使用128个图像的批次对30个时期的模型进行训练:

history <- model %>% fit(
  x_train, y_train, 
  epochs = 30, batch_size = 128, 
  validation_split = 0.2
)
Train on 48000 samples, validate on 12000 samples
Epoch 1/30
48000/48000 [==============================] - 4s 91us/step - loss: 0.4249 - acc: 0.8717 - val_loss: 0.1666 - val_acc: 0.9490
Epoch 2/30
48000/48000 [==============================] - 4s 81us/step - loss: 0.2023 - acc: 0.9399 - val_loss: 0.1278 - val_acc: 0.9634
Epoch 3/30
48000/48000 [==============================] - 4s 79us/step - loss: 0.1552 - acc: 0.9534 - val_loss: 0.1148 - val_acc: 0.9681
Epoch 4/30
48000/48000 [==============================] - 4s 81us/step - loss: 0.1320 - acc: 0.9609 - val_loss: 0.1008 - val_acc: 0.9716
Epoch 5/30
48000/48000 [==============================] - 4s 76us/step - loss: 0.1148 - acc: 0.9658 - val_loss: 0.0933 - val_acc: 0.9738
Epoch 6/30
48000/48000 [==============================] - 4s 77us/step - loss: 0.1048 - acc: 0.9684 - val_loss: 0.0914 - val_acc: 0.9752
Epoch 7/30
48000/48000 [==============================] - 4s 78us/step - loss: 0.0979 - acc: 0.9715 - val_loss: 0.0901 - val_acc: 0.9752
Epoch 8/30
48000/48000 [==============================] - 4s 77us/step - loss: 0.0887 - acc: 0.9745 - val_loss: 0.0919 - val_acc: 0.9758
Epoch 9/30
48000/48000 [==============================] - 4s 77us/step - loss: 0.0858 - acc: 0.9748 - val_loss: 0.0904 - val_acc: 0.9779
Epoch 10/30
48000/48000 [==============================] - 4s 78us/step - loss: 0.0807 - acc: 0.9769 - val_loss: 0.0903 - val_acc: 0.9783
Epoch 11/30
48000/48000 [==============================] - 4s 77us/step - loss: 0.0781 - acc: 0.9781 - val_loss: 0.0956 - val_acc: 0.9771
Epoch 12/30
48000/48000 [==============================] - 4s 78us/step - loss: 0.0768 - acc: 0.9788 - val_loss: 0.0917 - val_acc: 0.9787
Epoch 13/30
48000/48000 [==============================] - 4s 77us/step - loss: 0.0706 - acc: 0.9794 - val_loss: 0.0909 - val_acc: 0.9784
Epoch 14/30
48000/48000 [==============================] - 4s 85us/step - loss: 0.0684 - acc: 0.9804 - val_loss: 0.0933 - val_acc: 0.9787
Epoch 15/30
48000/48000 [==============================] - 4s 84us/step - loss: 0.0682 - acc: 0.9810 - val_loss: 0.1013 - val_acc: 0.9785
Epoch 16/30
48000/48000 [==============================] - 4s 82us/step - loss: 0.0647 - acc: 0.9812 - val_loss: 0.0951 - val_acc: 0.9795
Epoch 17/30
48000/48000 [==============================] - 4s 78us/step - loss: 0.0627 - acc: 0.9829 - val_loss: 0.1004 - val_acc: 0.9792
Epoch 18/30
48000/48000 [==============================] - 4s 79us/step - loss: 0.0671 - acc: 0.9823 - val_loss: 0.0959 - val_acc: 0.9803
Epoch 19/30
48000/48000 [==============================] - 4s 77us/step - loss: 0.0602 - acc: 0.9831 - val_loss: 0.0976 - val_acc: 0.9797
Epoch 20/30
48000/48000 [==============================] - 4s 76us/step - loss: 0.0593 - acc: 0.9835 - val_loss: 0.1051 - val_acc: 0.9786
Epoch 21/30
48000/48000 [==============================] - 4s 78us/step - loss: 0.0592 - acc: 0.9840 - val_loss: 0.1008 - val_acc: 0.9799
Epoch 22/30
48000/48000 [==============================] - 4s 76us/step - loss: 0.0561 - acc: 0.9846 - val_loss: 0.1023 - val_acc: 0.9800
Epoch 23/30
48000/48000 [==============================] - 4s 78us/step - loss: 0.0592 - acc: 0.9844 - val_loss: 0.1100 - val_acc: 0.9787
Epoch 24/30
48000/48000 [==============================] - 4s 83us/step - loss: 0.0566 - acc: 0.9848 - val_loss: 0.1048 - val_acc: 0.9790
Epoch 25/30
48000/48000 [==============================] - 4s 79us/step - loss: 0.0531 - acc: 0.9852 - val_loss: 0.1091 - val_acc: 0.9802
Epoch 26/30
48000/48000 [==============================] - 4s 79us/step - loss: 0.0570 - acc: 0.9850 - val_loss: 0.1055 - val_acc: 0.9803
Epoch 27/30
48000/48000 [==============================] - 4s 84us/step - loss: 0.0515 - acc: 0.9868 - val_loss: 0.1114 - val_acc: 0.9798
Epoch 28/30
48000/48000 [==============================] - 4s 78us/step - loss: 0.0532 - acc: 0.9861 - val_loss: 0.1148 - val_acc: 0.9799
Epoch 29/30
48000/48000 [==============================] - 4s 76us/step - loss: 0.0532 - acc: 0.9860 - val_loss: 0.1105 - val_acc: 0.9796
Epoch 30/30
48000/48000 [==============================] - 4s 77us/step - loss: 0.0519 - acc: 0.9869 - val_loss: 0.1179 - val_acc: 0.9796

《Keras的R语言接口》 image.png

测试集合上评估模型的性能

model %>% evaluate(x_test, y_test)
10000/10000 [==============================] - 1s 55us/step
$loss
[1] 0.1040304

$acc
[1] 0.9815

进行预测

model %>% predict_classes(x_test)
[1] 7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3
  [34] 4 7 2 7 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4
  [67] 6 4 3 0 7 0 2 9 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6
 [100] 9 6 0 5 4 9 9 2 1 9 4 8 7 3 9 7 4 4 4 9 2 5 4 7 6 7 9 0 5 8 5 6 6
 [133] 5 7 8 1 0 1 6 4 6 7 3 1 7 1 8 2 0 2 9 9 5 5 1 5 6 0 3 4 4 6 5 4 6
 [166] 5 4 5 1 4 4 7 2 3 2 7 1 8 1 8 1 8 5 0 8 9 2 5 0 1 1 1 0 9 0 3 1 6
 [199] 4 2 3 6 1 1 1 3 9 5 2 9 4 5 9 3 9 0 3 6 5 5 7 2 2 7 1 2 8 4 1 7 3
 [232] 3 8 8 7 9 2 2 4 1 5 9 8 7 2 3 0 2 4 2 4 1 9 5 7 7 2 8 2 0 8 5 7 7
 [265] 9 1 8 1 8 0 3 0 1 9 9 4 1 8 2 1 2 9 7 5 9 2 6 4 1 5 8 2 9 2 0 4 0
 [298] 0 2 8 4 7 1 2 4 0 2 7 4 3 3 0 0 3 1 9 6 5 2 5 9 7 9 3 0 4 2 0 7 1
 [331] 1 2 1 5 3 3 9 7 8 6 5 6 1 3 8 1 0 5 1 3 1 5 5 6 1 8 5 1 7 9 4 6 2
 [364] 2 5 0 6 5 6 3 7 2 0 8 8 5 4 1 1 4 0 7 3 7 6 1 6 2 1 9 2 8 6 1 9 5
 [397] 2 5 4 4 2 8 3 8 2 4 5 0 3 1 7 7 5 7 9 7 1 9 2 1 4 2 9 2 0 4 9 1 4
 [430] 8 1 8 4 5 9 8 8 3 7 6 0 0 3 0 2 0 6 9 9 3 3 3 2 3 9 1 2 6 8 0 5 6
 [463] 6 6 3 8 8 2 7 5 8 9 6 1 8 4 1 2 5 9 1 9 7 5 4 0 8 9 9 1 0 5 2 3 7
 [496] 0 9 4 0 6 3 9 5 2 1 3 1 3 6 5 7 4 2 2 6 3 2 6 5 4 8 9 7 1 3 0 3 8
 [529] 3 1 9 3 4 4 6 4 2 1 8 2 5 4 8 8 4 0 0 2 3 2 7 7 0 8 7 4 4 7 9 6 9
 [562] 0 9 8 0 4 6 0 6 3 5 4 8 3 3 9 3 3 3 7 8 0 2 2 1 7 0 6 5 4 3 8 0 9
 [595] 6 3 8 0 9 9 6 8 6 8 5 7 8 6 0 2 4 0 2 2 3 1 9 7 5 8 0 8 4 6 2 6 7
 [628] 9 3 2 9 8 2 2 9 2 7 3 5 9 1 8 0 2 0 5 2 1 3 7 6 7 1 2 5 8 0 3 7 2
 [661] 4 0 9 1 8 6 7 7 4 3 4 9 1 9 5 1 7 3 9 7 6 9 1 3 3 8 3 3 6 7 2 4 5
 [694] 8 5 1 1 4 4 3 1 0 7 7 0 7 9 4 4 8 5 5 4 0 8 2 1 0 8 4 5 0 4 0 6 1
 [727] 9 3 2 6 7 2 6 9 3 1 4 6 2 5 9 2 0 6 2 1 7 3 4 1 0 5 4 3 1 1 7 4 9
 [760] 9 4 8 4 0 2 4 5 1 1 6 4 7 1 9 4 2 4 1 5 5 3 8 3 1 4 5 6 8 9 4 1 5
 [793] 3 8 0 3 2 5 1 2 8 3 4 4 0 8 8 3 3 1 7 3 5 9 6 3 2 6 1 3 6 0 7 2 1
 [826] 7 1 4 2 4 2 1 7 9 6 1 1 2 4 8 1 7 7 4 8 0 7 3 1 3 1 0 7 7 0 3 5 5
 [859] 2 7 6 6 9 2 8 3 5 2 2 5 6 0 8 2 9 2 8 8 8 8 7 4 9 3 0 6 6 3 2 1 3
 [892] 2 2 9 3 0 0 5 7 8 3 4 4 6 0 2 9 1 4 7 4 7 3 9 8 8 4 7 1 2 1 2 2 3
 [925] 2 3 2 3 9 1 7 4 0 3 5 5 8 6 3 2 6 7 6 6 3 2 7 9 1 1 7 5 6 4 9 5 1
 [958] 3 3 4 7 8 9 1 1 0 9 1 4 4 5 4 0 6 2 2 3 1 5 1 2 0 3 8 1 2 6 7 1 6
 [991] 2 3 9 0 1 2 2 0 8 9
 [ reached getOption("max.print") -- omitted 9000 entries ]

很酷吧

    原文作者:Liam_ml
    原文地址: https://www.jianshu.com/p/5e5c140a563d
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞