使用cuFFT实现大整数乘法

序言

在某些场合,我们可能需要使用远超内置整型范围的整数进行运算,比如公钥加密等。如果使用最原始的竖式计算,那么时间复杂度是 T(n2) ,其中n是相乘的两个整数的位数。使用Karatsuba算法优化,时间复杂度可以降至 T(nlog23)T(n1.585) 。而如果使用快速傅里叶变换(FFT),则可以优化到 T(nlogn)

原理

整数乘法与多项式

一个整数,我们可以把它表示成一个多项式。比如长度为n的十进制整数 a0a1a2a3...an1 写成多项式:

A(x)=a0xn1+a1xn2+...+an2x+an1

整数2345,写成多项式就是:


A(x)=2x3+3x2+4x+5

这样一来,整数乘法就变成了多项式的乘法。比如2345*123,就变成了:


×+=2x5+2x5+4x4+3x4+7x4+2x3+6x3+6x3+4x3+16x3+3x2+x2+9x2+8x2+5x222x2+4x+2x+12x+10x22x+531515

(这时会发现结果多项式中的某些系数大于10,只要做进位处理即可得到288435,就是2345*123的结果。)

很显然,这样直接计算的时间复杂度是 T(n2) (假设两个多项式次数均为n-1)。

为了简洁,还可以把多项式系数抽出来表示成向量的形式: (2,7,16,22,22,15) ,这是多项式的系数表示法

对于一个n-1次多项式A(x),它由n个系数 a0,a1,a2,...,an1 决定。如果对n个不同的x进行求值,就能得到n个不同的点: (x0,A(x0)),(x1,A(x1)),...,(xn1,A(xn1))
反过来,只要我们知道y=A(x)上n个不同的点,那么我们就可以通过对这n个点插值求得A(x)的n个系数。

n个各异点跟n个系数一样都可以确定唯一的一个n-1次多项式。这种用n个各异的点表示多项式的方法称为多项式的点值表示法

多项式乘法与卷积

把多项式乘法结果 C(x)=A(x)B(x) 的系数用公式描述出来:

ci=k=0iakbik

如果学习过信号处理,你可能会觉得这个公式很眼熟。没错,这就是卷积公式。多项式乘法实际上就是在做卷积。

比如 (2,3,4,5) (1,2,3) 做卷积:

2345
321 (21)
2345
321 (2,2×2+3×1)
2345
321 (2,7,2×3+3×2+4×1)
2345
321 (2,7,16,3×3+4×2+5×1)
2345
321 (2,7,16,22,4×3+5×2)
2345
321 (2,7,16,22,22,5×3)=(2,7,16,22,22,15)

卷积与快速傅里叶变换

卷积的计算方法有很多种,这里主要介绍使用快速傅里叶变换的方法。这个方法利用了以下结论:

时域上的卷积等价于频域上的点乘->系数表示下的多项式乘法等价于点值表示下的点乘

设n-1次多项式A(x)、B(x)它们各有n个系数,以及2n-2次多项式C(x)=A(x)*B(x),它有2n-1个系数。

对A(x)、B(x)求值,得到点值表示下的A(x)、B(x):

A(x):(x0,A(x0)),(x1,A(x1)),...(x2n2,A(x2n2))B(x):(x0,B(x0)),(x1,B(x1)),...(x2n2,B(x2n2))

需要注意,尽管A(x)和B(x)的系数只有n个,但是这里使用了2n-1个点来表示,有n-1个点是“多余”的。这是因为C(x)是2n-2次多项式,需要2n-1个点来表示。

因为 C(xk)=A(xk)B(xk)) ,那么可以计算得到C(x)的点值表示:

C(x):(x0,A(x0)B(x0)),(x1,A(x1)B(x1)),...(x2n2,A(x2n2)B(x2n2))

也就是说,在点值表示下A(x)*B(x)的计算就是点乘。2n-1个点的点乘的时间复杂度是 T(n) 。比系数表示下的多项式乘法 T(n2) 的复杂度低得多。

