Skip to content

Instantly share code, notes, and snippets.

@erikaderstedt
Created December 3, 2020 05:32

Revisions

  1. erikaderstedt created this gist Dec 3, 2020.
    71 changes: 71 additions & 0 deletions grid.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,71 @@
    use std::marker::Sized;
    use std::ops::{Index,IndexMut};
    use std::cmp::{PartialEq,Eq};

    pub trait GridElement: Sized + PartialEq + Eq + Clone {
    fn from_char(c: &char) -> Option<Self>;
    }

    pub struct Grid<T: GridElement> {
    pub rows: usize,
    pub cols: usize,
    locations: Vec<T>,
    }

    type Row = usize;
    type Column = usize;

    #[derive(Clone,Debug)]
    pub struct Position { pub row: Row, pub column: Column }

    impl Position {
    pub fn above(&self) -> Position { Position { row: self.row - 1, column: self.column }}
    pub fn below(&self) -> Position { Position { row: self.row + 1, column: self.column }}
    pub fn left(&self) -> Position { Position { row: self.row, column: self.column - 1 }}
    pub fn right(&self) -> Position { Position { row: self.row, column: self.column + 1 }}

    pub fn origin() -> Position { Position { row: 0usize, column: 0 } }
    }

    impl<T: GridElement> Grid<T> {

    pub fn load(lines: &[String]) -> Grid<T> {
    let rows = lines.len();
    let locations: Vec<T> = lines
    .iter()
    .filter(|line| line.chars().count() > 2)
    .flat_map(|line| line.chars().filter_map(|c| T::from_char(&c)))
    .collect();
    let cols = locations.len() / rows;
    assert!(rows * cols == locations.len(), "Grid is not rectangular, perhaps some items won't parse");
    Grid { rows, cols, locations }
    }

    // Iterate over grid elements of a certain type.
    // Iterate over all grid points together with position

    //

    pub fn neighbors(&self, position: &Position) -> Vec<&T> {
    let mut n = Vec::new();
    if position.row > 0 { n.push(&self[&position.above()]) }
    if position.row < self.rows - 1 { n.push(&self[&position.below()]) }
    if position.column > 0 { n.push(&self[&position.left()]) }
    if position.column < self.cols - 1 { n.push(&self[&position.right()]) }
    n
    }
    }

    impl<T: GridElement> Index<&Position> for Grid<T> {
    type Output = T;

    fn index(&self, index: &Position) -> &Self::Output {
    &self.locations[index.row * self.cols + index.column]
    }
    }

    impl<T: GridElement> IndexMut<&Position> for Grid<T> {
    fn index_mut(&mut self, index: &Position) -> &mut Self::Output {
    &mut self.locations[index.row * self.cols + index.column]
    }
    }