设为首页 - 加入收藏 ASP站长网(Aspzz.Cn)- 科技、建站、经验、云计算、5G、大数据,站长网!
热搜: 重新 试卷 文件
当前位置: 首页 > 运营中心 > 建站资源 > 优化 > 正文

使用深度学习检测疟疾(7)

发布时间:2019-05-24 20:19 所属栏目:21 来源:Dipanjan (dj) Sarkar
导读:我们将使用预训练的 VGG-19 深度训练模型(由剑桥大学的视觉几何组(VGG)开发)进行我们的实验。像 VGG-19 这样的预训练模型是在一个大的数据集(Imagenet)上使用了很多不同的图像分类训练的。因此,这个模型应该

我们将使用预训练的 VGG-19 深度训练模型(由剑桥大学的视觉几何组(VGG)开发)进行我们的实验。像 VGG-19 这样的预训练模型是在一个大的数据集(Imagenet)上使用了很多不同的图像分类训练的。因此,这个模型应该已经学习到了健壮的特征层级结构,相对于你的 CNN 模型学到的特征,是空间不变的、转动不变的、平移不变的。因此,这个模型,已经从百万幅图片中学习到了一个好的特征显示,对于像疟疾检测这样的计算机视觉问题,可以作为一个好的合适新图像的特征提取器。在我们的问题中发挥迁移学习的能力之前,让我们先讨论 VGG-19 模型。

理解 VGG-19 模型

VGG-19 模型是一个构建在 ImageNet 数据库之上的 19 层(卷积和全连接的)的深度学习网络,ImageNet 数据库为了图像识别和分类的目的而开发。该模型是由 Karen Simonyan 和 Andrew Zisserman 构建的,在他们的论文“大规模图像识别的非常深的卷积网络”中进行了描述。VGG-19 的架构模型是:

使用深度学习检测疟疾

VGG-19 模型架构

你可以看到我们总共有 16 个使用 3x3 卷积过滤器的卷积层,与最大的池化层来下采样,和由 4096 个单元组成的两个全连接的隐藏层,每个隐藏层之后跟随一个由 1000 个单元组成的致密层,每个单元代表 ImageNet 数据库中的一个分类。我们不需要最后三层,因为我们将使用我们自己的全连接致密层来预测疟疾。我们更关心前五个块,因此我们可以利用 VGG 模型作为一个有效的特征提取器。

我们将使用模型之一作为一个简单的特征提取器,通过冻结五个卷积块的方式来确保它们的位权在每个纪元后不会更新。对于最后一个模型,我们会对 VGG 模型进行微调,我们会解冻最后两个块(第 4 和第 5)因此当我们训练我们的模型时,它们的位权在每个时期(每批数据)被更新。

模型 2:预训练的模型作为一个特征提取器

为了构建这个模型,我们将利用 TensorFlow 载入 VGG-19 模型并冻结卷积块,因此我们能够将它们用作特征提取器。我们在末尾插入我们自己的致密层来执行分类任务。

  1. vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet',
  2. input_shape=INPUT_SHAPE)
  3. vgg.trainable = False
  4. # Freeze the layers
  5. for layer in vgg.layers:
  6. layer.trainable = False
  7. base_vgg = vgg
  8. base_out = base_vgg.output
  9. pool_out = tf.keras.layers.Flatten()(base_out)
  10. hidden1 = tf.keras.layers.Dense(512, activation='relu')(pool_out)
  11. drop1 = tf.keras.layers.Dropout(rate=0.3)(hidden1)
  12. hidden2 = tf.keras.layers.Dense(512, activation='relu')(drop1)
  13. drop2 = tf.keras.layers.Dropout(rate=0.3)(hidden2)
  14.  
  15. out = tf.keras.layers.Dense(1, activation='sigmoid')(drop2)
  16.  
  17. model = tf.keras.Model(inputs=base_vgg.input, outputs=out)
  18. model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=1e-4),
  19. loss='binary_crossentropy',
  20. metrics=['accuracy'])
  21. model.summary()
  22.  
  23.  
  24. # Output
  25. Model: "model_1"
  26. _________________________________________________________________
  27. Layer (type) Output Shape Param #
  28. =================================================================
  29. input_2 (InputLayer) [(None, 125, 125, 3)] 0
  30. _________________________________________________________________
  31. block1_conv1 (Conv2D) (None, 125, 125, 64) 1792
  32. _________________________________________________________________
  33. block1_conv2 (Conv2D) (None, 125, 125, 64) 36928
  34. _________________________________________________________________
  35. ...
  36. ...
  37. _________________________________________________________________
  38. block5_pool (MaxPooling2D) (None, 3, 3, 512) 0
  39. _________________________________________________________________
  40. flatten_1 (Flatten) (None, 4608) 0
  41. _________________________________________________________________
  42. dense_3 (Dense) (None, 512) 2359808
  43. _________________________________________________________________
  44. dropout_2 (Dropout) (None, 512) 0
  45. _________________________________________________________________
  46. dense_4 (Dense) (None, 512) 262656
  47. _________________________________________________________________
  48. dropout_3 (Dropout) (None, 512) 0
  49. _________________________________________________________________
  50. dense_5 (Dense) (None, 1) 513
  51. =================================================================
  52. Total params: 22,647,361
  53. Trainable params: 2,622,977
  54. Non-trainable params: 20,024,384
  55. _________________________________________________________________

(编辑:ASP站长网)

网友评论
推荐文章
    热点阅读