设为首页 - 加入收藏 ASP站长网(Aspzz.Cn)- 科技、建站、经验、云计算、5G、大数据,站长网!
热搜: 数据 手机 公司
当前位置: 首页 > 运营中心 > 建站资源 > 经验 > 正文

看懂这篇指南,包你掌握神经网络的“黑匣子”(4)

发布时间:2019-06-04 09:24 所属栏目:19 来源:读芯术
导读:现在可以导入图像并对其进行加工: fromkeras.preprocessing.imageimportload_img #loadanimagefromfile image=load_img('car.jpeg',target_size=(224,224)) plt.imshow(image) plt.title('ORIGINALIMAGE') 一共分

现在可以导入图像并对其进行加工:

  1. from keras.preprocessing.image import load_img 
  2. # load an image from file 
  3. image = load_img('car.jpeg', target_size=(224, 224)) 
  4. plt.imshow(image) 
  5. plt.title('ORIGINAL IMAGE') 

看懂这篇指南,包你掌握神经网络的“黑匣子”

一共分为三个步骤:

  • 对图像进行预处理
  • 计算不同遮挡部分的概率
  • 绘制热图
  1. from keras.preprocessing.image import img_to_array 
  2. from keras.applications.vgg16 import preprocess_input 
  3. # convert the image pixels to a numpy array 
  4. image = img_to_array(image) 
  5. # reshape data for the model 
  6. imageimage = image.reshape((1, image.shape[0], image.shape[1], image.shape[2])) 
  7. # prepare the image for the VGG model 
  8. image = preprocess_input(image) 
  9. # predict the probability across all output classes 
  10. yhat = model.predict(image) 
  11. temp = image[0] 
  12. print(temp.shape) 
  13. heatmap = np.zeros((224,224)) 
  14. correct_class = np.argmax(yhat) 
  15. for n,(x,y,image) in enumerate(iter_occlusion(temp,14)): 
  16.     heatmap[x:x+14,y:y+14] = model.predict(image.reshape((1, image.shape[0], image.shape[1], image.shape[2])))[0][correct_class] 
  17.     print(x,y,n,' - ',image.shape) 
  18. heatmapheatmap1 = heatmap/heatmap.max() 
  19. plt.imshow(heatmap) 

看懂这篇指南,包你掌握神经网络的“黑匣子”

是不是很有趣呢?接着将使用标准化的热图概率来创建一个遮挡部分并进行绘制:

  1. import skimage.io as io 
  2. #creating mask from the standardised heatmap probabilities 
  3. mask = heatmap1 < 0.85 
  4. maskmask1 = mask *256 
  5. maskmask = mask.astype(int) 
  6. io.imshow(mask,cmap='gray') 

看懂这篇指南,包你掌握神经网络的“黑匣子”

最后,通过使用下述程序,对输入图像进行遮挡:

  1. import cv2 
  2. #read the image 
  3. image = cv2.imread('car.jpeg') 
  4. image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB) 
  5. #resize image to appropriate dimensions 
  6. image = cv2.resize(image,(224,224)) 
  7. maskmask = mask.astype('uint8') 
  8. #apply the mask to the image 
  9. final = cv2.bitwise_and(image,image,maskmask = mask) 
  10. final = cv2.cvtColor(final,cv2.COLOR_BGR2RGB) 
  11. #plot the final image 
  12. plt.imshow(final) 

看懂这篇指南,包你掌握神经网络的“黑匣子”

猜猜为什么只能看到某些部分?没错——只有那些对输出图片类型的概率有显著贡献的部分是可见的。简而言之,这就是遮挡图的全部含义。

特征图——将输入特征的贡献可视化

特征图是另一种基于梯度的可视化技术。这类图像在 Deep Inside Convolutional Networks:Visualising Image Classification Models and Saliency Maps.论文中有介绍。

特征图计算出每个像素对模型输出的影响,包括计算相对于输入图像每一像素而言输出的梯度。

(编辑:ASP站长网)

网友评论
推荐文章
    热点阅读