session://14:29:33
~/ / posts / 2017-01-pytorch-release.md

PyTorch 发布:动态图把“写模型”变回了“写 Python”

2017-01-18· 1 min read · [产品实践]
// TL;DR
  • PyTorch 用 define-by-run 动态图:代码跑到哪,图建到哪,print 和断点直接能用。
  • API 贴着 NumPy 设计,学习成本极低。
  • 框架之争进入双雄时代:研究圈倒向 PyTorch 的速度肉眼可见。

它解决了什么痛

用过 TensorFlow 的人都懂那种拧巴:先用 Python“描述”一张静态图,再 session.run 把数据喂进去——中间出了错,报错信息指向的是图编译器的内部,跟你写的代码隔着一层毛玻璃。调试基本靠猜。

PyTorch 的答案是 define-by-run:没有“先建图再执行”两个阶段,前向传播就是普通 Python 代码,跑到哪图建到哪。想看中间变量?print 就行。想调试?pdb 直接下断点。控制流就用 Python 的 if 和 for,RNN 处理变长序列再也不用跟 tf.while_loop 搏斗。

# PyTorch 的前向就是普通 Python
def forward(self, x):
    h = torch.relu(self.fc1(x))
    if h.norm() > 10:      # 动态控制流,就这么写
        h = h / h.norm()
    return self.fc2(h)

我的判断

性能上它未必赢——静态图理论上有更多优化空间。但研究的瓶颈从来不是跑得慢,是改得慢。一个让试错周期缩短一半的工具,在研究圈的传播速度会是病毒级的。我赌一年内顶会论文的开源代码会大面积换成 PyTorch。生产部署 TensorFlow、研究实验 PyTorch 的双轨格局,可能要持续很久。

开源项目深度学习工程实践
cat newsletter.txt

每周一封,<5 分钟读完

把这一周我读过、想过、动手做过的东西,压缩成一封信。订阅者目前 5210+ 人,0 干扰。