#ifndef GF256_MAT_CUH
#define GF256_MAT_CUH

#include "gf256_header.cuh"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

namespace gf256
{
    // Value returned by MatGF256::gpu_elim (implemented outside this header).
    // NOTE(review): the vector element types are assumed to be size_t -- confirm
    // against the gpu_elim implementation.
    struct ElimResult
    {
        size_t rank;
        std::vector<size_t> pivot;
        std::vector<size_t> swap_row;
    };

    // Matrix whose rows are stored as packed base_t words, gf256_num elements
    // per word. A `root` matrix owns its buffer (allocated with
    // cudaMallocManaged and released in the destructor); a `window` matrix
    // aliases a sub-rectangle of another matrix and never frees; a `moved`
    // matrix has given its buffer away.
    class MatGF256
    {
    public:
        enum MatType
        {
            root,
            window,
            moved,
        };

        // Allocates a zero-filled root matrix. This is the only way to create a
        // root matrix from scratch.
        // Precondition: ncols > 0 (the width computation uses ncols - 1) and
        // nrows > 0 (otherwise the allocation size is zero).
        MatGF256(size_t nrows, size_t ncols) : nrows(nrows), ncols(ncols), type(root)
        {
            width = (ncols - 1) / gf256_num + 1;
            // Row stride is rounded up to a multiple of 4 words
            // (32 bytes when base_t is 64 bits wide).
            rowstride = ((width - 1) / 4 + 1) * 4;
            const size_t bytes = nrows * rowstride * sizeof(base_t);
            CUDA_CHECK(cudaMallocManaged((void **)&data, bytes));
            CUDA_CHECK(cudaMemset(data, 0, bytes));
        }

        // Creates a window onto rows [begin_ri, end_rj) and words
        // [begin_wi, end_wj) of `src`. Windows can only be cut on base_t word
        // boundaries. The window shares src's storage and row stride, so it must
        // not outlive the buffer it points into.
        MatGF256(const MatGF256 &src, size_t begin_ri, size_t begin_wi, size_t end_rj, size_t end_wj)
            : nrows(end_rj - begin_ri),
              // A window that reaches src's last word inherits src's true column
              // count; otherwise it ends on a full word.
              ncols((end_wj == src.width ? src.ncols : end_wj * gf256_num) - begin_wi * gf256_num),
              width(end_wj - begin_wi),
              rowstride(src.rowstride),
              type(window),
              data(src.at_base(begin_ri, begin_wi))
        {
            assert(begin_ri < end_rj && end_rj <= src.nrows && begin_wi < end_wj && end_wj <= src.width);
        }

        // Copy construction always produces a freshly allocated root matrix,
        // whatever the type of `m`.
        MatGF256(const MatGF256 &m) : MatGF256(m.nrows, m.ncols)
        {
            CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t),
                                    m.data, m.rowstride * sizeof(base_t),
                                    m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
        }

        // Takes over m's buffer; `m` is left as a `moved` matrix with no data.
        MatGF256(MatGF256 &&m) noexcept
            : nrows(m.nrows), ncols(m.ncols), width(m.width), rowstride(m.rowstride), type(m.type), data(m.data)
        {
            m.type = moved;
            m.data = nullptr;
        }

        // Copies the contents of `m` into the existing storage of *this.
        // Both matrices must already have identical dimensions.
        MatGF256 &operator=(const MatGF256 &m)
        {
            if (this == &m)
            {
                return *this;
            }
            assert(nrows == m.nrows && ncols == m.ncols);
            CUDA_CHECK(cudaMemcpy2D(data, rowstride * sizeof(base_t),
                                    m.data, m.rowstride * sizeof(base_t),
                                    m.width * sizeof(base_t), nrows, cudaMemcpyDefault));
            return *this;
        }

