Last active: February 7, 2019 14:27
-
-
Save JellyWX/8d92011259ae304cec479372842d0547 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::ops::Mul; | |
use std::fmt::{Display, Formatter, Result as FmtResult}; | |
/// Error returned when a vector or matrix product is requested for operands whose
/// dimensions do not line up.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MulError {
    /// The operands' inner dimensions differ: two vectors of different lengths, or a
    /// left matrix whose column count differs from the right matrix's row count.
    BadDimensions,
}

impl Display for MulError {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        match *self {
            MulError::BadDimensions => f.write_str("operand dimensions are incompatible"),
        }
    }
}

// Lets `MulError` flow through `?` into `Box<dyn Error>` and similar error plumbing.
impl std::error::Error for MulError {}
/// A dense matrix of `f64` values stored in row-major order.
#[derive(Debug, Clone, PartialEq)]
struct Matrix {
    /// Entries laid out row by row: entry `(row, col)` lives at `row * columns + col`.
    contents: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of entries in each row.
    columns: usize,
}
trait DotProduct<T> { | |
fn dot_product(&self, vec: &[T]) -> Result<T, MulError>; | |
} | |
impl<T> DotProduct<T> for Vec<T> | |
where T: std::iter::Sum<<T as std::ops::Mul>::Output> + std::ops::Mul + Copy { | |
fn dot_product(&self, vec: &[T]) -> Result<T, MulError> { | |
if self.len() != vec.len() { | |
return Err(MulError::BadDimensions) | |
} | |
let out = self.iter().zip(vec).map(|(a, b)| *a * *b).sum(); | |
Ok(out) | |
} | |
} | |
impl Matrix { | |
fn new(contents: Vec<f64>, row_size: usize) -> Matrix { | |
let rows = contents.len() / row_size; | |
return Matrix { contents: contents, rows: rows, columns: row_size }; | |
} | |
fn from_scalar(scalar: f64, dimension: usize) -> Matrix { | |
let mut end = vec![]; | |
for i in 0..dimension { | |
for j in 0..dimension { | |
if j != i { | |
end.push(0f64); | |
} | |
else { | |
end.push(scalar); | |
} | |
} | |
} | |
return Matrix { contents: end, rows: dimension, columns: dimension } | |
} | |
fn transpose(&self) -> Matrix { | |
let mut output = vec![]; | |
let r = self.rows; | |
let c = self.columns; | |
for column in 0..c { | |
for row in 0..r { | |
output.push(self.contents[(row * c) + column]) | |
} | |
} | |
return Matrix::new(output, self.rows); | |
} | |
} | |
impl Display for Matrix { | |
fn fmt(&self, f: &mut Formatter) -> FmtResult { | |
let mut out = "___".repeat(self.columns); | |
for a in 0..self.rows { | |
out += "\n|"; | |
for b in 0..self.columns { | |
out += &format!(" {} ", self.contents[(a * self.columns) + b]); | |
} | |
out += "\n|" | |
} | |
write!(f, "{}", out) | |
} | |
} | |
impl Mul for Matrix { | |
type Output = Result<Self, MulError>; | |
fn mul(self, rhs: Matrix) -> Result<Matrix, MulError> { | |
let m = rhs.transpose(); | |
let mut out = vec![]; | |
for i in 0..self.rows { | |
let v1 = &self.contents[(i * self.columns)..((i + 1) * self.columns)].to_vec(); | |
for j in 0..m.rows { | |
let v2 = &m.contents[(j * m.columns)..((j + 1) * m.columns)]; | |
out.push(v1.dot_product(&v2)?); | |
} | |
} | |
return Ok(Matrix::new(out, self.columns)); | |
} | |
} | |
fn main() { | |
let mat = Matrix::new(vec![ | |
1.0, 0.0, 0.0, | |
1.0, 0.0, 0.0, | |
1.0, 0.0, 0.0, | |
1.0, 0.0, 0.0], 3); | |
let mat2 = Matrix::from_scalar(5.5, 2); | |
let mat3 = Matrix::new(vec![ | |
1.0, 2.0, 3.0, 4.0, | |
5.0, 6.0, 7.0, 8.0 | |
], 4); | |
println!("{}", mat); | |
println!("{}", mat3); | |
println!("{}", mat.transpose()); | |
println!("{}", (mat2 * mat3).unwrap()) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.