cuElim/include/gf2/gf2_mat.cuh

249 lines
7.4 KiB
Plaintext
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef GF2_MAT_CUH
#define GF2_MAT_CUH
#include "gf2_header.cuh"
#include <random>
#include <vector>
#include <bitset>
// #include <algorithm>
namespace gf2
{
// Result record returned by MatGF2::gpu_elim().
// NOTE(review): field meanings below are inferred from their names; the
// elimination routine itself is defined outside this header — confirm there.
struct ElimResult
{
// Presumably the rank found by the elimination.
size_t rank;
// Presumably the pivot column index for each pivot row — TODO confirm.
vector<size_t> pivot;
// Presumably the row swaps performed during elimination — TODO confirm.
vector<size_t> swap_row;
};
/**
 * Bit-packed matrix whose rows are stored as arrays of base_t words.
 *
 * Each row occupies `width` meaningful words and consecutive rows are
 * `rowstride` words apart. A matrix is either a `root` (owns its CUDA managed
 * allocation), a `window` (a non-owning view into another matrix's storage)
 * or `moved` (holds no storage).
 */
class MatGF2
{
public:
    // Ownership state of `data`.
    enum MatType
    {
        root,   // owns `data`; released with cudaFree on destruction / move-assignment
        window, // points into another matrix's storage; never frees it
        moved,  // moved-from or default-constructed; `data` is nullptr
    };

    /**
     * Constructs a zero-filled root matrix (this constructor only creates root matrices).
     *
     * @param nrows number of rows
     * @param ncols number of columns
     */
    MatGF2(size_t nrows, size_t ncols) : nrows(nrows), ncols(ncols), type(root)
    {
        // Words per row: ceil(ncols / gf2_num).
        width = (ncols - 1) / gf2_num + 1;
        // Round the row stride up to a multiple of 4 words
        // (original note: align to 32 bytes, i.e. 4 x 64-bit).
        rowstride = ((width - 1) / 4 + 1) * 4;
        CUDA_CHECK(cudaMallocManaged((void **)&data, nrows * rowstride * sizeof(base_t)));
        CUDA_CHECK(cudaMemset(data, 0, nrows * rowstride * sizeof(base_t)));
    }

    /**
     * Constructs a window (view) onto rows [begin_ri, end_rj) and words
     * [begin_wi, end_wj) of `src`. Windows can only be cut at base_t word
     * granularity.
     *
     * The window shares `src`'s storage through a raw pointer and does not
     * free it, so it must not be used after that storage is released.
     *
     * @param src      matrix to view
     * @param begin_ri first row (inclusive)
     * @param begin_wi first word column (inclusive)
     * @param end_rj   last row (exclusive)
     * @param end_wj   last word column (exclusive)
     */
    MatGF2(const MatGF2 &src, size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj)
        : nrows(end_rj - begin_ri),
          // A window reaching src's last word inherits src's (possibly partial) column count.
          ncols((end_wj == src.width ? src.ncols : end_wj * gf2_num) - begin_wi * gf2_num),
          width(end_wj - begin_wi),
          rowstride(src.rowstride),
          type(window),
          data(src.at_base(begin_ri, begin_wi))
    {
        assert(begin_ri < end_rj && end_rj <= src.nrows && begin_wi < end_wj && end_wj <= src.width);
    }

    /**
     * Copy constructor. Always produces a new root matrix (even when `m` is a
     * window) holding a copy of `m`'s `width` words per row.
     */
    MatGF2(const MatGF2 &m) : MatGF2(m.nrows, m.ncols)
    {
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t), m.data, m.rowstride * sizeof(base_t), m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
    }

    /**
     * Move constructor. Takes over `m`'s storage and type; `m` becomes `moved`
     * with a null data pointer.
     */
    MatGF2(MatGF2 &&m) noexcept : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
    {
        m.type = moved;
        m.data = nullptr;
    }

