cuElim/include/gfp/gfp_mat.cuh

288 lines
8.7 KiB
Plaintext
Raw Normal View History

2024-09-14 15:57:00 +08:00
#ifndef GFP_MAT_CUH
#define GFP_MAT_CUH
#include "gfp_header.cuh"
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>
2024-09-14 16:15:13 +08:00
namespace gfp
2024-09-14 15:57:00 +08:00
{
2024-09-19 15:59:53 +08:00
// Result of a Gaussian elimination on a MatGFP (returned by MatGFP::gpu_elim /
// MatGFP::cpu_elim, whose definitions live outside this header).
struct ElimResult
{
    // Rank found by the elimination. Zero-initialized so a default-constructed
    // result is never read with an indeterminate value.
    size_t rank = 0;
    // NOTE(review): presumably the pivot column index for each of the `rank`
    // pivot rows -- confirm against cpu_elim/gpu_elim.
    vector<size_t> pivot;
    // NOTE(review): presumably the row index swapped into position at each
    // elimination step -- confirm against cpu_elim/gpu_elim.
    vector<size_t> swap_row;
};
2024-09-14 16:15:13 +08:00
// Dense matrix over GF(gfprime), stored row-major in CUDA managed memory.
//
// Element (r, w) lives at data[r * rowstride + w]. Rows of a root matrix are
// padded to a multiple of 8 elements. A matrix is in one of three states
// (see MatType):
//   root   - owns its allocation and frees it in the destructor;
//   window - a non-owning view of a rectangular sub-block of another matrix
//            (it shares that matrix's rowstride and must not outlive it);
//   moved  - the empty shell left behind by a move; owns nothing.
class MatGFP
{
public:
    // Ownership state of a matrix.
    enum MatType
    {
        root,   // owns `data` (cudaMallocManaged) and frees it in the destructor
        window, // aliases another matrix's storage; frees nothing
        moved,  // emptied by a move; frees nothing
    };

    // Allocate a zero-filled nrows x ncols root matrix. This is the only public
    // constructor that allocates fresh storage.
    MatGFP(size_t nrows, size_t ncols) : nrows(nrows), ncols(ncols), type(root)
    {
        width = ncols;
        // Pad each row to a multiple of 8 elements (32 bytes for a 32-bit gfp_t).
        rowstride = (width + 7) / 8 * 8;
        data = nullptr;
        const size_t bytes = nrows * rowstride * sizeof(gfp_t);
        // cudaMallocManaged rejects a zero size with cudaErrorInvalidValue, so an
        // empty matrix keeps data == nullptr (cudaFree(nullptr) is a no-op).
        if (bytes != 0)
        {
            CUDA_CHECK(cudaMallocManaged((void **)&data, bytes));
            CUDA_CHECK(cudaMemset(data, 0, bytes));
        }
    }

    // Window view of rows [begin_ri, end_rj) and columns [begin_wi, end_wj) of
    // src. Nothing is copied: the view aliases src's storage, so src must outlive
    // it. Column bounds are in gfp_t elements.
    MatGFP(const MatGFP &src, size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj)
        : nrows(end_rj - begin_ri), ncols(end_wj - begin_wi), width(end_wj - begin_wi),
          rowstride(src.rowstride), type(window), data(src.at_base(begin_ri, begin_wi))
    {
        assert(begin_ri < end_rj && end_rj <= src.nrows && begin_wi < end_wj && end_wj <= src.width);
    }

    // Deep copy. The result is always a new root matrix, even when m is a window.
    MatGFP(const MatGFP &m) : MatGFP(m.nrows, m.ncols)
    {
        copy_block(m.data, m.rowstride, nrows, m.width);
    }

    // Take over m's storage and state. m is left in the `moved` state.
    MatGFP(MatGFP &&m) noexcept
        : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
    {
        m.type = moved;
        m.data = nullptr;
    }

    // Element-wise copy of m into this matrix's existing storage. Shapes must
    // match. If this is a window, the write goes through to the parent matrix.
    MatGFP &operator=(const MatGFP &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        copy_block(m.data, m.rowstride, nrows, m.width);
        return *this;
    }

    // Release this matrix's storage (if it is a root) and take over m's.
    // m is left in the `moved` state.
    MatGFP &operator=(MatGFP &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        rowstride = m.rowstride;
        type = m.type;
        data = m.data;
        m.type = moved;
        m.data = nullptr;
        return *this;
    }

    // Only root matrices own their storage.
    ~MatGFP()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Pointer to element (r, w); no bounds checking.
    inline gfp_t *at_base(size_t r, size_t w) const
    {
        return data + r * rowstride + w;
    }

    // Fill every element with a draw from uniform_int_distribution<gfp_t>,
    // reduced modulo gfprime. The engine is seeded from `seed` on every call, so
    // equal seeds yield equal contents. Only valid on root matrices.
    void randomize(uint_fast32_t seed)
    {
        assert(type == root);
        default_random_engine engine(seed);
        uniform_int_distribution<gfp_t> dist;
        for (size_t r = 0; r < nrows; r++)
        {
            gfp_t *row = at_base(r, 0);
            for (size_t w = 0; w < width; w++)
            {
                row[w] = dist(engine) % gfprime;
            }
        }
    }

    // Copy the rectangle rows [begin_ri, end_rj) x columns [begin_wi, end_wj) of
    // src into the top-left corner of this matrix. Elements outside that corner
    // are left untouched.
    void deepcopy(const MatGFP &src, size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj)
    {
        assert(begin_ri <= end_rj && end_rj <= src.nrows && begin_wi <= end_wj && end_wj <= src.width);
        assert(end_rj - begin_ri <= nrows && end_wj - begin_wi <= width);
        copy_block(src.at_base(begin_ri, begin_wi), src.rowstride, end_rj - begin_ri, end_wj - begin_wi);
    }

    // Build a random reduced row-echelon matrix. The nrows pivot columns are
    // drawn from the first rank_col columns, and the choice is reproducible from
    // `seed`. Each row has zeros left of its pivot, a one at its pivot, and zeros
    // in every other row's pivot column. All remaining entries keep the random
    // values from randomize(seed).
    void randomize(size_t rank_col, uint_fast32_t seed)
    {
        assert(nrows <= rank_col && rank_col <= ncols);
        randomize(seed);

        // Pick nrows distinct pivot columns from [0, rank_col) in ascending order.
        vector<size_t> pivot(rank_col);
        iota(pivot.begin(), pivot.end(), size_t{0});
        default_random_engine engine(seed);
        shuffle(pivot.begin(), pivot.end(), engine);
        pivot.resize(nrows);
        sort(pivot.begin(), pivot.end());

        vector<bool> is_pivot(width, false);
        for (size_t col : pivot)
        {
            is_pivot[col] = true;
        }

        for (size_t r = 0; r < nrows; r++)
        {
            gfp_t *row = at_base(r, 0);
            const size_t p = pivot[r];
            for (size_t w = 0; w < p; w++)
            {
                row[w] = base_zero;
            }
            row[p] = base_one;
            // Later pivot columns must be zero in this row (reduced form).
            for (size_t w = p + 1; w < rank_col; w++)
            {
                if (is_pivot[w])
                {
                    row[w] = base_zero;
                }
            }
        }
    }

    // True when both matrices have the same shape and identical elements.
    bool operator==(const MatGFP &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True when every element equals `base` (vacuously true for empty matrices).
    bool operator==(const gfp_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // this = (this + m) mod gfprime, element-wise, computed on the host.
    void operator+=(const MatGFP &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = (*at_base(r, w) + *m.at_base(r, w)) % gfprime;
            }
        }
    }

    // Element-wise sum as a new root matrix.
    MatGFP operator+(const MatGFP &m) const
    {
        MatGFP temp(*this);
        temp += m;
        return temp;
    }

    // Host reference: this += a (m x k) * b (k x n), mod gfprime. Each product
    // is reduced mod gfprime and added to the element, and the reduction of the
    // sum happens once at the end. NOTE(review): the k <= 65536 bound presumably
    // keeps that unreduced sum within gfp_t; confirm against gfprime and gfp_t.
    void cpu_addmul(const MatGFP &a, const MatGFP &b)
    {
        assert(a.ncols == b.nrows && a.nrows == nrows && b.ncols == ncols);
        assert(a.ncols <= 65536);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                gfp_t *dst = at_base(r, w);
                for (size_t i = 0; i < a.ncols; i++)
                {
                    *dst += (*a.at_base(r, i) * *b.at_base(i, w)) % gfprime;
                }
                *dst %= gfprime;
            }
        }
    }

    // GPU product of a and b into this matrix; defined outside this header.
    void gpu_mul(const MatGFP &a, const MatGFP &b);

    // Product (*this) * m as a new root matrix, computed by gpu_mul.
    MatGFP operator*(const MatGFP &m) const
    {
        MatGFP temp(nrows, m.width);
        temp.gpu_mul(*this, m);
        return temp;
    }

    // Swap rows r1 and r2 on the host.
    void cpu_swap_row(size_t r1, size_t r2)
    {
        if (r1 == r2)
        {
            return; // swap_ranges requires non-overlapping ranges
        }
        gfp_t *first = at_base(r1, 0);
        swap_ranges(first, first + width, at_base(r2, 0));
    }

    // Row r *= mul (mod gfprime) over columns [begin_w, end_w).
    // end_w == 0 means "up to width".
    void cpu_mul_row(size_t r, gfp_t mul, size_t begin_w = 0, size_t end_w = 0)
    {
        const size_t stop = end_w ? end_w : width;
        gfp_t *row = at_base(r, 0);
        for (size_t i = begin_w; i < stop; i++)
        {
            row[i] = (row[i] * mul) % gfprime;
        }
    }

    // Row r1 += mul * row r2 (mod gfprime) over columns [begin_w, end_w).
    // end_w == 0 means "up to width".
    void cpu_addmul_row(size_t r1, size_t r2, gfp_t mul, size_t begin_w = 0, size_t end_w = 0)
    {
        const size_t stop = end_w ? end_w : width;
        gfp_t *dst = at_base(r1, 0);
        const gfp_t *src = at_base(r2, 0);
        for (size_t i = begin_w; i < stop; i++)
        {
            dst[i] = (dst[i] + src[i] * mul) % gfprime;
        }
    }

    // Gaussian elimination on the GPU / host; defined outside this header.
    ElimResult gpu_elim();
    ElimResult cpu_elim();

    friend ostream &operator<<(ostream &out, const MatGFP &m);

    // Logical shape. width is the number of stored columns (always equal to ncols).
    size_t nrows, ncols, width;

