设为首页 - 加入收藏 ASP站长网(Aspzz.Cn)- 科技、建站、经验、云计算、5G、大数据,站长网!
热搜: 手机 数据 公司
当前位置: 首页 > 运营中心 > 建站资源 > 经验 > 正文

PyTorch最佳实践,怎样才能写出一手风格优美的代码(2)

发布时间:2019-05-06 17:40 所属栏目:19 来源:机器之心编译
导读:请注意以下几点: 我们复用了简单的循环构建模块(如卷积块 ConvBlocks),它们由相同的循环模式(卷积、激活函数、归一化)组成,并装入独立的 nn.Module 中。 我们构建了一个所需要层的列表,并最终使用「nn.Sequenti

请注意以下几点:

  • 我们复用了简单的循环构建模块(如卷积块 ConvBlocks),它们由相同的循环模式(卷积、激活函数、归一化)组成,并装入独立的 nn.Module 中。
  • 我们构建了一个所需要层的列表,并最终使用「nn.Sequential()」将所有层级组合到了一个模型中。我们在 list 对象前使用「*」操作来展开它。
  • 在前向传导过程中,我们直接使用输入数据运行模型。

2. PyTorch 环境下的简单残差网络

  1. class ResnetBlock(nn.Module): 
  2.     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 
  3.         super(ResnetBlock, self).__init__() 
  4.         selfself.conv_block = self.build_conv_block(...) 
  5.  
  6.     def build_conv_block(self, ...): 
  7.         conv_block = [] 
  8.  
  9.         conv_block += [nn.Conv2d(...), 
  10.                        norm_layer(...), 
  11.                        nn.ReLU()] 
  12.         if use_dropout: 
  13.             conv_block += [nn.Dropout(...)] 
  14.  
  15.         conv_block += [nn.Conv2d(...), 
  16.                        norm_layer(...)] 
  17.  
  18.         return nn.Sequential(*conv_block) 
  19.  
  20.     def forward(self, x): 
  21.         out = x + self.conv_block(x) 
  22.         return ou 

在这里,ResNet 模块的跳跃连接直接在前向传导过程中实现了,PyTorch 允许在前向传导过程中进行动态操作。

3. PyTorch 环境下的带多个输出的网络

对于有多个输出的网络(例如使用一个预训练好的 VGG 网络构建感知损失),我们使用以下模式:

  1. class Vgg19(torch.nn.Module): 
  2.   def __init__(self, requires_grad=False): 
  3.     super(Vgg19, self).__init__() 
  4.     vgg_pretrained_features = models.vgg19(pretrained=True).features 
  5.     self.slice1 = torch.nn.Sequential() 
  6.     self.slice2 = torch.nn.Sequential() 
  7.     self.slice3 = torch.nn.Sequential() 
  8.  
  9.     for x in range(7): 
  10.         self.slice1.add_module(str(x), vgg_pretrained_features[x]) 
  11.     for x in range(7, 21): 
  12.         self.slice2.add_module(str(x), vgg_pretrained_features[x]) 
  13.     for x in range(21, 30): 
  14.         self.slice3.add_module(str(x), vgg_pretrained_features[x]) 
  15.     if not requires_grad: 
  16.         for param in self.parameters(): 
  17.             param.requires_grad = False 
  18.  
  19.   def forward(self, x): 
  20.     h_relu1 = self.slice1(x) 
  21.     h_relu2 = self.slice2(h_relu1)         
  22.     h_relu3 = self.slice3(h_relu2)         
  23.     out = [h_relu1, h_relu2, h_relu3] 
  24.     return out 

(编辑:ASP站长网)

网友评论
推荐文章
    热点阅读