cuElim/include/gfp/gfp_mat.cuh

241 lines
7.2 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef GFP_MAT_CUH
#define GFP_MAT_CUH
#include "gfp_header.cuh"
#include <algorithm>
#include <iomanip>
#include <numeric>
#include <ostream>
#include <random>
#include <vector>
namespace gfp
{
class MatGFP
{
public:
    // Ownership state of a MatGFP object.
    enum MatType
    {
        root,   // owns its managed-memory buffer and frees it in the destructor
        window, // points into a sub-block of another matrix's buffer; never frees it
        moved,  // moved-from (or empty) object; owns nothing
    };

    // Only root matrices can be constructed this way.
    // Allocates a zero-filled nrows x ncols matrix with cudaMallocManaged. Each row
    // is padded to a multiple of 8 elements (32 bytes when gfp_t is 32-bit).
    // An empty matrix (nrows == 0 or ncols == 0) allocates nothing and keeps
    // data == nullptr, because cudaMallocManaged rejects a size of 0.
    MatGFP(size_t nrows, size_t ncols)
        : nrows(nrows), ncols(ncols), width(ncols), rowstride((ncols + 7) / 8 * 8), type(root), data(nullptr)
    {
        const size_t bytes = nrows * rowstride * sizeof(gfp_t);
        if (bytes == 0)
        {
            return;
        }
        CUDA_CHECK(cudaMallocManaged((void **)&data, bytes));
        CUDA_CHECK(cudaMemset(data, 0, bytes));
    }

    // Window matrices are addressed in whole gfp_t elements only.
    // Creates a non-owning view of rows [begin_ri, end_rj) and element columns
    // [begin_wi, end_wj) of src, sharing src's row stride. src's storage must
    // outlive the window, since the window never frees it.
    MatGFP(const MatGFP &src, size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj)
        : nrows(end_rj - begin_ri), ncols(end_wj - begin_wi), width(end_wj - begin_wi), rowstride(src.rowstride), type(window), data(src.at_base(begin_ri, begin_wi))
    {
        assert(begin_ri < end_rj && end_rj <= src.nrows && begin_wi < end_wj && end_wj <= src.width);
    }

    // Copy construction always yields a new root matrix (a deep copy), even when
    // m is a window.
    MatGFP(const MatGFP &m) : MatGFP(m.nrows, m.ncols)
    {
        copy_elements_from(m);
    }

    // Takes over m's buffer and ownership state; m is left as `moved` with null data.
    MatGFP(MatGFP &&m) noexcept : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
    {
        m.type = moved;
        m.data = nullptr;
    }

    // Copies m's elements into the existing storage; dimensions must match.
    // If *this is a window, the write goes through to the parent matrix.
    MatGFP &operator=(const MatGFP &m)
    {
        if (this == &m)
        {
            return *this;
        }
        assert(nrows == m.nrows && ncols == m.ncols);
        copy_elements_from(m);
        return *this;
    }

    // Releases this matrix's buffer (if it owns one), then takes over m's.
    MatGFP &operator=(MatGFP &&m) noexcept
    {
        if (this == &m)
        {
            return *this;
        }
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
        nrows = m.nrows;
        ncols = m.ncols;
        width = m.width;
        rowstride = m.rowstride;
        type = m.type;
        data = m.data;
        m.type = moved;
        m.data = nullptr;
        return *this;
    }

    // Only root matrices free their buffer; windows and moved-from objects do not.
    ~MatGFP()
    {
        if (type == root)
        {
            CUDA_CHECK(cudaFree(data));
        }
    }

    // Pointer to element (r, w), using the row stride (padded row length).
    inline gfp_t *at_base(size_t r, size_t w) const
    {
        return data + r * rowstride + w;
    }

    // Fills every element with a value drawn uniformly from [0, gfprime).
    // The engine is local rather than static, so the same seed always reproduces
    // the same matrix.
    void randomize(uint_fast32_t seed)
    {
        assert(type == root);
        std::default_random_engine engine(seed);
        fill_uniform(engine);
    }

    // Generates a random matrix in reduced row echelon form: nrows pivot columns are
    // chosen among the first rank_col columns. Row r gets a 1 in its pivot column,
    // zeros to the left of it, and zeros in every other pivot column; all remaining
    // entries stay random. Requires nrows <= rank_col <= ncols. The whole result,
    // including pivot selection, is determined by seed.
    void randomize(size_t rank_col, uint_fast32_t seed)
    {
        assert(nrows <= rank_col && rank_col <= ncols);
        assert(type == root);
        std::default_random_engine engine(seed);
        fill_uniform(engine);

        // Choose nrows distinct pivot columns from [0, rank_col), sorted ascending.
        // std::shuffle with the seeded engine replaces random_shuffle, which was
        // removed in C++17 and drew from the unseeded rand().
        std::vector<size_t> pivot(rank_col);
        std::iota(pivot.begin(), pivot.end(), size_t(0));
        std::shuffle(pivot.begin(), pivot.end(), engine);
        pivot.resize(nrows);
        std::sort(pivot.begin(), pivot.end());

        std::vector<bool> is_pivot(rank_col, false);
        for (size_t p : pivot)
        {
            is_pivot[p] = true;
        }

        for (size_t r = 0; r < nrows; r++)
        {
            const size_t pc = pivot[r];
            for (size_t w = 0; w < pc; w++)
            {
                *at_base(r, w) = base_zero;
            }
            *at_base(r, pc) = base_one;
            // Clear this row's entries in the pivot columns of the rows below it.
            for (size_t w = pc + 1; w < rank_col; w++)
            {
                if (is_pivot[w])
                {
                    *at_base(r, w) = base_zero;
                }
            }
        }
    }

