// cuElim/include/gfp/gfp_mul.cuh — GF(p) matrix multiplication (tiled GPU kernel + host wrapper).
#ifndef GFP_MUL_CUH
#define GFP_MUL_CUH
#include "gfp_mat.cuh"
namespace gfp
{
// Dimensions of the C sub-tile computed by one thread block.
static const int BlockRow = 128, BlockCol = 128;
// Number of A columns (= B rows) consumed per iteration of the kernel's K loop.
static const int StepSize = 8;

// Every thread accumulates a (BlockRow/THREAD_Y) x (BlockCol/THREAD_X) register
// sub-tile, so the block tile must divide evenly among the threads.
static_assert(BlockCol % THREAD_X == 0 && BlockRow % THREAD_Y == 0);
// Tiled GF(p) matrix product: C = A * B (mod gfprime).
//
// Expected launch configuration:
//   block = (THREAD_X, THREAD_Y)
//   grid  = (ceil(ncols / BlockCol), ceil(nrows / BlockRow))
// Each block computes one BlockRow x BlockCol tile of C; each thread owns a
// (BlockRow/THREAD_Y) x (BlockCol/THREAD_X) register sub-tile of it.
//
// a/b/c are device pointers addressed through at_base() with row strides
// a_rs/b_rs/c_rs; nrows x nsteps is the shape of A, nsteps x ncols of B.
__global__ void gpu_mul_kernel(gfp_t *__restrict__ a, const size_t a_rs, gfp_t *__restrict__ b, const size_t b_rs, gfp_t *__restrict__ c, const size_t c_rs, const size_t nrows, const size_t ncols, const size_t nsteps)
{
    const unsigned int bx = blockIdx.x;
    const unsigned int by = blockIdx.y;
    const unsigned int tx = threadIdx.x;
    const unsigned int ty = threadIdx.y;
    const unsigned int tid = ty * blockDim.x + tx;

    // Staging tiles: s_a holds the A slice transposed (s_a[step][row]) so the
    // compute loop reads both tiles along the fast index.
    __shared__ gfp_t s_a[StepSize][BlockRow];
    __shared__ gfp_t s_b[StepSize][BlockCol];

    // The single cooperative-load loop below fills both tiles with one index
    // range, which is only valid because the two tiles have equal element counts.
    static_assert(BlockRow == BlockCol, "shared-tile load loop assumes square block tile");

    // Per-thread accumulator sub-tile, kept in registers.
    gfp_t tmp_c[BlockRow / THREAD_Y][BlockCol / THREAD_X] = {0};

    // March over the K dimension in StepSize-wide slices. Using a size_t
    // counter with this condition avoids int overflow for large nsteps and
    // handles nsteps == 0 without unsigned underflow.
    for (size_t s = 0; s * StepSize < nsteps; s++)
    {
        // Cooperative load of the current A and B slices; out-of-range cells
        // are zero-padded so the compute loop needs no bounds checks.
        for (unsigned int k = tid; k < StepSize * BlockRow; k += blockDim.x * blockDim.y)
        {
            const int a_r = k / StepSize;
            const int a_c = k % StepSize;
            s_a[a_c][a_r] = by * BlockRow + a_r < nrows && s * StepSize + a_c < nsteps ? *at_base(a, a_rs, by * BlockRow + a_r, s * StepSize + a_c) : 0;
            const int b_r = k / BlockCol;
            const int b_c = k % BlockCol;
            s_b[b_r][b_c] = s * StepSize + b_r < nsteps && bx * BlockCol + b_c < ncols ? *at_base(b, b_rs, s * StepSize + b_r, bx * BlockCol + b_c) : 0;
        }
        __syncthreads();
        // Multiply-accumulate: each product is reduced mod gfprime before the
        // add, so every accumulator grows by at most StepSize*(gfprime-1) per slice.
        for (int k = 0; k < StepSize; k++)
        {
            for (int j = 0; j < BlockRow / THREAD_Y; j++)
            {
                for (int i = 0; i < BlockCol / THREAD_X; i++)
                {
                    tmp_c[j][i] += (s_a[k][j * THREAD_Y + ty] * s_b[k][i * THREAD_X + tx]) % gfprime;
                }
            }
        }
        __syncthreads();
        // Periodically fold the accumulators back below gfprime so they never
        // overflow gfp_t. BUG FIX: the original condition was
        // `s & gfp_fullmask == gfp_fullmask`, which parses as `s & 1` because
        // `==` binds tighter than `&` in C++; the parenthesized form reduces
        // once every (gfp_fullmask + 1) slices as the mask's sizing intends.
        if ((s & gfp_fullmask) == gfp_fullmask)
        {
            for (int j = 0; j < BlockRow / THREAD_Y; j++)
            {
                for (int i = 0; i < BlockCol / THREAD_X; i++)
                {
                    tmp_c[j][i] %= gfprime;
                }
            }
        }
    }
    // Write the result tile with a final reduction; stores outside the real
    // matrix extent are masked off.
    for (int j = 0; j < BlockRow / THREAD_Y; j++)
    {
        for (int i = 0; i < BlockCol / THREAD_X; i++)
        {
            if (by * BlockRow + j * THREAD_Y + ty < nrows && bx * BlockCol + i * THREAD_X + tx < ncols)
            {
                *at_base(c, c_rs, by * BlockRow + j * THREAD_Y + ty, bx * BlockCol + i * THREAD_X + tx) = tmp_c[j][i] % gfprime;
            }
        }
    }
}
// Computes *this = a * b over GF(p) on the GPU, blocking until the result is
// materialized in this->data. Dimensions must conform: a.ncols == b.nrows,
// a.nrows == nrows, b.ncols == ncols.
__host__ void MatGFP::gpu_mul(const MatGFP &a, const MatGFP &b)
{
    assert(a.ncols == b.nrows && a.nrows == nrows && b.ncols == ncols);

    // Nothing to launch for an empty product; also avoids unsigned underflow
    // in the ceil-divisions below when width or nrows is 0.
    if (nrows == 0 || width == 0)
        return;

    dim3 block(THREAD_X, THREAD_Y);
    // BUG FIX: each block of gpu_mul_kernel covers a BlockRow x BlockCol tile
    // of C, so the grid must be sized by the tile dimensions. The original
    // divided by block.x/block.y (THREAD_X/THREAD_Y), launching
    // (BlockRow/THREAD_Y)*(BlockCol/THREAD_X) times too many blocks — the
    // excess blocks ran the full compute loops with all stores masked out.
    dim3 grid((width - 1) / BlockCol + 1, (nrows - 1) / BlockRow + 1);
    gpu_mul_kernel<<<grid, block>>>(a.data, a.rowstride, b.data, b.rowstride, data, rowstride, nrows, width, a.width);
    cudaDeviceSynchronize(); // block until the product is fully written
}
}
#endif