Skip to content

Instantly share code, notes, and snippets.

@JellyWX
Last active February 7, 2019 14:27
Show Gist options
  • Save JellyWX/8d92011259ae304cec479372842d0547 to your computer and use it in GitHub Desktop.
Save JellyWX/8d92011259ae304cec479372842d0547 to your computer and use it in GitHub Desktop.
use std::ops::Mul;
use std::fmt::{Display, Formatter, Result as FmtResult};
/// Error returned when two operands cannot be multiplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MulError {
    /// The inner dimensions disagree: for matrices the left column count differs
    /// from the right row count; for a dot product the two lengths differ.
    BadDimensions,
}

/// A dense matrix of `f64` values stored in row-major order.
///
/// Invariant maintained by the constructors below: `contents.len() == rows * columns`,
/// and element `(r, c)` lives at index `r * columns + c`.
#[derive(Debug, Clone, PartialEq)]
struct Matrix {
    /// Elements in row-major order.
    contents: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns (the length of each row).
    columns: usize,
}

/// Dot (inner) product of two equal-length sequences.
trait DotProduct<T> {
    /// Returns the sum of the element-wise products of `self` and `vec`.
    ///
    /// # Errors
    /// Returns [`MulError::BadDimensions`] when the two lengths differ.
    fn dot_product(&self, vec: &[T]) -> Result<T, MulError>;
}

impl<T> DotProduct<T> for [T]
where
    T: std::iter::Sum<<T as std::ops::Mul>::Output> + std::ops::Mul + Copy,
{
    fn dot_product(&self, vec: &[T]) -> Result<T, MulError> {
        if self.len() != vec.len() {
            return Err(MulError::BadDimensions);
        }
        Ok(self.iter().zip(vec).map(|(&a, &b)| a * b).sum())
    }
}

impl<T> DotProduct<T> for Vec<T>
where
    T: std::iter::Sum<<T as std::ops::Mul>::Output> + std::ops::Mul + Copy,
{
    fn dot_product(&self, vec: &[T]) -> Result<T, MulError> {
        // Delegate to the slice implementation so both share one code path.
        <[T] as DotProduct<T>>::dot_product(self.as_slice(), vec)
    }
}

impl Matrix {
    /// Creates a matrix from row-major `contents` with `row_size` columns.
    ///
    /// The row count is `contents.len() / row_size` (integer division). Trailing
    /// elements that do not fill a complete row are discarded. A `row_size` of 0
    /// yields a matrix with no rows instead of panicking on division by zero.
    fn new(mut contents: Vec<f64>, row_size: usize) -> Matrix {
        let rows = contents.len().checked_div(row_size).unwrap_or(0);
        // Keep `contents.len() == rows * columns` so derived equality and
        // indexing agree with the logical shape.
        contents.truncate(rows * row_size);
        Matrix { contents, rows, columns: row_size }
    }

    /// Returns a `dimension` x `dimension` matrix with `scalar` on the main
    /// diagonal and `0.0` everywhere else.
    fn from_scalar(scalar: f64, dimension: usize) -> Matrix {
        let mut contents = vec![0.0; dimension * dimension];
        // Diagonal entry (i, i) is at i * dimension + i = i * (dimension + 1),
        // i.e. every (dimension + 1)-th element starting at index 0.
        for entry in contents.iter_mut().step_by(dimension + 1) {
            *entry = scalar;
        }
        Matrix { contents, rows: dimension, columns: dimension }
    }

    /// Returns the transpose: a `columns` x `rows` matrix with `(r, c)` and
    /// `(c, r)` swapped.
    fn transpose(&self) -> Matrix {
        let mut contents = Vec::with_capacity(self.rows * self.columns);
        for column in 0..self.columns {
            for row in 0..self.rows {
                contents.push(self.contents[row * self.columns + column]);
            }
        }
        // Build the shape directly rather than via `new`, so empty and
        // zero-width matrices keep their exact dimensions.
        Matrix { contents, rows: self.columns, columns: self.rows }
    }
}

impl Display for Matrix {
    /// Renders a top border of `___` per column, followed by one line per row
    /// of the form `| a  b  c |`.
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        f.write_str(&"___".repeat(self.columns))?;
        for row in 0..self.rows {
            f.write_str("\n|")?;
            for column in 0..self.columns {
                write!(f, " {} ", self.contents[row * self.columns + column])?;
            }
            // Close the row on the same line.
            f.write_str("|")?;
        }
        Ok(())
    }
}

impl Mul for Matrix {
    type Output = Result<Self, MulError>;

    /// Multiplies an `r1` x `c1` matrix by an `r2` x `c2` matrix, producing an
    /// `r1` x `c2` matrix.
    ///
    /// # Errors
    /// Returns [`MulError::BadDimensions`] when `c1 != r2`.
    fn mul(self, rhs: Matrix) -> Result<Matrix, MulError> {
        // Validate up front: the per-element dot products alone would miss a
        // mismatch when either operand contributes no rows or columns.
        if self.columns != rhs.rows {
            return Err(MulError::BadDimensions);
        }
        // Transposing the right operand makes each of its columns a contiguous
        // slice, so every output element is a dot product of two slices.
        let rhs_t = rhs.transpose();
        let mut contents = Vec::with_capacity(self.rows * rhs.columns);
        for i in 0..self.rows {
            let row = &self.contents[i * self.columns..(i + 1) * self.columns];
            for j in 0..rhs_t.rows {
                let column = &rhs_t.contents[j * rhs_t.columns..(j + 1) * rhs_t.columns];
                contents.push(row.dot_product(column)?);
            }
        }
        Ok(Matrix { contents, rows: self.rows, columns: rhs.columns })
    }
}
/// Demo driver: builds a few matrices, then prints two of them, a transpose,
/// and a matrix product.
fn main() {
    // 4 x 3 matrix whose first column is all ones.
    let tall = Matrix::new(
        vec![
            1.0, 0.0, 0.0,
            1.0, 0.0, 0.0,
            1.0, 0.0, 0.0,
            1.0, 0.0, 0.0,
        ],
        3,
    );
    // 2 x 2 matrix with 5.5 on the diagonal.
    let scaled_identity = Matrix::from_scalar(5.5, 2);
    // 2 x 4 matrix holding 1.0 through 8.0 in row-major order.
    let wide = Matrix::new((1..=8_u32).map(f64::from).collect(), 4);

    println!("{}", tall);
    println!("{}", wide);
    println!("{}", tall.transpose());

    let product = (scaled_identity * wide).unwrap();
    println!("{}", product);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment