Rajaraman, et al. 的论文在一个数据集上利用六个预训练模型在检测疟疾对比无感染样本获取到令人吃惊的 95.9% 的准确率。我们的重点是从头开始尝试一些简单的 CNN 模型和用一个预训练的训练模型使用迁移学习来查看我们能够从相同的数据集中得到什么。我们将使用开源工具和框架,包括 Python 和 TensorFlow,来构建我们的模型。
数据集
我们分析的数据来自 Lister Hill 国家生物医学交流中心(LHNCBC)的研究人员,该中心是国家医学图书馆(NLM)的一部分,他们细心收集和标记了公开可用的健康和受感染的血涂片图像的数据集。这些研究者已经开发了一个运行在 Android 智能手机的疟疾检测手机应用,连接到一个传统的光学显微镜。它们使用吉姆萨染液将 150 个受恶性疟原虫感染的和 50 个健康病人的薄血涂片染色,这些薄血涂片是在孟加拉的吉大港医学院附属医院收集和照相的。使用智能手机的内置相机获取每个显微镜视窗内的图像。这些图片由在泰国曼谷的马希多-牛津热带医学研究所的一个专家使用幻灯片阅读器标记的。
让我们简要地查看一下数据集的结构。首先,我将安装一些基础的依赖(基于使用的操作系统)。
Installing dependencies
我使用的是云上的带有一个 GPU 的基于 Debian 的操作系统,这样我能更快的运行我的模型。为了查看目录结构,我们必须使用 sudo apt install tree 安装 tree 及其依赖(如果我们没有安装的话)。
Installing the tree dependency
我们有两个文件夹包含血细胞的图像,包括受感染的和健康的。我们通过输入可以获取关于图像总数更多的细节:
import os import glob -
base_dir = os.path.join('./cell_images') infected_dir = os.path.join(base_dir,'Parasitized') healthy_dir = os.path.join(base_dir,'Uninfected') -
infected_files = glob.glob(infected_dir+'/*.png') healthy_files = glob.glob(healthy_dir+'/*.png') len(infected_files), len(healthy_files) -
# Output (13779, 13779)
看起来我们有一个平衡的数据集,包含 13,779 张疟疾的和 13,779 张非疟疾的(健康的)血细胞图像。让我们根据这些构建数据帧,我们将用这些数据帧来构建我们的数据集。
import numpy as np import pandas as pd -
np.random.seed(42) -
files_df = pd.DataFrame({ 'filename': infected_files + healthy_files, 'label': ['malaria'] * len(infected_files) + ['healthy'] * len(healthy_files) }).sample(frac=1, random_state=42).reset_index(drop=True) -
files_df.head()
Datasets
构建和了解图像数据集
为了构建深度学习模型,我们需要训练数据,但是我们还需要使用不可见的数据测试模型的性能。相应的,我们将使用 60:10:30 的比例来划分用于训练、验证和测试的数据集。我们将在训练期间应用训练和验证数据集,并用测试数据集来检查模型的性能。
from sklearn.model_selection import train_test_split from collections import Counter -
train_files, test_files, train_labels, test_labels = train_test_split(files_df['filename'].values, files_df['label'].values, test_size=0.3, random_state=42) train_files, val_files, train_labels, val_labels = train_test_split(train_files, train_labels, test_size=0.1, random_state=42) -
print(train_files.shape, val_files.shape, test_files.shape) print('Train:', Counter(train_labels), '\nVal:', Counter(val_labels), '\nTest:', Counter(test_labels)) -
# Output (17361,) (1929,) (8268,) Train: Counter({'healthy': 8734, 'malaria': 8627}) Val: Counter({'healthy': 970, 'malaria': 959}) Test: Counter({'malaria': 4193, 'healthy': 4075})
(编辑:ASP站长网)
|