cuElim/include/gf256/gf256_mat.cuh
2024-09-06 15:58:40 +08:00

194 lines
5.1 KiB
Plaintext
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef MATGF256_CUH
#define MATGF256_CUH
#include "gf256_header.cuh"
#include <random>
/**
 * Dense matrix whose columns are packed `gf256_num` per `base_t` word and
 * whose storage is allocated with cudaMallocManaged.
 *
 * Layout: row r occupies the words data[r * rowstride .. r * rowstride + width),
 * where `width` is the number of words needed for `ncols` columns and
 * `rowstride` is `width` rounded up to a multiple of 4 words.
 *
 * A matrix is either a `root` (owns and frees its allocation) or a `view`
 * (aliases a rectangular region of another matrix's storage and never frees).
 * A view must not outlive the root whose storage it points into.
 */
class MatGF256
{
public:
    // Ownership tag: `root` objects free `data` on destruction, `view` objects do not.
    enum MatType
    {
        root,
        view
    };

    /**
     * Allocates a zero-filled nrows x ncols root matrix.
     * This constructor can only create root matrices.
     *
     * Both dimensions must be non-zero: `(ncols - 1)` would wrap around for
     * ncols == 0, and a zero-row matrix would request a 0-byte managed
     * allocation, which the CUDA runtime rejects.
     */
    MatGF256(size_t nrows, size_t ncols)
        : nrows(nrows), ncols(ncols), width(0), rowstride(0), type(root), data(nullptr)
    {
        assert(nrows > 0 && ncols > 0);
        // Number of base_t words needed to hold ncols packed elements (ceil division).
        width = (ncols - 1) / gf256_num + 1;
        // Round the row length up to a multiple of 4 words
        // (rows aligned in units of 32 bytes = 4 * 64 bit).
        rowstride = ((width - 1) / 4 + 1) * 4;
        const size_t bytes = nrows * rowstride * sizeof(base_t);
        CUDA_CHECK(cudaMallocManaged((void **)&data, bytes));
        CUDA_CHECK(cudaMemset(data, 0, bytes));
    }