private:
    // Empty non-owning shell; only usable by members and friends.
    MatGFP() : nrows(0), ncols(0), width(0), rowstride(0), type(moved), data(nullptr) {}

    // Copy a rows x cols block from src (row pitch src_rowstride elements) into
    // the top-left corner of this matrix. cudaMemcpyDefault lets the runtime
    // infer the copy direction from the pointer values. Zero-sized copies are
    // skipped so empty matrices never hand null pointers to the runtime.
    void copy_block(const gfp_t *src, size_t src_rowstride, size_t rows, size_t cols)
    {
        if (rows == 0 || cols == 0)
        {
            return;
        }
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(gfp_t), src, src_rowstride * sizeof(gfp_t),
                                cols * sizeof(gfp_t), rows, cudaMemcpyDefault));
    }

    size_t rowstride; // elements between the starts of consecutive rows
    MatType type;     // ownership state
    gfp_t *data;      // element (0, 0); nullptr when moved-from or empty
};
2024-09-14 15:57:00 +08:00
2024-09-14 16:15:13 +08:00
// Write m to `out`, one matrix row per line. Each element is right-aligned,
// zero-padded to at least 5 decimal digits and followed by a space (the layout
// printf("%05u ") produced). The stream's fill character and format flags are
// restored before returning.
// `inline` is required because this definition lives in a header. Without it,
// every translation unit including this file emits its own external definition,
// and linking two of them fails with a multiple-definition error.
inline ostream &operator<<(ostream &out, const MatGFP &m)
{
    const ios_base::fmtflags saved_flags = out.flags();
    const auto saved_fill = out.fill('0');
    out.setf(ios_base::dec, ios_base::basefield);
    out.setf(ios_base::right, ios_base::adjustfield);
    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            out.width(5); // width() resets after each formatted insertion
            out << *m.at_base(r, w) << ' ';
        }
        out << '\n';
    }
    out.fill(saved_fill);
    out.flags(saved_flags);
    return out;
}
}
#endif