cuElim/include/gf256/gf256_mat.cuh
2024-09-19 15:59:53 +08:00

242 lines
7.4 KiB
Plaintext
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef GF256_MAT_CUH
#define GF256_MAT_CUH
#include "gf256_header.cuh"
#include <random>
#include <vector>
#include <algorithm>
namespace gf256
{
// Result record returned by MatGF256::gpu_elim.
// NOTE(review): field meanings below are inferred from the names only; the
// elimination routine is defined outside this header — confirm against it.
struct ElimResult
{
    // presumably the rank found by the elimination
    std::size_t rank;
    // presumably the pivot column index of each pivot row
    std::vector<std::size_t> pivot;
    // presumably the row chosen for the swap at each elimination step
    std::vector<std::size_t> swap_row;
};
// Row-major matrix of GF(256) elements packed gf256_num elements per base_t
// word. Each row occupies `rowstride` words of which the first `width` hold
// data. Storage is CUDA managed memory (cudaMallocManaged), and the CPU-side
// helpers in this class dereference it directly.
//
// Ownership is tracked by `type`:
//   root   - owns `data`; released with cudaFree in the destructor.
//   window - aliases a sub-rectangle of another matrix and never frees it.
//            The aliased matrix must outlive the window.
//   moved  - moved-from or default-constructed; `data` is nullptr.
class MatGF256
{
public:
    enum MatType
    {
        root,
        window,
        moved,
    };

    // Allocates a zero-filled root matrix. This is the only public way to
    // create fresh storage.
    // NOTE: ncols must be non-zero; the ceil-division below is written as
    // (ncols - 1) / gf256_num + 1 and wraps around for ncols == 0.
    MatGF256(size_t nrows, size_t ncols) : nrows(nrows), ncols(ncols), type(root)
    {
        width = (ncols - 1) / gf256_num + 1;
        // Round the stride up to a multiple of 4 words
        // (original note: align to units of 32 bytes, 4 * 64 bit).
        rowstride = ((width - 1) / 4 + 1) * 4;
        CUDA_CHECK(cudaMallocManaged((void **)&data, nrows * rowstride * sizeof(base_t)));
        CUDA_CHECK(cudaMemset(data, 0, nrows * rowstride * sizeof(base_t)));
    }

    // Builds a window over rows [begin_ri, end_rj) and words [begin_wi, end_wj)
    // of `src`. Windows can only be cut at base_t (word) granularity.
    // No data is copied: the window points into src's storage and reuses its
    // rowstride. If the window reaches src's last word it inherits src's
    // (possibly partial) trailing column count.
    MatGF256(const MatGF256 &src, size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj)
        : nrows(end_rj - begin_ri),
          ncols((end_wj == src.width ? src.ncols : end_wj * gf256_num) - begin_wi * gf256_num),
          width(end_wj - begin_wi),
          rowstride(src.rowstride),
          type(window),
          data(src.at_base(begin_ri, begin_wi))
    {
        assert(begin_ri < end_rj && end_rj <= src.nrows && begin_wi < end_wj && end_wj <= src.width);
    }

    // Copy construction always produces a new root matrix with its own
    // storage, whatever the type of `m`.
    MatGF256(const MatGF256 &m) : MatGF256(m.nrows, m.ncols)
    {
        // 2D copy because source and destination strides may differ
        // (e.g. when `m` is a window).
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t), m.data, m.rowstride * sizeof(base_t), m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
    }

    // Steals m's storage and type; `m` is left as `moved` with a null pointer.
    MatGF256(MatGF256 &&m) noexcept
        : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
    {
        m.type = moved;
        m.data = nullptr;
    }