        // Releases the current buffer (root matrices only) and takes over m's
        // shape, type and buffer; `m` is left as a `moved` matrix.
        MatGF256 &operator=(MatGF256 &&m) noexcept
        {
            if (this == &m)
            {
                return *this;
            }
            if (type == root)
            {
                CUDA_CHECK(cudaFree(data));
            }
            nrows = m.nrows;
            ncols = m.ncols;
            width = m.width;
            rowstride = m.rowstride;
            type = m.type;
            data = m.data;
            m.type = moved;
            m.data = nullptr;
            return *this;
        }

        // Only root matrices own (and therefore free) their buffer.
        ~MatGF256()
        {
            if (type == root)
            {
                CUDA_CHECK(cudaFree(data));
            }
        }

        // Pointer to word `w` of row `r`.
        inline base_t *at_base(size_t r, size_t w) const
        {
            return data + r * rowstride + w;
        }

        // Fills a root matrix with random words.
        // The random engine is a function-local static, so `seed` only takes
        // effect on the first call in the process; later calls continue the
        // same random stream regardless of the seed they pass.
        void randomize(uint_fast32_t seed)
        {
            assert(type == root);
            std::default_random_engine &engine = shared_engine(seed);
            // NOTE(review): distribution assumed to be over base_t -- confirm.
            static std::uniform_int_distribution<base_t> dist;
            // Each row stores width * base_len bits but only ncols * gf256_len of
            // them are used; shifting the full mask clears the surplus in the
            // last word.
            const base_t lastmask = base_fullmask >> (width * base_len - ncols * gf256_len);
            for (size_t r = 0; r < nrows; r++)
            {
                base_t *row = at_base(r, 0);
                for (size_t w = 0; w < width; w++)
                {
                    row[w] = dist(engine);
                }
                row[width - 1] &= lastmask;
            }
        }

        // Generates a random reduced row echelon matrix: nrows pivot columns are
        // chosen among the first rank_col columns.
        // Requires nrows <= rank_col <= ncols.
        void randomize(size_t rank_col, uint_fast32_t seed)
        {
            assert(nrows <= rank_col && rank_col <= ncols);
            randomize(seed);

            // Pick nrows distinct column indices in [0, rank_col), ascending.
            std::vector<size_t> pivot(rank_col);
            std::iota(pivot.begin(), pivot.end(), size_t{0});
            std::shuffle(pivot.begin(), pivot.end(), shared_engine(seed));
            pivot.resize(nrows);
            std::sort(pivot.begin(), pivot.end());

            // Per-word mask with every pivot column removed via del8
            // (presumably del8 clears that element's bits -- helper is defined
            // in gf256_header.cuh).
            std::vector<base_t> pivotmask(width, base_fullmask);
            for (size_t r = 0; r < nrows; r++)
            {
                del8(pivotmask[pivot[r] / gf256_num], pivot[r] % gf256_num);
            }

            // Number of words that can contain a pivot. Every pivot is below
            // rank_col, so this is ceil(rank_col / gf256_num); it never exceeds
            // width because rank_col <= ncols. (The former bound
            // rank_col / gf256_num + 1 ran one word past the end of pivotmask
            // whenever rank_col was a multiple of gf256_num.)
            const size_t pivot_words = (rank_col + gf256_num - 1) / gf256_num;

            for (size_t r = 0; r < nrows; r++)
            {
                const size_t pivot_word = pivot[r] / gf256_num;
                const size_t pivot_elem = pivot[r] % gf256_num;

                // Everything in the words before the pivot word is zero.
                for (size_t w = 0; w < pivot_word; w++)
                {
                    *at_base(r, w) = base_zero;
                }

                // In the pivot word: combine zeros with the masked random word
                // at position pivot_elem + 1 (presumably zeroing the elements up
                // to and including the pivot), then set the pivot itself to one.
                base_t *now = at_base(r, pivot_word);
                *now = concat8(base_zero, pivot_elem + 1, *now & pivotmask[pivot_word]);
                set8(*now, gf256_one, pivot_elem);

                // Strip the other rows' pivot columns from the remaining words.
                for (size_t w = pivot_word + 1; w < pivot_words; w++)
                {
                    *at_base(r, w) &= pivotmask[w];
                }
            }
        }