    /**
     * Deep copy. The result is always a freshly allocated root matrix, even
     * when `m` is a view; only the `m.width` meaningful words of each row are
     * copied, so the padding words of the new matrix stay zero.
     */
    MatGF256(const MatGF256 &m) : MatGF256(m.nrows, m.ncols)
    {
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t),
                                m.data, m.rowstride * sizeof(base_t),
                                m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
    }

    /**
     * Move construction: takes over `m`'s storage and ownership tag.
     * `m` is left as an empty view (null data, zero sizes) so that its
     * destructor does not free the transferred allocation.
     */
    MatGF256(MatGF256 &&m) noexcept
        : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
    {
        m.nrows = 0;
        m.ncols = 0;
        m.width = 0;
        m.rowstride = 0;
        m.type = view;
        m.data = nullptr;
    }

    /**
     * Element-wise copy of `m` into the existing storage of this matrix.
     * No reallocation happens, so both matrices must already have identical
     * dimensions (checked by assert). When this object is a view, the data is
     * written through into the storage it aliases.
     */
    MatGF256 &operator=(const MatGF256 &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t),
                                m.data, m.rowstride * sizeof(base_t),
                                m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
        return *this;
    }

    /**
     * Move assignment: releases the current allocation (root only), then
     * takes over `m`'s storage and ownership tag. `m` is left as an empty view.
     */
    MatGF256 &operator=(MatGF256 &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        rowstride = m.rowstride;
        type = m.type;
        data = m.data;
        m.nrows = 0;
        m.ncols = 0;
        m.width = 0;
        m.rowstride = 0;
        m.type = view;
        m.data = nullptr;
        return *this;
    }

    // Frees the allocation when this object owns it (root); views free nothing.
    ~MatGF256()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Pointer to word `w` of row `r`. No bounds checking is performed.
    inline base_t *at_base(size_t r, size_t w) const
    {
        return data + r * rowstride + w;
    }

    /**
     * Creates a non-owning view of rows [begin_ri, end_rj) and words
     * [begin_wi, end_wj) of this matrix. The column range can only be chosen
     * in units of whole base_t words.
     *
     * The view shares this matrix's storage and row stride. If the view
     * reaches the last word of this matrix, its column count ends at `ncols`;
     * otherwise it covers whole words only.
     */
    MatGF256 createView(size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj) const
    {
        assert(begin_ri < end_rj && end_rj <= nrows && begin_wi < end_wj && end_wj <= width);
        MatGF256 sub; // private default constructor yields an empty view
        sub.nrows = end_rj - begin_ri;
        // Exclusive end column: the real column count when the last word is
        // included, otherwise the word boundary.
        const size_t end_col = (end_wj == width) ? ncols : end_wj * gf256_num;
        sub.ncols = end_col - begin_wi * gf256_num;
        sub.width = end_wj - begin_wi;
        sub.rowstride = rowstride;
        sub.data = at_base(begin_ri, begin_wi);
        return sub;
    }

    /**
     * Fills a root matrix with random words, then masks the last word of each
     * row with `lastmask` (presumably so that bits beyond column `ncols`
     * stay zero - this relies on base_fullmask/base_len/gf256_len, which are
     * defined elsewhere).
     *
     * NOTE(review): the engine and distribution are `static`, so they are
     * shared by all matrices and `seed` only takes effect on the first call
     * in the program; later calls continue the same random stream regardless
     * of the seed passed. Callers may rely on successive calls producing
     * different matrices, so this behavior is kept as is.
     */
    void randomize(base_t seed)
    {
        assert(type == root);
        static std::default_random_engine e(seed);
        static std::uniform_int_distribution<base_t> d;
        base_t lastmask = base_fullmask >> (width * base_len - ncols * gf256_len);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = d(e);
            }
            *at_base(r, width - 1) &= lastmask;
        }
    }

    /**
     * True when both matrices have the same dimensions and all `width` words
     * of every row compare equal. Padding words are not compared.
     */
    bool operator==(const MatGF256 &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True when every one of the nrows x width stored words equals `base`.
    bool operator==(const base_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // In-place word-wise XOR with `m`; dimensions must match (checked by assert).
    void operator^=(const MatGF256 &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) ^= *m.at_base(r, w);
            }
        }
    }

    // Returns a new root matrix holding the word-wise XOR of *this and `m`.
    MatGF256 operator^(const MatGF256 &m) const
    {
        MatGF256 temp(*this);
        temp ^= m;
        return temp;
    }

    // Stream output of the matrix words; defined after the class.
    friend ostream &operator<<(ostream &out, const MatGF256 &m);
    // Defined elsewhere; presumably accumulates the product a * b into this
    // matrix using the field tables in `gf` - confirm against the definition.
    void gpu_addmul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);
    // Defined elsewhere; presumably returns the product a * b - confirm
    // against the definition.
    friend MatGF256 gpu_mul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);

private:
    // Empty view; only used internally (e.g. by createView) as a starting point.
    MatGF256() : nrows(0), ncols(0), width(0), rowstride(0), type(view), data(nullptr) {}
    size_t nrows, ncols;     // logical dimensions in elements
    size_t width, rowstride; // words used per row / words between row starts
    MatType type;            // root (owns data) or view (borrows data)
    base_t *data;            // first word of the matrix (managed memory for roots)
};
/**
 * Writes the matrix to `out`: one line per row, each of the row's `width`
 * words printed as 16 upper-case hex digits (after rev8) followed by a space.
 *
 * The text goes to the stream that was passed in rather than unconditionally
 * to stdout, so the operator also works with file and string streams.
 * Declared inline because this definition lives in a header.
 *
 * @param out destination stream
 * @param m   matrix to print
 * @return    `out`, to allow chaining
 */
inline ostream &operator<<(ostream &out, const MatGF256 &m)
{
    char word[17]; // 16 hex digits plus the terminating NUL
    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            // Cast to unsigned long long so the format is correct even where
            // `long` is only 32 bits wide.
            snprintf(word, sizeof(word), "%016llX",
                     static_cast<unsigned long long>(rev8(*m.at_base(r, w))));
            out << word << ' ';
        }
        out << '\n';
    }
    return out;
}
#endif