    // Element-wise copy into the existing storage. Both matrices must already
    // have identical dimensions; no reallocation happens. Assigning into a
    // window writes through to the matrix it aliases.
    MatGF256 &operator=(const MatGF256 &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t), m.data, m.rowstride * sizeof(base_t), m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
        return *this;
    }

    // Releases the current storage if this matrix owns it, then takes over
    // m's storage and type. `m` is left as `moved`.
    MatGF256 &operator=(MatGF256 &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        rowstride = m.rowstride;
        type = m.type;
        data = m.data;
        m.type = moved;
        m.data = nullptr;
        return *this;
    }

    // Only root matrices own their storage.
    ~MatGF256()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Pointer to word `w` of row `r`. No bounds checking.
    inline base_t *at_base(size_t r, size_t w) const
    {
        return data + r * rowstride + w;
    }

    // Fills a root matrix with random words.
    // NOTE: the engine is a function-local static, so `seed` is only used the
    // first time this function runs in the process; later calls (on any
    // matrix) continue the same random sequence and ignore their seed.
    void randomize(uint_fast32_t seed)
    {
        assert(type == root);
        static default_random_engine e(seed);
        static uniform_int_distribution<base_t> d;
        // Mask for the last word of each row: shifts out the bits that lie
        // beyond the ncols-th element (assumes base_fullmask is all-ones).
        base_t lastmask = base_fullmask >> (width * base_len - ncols * gf256_len);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = d(e);
            }
            *at_base(r, width - 1) &= lastmask;
        }
    }

    // Generates a random matrix in reduced row echelon form with full row
    // rank: nrows pivot columns are chosen among the first rank_col columns.
    // Requires nrows <= rank_col <= ncols and a root matrix.
    // NOTE(review): del8 / concat8 / set8 are defined outside this header; the
    // per-line comments below describe their presumed effect — confirm.
    void randomize(size_t rank_col, uint_fast32_t seed)
    {
        assert(nrows <= rank_col && rank_col <= ncols);
        randomize(seed);

        // Choose nrows distinct pivot columns in [0, rank_col), ascending.
        // std::shuffle with a seeded engine replaces random_shuffle, which is
        // removed in C++17 and did not depend on `seed` at all.
        vector<size_t> pivot(rank_col);
        iota(pivot.begin(), pivot.end(), 0);
        default_random_engine shuffle_engine(seed);
        shuffle(pivot.begin(), pivot.end(), shuffle_engine);
        pivot.resize(nrows);
        sort(pivot.begin(), pivot.end());

        // Per-word mask with every pivot column's element removed
        // (presumably del8 clears that 8-bit slot).
        vector<base_t> pivotmask(width, base_fullmask);
        for (size_t r = 0; r < nrows; r++)
        {
            del8(pivotmask[pivot[r] / gf256_num], pivot[r] % gf256_num);
        }

        // Number of words that can contain a pivot column. The previous bound
        // (rank_col / gf256_num + 1) is one word too large when rank_col is a
        // multiple of gf256_num, and indexed pivotmask[width] out of range
        // when additionally rank_col == ncols. Words at or beyond this bound
        // have an untouched mask, so skipping them changes nothing.
        const size_t rank_width = (rank_col + gf256_num - 1) / gf256_num;

        for (size_t r = 0; r < nrows; r++)
        {
            const size_t pivot_word = pivot[r] / gf256_num;
            const size_t pivot_idx = pivot[r] % gf256_num;

            // Everything left of the pivot word is zero.
            for (size_t w = 0; w < pivot_word; w++)
            {
                *at_base(r, w) = base_zero;
            }

            // In the pivot word: presumably zero the elements up to and
            // including the pivot, keep the masked random tail, then write
            // the pivot element as one.
            base_t *now = at_base(r, pivot_word);
            *now = concat8(base_zero, pivot_idx + 1, *now & pivotmask[pivot_word]);
            set8(*now, gf256_one, pivot_idx);

            // Clear the other rows' pivot columns in the remaining words.
            for (size_t w = pivot_word + 1; w < rank_width; w++)
            {
                *at_base(r, w) &= pivotmask[w];
            }
        }
    }

    // Word-wise equality of two matrices with identical dimensions.
    bool operator==(const MatGF256 &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True when every data word of the matrix equals `base`.
    bool operator==(const base_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // In-place word-wise XOR with `m` (same dimensions required).
    void operator^=(const MatGF256 &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) ^= *m.at_base(r, w);
            }
        }
    }

    // Returns a new root matrix holding the word-wise XOR of *this and m.
    MatGF256 operator^(const MatGF256 &m) const
    {
        MatGF256 temp(*this);
        temp ^= m;
        return temp;
    }

    // Declared here, defined outside this header.
    void gpu_addmul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);
    friend MatGF256 gpu_mul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);
    ElimResult gpu_elim(const GF256 &gf);
    friend ostream &operator<<(ostream &out, const MatGF256 &m);

    // nrows / ncols count GF(256) elements; width counts base_t words per row.
    size_t nrows, ncols, width;

private:
    // Empty placeholder (type `moved`, no storage); only reachable from
    // members and friends.
    MatGF256() : nrows(0), ncols(0), width(0), rowstride(0), type(moved), data(nullptr) {}

    // Swaps the data words of two rows on the host.
    void cpu_swap_row(size_t r1, size_t r2)
    {
        if (r1 == r2)
        {
            return;
        }
        base_t *p1 = at_base(r1, 0);
        base_t *p2 = at_base(r2, 0);
        for (size_t i = 0; i < width; i++)
        {
            base_t temp = p1[i];
            p1[i] = p2[i];
            p2[i] = temp;
        }
    }

    size_t rowstride; // distance between consecutive rows, in base_t words
    MatType type;
    base_t *data;
};
// Writes the matrix to `out`, one row per line, each data word as 16
// upper-case hex digits followed by a space. Words are passed through rev8
// before printing (rev8 is defined outside this header; presumably it
// reverses the element order within a word — confirm).
//
// The text is written to the supplied stream; the previous version printed
// to stdout via printf regardless of `out`. Declared inline because this is
// a function definition in a header.
inline ostream &operator<<(ostream &out, const MatGF256 &m)
{
    // 16 hex digits + space + terminator fits comfortably.
    char word[32];
    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            // Format through snprintf so the stream's own formatting flags
            // are left untouched; the cast makes the conversion specifier
            // correct regardless of how base_t is spelled.
            snprintf(word, sizeof(word), "%016llX ", static_cast<unsigned long long>(rev8(*m.at_base(r, w))));
            out << word;
        }
        out << '\n';
    }
    return out;
}
}
#endif