cuElim/include/matrix.cuh
2024-09-05 23:46:07 +08:00

209 lines
5.5 KiB
Plaintext
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef MATRIX_CUH
#define MATRIX_CUH
#include "header.cuh"
#include "gf28.cuh"
// Dense matrix over GF(2^8). Elements are packed into base_t words
// (base_num elements per word, judging by the width computation below) and
// stored row by row in CUDA managed memory with a row stride of `pitch` words.
//
// A matrix is either
//   * root: it owns `data` and releases it with cudaFree on destruction, or
//   * view: it borrows `data` from another matrix (see createView) and never
//     frees it; a view must not be used after the matrix it came from is gone.
class GF28Matrix
{
public:
    enum MatType
    {
        root, // owns its storage
        view  // non-owning window into another matrix's storage
    };

    // Allocates a zero-filled nrows x ncols matrix. This constructor only ever
    // builds root matrices.
    //
    // Empty shapes (nrows == 0 or ncols == 0) allocate nothing and leave
    // data == nullptr: for ncols == 0 the ceil-division below would wrap
    // `width` around (size_t underflow), and cudaMallocManaged rejects
    // 0-byte requests.
    GF28Matrix(size_t nrows, size_t ncols)
        : nrows(nrows), ncols(ncols), width(0), pitch(0), type(root), data(nullptr)
    {
        if (ncols != 0)
        {
            // Number of base_t words needed to hold ncols elements (ceil division).
            width = (ncols - 1) / base_num + 1;
            // Round the row stride up to a multiple of 4 words
            // (32 bytes when base_t is a 64-bit word).
            pitch = ((width - 1) / 4 + 1) * 4;
        }
        const size_t bytes = nrows * pitch * sizeof(base_t);
        if (bytes != 0)
        {
            CUDA_CHECK(cudaMallocManaged((void **)&data, bytes));
            CUDA_CHECK(cudaMemset(data, 0, bytes));
        }
    }

    // Deep copy. Always produces a root matrix, even when `m` is a view; only
    // the live words of each row are copied, so the padding stays zero.
    GF28Matrix(const GF28Matrix &m) : GF28Matrix(m.nrows, m.ncols)
    {
        copy_data_from(m);
    }

    // Takes over m's storage (and ownership, if m was a root). m is left as an
    // empty, non-owning view.
    GF28Matrix(GF28Matrix &&m) noexcept
        : nrows(m.nrows), ncols(m.ncols), width(m.width), pitch(m.pitch), type(m.type), data(m.data)
    {
        m.nrows = 0;
        m.ncols = 0;
        m.width = 0;
        m.pitch = 0;
        m.type = view;
        m.data = nullptr;
    }

    // Copies the contents of `m` into this matrix's existing storage. The shapes
    // must match. If this matrix is a view, the data is written through into
    // the storage it borrows.
    GF28Matrix &operator=(const GF28Matrix &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        copy_data_from(m);
        return *this;
    }

    // Releases this matrix's storage (if it is a root), then takes over m's
    // storage and ownership. m is left as an empty, non-owning view.
    GF28Matrix &operator=(GF28Matrix &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        pitch = m.pitch;
        type = m.type;
        data = m.data;
        m.nrows = 0;
        m.ncols = 0;
        m.width = 0;
        m.pitch = 0;
        m.type = view;
        m.data = nullptr;
        return *this;
    }

    // Roots free their storage; views leave it to the matrix that owns it.
    // (cudaFree(nullptr) is a no-op, so empty roots are fine.)
    ~GF28Matrix()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Pointer to word `w` of row `r` (no bounds checking).
    inline base_t *at_base(size_t r, size_t w) const
    {
        return data + r * pitch + w;
    }

    // Returns a non-owning view of rows [begin_ri, end_rj) and words
    // [begin_wi, end_wj) of this matrix. Column bounds are given in base_t
    // words, not elements, so a view always starts on a word boundary.
    // The view shares storage and pitch with this matrix. Writes through it
    // modify this matrix, and it must not outlive this matrix.
    GF28Matrix createView(size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj) const
    {
        assert(begin_ri < end_rj && end_rj <= nrows && begin_wi < end_wj && end_wj <= width);
        GF28Matrix sub; // private default ctor: type == view, data == nullptr
        sub.nrows = end_rj - begin_ri;
        // A view that reaches the last word inherits the parent's (possibly
        // partial) last word; otherwise it covers whole words only.
        sub.ncols = (end_wj == width ? ncols : end_wj * base_num) - begin_wi * base_num;
        sub.width = end_wj - begin_wi;
        sub.pitch = pitch;
        sub.data = at_base(begin_ri, begin_wi);
        return sub;
    }

