设为首页 - 加入收藏 ASP站长网(Aspzz.Cn)- 科技、建站、经验、云计算、5G、大数据,站长网!
热搜: 重新 试卷 文件
当前位置: 首页 > 运营中心 > 建站资源 > 优化 > 正文

使用深度学习检测疟疾(10)

发布时间:2019-05-24 20:19 所属栏目:21 来源:Dipanjan (dj) Sarkar
导读:这看起来是我们的最好的模型。它给了我们近乎 96.5% 的验证精确率,基于训练精度,它看起来不像我们的第一个模型那样过拟合。这可以通过下列的学习曲线验证。 微调过的预训练 CNN 的学习曲线 让我们保存这个模型,

这看起来是我们的最好的模型。它给了我们近乎 96.5% 的验证精确率,基于训练精度,它看起来不像我们的第一个模型那样过拟合。这可以通过下列的学习曲线验证。

使用深度学习检测疟疾

微调过的预训练 CNN 的学习曲线

让我们保存这个模型,因此我们能够在测试集上使用。

  1. model.save('vgg_finetuned.h5')

这就完成了我们的模型训练阶段。现在我们准备好了在测试集上测试我们模型的性能。

深度学习模型性能评估

我们将通过在我们的测试集上做预测来评估我们在训练阶段构建的三个模型,因为仅仅验证是不够的!我们同样构建了一个检测工具模块叫做 model_evaluation_utils,我们可以使用相关分类指标用来评估使用我们深度学习模型的性能。第一步是扩展我们的数据集。

  1. test_imgs_scaled = test_data / 255.
  2. test_imgs_scaled.shape, test_labels.shape
  3.  
  4. # Output
  5. ((8268, 125, 125, 3), (8268,))

下一步包括载入我们保存的深度学习模型,在测试集上预测。

  1. # Load Saved Deep Learning Models
  2. basic_cnn = tf.keras.models.load_model('./basic_cnn.h5')
  3. vgg_frz = tf.keras.models.load_model('./vgg_frozen.h5')
  4. vgg_ft = tf.keras.models.load_model('./vgg_finetuned.h5')
  5.  
  6. # Make Predictions on Test Data
  7. basic_cnn_preds = basic_cnn.predict(test_imgs_scaled, batch_size=512)
  8. vgg_frz_preds = vgg_frz.predict(test_imgs_scaled, batch_size=512)
  9. vgg_ft_preds = vgg_ft.predict(test_imgs_scaled, batch_size=512)
  10.  
  11. basic_cnn_pred_labels = le.inverse_transform([1 if pred > 0.5 else 0
  12. for pred in basic_cnn_preds.ravel()])
  13. vgg_frz_pred_labels = le.inverse_transform([1 if pred > 0.5 else 0
  14. for pred in vgg_frz_preds.ravel()])
  15. vgg_ft_pred_labels = le.inverse_transform([1 if pred > 0.5 else 0
  16. for pred in vgg_ft_preds.ravel()])

下一步是应用我们的 model_evaluation_utils 模块根据相应分类指标来检查每个模块的性能。

  1. import model_evaluation_utils as meu
  2. import pandas as pd
  3.  
  4. basic_cnn_metrics = meu.get_metrics(true_labels=test_labels, predicted_labels=basic_cnn_pred_labels)
  5. vgg_frz_metrics = meu.get_metrics(true_labels=test_labels, predicted_labels=vgg_frz_pred_labels)
  6. vgg_ft_metrics = meu.get_metrics(true_labels=test_labels, predicted_labels=vgg_ft_pred_labels)
  7.  
  8. pd.DataFrame([basic_cnn_metrics, vgg_frz_metrics, vgg_ft_metrics],
  9. index=['Basic CNN', 'VGG-19 Frozen', 'VGG-19 Fine-tuned'])

使用深度学习检测疟疾

Model accuracy

看起来我们的第三个模型在我们的测试集上执行的最好,给出了一个模型精确性为 96% 的 F1 得分,这非常好,与我们之前提到的研究论文和文章中的更复杂的模型相当。

总结

(编辑:ASP站长网)

网友评论
推荐文章
    热点阅读