Created
September 29, 2018 03:31
-
-
Save tkokof/ebaa026e126faa7e4ecc69293b5096f3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// desc simple implementation of sparse matrix | |
// maintainer hugoyu | |
#ifndef __sparse_matrix_h__ | |
#define __sparse_matrix_h__ | |
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "common.h"
// Sparse row x col matrix that stores only explicitly written elements.
//
// Elements live in m_element_buffer keyed by gen_element_key(row, col); any
// element without an entry reads as T(). Container must be a map-like type
// keyed by uint64 with mapped type T that provides find()/end(), operator[]
// and iteration over (key, value) pairs -- e.g. std::map<uint64, T> or
// std::unordered_map<uint64, T>. T must be default-constructible and support
// +=, -=, *= and *.
template<typename T, typename Container>
class sparse_matrix
{
public:
    // Creates a row x col matrix with no stored elements (every element reads
    // as T()). Both dimensions must be non-zero (checked by assert).
    sparse_matrix(uint32 row, uint32 col)
        : m_row(row)
        , m_col(col)
    {
        assert(is_valid_size(row, col));
    }

    // Every member manages its own storage, so the compiler-generated copy and
    // move operations transfer the dimensions together with the elements and
    // return *this (Rule of Zero, spelled out here for readability).
    sparse_matrix(const sparse_matrix& other) = default;
    sparse_matrix(sparse_matrix&& other) = default;
    sparse_matrix& operator =(const sparse_matrix& other) = default;
    sparse_matrix& operator =(sparse_matrix&& other) = default;
    ~sparse_matrix() = default;

    // Number of rows.
    constexpr uint32 row() const noexcept
    {
        return m_row;
    }

    // Number of columns.
    constexpr uint32 col() const noexcept
    {
        return m_col;
    }

    // Read-only element access; (row, col) must be in range (checked by assert).
    // Returns the stored element, or a reference to a shared T() when the
    // element has never been written. Never inserts into the buffer.
    const T& operator ()(uint32 row, uint32 col) const
    {
        assert(is_valid_index(row, col));
        const auto iter = m_element_buffer.find(gen_element_key(row, col));
        if (iter != m_element_buffer.end())
        {
            return iter->second;
        }
        // A function-local static outlives this call (a T() temporary would
        // not), so the returned reference never dangles.
        static const T s_default_element{};
        return s_default_element;
    }

    // Writable element access; (row, col) must be in range (checked by assert).
    // Goes through Container::operator[], which for the std map types inserts
    // a T() entry when the element is not stored yet.
    T& operator ()(uint32 row, uint32 col)
    {
        assert(is_valid_index(row, col));
        return m_element_buffer[gen_element_key(row, col)];
    }

    // Returns a copy of this matrix with every stored element multiplied by right.
    sparse_matrix operator *(const T& right) const
    {
        sparse_matrix temp(*this);
        temp *= right;
        return temp;
    }

    // Multiplies every stored element by right in place.
    sparse_matrix& operator *=(const T& right)
    {
        for (auto& element : m_element_buffer)
        {
            element.second *= right;
        }
        return *this;
    }

    // Element-wise sum; both matrices must have the same dimensions.
    sparse_matrix operator +(const sparse_matrix& right) const
    {
        assert(m_row == right.m_row && m_col == right.m_col);
        sparse_matrix temp(*this);
        temp += right;
        return temp;
    }

    // Adds right element-wise in place; dimensions must match (checked by
    // assert). Only right's stored elements are visited.
    sparse_matrix& operator +=(const sparse_matrix& right)
    {
        assert(m_row == right.m_row && m_col == right.m_col);
        for (const auto& element : right.m_element_buffer)
        {
            // Elements not stored on the left start from T() via operator[].
            m_element_buffer[element.first] += element.second;
        }
        return *this;
    }

    // Element-wise difference; both matrices must have the same dimensions.
    sparse_matrix operator -(const sparse_matrix& right) const
    {
        assert(m_row == right.m_row && m_col == right.m_col);
        sparse_matrix temp(*this);
        temp -= right;
        return temp;
    }