    // Element-wise equality; matrices with different dimensions compare unequal.
    // Reads managed memory from the host, so pending GPU writes must already be
    // synchronized (NOTE(review): confirm gpu_mul synchronizes before returning).
    bool operator==(const MatGFP &m) const
    {
        if (nrows != m.nrows || ncols != m.ncols)
        {
            return false;
        }
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != *m.at_base(r, w))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // True when every element equals base (vacuously true for an empty matrix).
    bool operator==(const gfp_t base) const
    {
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                if (*at_base(r, w) != base)
                {
                    return false;
                }
            }
        }
        return true;
    }

    // In-place element-wise addition modulo gfprime; dimensions must match.
    void operator+=(const MatGFP &m)
    {
        assert(nrows == m.nrows && ncols == m.ncols);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = (*at_base(r, w) + *m.at_base(r, w)) % gfprime;
            }
        }
    }

    // Returns a new root matrix holding (*this + m) mod gfprime.
    MatGFP operator+(const MatGFP &m) const
    {
        MatGFP temp(*this);
        temp += m;
        return temp;
    }

    // CPU reference: *this += a * b (mod gfprime), where a is (nrows x k) and b is
    // (k x ncols), with k not exceeding 65536. Each product is reduced mod gfprime
    // before it is added, and the running sum is reduced once per element. The k
    // limit presumably keeps that unreduced sum within gfp_t's range for the
    // configured gfprime (NOTE(review): confirm against gfp_header.cuh).
    void cpu_addmul(const MatGFP &a, const MatGFP &b)
    {
        assert(a.ncols == b.nrows && a.nrows == nrows && b.ncols == ncols);
        assert(a.ncols <= 65536);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                for (size_t i = 0; i < a.ncols; i++)
                {
                    *at_base(r, w) += (*a.at_base(r, i) * *b.at_base(i, w)) % gfprime;
                }
                *at_base(r, w) %= gfprime;
            }
        }
    }

    // GPU multiplication, defined elsewhere (presumably writes a * b into *this;
    // confirm against its definition).
    void gpu_mul(const MatGFP &a, const MatGFP &b);

    // Returns *this * m, computed by gpu_mul into a fresh zero-filled root matrix.
    MatGFP operator*(const MatGFP &m) const
    {
        MatGFP temp(nrows, m.width);
        temp.gpu_mul(*this, m);
        return temp;
    }
    // void cpu_swap_row(size_t r1, size_t r2);
    // ElimResult gpu_elim();
    friend ostream &operator<<(ostream &out, const MatGFP &m);
    size_t nrows, ncols, width;

private:
    // Empty, non-owning object.
    MatGFP() : nrows(0), ncols(0), width(0), rowstride(0), type(moved), data(nullptr) {}

    // Draws every visible element uniformly from [0, gfprime) without modulo bias.
    void fill_uniform(std::default_random_engine &engine)
    {
        std::uniform_int_distribution<gfp_t> dist(0, gfprime - 1);
        for (size_t r = 0; r < nrows; r++)
        {
            for (size_t w = 0; w < width; w++)
            {
                *at_base(r, w) = dist(engine);
            }
        }
    }

    // Copies the nrows x width visible region of m into this matrix. This is a
    // no-op for empty matrices, whose data pointer may be null.
    void copy_elements_from(const MatGFP &m)
    {
        if (nrows == 0 || width == 0)
        {
            return;
        }
        CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(gfp_t), m.data, m.rowstride * sizeof(gfp_t), m.width * sizeof(gfp_t), nrows, cudaMemcpyDefault));
    }

    size_t rowstride; // elements between the starts of consecutive rows
    MatType type;
    gfp_t *data;
};
// Writes m to `out` with one matrix row per line. Each element is zero-padded to at
// least five decimal digits and followed by a space (the layout printf("%05u ")
// produced). The stream's fill character and format flags are restored before
// returning.
// Declared `inline` because this definition lives in a header; without it, every
// translation unit that includes the header emits its own definition and linking
// fails.
inline ostream &operator<<(ostream &out, const MatGFP &m)
{
    const std::ios_base::fmtflags saved_flags = out.flags();
    const char saved_fill = out.fill('0');
    out.setf(std::ios_base::dec, std::ios_base::basefield);
    out.setf(std::ios_base::right, std::ios_base::adjustfield);
    for (size_t r = 0; r < m.nrows; r++)
    {
        for (size_t w = 0; w < m.width; w++)
        {
            // Widen so small integer types print as numbers rather than characters.
            out << std::setw(5) << static_cast<unsigned long long>(*m.at_base(r, w)) << ' ';
        }
        out << '\n';
    }
    out.fill(saved_fill);
    out.flags(saved_flags);
    return out;
}
}
#endif