设为首页 - 加入收藏 ASP站长网(Aspzz.Cn)- 科技、建站、经验、云计算、5G、大数据,站长网!
热搜: 重新 试卷 文件
当前位置: 首页 > 运营中心 > 建站资源 > 优化 > 正文

100行Python代码,轻松搞定神经网络(4)

发布时间:2019-05-05 23:07 所属栏目:21 来源:大数据文摘
导读:现在可以用一些数据测试下我们的代码了。 X=np.random.randn(100,10) w=np.random.randn(10,1) b=np.random.randn(1) Y=X@W+B model=Linear(10,1) learner=Learner(model,mse_loss,SGDOptimizer(lr=0.05)) learner.

现在可以用一些数据测试下我们的代码了。

  1. X = np.random.randn(100, 10) 
  2. w = np.random.randn(10, 1) 
  3. b = np.random.randn(1) 
  4. Y = X @ W + B 
  5.  
  6. model = Linear(10, 1) 
  7. learner = Learner(model, mse_loss, SGDOptimizer(lr=0.05)) 
  8. learner.fit(X, Y, epochs=10, bs=10) 

我一共训练了10轮。

我们还能检查学到的权重和真实的权重是否一致。

  1. print(np.linalg.norm(m.weights.tensor - W), (m.bias.tensor - B)[0]) 
  2. > 1.848553648022619e-05 5.69305886743976e-06 

好了,就这么简单。让我们再试试非线性数据集,例如y=x1x2,并且再加上一个Sigmoid非线性层和另一个线性层让我们的模型更复杂些。像下面这样:

  1. X = np.random.randn(1000, 2) 
  2. Y = X[:, 0] * X[:, 1] 
  3.  
  4. losses1 = Learner( 
  5.     Sequential(Linear(2, 1)), 
  6.     mse_loss, 
  7.     SGDOptimizer(lr=0.01) 
  8. ).fit(X, Y, epochs=50, bs=50) 
  9.  
  10. losses2 = Learner( 
  11.     Sequential( 
  12.         Linear(2, 10), 
  13.         Sigmoid(), 
  14.         Linear(10, 1) 
  15.     ), 
  16.     mse_loss, 
  17.     SGDOptimizer(lr=0.3) 
  18. ).fit(X, Y, epochs=50, bs=50) 
  19.  
  20. plt.plot(losses1) 
  21. plt.plot(losses2) 
  22. plt.legend(['1 Layer', '2 Layers']) 
  23. plt.show() 

比较单一层vs两层模型在使用sigmoid激活函数的情况下的训练损失。

最后

希望通过搭建这个简单的神经网络,你已经掌握了用python和numpy实现神经网络的基本思路。

在这篇文章中,我们只定义了三种类型的层和一个损失函数, 所以还有很多事情可做,但基本原理都相似。感兴趣的同学可以试着实现更复杂的神经网络哦!

References:

  • Thinc Deep Learning Library:https://github.com/explosion/thinc
  • PyTorch Tutorial:https://pytorch.org/tutorials/beginner/nn_tutorial.html
  • Calculus on Computational Graphs:http://colah.github.io/posts/2015-08-Backprop/
  • HIPS Autograd:https://github.com/HIPS/autograd

相关报道:https://eisenjulian.github.io/deep-learning-in-100-lines/

【本文是51CTO专栏机构大数据文摘的原创文章,微信公众号“大数据文摘( id: BigDataDigest)”】

     大数据文摘二维码

戳这里,看该作者更多好文

【编辑推荐】

  1. 一行代码引发恐惧,深思提高线上代码质量的方法
  2. DeBug Python代码全靠print函数?换用这个一天2K+Star的工具吧
  3. 这里有8个流行的Python可视化工具包,你喜欢哪个?
  4. 出神入化:特斯拉AI主管、李飞飞高徒Karpathy的33个神经网络「炼丹」技巧
  5. 14个Q&A,讲述python与数据科学的“暧昧情事”
【责任编辑:赵宁宁 TEL:(010)68476606】
点赞 0

(编辑:ASP站长网)

网友评论
推荐文章
    热点阅读