    // Subtracts right element-wise in place; dimensions must match (checked by
    // assert). Returns *this by reference, like operator +=.
    sparse_matrix& operator -=(const sparse_matrix& right)
    {
        assert(m_row == right.m_row && m_col == right.m_col);
        for (const auto& element : right.m_element_buffer)
        {
            // Elements not stored on the left start from T() via operator[].
            m_element_buffer[element.first] -= element.second;
        }
        return *this;
    }

    // Matrix product: (row() x col()) * (right.row() x right.col()) yields a
    // row() x right.col() matrix; requires col() == right.row() (checked by
    // assert). Only stored elements are visited: right's elements are bucketed
    // by row, then each stored left element (i, k) is combined with the stored
    // elements of right's row k. The work is proportional to the number of
    // stored elements plus the number of matching (left, right) pairs.
    sparse_matrix operator *(const sparse_matrix& right) const
    {
        assert(m_col == right.m_row);
        sparse_matrix temp(m_row, right.m_col);

        // row k of right -> list of (column j, pointer to right(k, j))
        std::unordered_map<uint32, std::vector<std::pair<uint32, const T*>>> right_rows;
        for (const auto& element : right.m_element_buffer)
        {
            uint32 k = 0;
            uint32 j = 0;
            extract_element_key(element.first, k, j);
            right_rows[k].emplace_back(j, &element.second);
        }

        for (const auto& element : m_element_buffer)
        {
            uint32 i = 0;
            uint32 k = 0;
            extract_element_key(element.first, i, k);
            const auto bucket = right_rows.find(k);
            if (bucket == right_rows.end())
            {
                continue;
            }
            for (const auto& entry : bucket->second)
            {
                temp.m_element_buffer[gen_element_key(i, entry.first)] += element.second * *entry.second;
            }
        }
        return temp;
    }

    // In-place matrix product; afterwards this matrix is row() x right.col().
    // The product is built in a separate matrix first, so right may be *this.
    sparse_matrix& operator *=(const sparse_matrix& right)
    {
        assert(m_col == right.m_row);
        *this = (*this) * right;
        return *this;
    }

    // Dense text rendering: every element (stored or not) formatted with
    // std::to_string and followed by ", ", one line per row ending in "\n".
    // const, so printing never inserts default entries into the buffer.
    std::string to_string() const
    {
        std::string string_buffer;
        // Widen to size_t before multiplying; in uint32 arithmetic the size
        // hint wraps around for moderately large matrices.
        string_buffer.reserve(static_cast<std::size_t>(m_row) * m_col * 8);
        for (uint32 i = 0; i < m_row; ++i)
        {
            for (uint32 j = 0; j < m_col; ++j)
            {
                string_buffer.append(std::to_string((*this)(i, j)));
                string_buffer.append(", ");
            }
            string_buffer.append("\n");
        }
        return string_buffer;
    }

private:
    // A matrix must have at least one row and one column.
    constexpr bool is_valid_size(uint32 row, uint32 col) const noexcept
    {
        return row > 0 && col > 0;
    }

    // True when (row, col) lies inside this matrix.
    constexpr bool is_valid_index(uint32 row, uint32 col) const noexcept
    {
        return row < m_row && col < m_col;
    }

    // Packs (row, col) into one key: row in the high 32 bits, col in the low 32.
    constexpr static uint64 gen_element_key(uint32 row, uint32 col) noexcept
    {
        return (static_cast<uint64>(row) << 32) | static_cast<uint64>(col);
    }

    // Inverse of gen_element_key.
    constexpr static void extract_element_key(uint64 key, uint32& row, uint32& col) noexcept
    {
        row = static_cast<uint32>((key >> 32) & 0xFFFFFFFFu);
        col = static_cast<uint32>(key & 0xFFFFFFFFu);
    }

private:
    uint32 m_row{ 0 };
    uint32 m_col{ 0 };
    Container m_element_buffer;
};
template<typename T, typename Container> | |
constexpr sparse_matrix<T, Container> operator *(const T& left, const sparse_matrix<T, Container>& right) | |
{ | |
return right * left; | |
} | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment