Author: Tyan
Blog: noahsnail.com | CSDN | Jianshu (简书)
This post covers some basic Keras usage, mainly fine-tuning an existing network, using ResNet50 as the example.
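In the demo, fine-tuning keeps ResNet50's convolutional layers, drops its ImageNet classifier (`include_top=False`), averages the last feature map over its spatial positions (`pooling='avg'`), and feeds the result into a new `Dense(3, activation='softmax')` layer trained with `categorical_crossentropy`. The pure-Python sketch below only illustrates what that new head computes, on a made-up 2 × 2 × 4 feature map (for a 224 × 224 input, ResNet50's real one is 7 × 7 × 2048). The function names are hypothetical and are not Keras APIs.

```python
import math


def global_average_pool(feature_map):
    """Average an H x W x C feature map (nested lists) over H and W, giving C values."""
    h, w, c = len(feature_map), len(feature_map[0]), len(feature_map[0][0])
    return [sum(feature_map[i][j][k] for i in range(h) for j in range(w)) / (h * w)
            for k in range(c)]


def dense_softmax(features, weights, bias):
    """Fully connected layer (weights: C x num_classes) followed by softmax."""
    logits = [sum(f * row[n] for f, row in zip(features, weights)) + b
              for n, b in enumerate(bias)]
    top = max(logits)  # subtract the largest logit for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(one_hot, probs, eps=1e-7):
    """-sum(y * log(p)); p is clipped to [eps, 1 - eps] to avoid log(0)."""
    return -sum(y * math.log(min(max(p, eps), 1 - eps)) for y, p in zip(one_hot, probs))


# Hypothetical 2 x 2 feature map with 4 channels
fmap = [[[1, 0, 2, 0], [3, 0, 2, 0]],
        [[1, 4, 2, 0], [3, 0, 2, 0]]]
pooled = global_average_pool(fmap)
print(pooled)  # [2.0, 1.0, 2.0, 0.0]

# Untrained head with all-zero weights: every class gets probability 1/3
probs = dense_softmax(pooled, [[0.0] * 3 for _ in range(4)], [0.0] * 3)
print(categorical_crossentropy([1.0, 0.0, 0.0], probs))  # about 1.0986 (= ln 3)
```

With all-zero weights the head predicts each of the 3 classes with probability 1/3, which gives a loss of ln 3 ≈ 1.0986. That is a handy reference point when reading the loss values of a 3-class run.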
- Demo
#!/usr/bin/env python
# _*_ coding: utf-8 _*_
from keras.models import Model
from keras.layers import Dense
from keras.applications.resnet50 import ResNet50
from keras.preprocessing.image import ImageDataGenerator
# Batch size used for training
batch_size = 16
# Number of training epochs
epochs = 100
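# Image generator that builds the (randomly augmented) training input
train_datagen = ImageDataGenerator(
    width_shift_range=0.1,    # random horizontal shift of up to 10% of the width
    height_shift_range=0.1,   # random vertical shift of up to 10% of the height
    zoom_range=0.2,           # random zoom factor in [0.8, 1.2]
    horizontal_flip=True)     # randomly mirror images left to right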
# Read images from disk: ./train must contain one subdirectory per class,
# each holding the images of that class
train_generator = train_datagen.flow_from_directory('./train', target_size=(224, 224), batch_size=batch_size)
# Number of training images
image_numbers = train_generator.samples
# Print the mapping from class name to class index
print(train_generator.class_indices)
# Validation data (no augmentation)
test_datagen = ImageDataGenerator()
validation_generator = test_datagen.flow_from_directory('./validation', target_size=(224, 224), batch_size=batch_size)
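# Use the ResNet50 architecture without its final fully connected layer (include_top=False),
# load the ImageNet pre-trained weights, and global-average-pool the last feature map (pooling='avg')
base_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')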
# Build the new output layer; 3 is the number of classes in your own dataset
predictions = Dense(3, activation='softmax')(base_model.output)
# Define the full model
model = Model(inputs=base_model.input, outputs=predictions)
# Compile the model with categorical cross-entropy as the loss
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
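# Train the model (no layer is frozen, so the pre-trained layers are updated too)
model.fit_generator(train_generator,
                    steps_per_epoch=image_numbers // batch_size,
                    epochs=epochs,
                    validation_data=validation_generator,
                    # validation_steps counts batches, not images
                    validation_steps=validation_generator.samples // batch_size)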
# Save the trained weights (weights only; model.save() would also store the architecture)
model.save_weights('weights.h5')
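`flow_from_directory` derives labels from folder names: with the default `classes=None` it lists the subdirectories of the given directory, sorts them alphanumerically, and numbers them from 0, which is why `Type_1`, `Type_2` and `Type_3` map to 0, 1 and 2. With the default `class_mode='categorical'` each label is then one-hot encoded, matching `categorical_crossentropy`. The standard-library sketch below imitates that mapping and the batch-count arithmetic; it is an illustration, not Keras's own code, and `class_indices`, `one_hot` and `steps` are hypothetical helpers.

```python
import math
import os
import tempfile


def class_indices(directory):
    """Number the class subdirectories of `directory` in sorted (alphanumeric) order."""
    classes = sorted(name for name in os.listdir(directory)
                     if os.path.isdir(os.path.join(directory, name)))
    return {name: index for index, name in enumerate(classes)}


def one_hot(index, num_classes):
    """Label vector of the kind categorical_crossentropy expects."""
    return [1.0 if i == index else 0.0 for i in range(num_classes)]


def steps(samples, batch_size, round_up=False):
    """Batches per Keras epoch: floor division (as in the demo) or rounded up."""
    return math.ceil(samples / batch_size) if round_up else samples // batch_size


# Throwaway directory tree shaped like ./train, using the class names from the log
with tempfile.TemporaryDirectory() as root:
    for name in ["Type_3", "Type_1", "Type_2"]:
        os.makedirs(os.path.join(root, name))
    open(os.path.join(root, "notes.txt"), "w").close()  # plain files are not classes
    indices = class_indices(root)

print(indices)                                        # {'Type_1': 0, 'Type_2': 1, 'Type_3': 2}
print(one_hot(indices["Type_3"], len(indices)))       # [0.0, 0.0, 1.0]
print(steps(761, 16), steps(761, 16, round_up=True))  # 47 48
```

`image_numbers // batch_size` rounds down, so with e.g. 761 images each Keras epoch draws 47 batches (752 images), slightly less than one full pass. Rounding up to 48 makes one epoch match one pass, because the generator's last batch in each pass only holds the remaining 9 images.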
- Partial results (the run shown used 40 epochs; only the first 16 are listed)
{'Type_3': 2, 'Type_2': 1, 'Type_1': 0}
Found 761 images belonging to 3 classes.
Epoch 1/40
1/16 [>.............................] - ETA: 119s - loss: 1.3392
2017-06-07 10:18:48.246289: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:247] PoolAllocator: After 2521 get requests, put_count=2161 evicted_count=1000 eviction_rate=0.462749 and unsatisfied allocation rate=0.579135
2017-06-07 10:18:48.246348: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:259] Raising pool_size_limit_ from 100 to 110
16/16 [==============================] - 120s - loss: 2.3753 - val_loss: 10.8293
Epoch 2/40
1/16 [>.............................] - ETA: 5s - loss: 1.0054
2017-06-07 10:20:40.464589: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:247] PoolAllocator: After 2270 get requests, put_count=2642 evicted_count=1000 eviction_rate=0.378501 and unsatisfied allocation rate=0.286784
2017-06-07 10:20:40.464643: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:259] Raising pool_size_limit_ from 256 to 281
16/16 [==============================] - 83s - loss: 1.7988 - val_loss: 11.5219
Epoch 3/40
16/16 [==============================] - 81s - loss: 1.6640 - val_loss: 11.0043
Epoch 4/40
3/16 [====>.........................] - ETA: 4s - loss: 1.8745
2017-06-07 10:23:26.725923: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:247] PoolAllocator: After 11057 get requests, put_count=11071 evicted_count=1000 eviction_rate=0.0903261 and unsatisfied allocation rate=0.0945103
2017-06-07 10:23:26.725986: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:259] Raising pool_size_limit_ from 655 to 720
16/16 [==============================] - 83s - loss: 1.7237 - val_loss: 11.7738
Epoch 5/40
16/16 [==============================] - 83s - loss: 1.6304 - val_loss: 10.6538
Epoch 6/40
16/16 [==============================] - 80s - loss: 1.2182 - val_loss: 4.5027
Epoch 7/40
16/16 [==============================] - 83s - loss: 1.3179 - val_loss: 11.5891
Epoch 8/40
16/16 [==============================] - 82s - loss: 1.1806 - val_loss: 10.5800
Epoch 9/40
16/16 [==============================] - 81s - loss: 1.1935 - val_loss: 11.1477
Epoch 10/40
16/16 [==============================] - 80s - loss: 1.1727 - val_loss: 7.0913
Epoch 11/40
16/16 [==============================] - 83s - loss: 1.2058 - val_loss: 6.4474
Epoch 12/40
16/16 [==============================] - 82s - loss: 1.2702 - val_loss: 7.7678
Epoch 13/40
16/16 [==============================] - 84s - loss: 1.2060 - val_loss: 7.9961
Epoch 14/40
16/16 [==============================] - 83s - loss: 1.0768 - val_loss: 11.2121
Epoch 15/40
16/16 [==============================] - 80s - loss: 1.1401 - val_loss: 13.2052
Epoch 16/40
16/16 [==============================] - 83s - loss: 1.1961 - val_loss: 13.0330
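Over these 16 epochs the training loss falls from 2.3753 to around 1.1 to 1.3, while val_loss stays between 6.4474 and 13.2052 apart from a single dip to 4.5027 at epoch 6. To track such numbers from a saved text log, a small regular-expression parser is enough. The sketch below assumes the summary-line format shown above (`... - loss: <float> - val_loss: <float>`); `parse_epochs` is a hypothetical helper, and the log lines are copied verbatim from this run.

```python
import re

# Epoch summary lines copied from the log above (epochs 1 to 16)
LOG = """\
16/16 [==============================] - 120s - loss: 2.3753 - val_loss: 10.8293
16/16 [==============================] - 83s - loss: 1.7988 - val_loss: 11.5219
16/16 [==============================] - 81s - loss: 1.6640 - val_loss: 11.0043
16/16 [==============================] - 83s - loss: 1.7237 - val_loss: 11.7738
16/16 [==============================] - 83s - loss: 1.6304 - val_loss: 10.6538
16/16 [==============================] - 80s - loss: 1.2182 - val_loss: 4.5027
16/16 [==============================] - 83s - loss: 1.3179 - val_loss: 11.5891
16/16 [==============================] - 82s - loss: 1.1806 - val_loss: 10.5800
16/16 [==============================] - 81s - loss: 1.1935 - val_loss: 11.1477
16/16 [==============================] - 80s - loss: 1.1727 - val_loss: 7.0913
16/16 [==============================] - 83s - loss: 1.2058 - val_loss: 6.4474
16/16 [==============================] - 82s - loss: 1.2702 - val_loss: 7.7678
16/16 [==============================] - 84s - loss: 1.2060 - val_loss: 7.9961
16/16 [==============================] - 83s - loss: 1.0768 - val_loss: 11.2121
16/16 [==============================] - 80s - loss: 1.1401 - val_loss: 13.2052
16/16 [==============================] - 83s - loss: 1.1961 - val_loss: 13.0330
"""

# Assumed line format; in-progress lines without val_loss do not match
SUMMARY = re.compile(r"loss: ([0-9.]+) - val_loss: ([0-9.]+)")


def parse_epochs(log_text):
    """Return (epoch, loss, val_loss) for every line that carries both values."""
    rows = []
    for line in log_text.splitlines():
        match = SUMMARY.search(line)
        if match:
            rows.append((len(rows) + 1, float(match.group(1)), float(match.group(2))))
    return rows


rows = parse_epochs(LOG)
best = min(rows, key=lambda row: row[2])
print(f"best val_loss {best[2]} at epoch {best[0]} (training loss {best[1]})")
# best val_loss 4.5027 at epoch 6 (training loss 1.2182)
```

When the training script is at hand, parsing is unnecessary: the `History` object returned by `fit_generator` already holds these values in `history.history['loss']` and `history.history['val_loss']`.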