admin管理员组

文章数量:1596408

官方文档链接:https://pytorch/docs/master/fx.html#

概述

FX是供开发人员用于转换nn.Module实例的工具包。FX由三个主要组件组成:符号追踪:symbolic tracer, 中间层表示:intermediate representation, Python代码生成:Python code generation。这些组件的运行演示:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪(symbolic tracer): 对Python代码进行"符号执行"。它以构造的值(也叫作:代理Proxies)为输入,贯穿运行所有代码。记录下对这些Proxie的操作。更多的符号追踪的信息可见 symbolic_trace() 和Tracer的相关文档。

**中间层表示(intermediate representation): ** 它里面保存了在符号追中中记录下的运算操作。它由表示函数输入、调用哪些对象(函数、方法或torch.nn.Module实例)和返回值的节点列表组成。关于IR的更多信息可以在Graph的文档中找到。IR是应用转换的格式。

**Python代码生成(Python code generation): ** Python代码生成使FX成为Python代码到Python代码(或模块到模块)转换工具包。对于每个 Graph IR,我们可以创建与图的语义匹配的有效Python代码。此功能包含在GraphModule中,GraphModule是一个torch.nn.Module实例,它包含一个图以及从该图生成的正向方法。

综合起来,这个组件的流水线(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python-to-Python 转换通道。 此外,这些组件可以单独使用。 例如,可以单独使用符号跟踪来捕获代码形式以用于分析(而不是转换)。 代码生成可用于以编程方式生成模型,例如从配置文件生成模型。 FX 有很多用途!

在示例库中有几个转换的样例。

API

symbolic_trace

torch.fx.symbolic_trace(root, concrete_args=None, enable_cpatching=False)

符号追踪的函数,以nn.Module或者函数实例为输入,然后将追踪过程中记录的操作记录下来构造一个GraphModule对象并返回。

concrete_args的作用是根据函数中的分支和参数进行定制化,无论是删除控制流还是数据结构。

例如:

def f(a, b):
    if b 

本文标签: 笔记PytorchFX