    /**
     * Copy assignment. Copies `m`'s contents into this matrix's existing
     * storage (so assigning to a window writes into the viewed matrix).
     * Both matrices must have identical nrows and ncols (checked by assert).
     */
    MatGF2 &operator=(const MatGF2 &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t), m.data, m.rowstride * sizeof(base_t), m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
        return *this;
    }

    /**
     * Move assignment. Frees this matrix's storage if it is a root, then takes
     * over `m`'s storage and type; `m` becomes `moved`.
     */
    MatGF2 &operator=(MatGF2 &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        rowstride = m.rowstride;
        type = m.type;
        data = m.data;
        m.type = moved;
        m.data = nullptr;
        return *this;
    }

    // Only root matrices own their allocation; windows and moved-from objects free nothing.
    ~MatGF2()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Returns a pointer to word `w` of row `r` (no bounds checking).
    inline base_t *at_base(size_t r, size_t w) const
    {
        return data + r * rowstride + w;
    }

    /**
     * Fills a root matrix with pseudo-random words and masks the last word of
     * every row with `lastmask`.
     *
     * NOTE(review): the engine and distribution are function-local statics, so
     * `seed` only takes effect on the first call in the program; later calls
     * continue the same random stream regardless of the seed passed.
     *
     * @param seed seed for the (static) random engine
     */
    void randomize(uint_fast32_t seed)
    {
        assert(type == root);
        static default_random_engine e(seed);
        static uniform_int_distribution<base_t> d;
        // Mask for the row's final word; presumably strips the padding bits past
        // column ncols (base_len / gf2_len come from gf2_header.cuh — confirm).
        base_t lastmask = base_fullmask >> (width * base_len - ncols * gf2_len);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = d(e);
            }
            *at_base(r, width - 1) &= lastmask;
        }
    }

    /**
     * Generates a random matrix in reduced row echelon form, choosing the
     * nrows pivot columns among the first `rank_col` columns.
     *
     * Requires nrows <= rank_col <= ncols (checked by assert).
     *
     * NOTE(review): the pivot selection uses random_shuffle, which is not
     * driven by `seed` and was removed from the standard in C++17.
     *
     * @param rank_col number of leading columns from which pivots are drawn
     * @param seed     forwarded to randomize(seed) for the random fill
     */
    void randomize(size_t rank_col, uint_fast32_t seed)
    {
        assert(nrows <= rank_col && rank_col <= ncols);
        randomize(seed);

        // Pick nrows distinct pivot columns in [0, rank_col), in ascending order.
        vector<size_t> pivot(rank_col);
        iota(pivot.begin(), pivot.end(), 0);
        random_shuffle(pivot.begin(), pivot.end());
        pivot.resize(nrows);
        sort(pivot.begin(), pivot.end());

        // Per-word mask with every pivot column removed via del().
        vector<base_t> pivotmask(width, base_fullmask);
        for (size_t r = 0; r < nrows; r++)
        {
            del(pivotmask[pivot[r] / gf2_num], pivot[r] % gf2_num);
        }

        // Number of words covering columns [0, rank_col). Every pivot lies in
        // these words, so later words keep a full mask and need no update.
        // The previous bound `rank_col / gf2_num + 1` was one word too large
        // whenever rank_col was a multiple of gf2_num, which indexed
        // pivotmask[width] / at_base(r, width) out of range when rank_col == ncols.
        const size_t pivot_words = rank_col == 0 ? 0 : (rank_col - 1) / gf2_num + 1;

        for (size_t r = 0; r < nrows; r++)
        {
            const auto pivot_word = pivot[r] / gf2_num;
            const auto pivot_bit = pivot[r] % gf2_num;

            // Words entirely before the pivot word are zeroed.
            for (size_t w = 0; w < pivot_word; w++)
            {
                *at_base(r, w) = base_zero;
            }

            // In the pivot word: drop the other pivot columns, combine with zeros
            // through concat() (helper from gf2_header.cuh), then set the pivot itself.
            base_t *now = at_base(r, pivot_word);
            *now = concat(base_zero, pivot_bit + 1, *now & pivotmask[pivot_word]);
            set(*now, pivot_bit);

            // Remaining words that may contain pivots of other rows: clear those columns.
            for (size_t w = pivot_word + 1; w < pivot_words; w++)
            {
                *at_base(r, w) &= pivotmask[w];
            }
        }
    }

    // True when both matrices have the same dimensions and identical words.
    bool operator==(const MatGF2 &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True when every word of the matrix equals `base`.
    bool operator==(const base_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // XORs `m` into this matrix word by word; dimensions must match (checked by assert).
    void operator^=(const MatGF2 &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) ^= *m.at_base(r, w);
            }
        }
    }

    // Returns a new root matrix holding the word-wise XOR of *this and `m`.
    MatGF2 operator^(const MatGF2 &m) const
    {
        MatGF2 temp(*this);
        temp ^= m;
        return temp;
    }

    // Defined elsewhere. Presumably accumulates a * b into *this — confirm at the definition.
    void gpu_addmul(const MatGF2 &a, const MatGF2 &b);
    // Defined elsewhere; used by operator* below.
    friend MatGF2 gpu_mul(const MatGF2 &a, const MatGF2 &b);

    // Matrix product, delegated to gpu_mul.
    MatGF2 operator*(const MatGF2 &m) const
    {
        return gpu_mul(*this, m);
    }

    // Defined elsewhere. Presumably performs elimination on this matrix — confirm at the definition.
    ElimResult gpu_elim();
    friend ostream &operator<<(ostream &out, const MatGF2 &m);

    // Rows, columns, and meaningful base_t words per row.
    size_t nrows, ncols, width;

private:
    // Creates an empty matrix in the `moved` state (no storage).
    MatGF2() : nrows(0), ncols(0), width(0), rowstride(0), type(moved), data(nullptr) {}

    // Swaps rows r1 and r2 word by word on the host.
    void cpu_swap_row(size_t r1, size_t r2)
    {
        if (r1 == r2)
        {
            return;
        }
        base_t *p1 = at_base(r1, 0);
        base_t *p2 = at_base(r2, 0);
        for (size_t i = 0; i < width; i++)
        {
            base_t temp = p1[i];
            p1[i] = p2[i];
            p2[i] = temp;
        }
    }

    size_t rowstride; // distance between consecutive rows, in base_t words
    MatType type;     // ownership state of `data`
    base_t *data;     // first word of row 0
};
/**
 * Writes the matrix to `out`, one row per line. Each base_t word of a row is
 * passed through rev() and printed as a gf2_num-bit bitset followed by a space.
 *
 * Marked inline because the definition lives in a header: without it, every
 * translation unit that includes this file would emit its own external
 * definition and the program would fail to link (multiple definition).
 *
 * @param out destination stream
 * @param m   matrix to print
 * @return `out`, to allow chaining
 */
inline ostream &operator<<(ostream &out, const MatGF2 &m)
{
    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            bitset<gf2_num> temp(rev(*m.at_base(r, w)));
            out << temp << " ";
        }
        out << endl;
    }
    return out;
}
}
#endif