cuElim/include/gfp/gfp_mul.cuh

93 lines
3.3 KiB
Plaintext
Raw Normal View History

2024-09-14 15:57:00 +08:00
#ifndef GFP_MUL_CUH
#define GFP_MUL_CUH
#include <cstdio>
#include <cstdlib>
#include "gfp_mat.cuh"
// Size of the tile of C produced by a single thread block:
// BlockRow rows by BlockCol columns.
static constexpr int BlockRow = 128;
static constexpr int BlockCol = 128;
// Depth of the inner (K) dimension staged per loop iteration inside a block:
// StepSize columns of A and, correspondingly, StepSize rows of B.
static constexpr int StepSize = 8;
// Every thread must own a whole number of tile rows and tile columns.
static_assert(BlockRow % THREAD_Y == 0 && BlockCol % THREAD_X == 0,
              "tile dimensions must be multiples of the thread-block dimensions");
// Tiled matrix product C = A * B with entries reduced modulo gfp.
//
// Launch contract:
//   - blockDim must be exactly (THREAD_X, THREAD_Y): each thread owns the
//     (BlockRow / THREAD_Y) x (BlockCol / THREAD_X) accumulators in tmp_c,
//     covering tile rows j * THREAD_Y + ty and columns i * THREAD_X + tx.
//   - Block (blockIdx.x, blockIdx.y) computes the BlockRow x BlockCol tile of C
//     starting at row blockIdx.y * BlockRow, column blockIdx.x * BlockCol, so
//     the grid should be at least ceil(ncols / BlockCol) x ceil(nrows / BlockRow).
//     Surplus blocks are harmless: their loads are zero-filled, stores skipped.
//   - Static shared memory: StepSize * (BlockRow + BlockCol) * sizeof(gfp_t).
//
// Parameters:
//   a, a_rs : A (nrows x nsteps), addressed as *at_base(a, a_rs, row, col);
//             a_rs is presumably the row stride — TODO confirm in gfp_mat.cuh.
//   b, b_rs : B (nsteps x ncols), addressed the same way.
//   c, c_rs : output C (nrows x ncols); every in-range element is overwritten
//             with its value reduced modulo gfp.
//   nsteps  : shared inner dimension; nsteps == 0 produces C = 0.
__global__ void gfp_gpu_mul_kernel(gfp_t *__restrict__ a, const size_t a_rs, gfp_t *__restrict__ b, const size_t b_rs, gfp_t *__restrict__ c, const size_t c_rs, const size_t nrows, const size_t ncols, const size_t nsteps)
{
    const unsigned int bx = blockIdx.x;
    const unsigned int by = blockIdx.y;
    const unsigned int tx = threadIdx.x;
    const unsigned int ty = threadIdx.y;
    const unsigned int tid = ty * blockDim.x + tx;
    const unsigned int nthreads = blockDim.x * blockDim.y;
    // First row (of C and A) and first column (of C and B) owned by this block.
    const size_t row0 = (size_t)by * BlockRow;
    const size_t col0 = (size_t)bx * BlockCol;
#if gfp_bits == 64
    __shared__ alignas(8) gfp_t s_a[StepSize][BlockRow];
    __shared__ alignas(8) gfp_t s_b[StepSize][BlockCol];
#else
    __shared__ gfp_t s_a[StepSize][BlockRow];
    __shared__ gfp_t s_b[StepSize][BlockCol];
#endif
    gfp_t tmp_c[BlockRow / THREAD_Y][BlockCol / THREAD_X] = {0};
    // Number of StepSize-deep slices of the inner dimension. Written as a
    // ceil-division so that nsteps == 0 yields zero iterations instead of
    // wrapping around (the form (nsteps - 1) / StepSize + 1 underflows).
    const size_t nslices = (nsteps + StepSize - 1) / StepSize;
    for (size_t s = 0; s < nslices; s++)
    {
        const size_t k0 = s * StepSize;
        // Stage a BlockRow x StepSize slice of A, stored transposed so the
        // compute loop reads s_a[k][row]. Out-of-range elements become 0.
        for (unsigned int k = tid; k < StepSize * BlockRow; k += nthreads)
        {
            const unsigned int a_r = k / StepSize;
            const unsigned int a_c = k % StepSize;
            s_a[a_c][a_r] = row0 + a_r < nrows && k0 + a_c < nsteps ? *at_base(a, a_rs, row0 + a_r, k0 + a_c) : 0;
        }
        // Stage a StepSize x BlockCol slice of B. This is a separate loop so
        // the whole tile is loaded even if BlockRow != BlockCol.
        for (unsigned int k = tid; k < StepSize * BlockCol; k += nthreads)
        {
            const unsigned int b_r = k / BlockCol;
            const unsigned int b_c = k % BlockCol;
            s_b[b_r][b_c] = k0 + b_r < nsteps && col0 + b_c < ncols ? *at_base(b, b_rs, k0 + b_r, col0 + b_c) : 0;
        }
        __syncthreads(); // tiles fully written before any thread reads them
        for (int k = 0; k < StepSize; k++)
        {
            for (int j = 0; j < BlockRow / THREAD_Y; j++)
            {
                for (int i = 0; i < BlockCol / THREAD_X; i++)
                {
#if gfp_bits == 64
                    // NOTE(review): no per-product reduction here; this relies
                    // on the accumulated sums fitting in gfp_t for gfp_bits == 64
                    // — confirm against the definition of gfp.
                    tmp_c[j][i] += (s_a[k][j * THREAD_Y + ty] * s_b[k][i * THREAD_X + tx]);
#else
                    tmp_c[j][i] += (s_a[k][j * THREAD_Y + ty] * s_b[k][i * THREAD_X + tx]) % gfp;
#endif
                }
            }
        }
        __syncthreads(); // all reads done before the next slice overwrites the tiles
#if gfp_bits != 64
        // Periodically fold the accumulators back below gfp. The parentheses are
        // required: `s & gfp_fullmask == gfp_fullmask` parses as
        // `s & (gfp_fullmask == gfp_fullmask)`, i.e. `s & 1`.
        // NOTE(review): assumes gfp_fullmask is of the form 2^k - 1 and is sized
        // so that (gfp_fullmask + 1) * StepSize terms (each < gfp), added to a
        // reduced value, fit in gfp_t — confirm in the gfp headers.
        if ((s & gfp_fullmask) == gfp_fullmask)
        {
            for (int j = 0; j < BlockRow / THREAD_Y; j++)
            {
                for (int i = 0; i < BlockCol / THREAD_X; i++)
                {
                    tmp_c[j][i] %= gfp;
                }
            }
        }
#endif
    }
    // Write back the in-range part of this block's tile, fully reduced.
    for (int j = 0; j < BlockRow / THREAD_Y; j++)
    {
        const size_t r = row0 + j * THREAD_Y + ty;
        for (int i = 0; i < BlockCol / THREAD_X; i++)
        {
            const size_t col = col0 + i * THREAD_X + tx;
            if (r < nrows && col < ncols)
            {
                *at_base(c, c_rs, r, col) = tmp_c[j][i] % gfp;
            }
        }
    }
}
// Computes *this = a * b (entries modulo gfp) on the GPU and blocks until done.
//
// Preconditions (assert): a.ncols == b.nrows, a.nrows == nrows, b.ncols == ncols.
// The data pointers are handed directly to the kernel, so they must be
// device-accessible — NOTE(review): confirm how MatGFP allocates `data`.
// On a launch or execution failure the CUDA error is printed to stderr and
// the process aborts.
__host__ void MatGFP::gpu_mul(const MatGFP &a, const MatGFP &b)
{
    assert(a.ncols == b.nrows && a.nrows == nrows && b.ncols == ncols);
    // An empty result needs no work. Returning early also avoids the
    // ceil-division below yielding a zero-sized (invalid) grid.
    if (nrows == 0 || width == 0)
        return;
    dim3 block(THREAD_X, THREAD_Y);
    // The kernel maps blockIdx.{x,y} to BlockCol x BlockRow tiles of the
    // result, so the grid is counted in tiles, not in threads.
    dim3 grid((width + BlockCol - 1) / BlockCol, (nrows + BlockRow - 1) / BlockRow);
    gfp_gpu_mul_kernel<<<grid, block>>>(a.data, a.rowstride, b.data, b.rowstride, data, rowstride, nrows, width, a.width);
    // Launch-configuration errors surface via cudaGetLastError; faults inside
    // the kernel surface at the synchronization.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess)
        err = cudaDeviceSynchronize();
    if (err != cudaSuccess)
    {
        std::fprintf(stderr, "gfp_gpu_mul_kernel failed: %s\n", cudaGetErrorString(err));
        std::abort();
    }
}
#endif