我们将使用预训练的 VGG-19 深度训练模型(由剑桥大学的视觉几何组(VGG)开发)进行我们的实验。像 VGG-19 这样的预训练模型是在一个大的数据集(Imagenet)上使用了很多不同的图像分类训练的。因此,这个模型应该已经学习到了健壮的特征层级结构,相对于你的 CNN 模型学到的特征,是空间不变的、转动不变的、平移不变的。因此,这个模型,已经从百万幅图片中学习到了一个好的特征显示,对于像疟疾检测这样的计算机视觉问题,可以作为一个好的合适新图像的特征提取器。在我们的问题中发挥迁移学习的能力之前,让我们先讨论 VGG-19 模型。
理解 VGG-19 模型
VGG-19 模型是一个构建在 ImageNet 数据库之上的 19 层(卷积和全连接的)的深度学习网络,ImageNet 数据库为了图像识别和分类的目的而开发。该模型是由 Karen Simonyan 和 Andrew Zisserman 构建的,在他们的论文“大规模图像识别的非常深的卷积网络”中进行了描述。VGG-19 的架构模型是:
VGG-19 模型架构
你可以看到我们总共有 16 个使用 3x3 卷积过滤器的卷积层,与最大的池化层来下采样,和由 4096 个单元组成的两个全连接的隐藏层,每个隐藏层之后跟随一个由 1000 个单元组成的致密层,每个单元代表 ImageNet 数据库中的一个分类。我们不需要最后三层,因为我们将使用我们自己的全连接致密层来预测疟疾。我们更关心前五个块,因此我们可以利用 VGG 模型作为一个有效的特征提取器。
我们将使用模型之一作为一个简单的特征提取器,通过冻结五个卷积块的方式来确保它们的位权在每个纪元后不会更新。对于最后一个模型,我们会对 VGG 模型进行微调,我们会解冻最后两个块(第 4 和第 5)因此当我们训练我们的模型时,它们的位权在每个时期(每批数据)被更新。
模型 2:预训练的模型作为一个特征提取器
为了构建这个模型,我们将利用 TensorFlow 载入 VGG-19 模型并冻结卷积块,因此我们能够将它们用作特征提取器。我们在末尾插入我们自己的致密层来执行分类任务。
vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet', input_shape=INPUT_SHAPE) vgg.trainable = False # Freeze the layers for layer in vgg.layers: layer.trainable = False base_vgg = vgg base_out = base_vgg.output pool_out = tf.keras.layers.Flatten()(base_out) hidden1 = tf.keras.layers.Dense(512, activation='relu')(pool_out) drop1 = tf.keras.layers.Dropout(rate=0.3)(hidden1) hidden2 = tf.keras.layers.Dense(512, activation='relu')(drop1) drop2 = tf.keras.layers.Dropout(rate=0.3)(hidden2) -
out = tf.keras.layers.Dense(1, activation='sigmoid')(drop2) -
model = tf.keras.Model(inputs=base_vgg.input, outputs=out) model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy']) model.summary() -
-
# Output Model: "model_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) [(None, 125, 125, 3)] 0 _________________________________________________________________ block1_conv1 (Conv2D) (None, 125, 125, 64) 1792 _________________________________________________________________ block1_conv2 (Conv2D) (None, 125, 125, 64) 36928 _________________________________________________________________ ... ... _________________________________________________________________ block5_pool (MaxPooling2D) (None, 3, 3, 512) 0 _________________________________________________________________ flatten_1 (Flatten) (None, 4608) 0 _________________________________________________________________ dense_3 (Dense) (None, 512) 2359808 _________________________________________________________________ dropout_2 (Dropout) (None, 512) 0 _________________________________________________________________ dense_4 (Dense) (None, 512) 262656 _________________________________________________________________ dropout_3 (Dropout) (None, 512) 0 _________________________________________________________________ dense_5 (Dense) (None, 1) 513 ================================================================= Total params: 22,647,361 Trainable params: 2,622,977 Non-trainable params: 20,024,384 _________________________________________________________________
(编辑:ASP站长网)
|