2024-09-12 18:53:59 +08:00
|
|
|
|
#ifndef GF256_MAT_CUH
|
|
|
|
|
#define GF256_MAT_CUH
|
2024-09-05 16:56:58 +08:00
|
|
|
|
|
2024-09-06 15:58:40 +08:00
|
|
|
|
#include "gf256_header.cuh"
|
2024-09-05 16:56:58 +08:00
|
|
|
|
|
2024-09-06 15:58:40 +08:00
|
|
|
|
#include <random>
|
2024-09-12 18:53:59 +08:00
|
|
|
|
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
|
|
|
|
|
|
|
|
|
|
// Result record returned by MatGF256::gpu_elim (an elimination over GF(256)).
// NOTE(review): the member meanings below are inferred from their names and
// should be confirmed against the gpu_elim implementation.
struct ElimResult
{
    std::size_t rank;                  // presumably the rank found by the elimination
    std::vector<std::size_t> pivot;    // presumably one pivot column index per pivot row — TODO confirm
    std::vector<std::size_t> swap_row; // presumably the row exchanged with each row during pivoting — TODO confirm
};
|
2024-09-06 15:58:40 +08:00
|
|
|
|
|
|
|
|
|
// Dense matrix over GF(256) stored in CUDA managed memory.
// Elements are packed gf256_num per base_t word; a row spans `width` words and
// consecutive rows are `rowstride` words apart (width rounded up to a multiple
// of 4 words).
class MatGF256
{
public:
    // Ownership state of `data`:
    //   root   - this object allocated `data` (cudaMallocManaged) and frees it;
    //   window - `data` points into another matrix's storage and is not freed here;
    //   moved  - the contents were moved out; `data` is nullptr.
    enum MatType
    {
        root,
        window,
        moved,
    };

    // Allocates a zero-filled root matrix of nrows x ncols GF(256) elements.
    // Only root matrices can be created this way.
    // NOTE(review): ncols == 0 makes `width` wrap around; callers presumably pass ncols >= 1.
    MatGF256(size_t nrows, size_t ncols) : nrows(nrows), ncols(ncols), type(root)
    {
        width = (ncols - 1) / gf256_num + 1;
        rowstride = ((width - 1) / 4 + 1) * 4; // align rows to 32 bytes (4 x 64-bit words)
        CUDA_CHECK(cudaMallocManaged((void **)&data, nrows * rowstride * sizeof(base_t)));
        CUDA_CHECK(cudaMemset(data, 0, nrows * rowstride * sizeof(base_t)));
    }

    // Creates a non-owning window over rows [begin_ri, end_rj) and words
    // [begin_wi, end_wj) of `src`. Column bounds are given in base_t words, so a
    // window always starts on a word boundary. When the window reaches src's last
    // word it inherits src's (possibly partial) column count.
    MatGF256(const MatGF256 &src, size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj)
        : nrows(end_rj - begin_ri),
          ncols((end_wj == src.width ? src.ncols : end_wj * gf256_num) - begin_wi * gf256_num),
          width(end_wj - begin_wi),
          rowstride(src.rowstride),
          type(window),
          data(src.at_base(begin_ri, begin_wi))
    {
        assert(begin_ri < end_rj && end_rj <= src.nrows && begin_wi < end_wj && end_wj <= src.width);
    }

