86 lines
3.3 KiB
Plaintext
86 lines
3.3 KiB
Plaintext
#ifndef GFP_MUL_CUH
|
||
#define GFP_MUL_CUH
|
||
|
||
#include <cstdio>
#include <cstdlib>

#include "gfp_mat.cuh"
|
||
|
||
namespace gfp
{

// Each thread block computes one BlockRow x BlockCol tile of the result matrix c.
static const int BlockRow = 128, BlockCol = 128;
// Number of columns of a (equivalently, rows of b) staged into shared memory
// per iteration of the kernel's main loop.
static const int StepSize = 8;

// Every thread owns a (BlockRow / THREAD_Y) x (BlockCol / THREAD_X) sub-tile of
// its block's output tile, so the tile must split evenly across the threads.
static_assert(BlockCol % THREAD_X == 0 && BlockRow % THREAD_Y == 0,
              "BlockCol/BlockRow must be multiples of THREAD_X/THREAD_Y");

/*
 * Computes c = a * b with every entry reduced modulo gfprime.
 *
 * a is nrows x nsteps, b is nsteps x ncols, c is nrows x ncols. a_rs, b_rs and
 * c_rs are row strides forwarded to at_base (defined elsewhere; presumably it
 * returns the address of element (row, col) -- TODO confirm).
 *
 * Launch contract:
 *   - blockDim must be exactly (THREAD_X, THREAD_Y): accumulation and
 *     write-back index the tile with those constants, not with blockDim.
 *   - gridDim must cover c in BlockCol x BlockRow tiles, i.e.
 *     gridDim.x >= ceil(ncols / BlockCol), gridDim.y >= ceil(nrows / BlockRow).
 *     Blocks past the edge of c are harmless: every load and store is guarded.
 *   - Uses StepSize * (BlockRow + BlockCol) * sizeof(gfp_t) bytes of static
 *     shared memory and no dynamic shared memory.
 */
__global__ void gpu_mul_kernel(gfp_t *__restrict__ a, const size_t a_rs, gfp_t *__restrict__ b, const size_t b_rs, gfp_t *__restrict__ c, const size_t c_rs, const size_t nrows, const size_t ncols, const size_t nsteps)
{
    constexpr int TileRows = BlockRow / THREAD_Y; // output rows owned by one thread
    constexpr int TileCols = BlockCol / THREAD_X; // output columns owned by one thread
    constexpr unsigned int ALoads = StepSize * BlockRow; // elements in one a-panel
    constexpr unsigned int BLoads = StepSize * BlockCol; // elements in one b-panel

    const unsigned int tx = threadIdx.x;
    const unsigned int ty = threadIdx.y;
    const unsigned int tid = ty * blockDim.x + tx;
    const unsigned int nthreads = blockDim.x * blockDim.y;

    // Top-left corner of this block's tile of c, computed in size_t so large
    // matrices do not wrap 32-bit arithmetic.
    const size_t row0 = static_cast<size_t>(blockIdx.y) * BlockRow;
    const size_t col0 = static_cast<size_t>(blockIdx.x) * BlockCol;

    // The a-panel is stored transposed (s_a[k][row]) so both panels are read
    // along the same k index in the inner product loop.
    __shared__ gfp_t s_a[StepSize][BlockRow];
    __shared__ gfp_t s_b[StepSize][BlockCol];

    gfp_t tmp_c[TileRows][TileCols] = {0};

    // Ceiling division that yields 0 chunks when nsteps == 0 (the form
    // (nsteps - 1) / StepSize + 1 wraps around for size_t 0).
    const size_t nchunks = (nsteps + StepSize - 1) / StepSize;

    for (size_t s = 0; s < nchunks; s++)
    {
        const size_t k0 = s * StepSize; // first column of a / row of b in this chunk

        // Stage a[row0 .. row0+BlockRow) x [k0 .. k0+StepSize), zero-padded
        // outside the matrix. a and b use separate loops so that BlockRow and
        // BlockCol are not required to be equal.
        for (unsigned int idx = tid; idx < ALoads; idx += nthreads)
        {
            const unsigned int r = idx / StepSize;
            const unsigned int k = idx % StepSize;
            const size_t gr = row0 + r;
            const size_t gk = k0 + k;
            s_a[k][r] = (gr < nrows && gk < nsteps) ? *at_base(a, a_rs, gr, gk) : 0;
        }
        // Stage b[k0 .. k0+StepSize) x [col0 .. col0+BlockCol), zero-padded likewise.
        for (unsigned int idx = tid; idx < BLoads; idx += nthreads)
        {
            const unsigned int k = idx / BlockCol;
            const unsigned int cc = idx % BlockCol;
            const size_t gk = k0 + k;
            const size_t gc = col0 + cc;
            s_b[k][cc] = (gk < nsteps && gc < ncols) ? *at_base(b, b_rs, gk, gc) : 0;
        }
        __syncthreads(); // panels fully written before any thread reads them

        for (int k = 0; k < StepSize; k++)
        {
            for (int j = 0; j < TileRows; j++)
            {
                for (int i = 0; i < TileCols; i++)
                {
                    // Each product is reduced before accumulating, so every term adds
                    // at most gfprime - 1 (assumes the product of two reduced entries
                    // fits in gfp_t -- TODO confirm against gfp_t / gfprime).
                    tmp_c[j][i] += (s_a[k][j * THREAD_Y + ty] * s_b[k][i * THREAD_X + tx]) % gfprime;
                }
            }
        }
        __syncthreads(); // all reads done before the next chunk overwrites the panels

        // Periodically fold the accumulators back below gfprime to avoid overflow.
        // The parentheses matter: without them `==` binds first and the test
        // degenerates to `s & 1`.
        // NOTE(review): assumes gfp_fullmask is 2^k - 1 and small enough that
        // (gfp_fullmask + 1) * StepSize * (gfprime - 1) + (gfprime - 1) fits in
        // gfp_t -- confirm against its definition.
        if ((s & gfp_fullmask) == gfp_fullmask)
        {
            for (int j = 0; j < TileRows; j++)
            {
                for (int i = 0; i < TileCols; i++)
                {
                    tmp_c[j][i] %= gfprime;
                }
            }
        }
    }

    // Write back the in-bounds part of this thread's sub-tile, fully reduced.
    for (int j = 0; j < TileRows; j++)
    {
        const size_t r = row0 + j * THREAD_Y + ty;
        for (int i = 0; i < TileCols; i++)
        {
            const size_t col = col0 + i * THREAD_X + tx;
            if (r < nrows && col < ncols)
            {
                *at_base(c, c_rs, r, col) = tmp_c[j][i] % gfprime;
            }
        }
    }
}

/*
 * Overwrites this matrix with a * b (entries reduced modulo gfprime), computed
 * on the GPU. Blocks until the kernel has finished.
 *
 * The kernel receives `width` as its column count and `a.width` as the inner
 * dimension (presumably equal to ncols / a.ncols for these matrices -- TODO confirm).
 * On any CUDA launch or execution error, prints a diagnostic and aborts instead
 * of silently leaving the result undefined.
 */
__host__ void MatGFP::gpu_mul(const MatGFP &a, const MatGFP &b)
{
    assert(a.ncols == b.nrows && a.nrows == nrows && b.ncols == ncols);

    // An empty result needs no work. This also keeps the ceil-divisions below
    // from producing a zero-sized (invalid) grid.
    if (nrows == 0 || width == 0)
    {
        return;
    }

    dim3 block(THREAD_X, THREAD_Y);
    // One thread block per BlockCol x BlockRow tile of the result: the kernel
    // maps blockIdx to whole tiles, not to individual elements.
    dim3 grid((width + BlockCol - 1) / BlockCol, (nrows + BlockRow - 1) / BlockRow);
    gpu_mul_kernel<<<grid, block>>>(a.data, a.rowstride, b.data, b.rowstride, data, rowstride, nrows, width, a.width);

    // Launch-configuration errors surface via cudaGetLastError(); faults during
    // execution surface at the synchronization.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess)
    {
        err = cudaDeviceSynchronize();
    }
    if (err != cudaSuccess)
    {
        std::fprintf(stderr, "gfp::MatGFP::gpu_mul: CUDA error: %s\n", cudaGetErrorString(err));
        std::abort();
    }
}

}
|
||
|
||
#endif
|