Created
September 30, 2017 02:02
-
-
Save khvorov/b48bd69b2f92e3802790c971c5b0470e to your computer and use it in GitHub Desktop.
Matrix addition & multiplication at compile time using constexpr
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <array> | |
#include <type_traits> | |
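// NOTE: requires C++17 (-std=c++1z). The non-const std::array accessors (at, operator[])
// only became constexpr in C++17, and they are needed to fill result matrices at compile time.
// Build: g++ -Wall -pedantic -Wno-missing-braces -std=c++1z -O0
// (-Wno-missing-braces silences warnings about brace-elided initializers such as {0, -1, 1, 0}.)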
/// Fixed-size R x C matrix whose arithmetic can be evaluated at compile time.
///
/// The struct is a plain aggregate so it can be brace-initialised element by
/// element in row-major order, e.g. `matrix<int, 2, 2> m {1, 2, 3, 4};`
/// (brace elision fills data_[0][0], data_[0][1], data_[1][0], data_[1][1]).
template <typename T, std::size_t R, std::size_t C>
struct matrix
{
    // Row-major storage: data_[row][column]; value-initialised by default.
    std::array<std::array<T, C>, R> data_ {};

    /// Mutable element access, bounds-checked via std::array::at.
    /// @param i row index (must be < R)
    /// @param j column index (must be < C)
    /// @throws std::out_of_range at runtime for a bad index; inside a constant
    ///         expression a bad index is therefore a compile-time error.
    constexpr T & operator()(const std::size_t i, const std::size_t j)
    {
        auto & row = data_.at(i);
        return row.at(j);
    }

    /// Read-only element access; same bounds checking as the mutable overload.
    constexpr const T & operator()(const std::size_t i, const std::size_t j) const
    {
        const auto & row = data_.at(i);
        return row.at(j);
    }

    /// Element-wise equality with a matrix of the same shape (element type may differ).
    /// Matrices of a different shape have no matching overload, so comparing them
    /// fails to compile.
    template <typename U>
    constexpr bool operator==(const matrix<U, R, C> & rhs) const
    {
        for (std::size_t row = 0; row < R; ++row)
        {
            for (std::size_t col = 0; col < C; ++col)
            {
                // Stop at the first differing element.
                if ((*this)(row, col) != rhs(row, col))
                {
                    return false;
                }
            }
        }
        return true;
    }

    /// Negation of operator==.
    template <typename U>
    constexpr bool operator!=(const matrix<U, R, C> & rhs) const
    {
        const bool same = (*this == rhs);
        return !same;
    }

    /// Element-wise sum of two matrices with identical element type and shape.
    constexpr auto operator+(const matrix<T, R, C> & b) const
    {
        matrix<T, R, C> sum {};
        for (std::size_t row = 0; row < R; ++row)
        {
            for (std::size_t col = 0; col < C; ++col)
            {
                sum(row, col) = (*this)(row, col) + b(row, col);
            }
        }
        return sum;
    }

    /// Matrix product (R x C) * (R2 x C2) -> (R x C2).
    /// Only participates in overload resolution when the inner dimensions agree
    /// (C == R2). Each entry is accumulated in T, so the result keeps this
    /// matrix's element type even when U differs.
    template <typename U, std::size_t R2, std::size_t C2>
    constexpr auto operator*(const matrix<U, R2, C2> & b) const -> std::enable_if_t<C == R2, matrix<T, R, C2>>
    {
        matrix<T, R, C2> product {};
        for (std::size_t row = 0; row < R; ++row)
        {
            for (std::size_t col = 0; col < C2; ++col)
            {
                // Dot product of this matrix's row `row` with b's column `col`.
                T dot { 0 };
                for (std::size_t k = 0; k < R2; ++k)
                {
                    dot += (*this)(row, k) * b(k, col);
                }
                product(row, col) = dot;
            }
        }
        return product;
    }
};
template <std::size_t R, std::size_t C> | |
using imatrix = matrix<int, R, C>; | |
constexpr void test_multiply() | |
{ | |
constexpr imatrix<2, 2> a {0, -1, 1, 0}; | |
constexpr imatrix<2, 2> b {1, 2, 3, 4}; | |
// constexpr imatrix<3, 3> e; | |
// static_assert(a == e, "equality test"); // --> size mismatch, will not compile | |
// test multiplication | |
constexpr auto ab = a * b; | |
constexpr imatrix<2, 2> aXb {-3, -4, 1, 2}; | |
static_assert(ab == aXb, "multiply failed"); | |
// test addition | |
constexpr imatrix<2, 2> c {2, -1, 1, 2}; | |
constexpr auto abc = ab + c; | |
constexpr imatrix<2, 2> aXbPc {-1, -5, 2, 4}; | |
static_assert(abc == aXbPc, "addition failed"); | |
} | |
int main() | |
{ | |
test_multiply(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment