cuElim/include/multiplication.cuh
2024-09-05 23:46:07 +08:00

56 lines
1.9 KiB
Plaintext

#ifndef MULTIPLICATION_CUH
#define MULTIPLICATION_CUH
#include "matrix.cuh"
#include "gf28.cuh"
__global__ void gpu_addmul_kernel(base_t *a, size_t a_pitch, base_t *tb, size_t tb_pitch, base_t *c, size_t c_pitch, size_t tb_num, size_t width, size_t nrows)
{
    // Accumulate one base-word slice of A times B into C, using the
    // precomputed multiple table tb. One thread per (row, packed column)
    // element of C: grid x spans the packed column index [0, width),
    // grid y spans the row index [0, nrows).
    const size_t col = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t row = blockIdx.y * blockDim.y + threadIdx.y;
    if (col < width && row < nrows)
    {
        // This row's base word of A packs tb_num GF(2^8) entries.
        const base_t packed = *at_pitch(a, a_pitch, row, 0);
        base_t acc = base_zero;
        for (size_t k = 0; k < tb_num; k++)
        {
            // Row block k of tb holds the (1 << base_deg) precomputed
            // multiples for byte k; the k-th byte of `packed` selects one.
            acc ^= *at_pitch(tb, tb_pitch, k * (1 << base_deg) + get8(packed, k), col);
        }
        // XOR is addition in GF(2^8): fold the partial product into C.
        *at_pitch(c, c_pitch, row, col) ^= acc;
    }
}
__host__ void GF28Matrix::gpu_addmul(const GF28Matrix &a, const GF28Matrix &b, const GF28 &gf)
{
    // this += a * b over GF(2^8), processed one base-word column slice of `a`
    // at a time. For each slice a table of all scalar multiples of the
    // corresponding rows of `b` is built on the device, then consumed by the
    // accumulate kernel.
    // Precondition: dimensions must be conformable.
    assert(a.ncols == b.nrows && a.nrows == nrows && b.ncols == ncols);
    // Upload the GF(2^8) multiplication table to constant memory once per call.
    cudaError_t err = cudaMemcpyToSymbol(d_mul_table, gf.mul_table, (1 << base_deg) * (1 << base_deg) * sizeof(gf28_t));
    assert(err == cudaSuccess);
    (void)err; // silence unused warning under NDEBUG
    // Scratch table: for each of the base_num entries packed in a base word,
    // all (1 << base_deg) possible multiples of the matching row of `b`.
    GF28Matrix tb(base_num * (1 << base_deg), b.ncols);
    for (size_t w = 0; w < a.width; w++)
    {
        // The last slice of `a` may hold fewer than base_num packed columns.
        size_t tb_num = min(base_num, a.ncols - w * base_num);
        dim3 block_tb(THREAD_X, THREAD_Y);
        dim3 grid_tb((b.width - 1) / block_tb.x + 1, (tb.nrows - 1) / block_tb.y + 1);
        gpu_mktb_kernel<<<grid_tb, block_tb>>>(tb.data, tb.pitch, b.at_base(w * base_num, 0), b.pitch, tb.width, tb_num * (1 << base_deg));
        err = cudaGetLastError(); // catch launch-configuration errors
        assert(err == cudaSuccess);
        // No device-wide sync needed here: both kernels run on the default
        // stream, so the table build is ordered before its consumption.
        dim3 block(THREAD_X, THREAD_Y);
        dim3 grid((b.width - 1) / block.x + 1, (nrows - 1) / block.y + 1);
        gpu_addmul_kernel<<<grid, block>>>(a.at_base(0, w), a.pitch, tb.data, tb.pitch, data, pitch, tb_num, width, nrows);
        err = cudaGetLastError();
        assert(err == cudaSuccess);
    }
    // One sync at the end (instead of two per iteration in the original):
    // surfaces asynchronous execution errors and makes the result visible to
    // subsequent host-side reads.
    err = cudaDeviceSynchronize();
    assert(err == cudaSuccess);
}
__host__ GF28Matrix gpu_mul(const GF28Matrix &a, const GF28Matrix &b, const GF28 &gf)
{
    // Full GPU matrix product over GF(2^8): allocate the result and
    // accumulate a * b into it via gpu_addmul.
    // Dimensions must be conformable for the product.
    assert(a.ncols == b.nrows);
    // NOTE(review): gpu_addmul XORs into the result, so correctness presumes
    // the GF28Matrix(nrows, ncols) constructor zero-initializes -- confirm.
    GF28Matrix product(a.nrows, b.ncols);
    product.gpu_addmul(a, b, gf);
    return product;
}
#endif