PyTorch源码浅析(五)

我前几天在物理所为了讲Julia的多重派发写了一个C的实现,其实在PyTorch里也是利用这个机制来实现运行时(runtime)根据底层类型调用不同的SIMD加速的。关于中科院物理所的数值俱乐部:http://num.v2nobel.com/talks/9/

但是即便你注册了也不一定能够参加,因为我们暂时没有很大的场地,此外物理所是有门禁的你也不一定进的来。我为物理相关方向的人写了一个Julia Tutorial,本章的C语言的demo源码可以在里面找到:Roger-luo/TutorialForPhysicists.jl

多重派发

首先我们需要了解什么是多重派发,详细的内容可以参考这个wiki:Multiple dispatch,具体来说就是根据一些条件(函数签名)在运行时分发相关的函数。PyTorch中使用了这个方法,就使得CPU上可以根据环境变量,硬件指令实现等条件选择SIMD加速的backend。首先我们用一个C的demo来展示如何在C语言中模拟多重派发。

目标

我们这个demo的目标是根据C语言的内置类型对不同函数进行分发,具体来说就是现在有(以下是伪代码)

func_1(int32, int64)
func_2(int32, float32)

...

他们都是根据不同类型实现的类似功能的函数,我们希望在调用一个入口函数func的时候自动根据输入变量的类型进行方法的派发。

首先我们需要标记一下各个类型

typedef enum {
    tagInt32 = 0,
    tagInt64,
    tagFloat32,
    tagFloat64,
    ntypes,
} Tag;

然后定义一个结构体作为类型

typedef struct
{
    void *data;
    size_t size;
    Tag tag;
} Type;

之后我们需要对不同的类型定义一些工厂函数来产生相对应的实例,所以这里定义一个宏来批量产生函数定义

#define MAKE_TYPE(NAME, CTYPE) \     Type *make_##NAME(CTYPE data) \     {                                               \
        Type *type;                                 \
        type = (Type *)malloc(sizeof(Type));        \
        type->size = sizeof(CTYPE);                 \
        type->data = (void *)malloc(sizeof(CTYPE)); \
        type->tag = tag##NAME; \         CTYPE *temp_ptr = (CTYPE *)type->data;      \
        *temp_ptr = data;                           \
        return type;                                \
    }

MAKE_TYPE(Int32, int32_t)
MAKE_TYPE(Int64, int64_t)
MAKE_TYPE(Float32, float)
MAKE_TYPE(Float64, double)

最后提供用来释放(析构)的函数

void type_free(Type *type)
{
    free(type->data);
    free(type);
}

然后定义函数类型,我们这里的函数都是输入变量为2的(可变参数的话,看情况换成指针的指针之类的方案)

typedef int (*FuncType)(Type *a, Type *b);

然后我们有一些函数在不同类型上的实现

// int32 float32
int Func_0x001(Type *a, Type *b)
{
    int32_t *a_data = a->data;
    float *b_data = b->data;

    printf("input: int32 %d, float32 %f", *a_data, *b_data);
    return 0;
}

// int32 int64
int Func_0x102(Type *a, Type *b)
{
    int32_t *a_data = a->data;
    int64_t *b_data = b->data;

    printf("input: int32 %d, int64 %ld", *a_data, *b_data);
    return 0;
}

// int64 int64
int Func_0x013(Type *a, Type *b)
{
    int64_t *a_data = a->data;
    int64_t *b_data = b->data;

    printf("input: int64 %ld, int64 %ld", *a_data, *b_data);
    return 0;
}

// int32 float64
int Func_0x010(Type *a, Type *b)
{
    int32_t *a_data = a->data;
    double *b_data = b->data;

    printf("input: int32 %d, float64 %lf", *a_data, *b_data);

    return 0;
}

最后别忘了让没有相关实现的类型fall back到错误处理或者默认方法上去(PyTorch中是回退到一般的没有SIMD指令的实现上去)

int fallback()
{
    printf("Error: MethodError: No method match input type\n");
    exit(-1);
    return 0;
}

接下来定义一个函数调用表,表的各个维度就是表示各个类型

FuncType FUNC_CALL_LIST[ntypes][ntypes] = {
    {
        NULL,        // int32 int32
        &Func_0x102, // int32 int64
        &Func_0x001, // int32 float32
        &Func_0x010, // int32 float64
    },

    {
        NULL,        // int64 int32
        &Func_0x013, // int64 int64
        NULL,        // int64 float32
        NULL,        // int64 float64
    },

    {
        NULL, // float32 int32
        NULL, // float32 int64
        NULL, // float64 int32
        NULL, // float64 int64
    },
};

然后定义函数入口

int FuncEntry(Type *a, Type *b)
{
    FuncType func_ptr = FUNC_CALL_LIST[a->tag][b->tag];

    if (func_ptr != NULL)
        (*func_ptr)(a, b);
    else
        return fallback();
}

最后编译一下,就可以看到效果,程序会在运行时分发对应的方法,然后输出结果,如果没有对应的方法就回退到默认实现上去。

int main()
{
    Type *a = make_Int32(2);
    Type *b = make_Float32(1.0);
    FuncEntry(a, b);
    type_free(a);
    type_free(b);

    printf("\n");

    a = make_Int64(2);
    b = make_Int64(3);
    FuncEntry(a, b);
    type_free(a); type_free(b);

    printf("\n");
    a = make_Float32(2.0);
    b = make_Float64(2.0);
    FuncEntry(a, b);
    type_free(a); type_free(b);
    return 0;
}

编译指令只要是就行

gcc main.c -o main

然后需要头文件

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

PyTorch中的TH库的实现机制是类似的,但是没有类型判断而是根据支持的SIMD指令和环境变量来分发方法。

类似我在demo里使用的,每个SIMD实现都是

THVector_(name_EXT)

的形式,其中name是函数名,EXT是对应的SIMD类型,有:DEFAULT,AVX,AVX2,SSE,等

相关的信息会记录在 FunctionDescription这个结构体中,然后最后根据对应的条件进行派发。如果有新的实现只需要插入到generic/THVectorDispatch.cpp中即可,不需要管其它部分,当硬件和环境变量的相关条件满足的时候会自动分配过去,这是运行时分配的,因为是直接访问地址,所以复杂度也是O(1),不需要进行重新编译。

    原文作者:罗秀哲
    原文地址: https://zhuanlan.zhihu.com/p/35703294
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