    // Fills the matrix with pseudo-random words, then ANDs the last word of
    // each row with `lastmask` so the bits past the ncols-th element stay clear.
    // NOTE(review): this assumes base_fullmask is an all-ones word and that
    // base_len / base_deg are the word / element widths in bits. Confirm
    // in header.cuh.
    //
    // NOTE: the engine and distribution are function-local statics. So `seed`
    // is only honoured on the first call in the whole program. Later calls, on
    // any matrix, continue the same random stream and ignore `seed`.
    void randomize(base_t seed)
    {
        assert(type == root);
        static default_random_engine e(seed);
        static uniform_int_distribution<base_t> d;
        if (width == 0)
        {
            return; // no columns: nothing to fill, and width - 1 would underflow
        }
        base_t lastmask = base_fullmask >> (width * base_len - ncols * base_deg);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = d(e);
            }
            *at_base(r, width - 1) &= lastmask;
        }
    }
    // void write(string path) const
    // {
    //     assert(type == root);
    //     ofstream out(path, ios::binary);
    //     out.write((char *)data, nrows * pitch * sizeof(base_t));
    //     out.close();
    // }
    // void read(string path)
    // {
    //     assert(type == root);
    //     ifstream in(path, ios::binary);
    //     in.read((char *)data, nrows * pitch * sizeof(base_t));
    //     in.close();
    // }

    // Equality: the shapes must match and every live word must be equal
    // (pitch padding is ignored). Runs on the host over managed memory.
    bool operator==(const GF28Matrix &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True iff every live word equals `base`. Whole words are compared,
    // including any unused bits in the last word of each row.
    bool operator==(const base_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // XORs `m` into this matrix word by word (for GF(2^8) elements this is
    // matrix addition). The shapes must match. Runs on the host.
    void operator^=(const GF28Matrix &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) ^= *m.at_base(r, w);
            }
        }
    }

    // Returns a new root matrix holding (*this XOR m).
    GF28Matrix operator^(const GF28Matrix &m) const
    {
        GF28Matrix temp(*this);
        temp ^= m;
        return temp;
    }

    friend ostream &operator<<(ostream &out, const GF28Matrix &m);
    void gpu_addmul(const GF28Matrix &a, const GF28Matrix &b, const GF28 &gf);
    friend GF28Matrix gpu_mul(const GF28Matrix &a, const GF28Matrix &b, const GF28 &gf);

private:
    // Empty, non-owning matrix; only used internally (e.g. by createView).
    GF28Matrix() : nrows(0), ncols(0), width(0), pitch(0), type(view), data(nullptr) {}

    // Copies the m.width live words of each of the nrows rows of `m` into this
    // matrix's storage. Skips the call entirely for empty shapes, where `data`
    // may be nullptr.
    void copy_data_from(const GF28Matrix &m)
    {
        if (nrows == 0 || m.width == 0)
        {
            return;
        }
        CUDA_CHECK(cudaMemcpy2D(data, pitch * sizeof(base_t), m.data, m.pitch * sizeof(base_t),
                                m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
    }

    size_t nrows, ncols; // shape in GF(2^8) elements
    size_t width, pitch; // live words per row / row stride in words
    MatType type;        // root (owning) or view (borrowing)
    base_t *data;        // first word of row 0
};
// Writes the matrix to `out`: one line per row, each word printed via rev8()
// as 16 upper-case, zero-padded hex digits followed by a space.
//
// Marked `inline` because this definition lives in a header. Without it,
// including the header from two translation units causes a
// multiple-definition link error.
// Output goes to `out` rather than stdout, so std::cerr, std::ostringstream,
// etc. work as expected. The caller's fill character and format flags are
// restored before returning.
inline ostream &operator<<(ostream &out, const GF28Matrix &m)
{
    const std::ios_base::fmtflags saved_flags = out.flags();
    const char saved_fill = out.fill();
    out << std::hex << std::uppercase << std::right << std::noshowbase;
    out.fill('0');
    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            out.width(16); // width is reset after every formatted insertion
            out << static_cast<unsigned long long>(rev8(*m.at_base(r, w))) << ' ';
        }
        out << '\n';
    }
    out.flags(saved_flags);
    out.fill(saved_fill);
    return out;
}
#endif