#ifndef MULTIPLICATION_CUH
#define MULTIPLICATION_CUH

#include <cassert>
#include <cstdio>
#include <cstdlib>

#include "gf256_mat.cuh"

// Aborts with a diagnostic when a CUDA runtime call reports an error.
// Kernel launches do not return an error code themselves, so callers pass
// cudaGetLastError() right after a launch (configuration errors) and the
// result of cudaDeviceSynchronize() (asynchronous execution faults).
static inline void gf256_mul_cuda_check(cudaError_t err, const char *what)
{
    if (err != cudaSuccess)
    {
        std::fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
        std::abort();
    }
}

// Builds the multiplication lookup table for one base-word column of the
// left operand.
//
// Launch: 2D grid; x indexes the `width` base words of a row, y indexes the
// `nrows` table rows. Threads outside [0, width) x [0, nrows) exit early.
//
// Table row r is mul_base(get8(r, 0), row get8(r, 1) of src), i.e. the low
// lane of r selects the multiplier and the next lane selects which of the
// source rows is multiplied.
// NOTE(review): together with the i * (1 << gf256_len) + lane indexing in
// gpu_addmul_kernel, this assumes gf256_len == 8 and that get8(x, k) extracts
// the k-th 8-bit lane of x -- confirm against gf256_mat.cuh.
//
// r_tb, tb_rowstride : output table
// src, s_rowstride   : first of the source rows (rows of the right operand)
// width              : base words per row to process
// nrows              : number of table rows to fill
__global__ void gpu_mktb_kernel(base_t *r_tb, size_t tb_rowstride, base_t *src, size_t s_rowstride, size_t width, size_t nrows)
{
    // Widen before multiplying: blockIdx * blockDim is otherwise a 32-bit product.
    size_t w = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    size_t r = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;

    if (w >= width || r >= nrows)
    {
        return;
    }

    gf256_t val = get8(r, 0);
    base_t s = *at_base(src, s_rowstride, get8(r, 1), w);
    base_t d = mul_base(val, s);
    *at_base(r_tb, tb_rowstride, r, w) = d;
}

// XOR-accumulates one column block of a * b into c using the table produced
// by gpu_mktb_kernel.
//
// Launch: 2D grid; x indexes the `width` base words of a row, y indexes the
// `nrows` rows of c. Threads outside [0, width) x [0, nrows) exit early.
//
// For row r, the base word of `a` at (r, 0) is split into lanes with get8;
// lane i selects table row i * (1 << gf256_len) + lane, the selected words
// are XORed together (GF(2^8) addition), and the sum is XORed into c at (r, w).
//
// a, a_rowstride   : left operand, already offset to the base-word column in use
// tb, tb_rowstride : lookup table for that column
// c, c_rowstride   : accumulator
// tb_num           : number of lanes of a's word that are in use
__global__ void gpu_addmul_kernel(base_t *a, size_t a_rowstride, base_t *tb, size_t tb_rowstride, base_t *c, size_t c_rowstride, size_t tb_num, size_t width, size_t nrows)
{
    // Widen before multiplying: blockIdx * blockDim is otherwise a 32-bit product.
    size_t w = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    size_t r = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;

    if (w >= width || r >= nrows)
    {
        return;
    }

    base_t val = *at_base(a, a_rowstride, r, 0);
    base_t temp = base_zero;
    for (size_t i = 0; i < tb_num; i++)
    {
        temp ^= *at_base(tb, tb_rowstride, i * (1 << gf256_len) + get8(val, i), w);
    }
    *at_base(c, c_rowstride, r, w) ^= temp;
}

// this += a * b over GF(256), computed on the GPU.
//
// For every base-word column w of `a`:
//   1. a lookup table of (multiplier x row of b) products is built for the
//      up-to gf256_num rows of `b` that column covers;
//   2. every row of `this` XOR-accumulates the table rows selected by the
//      lanes of a's word.
// Both kernels are launched on the same (default) stream, so stream order
// already keeps each table build after the previous accumulate and before the
// next one; the host synchronizes only once, before returning.
//
// Aborts via gf256_mul_cuda_check if a launch or the final synchronization
// reports a CUDA error.
__host__ void MatGF256::gpu_addmul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf)
{
    assert(a.ncols == b.nrows && a.nrows == nrows && b.ncols == ncols);
    gf.cpy_to_constant();

    // Nothing to accumulate. Returning here also avoids the size_t wrap of
    // (n - 1) in the grid computations below, which would yield an invalid grid.
    if (nrows == 0 || width == 0 || a.width == 0)
    {
        return;
    }

    MatGF256 tb(gf256_num * (1 << gf256_len), b.ncols);

    // The block shape and the accumulate grid do not depend on w.
    const dim3 block(THREAD_X, THREAD_Y);
    const dim3 grid((width - 1) / block.x + 1, (nrows - 1) / block.y + 1);

    for (size_t w = 0; w < a.width; w++)
    {
        size_t tb_num = min(gf256_num, a.ncols - w * gf256_num);
        size_t tb_rows = tb_num * (1 << gf256_len);

        // Size the table grid by the rows actually filled for this column.
        dim3 grid_tb((tb.width - 1) / block.x + 1, (tb_rows - 1) / block.y + 1);
        gpu_mktb_kernel<<<grid_tb, block>>>(tb.data, tb.rowstride, b.at_base(w * gf256_num, 0), b.rowstride, tb.width, tb_rows);
        gf256_mul_cuda_check(cudaGetLastError(), "gpu_mktb_kernel launch");

        gpu_addmul_kernel<<<grid, block>>>(a.at_base(0, w), a.rowstride, tb.data, tb.rowstride, data, rowstride, tb_num, width, nrows);
        gf256_mul_cuda_check(cudaGetLastError(), "gpu_addmul_kernel launch");
    }

    // Result is complete (and tb no longer in use) once this returns.
    gf256_mul_cuda_check(cudaDeviceSynchronize(), "gpu_addmul synchronize");
}

// Returns a * b over GF(256), computed on the GPU.
// NOTE(review): gpu_addmul XOR-accumulates into its destination, so this relies
// on MatGF256(nrows, ncols) producing a zero matrix -- confirm in gf256_mat.cuh.
__host__ MatGF256 gpu_mul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf)
{
    assert(a.ncols == b.nrows);
    MatGF256 c(a.nrows, b.ncols);
    c.gpu_addmul(a, b, gf);
    return c;
}

#endif