所以,通过点值表示法计算多项式乘法的步骤如下:

  1. 求值。对系数表示的多项式 A(x),B(x) 进行求值,得到点值表示的 A(x),B(x)
  2. 点乘。在点值表示下,点乘 A(x),B(x) ,得到 C(x) 的点值表示。
  3. 插值。对表示 C(x) 的点进行插值,得到系数表示的 C(x)

我们知道,第二步点乘时间复杂度是 T(n) ,关键在于第一步和第三步的算法。

如果我们直接对n-1次的多项式求值n次得到n个点,每个点求值时间复杂度是 T(n) ,那么总的时间复杂度就是 T(n2) ,这跟直接相乘的复杂度是一样的,没有改进。

关键在于,通用的求值算法是要对多项式在所有点上都适用的,也就是对于所有的x,都能求出多项式A(x)的值。但是在多项式用于整数乘法时,我们根本不关心x是多少,只关心多项式的系数。x不管是整数、小数乃至虚数都对多项式系数没影响。

于是我们可以放松对求值的要求,只在某些特定的x上进行求值。

快速傅里叶变换(FFT)就是这样一种方法,对于一个n-1次多项式,FFT取的x的值是主n次单位复根。然后利用n次单位复根的特性,通过分而治之的方式将单个点求值的时间复杂度从 T(n) 降至 T(logn)

第一步如果利用FFT变换来求值,那么时间复杂度就是 T(nlogn)
同理,第三步可以使用逆FFT来插值,时间复杂度也是 T(nlogn)

这里就不介绍快速傅里叶变换的具体原理了,因为它实在是太重要了,以致于你可以在几乎每一本讲数字信号处理的书中找到它的详细介绍。

cuFFT实现

cuFFT

cuFFT是nVidia推出的基于CUDA的FFT运算库。使用cuFFT可以利用显卡对FFT运算进行加速。

改进

cuFFT是基于浮点运算的,从浮点数转换为整数,误差难以避免,比如把下面的代码里的进制改成万进制,得到的结果后几位往往是错的,这就是误差引起的。

更好的解决方法是使用基于模n加法群的快速傅里叶变换(此时称为快速数论变换NTT),完全不涉及浮点数,不存在误差问题。

完整代码

nvcc编译需要加上参数-lcufft。使用Visual Studio时,在项目的额外依赖项里加上cufft.lib。

#include <device_launch_parameters.h>
#include <cuda_runtime.h>
#include <cufft.h>
#include <vector>
#include <cmath>

const auto BATCH = 1;

__global__ void ComplexPointwiseMulAndScale(cufftComplex *a, cufftComplex *b, int size)
{
    const int numThreads = blockDim.x * gridDim.x;
    const int threadID = blockIdx.x * blockDim.x + threadIdx.x;
    float scale = 1.0f / (float)size;
    cufftComplex c;
    for (int i = threadID; i < size; i += numThreads)
    {
        c = cuCmulf(a[i], b[i]);
        b[i] = make_cuFloatComplex(scale*cuCrealf(c), scale*cuCimagf(c));
    }
}

__global__ void ConvertToInt(cufftReal *a, int size)
{
    const int numThreads = blockDim.x * gridDim.x;
    const int threadID = blockIdx.x * blockDim.x + threadIdx.x;
    auto b = (int*)a;
    for (int i = threadID; i < size; i += numThreads)
        b[i] = static_cast<int>(round(a[i]));
}