        // Word-by-word equality; matrices of different shape are never equal.
        bool operator==(const MatGF256 &m) const
        {
            if (nrows != m.nrows || ncols != m.ncols)
            {
                return false;
            }
            for (size_t r = 0; r < nrows; r++)
            {
                const base_t *lhs = at_base(r, 0);
                const base_t *rhs = m.at_base(r, 0);
                for (size_t w = 0; w < width; w++)
                {
                    if (lhs[w] != rhs[w])
                    {
                        return false;
                    }
                }
            }
            return true;
        }

        // True when every stored word of the matrix equals `base`.
        bool operator==(const base_t base) const
        {
            for (size_t r = 0; r < nrows; r++)
            {
                const base_t *row = at_base(r, 0);
                for (size_t w = 0; w < width; w++)
                {
                    if (row[w] != base)
                    {
                        return false;
                    }
                }
            }
            return true;
        }

        // In-place word-wise XOR with `m`, which must have the same dimensions.
        void operator^=(const MatGF256 &m)
        {
            assert(nrows == m.nrows && ncols == m.ncols);
            for (size_t r = 0; r < nrows; r++)
            {
                base_t *dst = at_base(r, 0);
                const base_t *src = m.at_base(r, 0);
                for (size_t w = 0; w < width; w++)
                {
                    dst[w] ^= src[w];
                }
            }
        }

        // Word-wise XOR returned as a new root matrix.
        MatGF256 operator^(const MatGF256 &m) const
        {
            MatGF256 temp(*this);
            temp ^= m;
            return temp;
        }

        // Declared here, implemented elsewhere.
        void gpu_addmul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);
        friend MatGF256 gpu_mul(const MatGF256 &a, const MatGF256 &b, const GF256 &gf);
        ElimResult gpu_elim(const GF256 &gf);

        friend std::ostream &operator<<(std::ostream &out, const MatGF256 &m);

        // nrows/ncols: dimensions in elements; width: number of base_t words
        // used per row.
        size_t nrows, ncols, width;

    private:
        // Empty placeholder matrix (type `moved`, no buffer).
        MatGF256() : nrows(0), ncols(0), width(0), rowstride(0), type(moved), data(nullptr) {}

        // Random engine shared by both randomize() overloads. It is a
        // function-local static, so it is seeded by the first call only.
        static std::default_random_engine &shared_engine(uint_fast32_t seed)
        {
            static std::default_random_engine engine(seed);
            return engine;
        }

        // Swaps the used words of rows r1 and r2 on the host.
        void cpu_swap_row(size_t r1, size_t r2)
        {
            if (r1 == r2)
            {
                return;
            }
            base_t *p1 = at_base(r1, 0);
            base_t *p2 = at_base(r2, 0);
            std::swap_ranges(p1, p1 + width, p2);
        }

        size_t rowstride; // distance between consecutive rows, in base_t words
        MatType type;
        base_t *data;
    };

    // Writes the matrix to `out`: one line per row, each word passed through
    // rev8 and printed as 16 upper-case, zero-padded hex digits followed by a
    // space. The stream's formatting flags and fill character are restored
    // before returning. Defined inline because it lives in a header.
    inline std::ostream &operator<<(std::ostream &out, const MatGF256 &m)
    {
        const std::ios_base::fmtflags saved_flags = out.flags();
        const char saved_fill = out.fill('0');
        out << std::hex << std::uppercase << std::right;
        for (size_t r = 0; r < m.nrows; r++)
        {
            for (size_t w = 0; w < m.width; w++)
            {
                // Cast so the word is always formatted as an integer, whatever
                // the underlying typedef of base_t is.
                out << std::setw(16) << static_cast<unsigned long long>(rev8(*m.at_base(r, w))) << ' ';
            }
            out << '\n';
        }
        out.fill(saved_fill);
        out.flags(saved_flags);
        return out;
    }
}

#endif