看PyTorch源代码的心路历程

1. 起因

曾经碰到过别人的模型prelu在内部的推理引擎算出的结果与其在原始框架PyTorch中不一致的情况,虽然理论上大家实现的都是一个算法,但是从参数上看,因为经过了模型转换,中间做了一些调整。为了确定究竟是初始参数传递就出了问题还是在后续传递过程中继续做了更改、亦或者是最终算法实现方面有着细微差别导致最终输出不同,就想着去看一看PyTorch一路下来是怎么做的。
但是代码跟着跟着就跟丢了,才会发现,PyTorch真的是一个很复杂的项目,但就像舌尖里面说的,环境越是恶劣,回报越是丰厚。为了以后再想跟踪的时候方便,因此决定以PReLU为例静态梳理一下PyTorch的代码结构。捣鼓的这些天,对如何构建一个带有C/C++代码的Python又有了新的了解,这也算是意外的收获吧。

2. 历程

首先,我们从PReLU的导入路径torch.nn.PReLU中知道,他应在径进torch\nn\之下,进入该路径虽然没看到,但是我们在该路径下的__init__.py中知道,其实它就在torch\nn\modules\activation.py中。类PReLU最终调用了从torch\nn\functional.py导入的prelu方法。顺腾摸瓜,找到prelu,它长下面这样:

def prelu(input, weight):
    # type: (Tensor, Tensor) -> Tensor
    if not torch.jit.is_scripting(): 
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(prelu, (input,), input, weight)
    return torch.prelu(input, weight)

经过人脑对代码的一番执行你会发现,第一个if条件满足,而第二个if不满足。因此,最终想看算法,得去看torch.prelu()。好吧,接着干……

一番搜寻之后你会发现,Python代码中在torch这个包下面你是找不到prelu的定义的。但是绝望之际我们在torch包的__init__.py之中看到看下面几行代码:

# pytorch\torch\__init__.py

# 为了简洁,省去不必要代码,详细代码参见pytorch\torch\__init__.py
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension


__all__ += [name for name in dir(_C)
            if name[0] != '_' and
            not name.endswith('Base')]

if TYPE_CHECKING:
    # Some type signatures pulled in from _VariableFunctions here clash with
    # signatures already imported. For now these clashes are ignored; see
    # PR #43339 for details.
    from torch._C._VariableFunctions import *  # type: ignore

for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    globals()[name] = getattr(_C._VariableFunctions, name)
    __all__.append(name)

这是全村最后的希望了。我们知道__all__中的名字其实就是该模块有意暴露出去的API。
什么意思呢?也就是说虽然我们明文上已经看不到了prelu的定义,但是这几行代码表明有一大堆身份不明的API被暗搓搓的导入了,这其中就很有可能存在我们朝思暮想的prelu

那么我们怎么凭借这么一点微弱的线索确定我们的猜测到底对不对呢?这里我们就用到了Python的一个关键知识:C/C++扩展。(戳这里《使用C语言编写Python模块-引子》《Python调用C++之PYBIND11简介》了解更多)

我们知道Python C/C++扩展有着固定的格式,只要我们找到模块初始化入口,就能顺藤摸瓜找到该模块暴露的给Python解释器所有函数。Python 3中的初始化函数样子为PyInit_<module_name>,其中<module_name>就是模块的名字。例如在前面提到的from torch._C import *中,模块torch下面必要有一个名字为_C的子模块。因此它的初始化函数应该为PyInit__C,我们搜索该名字就能找到模块入口。当然另外还有一种方法,就是查看setup.py文件中关于扩展的描述信息:

// pytorch\setup.py
main_sources = ["torch/csrc/stub.c"]
C = Extension("torch._C",
                  libraries=main_libraries,
                  sources=main_sources,
                  language='c',
                  extra_compile_args=main_compile_args + extra_compile_args,
                  include_dirs=[],
                  library_dirs=library_dirs,
                  extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
    extensions.append(C)

不管是通过搜索还是查看setup.py,我们最终都成功定位到了位于pytorch\torch\csrc\stub.c下的模块初始化函数PyInit__C(void),并进一步跟踪其调用的函数initModule(),便可以知道具体都暴露了哪些API给Python解释器。

// pytorch\torch\csrc\stub.c
PyMODINIT_FUNC PyInit__C(void)
{
  return initModule();
}


// pytorch\torch\csrc\Module.cpp
initModule()

进入initModule()寻找一番,你会发现,模块_C中依然没有prelu的Python接口。怎么办?莫慌,通过前面对torch.__init__.py的分析,我们知道我们还有希望——_C模块下的子模块_VariableFunctions,这真的是最后的希望了!没了别的路可以走了,只能是硬着头皮找。经过一番惊天地泣鬼神、艰苦卓绝的寻找,我们在initModule()的调用链initModule()->THPVariable_initModule(module)->torch::autograd::initTorchFunctions(module)中发现了_VariableFunctions的踪影。Aha,simple!

void initTorchFunctions(PyObject* module) {
  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  Py_INCREF(&THPVariableFunctions);

  // Steals
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(module, "_VariableFunctionsClass", reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
    throw python_error();
  }
  // PyType_GenericNew returns a new reference
  THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
  // PyModule_AddObject steals a reference
  if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
    throw python_error();
  }
}