std::vector<int> multiply(const std::vector<float> &a, const std::vector<float> &b)
{
    const auto NX = a.size();
    cufftHandle plan_a, plan_b, plan_c;
    cufftComplex *data_a, *data_b;
    std::vector<int> c(a.size() + 1);
    c[0] = 0;

    //分配显卡内存并初始化,这里假设sizeof(int)==sizeof(float), sizeof(cufftComplex)==2*sizeof(float)
    cudaMalloc((void**)&data_a, sizeof(cufftComplex) * (NX / 2 + 1) * BATCH);
    cudaMalloc((void**)&data_b, sizeof(cufftComplex) * (NX / 2 + 1) * BATCH);
    cudaMemcpy(data_a, a.data(), sizeof(float) * a.size(), cudaMemcpyHostToDevice);
    cudaMemcpy(data_b, b.data(), sizeof(float) * b.size(), cudaMemcpyHostToDevice);
    if (cudaGetLastError() != cudaSuccess) { fprintf(stderr, "Cuda error: Failed to allocate\n"); return c; }

    if (cufftPlan1d(&plan_a, NX, CUFFT_R2C, BATCH) != CUFFT_SUCCESS) { fprintf(stderr, "CUFFT error: Plan creation failed"); return c; }
    if (cufftPlan1d(&plan_b, NX, CUFFT_R2C, BATCH) != CUFFT_SUCCESS) { fprintf(stderr, "CUFFT error: Plan creation failed"); return c; }
    if (cufftPlan1d(&plan_c, NX, CUFFT_C2R, BATCH) != CUFFT_SUCCESS) { fprintf(stderr, "CUFFT error: Plan creation failed"); return c; }

    //把A(x)转换到频域
    if (cufftExecR2C(plan_a, (cufftReal*)data_a, data_a) != CUFFT_SUCCESS)
    {
        fprintf(stderr, "CUFFT error: ExecR2C Forward failed");
        return c;
    }

    //把B(x)转换到频域
    if (cufftExecR2C(plan_b, (cufftReal*)data_b, data_b) != CUFFT_SUCCESS)
    {
        fprintf(stderr, "CUFFT error: ExecR2C Forward failed");
        return c;
    }

    //点乘
    ComplexPointwiseMulAndScale<<<NX / 256 + 1, 256>>>(data_a, data_b, NX);

    //把C(x)转换回时域
    if (cufftExecC2R(plan_c, data_b, (cufftReal*)data_b) != CUFFT_SUCCESS)
    {
        fprintf(stderr, "CUFFT error: ExecC2R Forward failed");
        return c;
    }

    //将浮点数的结果转换为整数
    ConvertToInt<<<NX / 256 + 1, 256>>>((cufftReal*)data_b, NX);

    if (cudaDeviceSynchronize() != cudaSuccess) 
    {
        fprintf(stderr, "Cuda error: Failed to synchronize\n");
        return c;
    }

    cudaMemcpy(&c[1], data_b, sizeof(float) * b.size(), cudaMemcpyDeviceToHost);

    cufftDestroy(plan_a);
    cufftDestroy(plan_b);
    cufftDestroy(plan_c);
    cudaFree(data_a);
    cudaFree(data_b);
    return c;
}


int main(int argc, char **argv) 
{
    //设置进制
    const auto base = 10;

    //999 * 9
    std::vector<float> a{ 0, 9, 9, 9 }; 
    std::vector<float> b{ 0, 0, 0, 9 };

    auto c = multiply(a, b);

    for (auto i : c)
        printf("%d ", i);
    printf("\n");

    //处理进位
    for (int i = c.size() - 1; i > 0; i--)
    {
        if (c[i] >= base)
        {
            c[i - 1] += c[i] / base;
            c[i] %= base;
        }
    }

    //去掉多余的零
    c.pop_back();
    auto i = 0;
    if (c[0] == 0)
        i++;

    //输出最终结果,改了进制需要改这里的输出方式,比如百进制是"%02d",千进制是"%03d"
    for (; i < c.size(); i++)
        printf("%d", c[i]);
    printf("\n");

    return 0;
}

参考文献

《算法导论》
cuFFT documentation

    原文作者:大整数乘法问题
    原文地址: https://blog.csdn.net/qmickecs/article/details/77504297
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