    // Deep copy; the result is always a freshly allocated root matrix, even
    // when `m` is a window.
    MatGF256(const MatGF256 &m) : MatGF256(m.nrows, m.ncols)
    {
        assert(m.type != moved); // a moved-from matrix has no data to copy
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t), m.data, m.rowstride * sizeof(base_t), m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
    }

    // Takes over m's storage and ownership state; m is left in the `moved` state.
    MatGF256(MatGF256 &&m) noexcept : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
    {
        m.type = moved;
        m.data = nullptr;
    }

    // Copies m's elements into this matrix's existing storage (dimensions must
    // match). When this is a window, the write goes through to the parent matrix.
    MatGF256 &operator=(const MatGF256 &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        assert(type != moved && m.type != moved); // moved-from matrices have no storage
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t), m.data, m.rowstride * sizeof(base_t), m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
        return *this;
    }

    // Releases owned storage (if root) and takes over m's storage and state;
    // m is left in the `moved` state.
    MatGF256 &operator=(MatGF256 &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        rowstride = m.rowstride;
        type = m.type;
        data = m.data;
        m.type = moved;
        m.data = nullptr;
        return *this;
    }

    // Frees the storage only when this object owns it (root).
    ~MatGF256()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Pointer to word w of row r (no bounds checking).
    inline base_t *at_base(size_t r, size_t w) const
    {
        return data + r * rowstride + w;
    }

    // Fills the matrix with random words, clearing the bits past column ncols in
    // each row's last word. Root matrices only. A fresh engine is seeded on every
    // call, so the result depends only on `seed`. (Previously a function-static
    // engine was seeded once, and every later call ignored its seed.)
    void randomize(uint_fast32_t seed)
    {
        default_random_engine e(seed);
        fill_random(e);
    }

    // Generates a random matrix in reduced row echelon form. nrows pivot columns
    // are chosen at random among the first rank_col columns, sorted ascending
    // with one per row. Each pivot entry is set to gf256_one; every other entry
    // of a pivot column within the first rank_col columns is cleared; words left
    // of a row's pivot word are zeroed. Requires nrows <= rank_col <= ncols and a
    // root matrix. The result depends only on `seed`.
    void randomize(size_t rank_col, uint_fast32_t seed)
    {
        assert(nrows <= rank_col && rank_col <= ncols);
        default_random_engine e(seed);
        fill_random(e);

        // Choose nrows distinct pivot columns in [0, rank_col), ascending.
        // std::random_shuffle (removed in C++17) used std::rand and ignored
        // `seed`; shuffle with the seeded engine instead.
        vector<size_t> pivot(rank_col);
        iota(pivot.begin(), pivot.end(), size_t{0});
        shuffle(pivot.begin(), pivot.end(), e);
        pivot.resize(nrows);
        sort(pivot.begin(), pivot.end());

        // pivotmask[w]: full mask with the slots of every pivot column in word w deleted.
        vector<base_t> pivotmask(width, base_fullmask);
        for (size_t r = 0; r < nrows; r++)
        {
            del8(pivotmask[pivot[r] / gf256_num], pivot[r] % gf256_num);
        }

        // Pivot columns lie in words [0, pivot_words). pivot_words <= width because
        // rank_col <= ncols. The former bound `rank_col / gf256_num + 1` overran
        // the row (and pivotmask) by one word when rank_col % gf256_num == 0.
        const size_t pivot_words = (rank_col + gf256_num - 1) / gf256_num;
        for (size_t r = 0; r < nrows; r++)
        {
            const size_t pw = pivot[r] / gf256_num; // word holding this row's pivot
            const size_t pi = pivot[r] % gf256_num; // slot of the pivot inside that word
            for (size_t w = 0; w < pw; w++)
            {
                *at_base(r, w) = base_zero;
            }
            base_t *now = at_base(r, pw);
            // Presumably concat8 takes the first pi+1 slots from base_zero and the
            // rest from the masked word, i.e. zeros up to the pivot — TODO confirm.
            *now = concat8(base_zero, pi + 1, *now & pivotmask[pw]);
            set8(*now, gf256_one, pi);
            for (size_t w = pw + 1; w < pivot_words; w++)
            {
                *at_base(r, w) &= pivotmask[w];
            }
        }
    }

    // True when both matrices have the same shape and identical words in every row.
    bool operator==(const MatGF256 &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True when every word of every row equals `base`.
    bool operator==(const base_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // In-place word-wise XOR (addition in GF(256)) with a matrix of the same shape.
    void operator^=(const MatGF256 &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) ^= *m.at_base(r, w);
            }
        }
    }

    // Returns a new root matrix equal to *this XOR m.
    MatGF256 operator^(const MatGF256 &m) const
    {
        MatGF256 temp(*this);
        temp ^= m;
        return temp;
    }

    // GPU operations defined elsewhere.
    void gpu_addmul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);
    friend MatGF256 gpu_mul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);

    // size_t cpu_elim_base(base_t *base_col, size_t st_r, size_t w, vector<size_t> &p_col, vector<size_t> &p_row, base_t step[gf256_num], const GF256 &gf);
    void cpu_swap_row(size_t r1, size_t r2);
    // void cpu_mul_row(size_t r, gf256_t val, const GF256 &gf);
    ElimResult gpu_elim(const GF256 &gf);

    friend ostream &operator<<(ostream &out, const MatGF256 &m);

    size_t nrows, ncols, width;

private:
    // Empty placeholder in the `moved` state (no storage).
    MatGF256() : nrows(0), ncols(0), width(0), rowstride(0), type(moved), data(nullptr) {}

    // Overwrites every word with a draw from `e`, then clears the bits past
    // column ncols in each row's last word. Root matrices only.
    void fill_random(default_random_engine &e)
    {
        assert(type == root);
        uniform_int_distribution<base_t> d;
        const base_t lastmask = base_fullmask >> (width * base_len - ncols * gf256_len);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = d(e);
            }
            *at_base(r, width - 1) &= lastmask;
        }
    }

    size_t rowstride;
    MatType type;
    base_t *data;
};
|
|
|
|
|
|
2024-09-06 15:58:40 +08:00
|
|
|
|
// Writes the matrix to `out`, one row per line. Each base_t word is passed
// through rev8 and printed as 16 zero-padded uppercase hex digits followed by
// a space.
// - `inline`: this definition lives in a header. Without it, including the
//   header from more than one translation unit produces multiple-definition
//   link errors.
// - Output now goes to `out` rather than always to stdout via printf, and
//   avoids the non-portable "%lX" conversion.
// - The stream's formatting flags and fill character are restored before
//   returning.
inline ostream &operator<<(ostream &out, const MatGF256 &m)
{
    const ios_base::fmtflags saved_flags = out.flags(ios_base::hex | ios_base::uppercase | ios_base::right);
    const char saved_fill = out.fill('0');
    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            out.width(16); // width applies only to the next formatted output
            out << static_cast<unsigned long long>(rev8(*m.at_base(r, w))) << ' ';
        }
        out << '\n';
    }
    out.flags(saved_flags);
    out.fill(saved_fill);
    return out;
}
|
|
|
|
|
|
|
|
|
|
#endif
|