// cuElim/include/gf256/gf256_mat.cuh
//
// Source-viewer metadata: 194 lines, 5.1 KiB, plain text (Raw / Normal View / History).
//
// Last modified: 2024-09-06 15:58:40 +08:00
#ifndef MATGF256_CUH
#define MATGF256_CUH
// 2024-09-05 16:56:58 +08:00
// 2024-09-06 15:58:40 +08:00
#include "gf256_header.cuh"
// 2024-09-05 16:56:58 +08:00
// 2024-09-06 15:58:40 +08:00
#include <random>
// Dense matrix over GF(256), stored row-major in CUDA managed memory.
// Elements are packed gf256_num per base_t word; a row occupies `width` words and
// consecutive rows are `rowstride` words apart (rowstride >= width, padding is zero).
//
// NOTE(review): the host-side helpers below (at_base, randomize, operator==, ^=)
// dereference the managed buffer directly on the host. Callers presumably must make
// sure no GPU work touching it is still in flight (e.g. cudaDeviceSynchronize) -
// this class does not enforce that.
class MatGF256
{
public:
    // Ownership kind of a matrix.
    //   root: owns `data` and frees it in the destructor.
    //   view: borrows storage owned by another matrix (see createView); never frees.
    enum MatType
    {
        root,
        view
    };

    // Creates a zero-filled root matrix of nrows x ncols elements.
    // Only root matrices can be constructed this way.
    // Each row is padded to a multiple of 4 words (32 bytes when base_t is 64-bit).
    // An empty matrix (nrows == 0 or ncols == 0) allocates nothing and keeps
    // data == nullptr, because cudaMallocManaged rejects zero-byte requests.
    MatGF256(size_t nrows, size_t ncols)
        : nrows(nrows), ncols(ncols), width(0), rowstride(0), type(root), data(nullptr)
    {
        // Ceil-divisions written as (x + d - 1) / d so that ncols == 0 yields 0
        // instead of underflowing size_t via (ncols - 1).
        width = (ncols + gf256_num - 1) / gf256_num;
        rowstride = (width + 3) / 4 * 4; // round row length up to 4 words
        const size_t bytes = nrows * rowstride * sizeof(base_t);
        if (bytes != 0)
        {
            CUDA_CHECK(cudaMallocManaged((void **)&data, bytes));
            CUDA_CHECK(cudaMemset(data, 0, bytes));
        }
    }

    // Deep copy. The result is always a root matrix with its own storage, even
    // when `m` is a view; only the first m.width words of each row are copied.
    MatGF256(const MatGF256 &m) : MatGF256(m.nrows, m.ncols)
    {
        copy_elements_from(m);
    }

    // Takes over m's storage and ownership kind; `m` is left as an empty view
    // so that its destructor frees nothing.
    MatGF256(MatGF256 &&m) noexcept : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
    {
        m.nrows = 0;
        m.ncols = 0;
        m.width = 0;
        m.rowstride = 0;
        m.type = view;
        m.data = nullptr;
    }

    // Copies m's elements into this matrix's existing storage (no reallocation).
    // Dimensions must already match. If *this is a view, the write goes through
    // to the storage it was cut from.
    MatGF256 &operator=(const MatGF256 &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        copy_elements_from(m);
        return *this;
    }

    // Releases our own storage (if root), then takes over m's storage and
    // ownership kind; `m` is left as an empty view.
    MatGF256 &operator=(MatGF256 &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        rowstride = m.rowstride;
        type = m.type;
        data = m.data;
        m.nrows = 0;
        m.ncols = 0;
        m.width = 0;
        m.rowstride = 0;
        m.type = view;
        m.data = nullptr;
        return *this;
    }

    // Root matrices free their storage (cudaFree(nullptr) is a no-op for empty
    // matrices); views never do.
    ~MatGF256()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Pointer to word `w` of row `r`. No bounds checking.
    inline base_t *at_base(size_t r, size_t w) const
    {
        return data + r * rowstride + w;
    }