但是!!别高兴太早!查看模块_VariableFunctions中暴露的接口你会发现,根本就没有我们想要的!如下面的代码所示:

static PyMethodDef torch_functions[] = {
  {"arange", castPyCFunctionWithKeywords(THPVariable_arange),
    METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor),
    METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"dsmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
  {"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"hsmm", castPyCFunctionWithKeywords(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ${py_method_defs}
  {NULL}
};

上面的代码中我们找不到prelu的任何身影。会不会prelu可以绕开C/C++扩展的方式直接被Python使用呢?所以不会出现在这里?答案是不会,自古华山一条路,程序是不会跟你讲潜规则的。那么既然最终代码已经跟丢了,作者一定是使用了黑魔法,作为麻瓜的我无计可施,本文也该结束了……

等等,上面的C代码中好像混入了奇怪的东西——${py_method_defs}。这种语法好像C/C++语法里面是没有的,反而是Shell这类脚本里面才会有,难道是新特性?费劲查找了一圈,并没有发现C/C++中有这种语法,既然不是正经语法,那么混入C/C++中肯定会导致编译失败,但是它确实就在那里。那么真相只有一个:它就是个占位符,后面肯定会有真正的代码替换它!

接下来怎么办?搜索!使用py_method_defs作为关键字全局搜索,最终我们会发现,确实是有一个Python脚本对这个占位符进行了替换,而替换的结果就是我们一直寻找的prelu终于出现在了模块_VariableFunctions之中。好,破案了。

但是就像警察破案,即便有单个证据,也要找到其他证据形成完整证据链才能使得证据具有说服力。虽然我们通过搜索得知了prelu会出现在模块_VariableFunctions中,但是它究竟怎么来的目前还是很模糊:占位符在什么时候被谁调用的脚本进行了替换?

实际上,这一切都是有迹可循的。踪迹依旧在setup.py中。进入setup.py的主函数,在调用setup函数之前会看到一个名为build_deps()的函数调用,此函数最终会调用指定平台的CMake去按照根目录下CMakeLists.txt中的脚本进行构建。根目录下的CMakeLists.txt最终又会调用到caffe2目录下的CMakeLists.txt(add_subdirectory(caffe2)),而caffe2/CMakeLists.txt中就会调用到进行代码生成的Python脚本,如下所示:

代码生成脚本起调过程示意图

// pytorch\caffe2\CMakeLists.txt
  add_custom_command( OUTPUT
    ${TORCH_GENERATED_CODE}
    COMMAND
    "${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
      --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
      --native-functions-path "aten/src/ATen/native/native_functions.yaml"
      --nn-path "aten/src"
      $<$<BOOL:${INTERN_DISABLE_AUTOGRAD}>:--disable-autograd>
      $<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
      --force_schema_registration

进行代码生成的主要流程如下面代码块所示,其大概流程是main()先解析传递给脚本的参数,之后将参数传递给generate_code()。结合caffe2/CMakeLists.txt中脚本调用时传递的参数可知,generate_code()中的是三个gen_*()函数都得到了调用,而在gen_autograd_python()会调用到一个名为create_python_bindings()的函数,这个函数就是真正执行代码生成的地方。

代码生成器调用流程示意图
// tools/setup_helpers/generate_code.py
def generate_code(ninja_global=None,
                  declarations_path=None,
                  nn_path=None,
                  native_functions_path=None,
                  install_dir=None,
                  subset=None,
                  disable_autograd=False,
                  force_schema_registration=False,
                  operator_selector=None):

    if subset == "pybindings" or not subset:
        gen_autograd_python(
            declarations_path or DECLARATIONS_PATH,
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            autograd_gen_dir,
            autograd_dir)

    if operator_selector is None:
        operator_selector = SelectiveBuilder.get_nop_selector()

    if subset == "libtorch" or not subset:

        gen_autograd(
            declarations_path or DECLARATIONS_PATH,
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            autograd_gen_dir,
            autograd_dir,
            disable_autograd=disable_autograd,
            operator_selector=operator_selector,
        )

    if subset == "python" or not subset:
        gen_annotated(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            python_install_dir,
            autograd_dir)

def main():
    parser = argparse.ArgumentParser(description='Autogenerate code')
    parser.add_argument('--declarations-path')
    parser.add_argument('--native-functions-path')
    parser.add_argument('--nn-path')
    parser.add_argument('--ninja-global')
    parser.add_argument('--install_dir')
    parser.add_argument(
        '--subset',
        help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
    )
    parser.add_argument(
        '--disable-autograd',
        default=False,
        action='store_true',
        help='It can skip generating autograd related code when the flag is set',
    )
    parser.add_argument(
        '--selected-op-list-path',
        help='Path to the YAML file that contains the list of operators to include for custom build.',
    )
    parser.add_argument(
        '--operators_yaml_path',
        help='Path to the model YAML file that contains the list of operators to include for custom build.',
    )
    parser.add_argument(
        '--force_schema_registration',
        action='store_true',
        help='force it to generate schema-only registrations for ops that are not'
        'listed on --selected-op-list'
    )
    options = parser.parse_args()

    generate_code(
        options.ninja_global,
        options.declarations_path,
        options.nn_path,
        options.native_functions_path,
        options.install_dir,
        options.subset,
        options.disable_autograd,
        options.force_schema_registration,
        # options.selected_op_list
        operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path),
    )

if __name__ == "__main__":
    main()
// pytorch\tools\autograd\gen_autograd.py
def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
    from .load_derivatives import load_derivatives
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)

    template_path = os.path.join(autograd_dir, 'templates')

    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_python
    gen_autograd_functions_python(
        out, differentiability_infos, template_path)

    # Generate Python bindings
    from . import gen_python_functions
    deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
    gen_python_functions.gen(
        out, native_functions_path, deprecated_path, template_path)