    // Returns a non-owning view of rows [begin_ri, end_rj) and words
    // [begin_wi, end_wj). Views can only be cut at base_t (word) granularity.
    // The view shares this matrix's storage and rowstride and must not outlive it.
    MatGF256 createView(size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj) const
    {
        assert(begin_ri < end_rj && end_rj <= nrows && begin_wi < end_wj && end_wj <= width);
        MatGF256 view;
        view.nrows = end_rj - begin_ri;
        // A view that reaches the last word inherits the exact (possibly partial)
        // column count; otherwise it spans whole words.
        view.ncols = (end_wj == width ? ncols : end_wj * gf256_num) - begin_wi * gf256_num;
        view.width = end_wj - begin_wi;
        view.rowstride = rowstride;
        view.data = at_base(begin_ri, begin_wi);
        return view;
    }

    // Fills every element with pseudo-random values; the unused high bits of
    // each row's last word are cleared. Only valid on root matrices.
    // NOTE: the engine and distribution are function-local statics shared by all
    // matrices, so `seed` only takes effect on the first call in the process;
    // later calls continue the same random stream.
    void randomize(base_t seed)
    {
        assert(type == root);
        static std::default_random_engine e(seed);
        static std::uniform_int_distribution<base_t> d;
        if (width == 0)
        {
            return; // empty rows: nothing to fill, and width - 1 below would underflow
        }
        // Keep only the bits that belong to real columns in the last word.
        base_t lastmask = base_fullmask >> (width * base_len - ncols * gf256_len);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = d(e);
            }
            *at_base(r, width - 1) &= lastmask;
        }
    }

    // True when both matrices have the same shape and identical words
    // (padding beyond `width` is not compared).
    bool operator==(const MatGF256 &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True when every word in the matrix (padding excluded) equals `base`.
    bool operator==(const base_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // Element-wise addition in GF(256) (word-wise XOR), in place, on the host.
    void operator^=(const MatGF256 &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) ^= *m.at_base(r, w);
            }
        }
    }

    // Element-wise addition in GF(256); returns a new root matrix.
    MatGF256 operator^(const MatGF256 &m) const
    {
        MatGF256 temp(*this);
        temp ^= m;
        return temp;
    }

    friend std::ostream &operator<<(std::ostream &out, const MatGF256 &m);
    void gpu_addmul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);
    friend MatGF256 gpu_mul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);

private:
    // Empty, non-owning placeholder; used by createView before filling fields in.
    MatGF256() : nrows(0), ncols(0), width(0), rowstride(0), type(view), data(nullptr) {}

    // Copies m's first m.width words of every row into this matrix's storage,
    // honouring both row strides (m may be a view). Shapes must already match.
    void copy_elements_from(const MatGF256 &m)
    {
        if (nrows == 0 || m.width == 0)
        {
            return; // nothing to copy; avoids handing null pointers to CUDA
        }
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t), m.data, m.rowstride * sizeof(base_t), m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
    }

    size_t nrows, ncols;     // matrix shape in GF(256) elements
    size_t width, rowstride; // words used per row / words between row starts
    MatType type;
    base_t *data;
};
// Writes `m` to `out`, one matrix row per line. Each base_t word of the row is
// passed through rev8 and printed as 16 upper-case, zero-padded hex digits
// followed by a space.
// Output goes to `out` itself, so file and string streams work as well as
// std::cout. The stream's format flags and fill character are restored before
// returning.
// `inline` is required because this definition lives in a header: without it,
// every translation unit including the header would emit its own external
// definition and the link would fail.
inline std::ostream &operator<<(std::ostream &out, const MatGF256 &m)
{
    const std::ios_base::fmtflags saved_flags = out.flags();
    const char saved_fill = out.fill();

    out.setf(std::ios_base::hex, std::ios_base::basefield);
    out.setf(std::ios_base::right, std::ios_base::adjustfield);
    out.setf(std::ios_base::uppercase);
    out.unsetf(std::ios_base::showbase);
    out.fill('0');

    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            out.width(16); // width resets after every formatted insertion
            out << static_cast<unsigned long long>(rev8(*m.at_base(r, w))) << ' ';
        }
        out << '\n';
    }

    out.flags(saved_flags);
    out.fill(saved_fill);
    return out;
}
#endif