// pytorch\tools\autograd\gen_python_functions.py
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                            Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    methods = load_signatures(native_yaml_path, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    functions = load_signatures(native_yaml_path, deprecated_yaml_path, method=False)
    create_python_bindings(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)

def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generates Python bindings to ATen functions"""
    py_methods: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []

    grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        if pred(pair.function):
            grouped[pair.function.func.name.name].append(pair)

    for name in sorted(grouped.keys(), key=lambda x: str(x)):
        overloads = grouped[name]
        py_methods.append(method_impl(name, module, overloads, method=method))
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))

    fm.write_with_template(filename, filename, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
        'py_forwards': py_forwards,
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
    })

最终通过查看native_functions.yaml的内容以及深入跟踪加载native_functions.yaml的代码发现,native_functions.yaml中的prelu最终会被写到以python_torch_functions.cpp为模板的文件中,也就是调用

    create_python_bindings(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)

的时候被生成。整个生成的过程其实是很繁琐的,一层层跟踪后可以发现,最终生成的代码可以实现将一个名为at::<func_name>的函数暴露给Python。例如我们的prelu,暴露给Python的API最终会调用一个名为at::prelu()的函数来做真正的计算。那么这个at::<func_name>(例如at::prelu())的定义又在哪里呢?

还是一样,故技重施!仍然使用Python脚本根据native_functions.yaml文件中的内容去以pytorch\aten\src\ATen\templates目录下的各种模板去生成对应的实际C++源文件。最终结果是得到at::<func_name>,在这个函数中,它调用了Dispatcher这个类寻找到目标函数的句柄。通常情况下能够使用的函数句柄都通过一个叫Library的类来管理。Python脚本以RegisterSchema.cpp为模板,生成了注册这些目标函数的注册代码,并通过一个名为TORCH_LIBRARY的宏调用Library类来注册管理。

#define TORCH_LIBRARY(ns, m) \
  static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
  static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
    torch::Library::DEF, \
    &TORCH_LIBRARY_init_ ## ns, \
    #ns, c10::nullopt, __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
class TorchLibraryInit final {
private:
  using InitFn = void(Library&);
  Library lib_;
public:
  TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
    : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};
PyTorch组成示意图

3. 总结

PyTorch虽然在使用上是非常的Pythonic,但实际上Python只不过是为了方便使用裹在C++代码上的一层糖衣。用起来虽然好用,但是看起来实在是非常费劲,特别是如果静态的梳理代码,很多用于连接Python C/C++接口与实际逻辑代码之间的C++代码都是通过Python脚本生成的。至此,整个大的线索已经摸清了,剩下的就是去查看具体细节的实现。

说实话,人脑执行Python代码之后再去理解C++代码实在是费劲,也费头发。因此我决定的让电脑去生成C++代码再接着看更具体的细节,比如究竟每一个算子是怎么注册到Library之中的。

4. Bonus

我真心怀疑我们生活在一个虚拟机里,为什么呢?因为到处可见运用于计算机里面的空间和时间局部性原理的实例。就在我写完这个博客的时候,意外的发现了一篇PyTorch工程师讲解PyTorch内部原理的博文,这对后续读代码应该会有很大帮助。等不及就戳它吧 http://blog.ezyang.com/2019/05/pytorch-internals/

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,324评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,303评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,192评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,555评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,569评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,566评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,927评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,583评论 0 257
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,827评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,590评论 2 320
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,669评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,365评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,941评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,928评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,159评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,880评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,399评论 2 342