Created
December 5, 2024 07:24
-
-
Save Rassibassi/ebc9fe70a983e0bf6d0171dc68d7c3f6 to your computer and use it in GitHub Desktop.
Advent of Code 2024, Day 4 — solved in the Fourier domain
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from numpy.fft import fft2, fftshift\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "\n", | |
| "with open(\"input.txt\", \"r\") as f:\n", | |
| " text = f.read()\n", | |
| "lines = text.split(\"\\n\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[['X', 'X', 'M', 'M', 'A', 'M', 'X', 'X', 'X', 'M', 'A', 'S', 'M', 'X', 'M', 'A', 'S', 'M', 'S', 'M', 'S', 'M', 'M', 'A', 'M', 'M', 'M', 'X', 'X', 'X', 'M', 'M', 'A', 'M', 'X', 'X', 'M', 'A', 'S', 'A', 'M', 'X', 'S', 'M', 'A', 'M', 'S', 'A', 'M', 'X', 'X', 'X', 'X', 'M', 'M', 'S', 'X', 'S', 'X', 'M', 'X', 'M', 'M', 'X', 'M', 'X', 'M', 'X', 'A', 'M', 'S', 'S', 'S', 'S', 'X', 'M', 'A', 'S', 'X', 'S', 'X', 'M', 'M', 'M', 'S', 'A', 'A', 'M', 'X', 'M', 'M', 'M', 'X', 'A', 'M', 'M', 'M', 'M', 'S', 'X', 'M', 'A', 'X', 'M', 'S', 'M', 'S', 'S', 'M', 'S', 'M', 'X', 'M', 'A', 'S', 'X', 'S', 'S', 'M', 'X', 'S', 'M', 'A', 'M', 'M', 'S', 'M', 'S', 'X', 'S', 'X', 'M', 'A', 'M', 'S', 'A', 'M', 'X', 'S', 'M']]\n", | |
| "140\n", | |
| "140\n", | |
| "{'M', 'X', 'A', 'S'}\n", | |
| "n_chars=4, n_channels=100\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "data_char = []\n", | |
| "data_char_set = set()\n", | |
| "for line in lines:\n", | |
| " if line.strip():\n", | |
| " chars = list(line)\n", | |
| " data_char.append(chars)\n", | |
| " data_char_set.update(chars)\n", | |
| "\n", | |
| "print(data_char[-1:])\n", | |
| "print(len(data_char[0]))\n", | |
| "print(len(data_char))\n", | |
| "print(data_char_set)\n", | |
| "\n", | |
| "n_chars = len(data_char_set)\n", | |
| "n_channels = 100\n", | |
| "signal_std = 1.0\n", | |
| "noise_std = 0.1\n", | |
| "\n", | |
| "print(f\"{n_chars=}, {n_channels=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "patterns_data = np.random.normal(scale=signal_std, size=(n_chars, n_channels))\n", | |
| "patterns = np.vsplit(patterns_data, indices_or_sections=4)\n", | |
| "patterns_map = {char: patterns[ii][0] for ii, char in enumerate(data_char_set)}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "M\n", | |
| "[-1.00865786 1.59814271 0.68408232 -2.09696917 0.66296818] ...\n", | |
| "X\n", | |
| "[-0.11301961 1.19589783 -0.52924807 -1.34699305 -0.66106872] ...\n", | |
| "A\n", | |
| "[ 0.60423744 0.27361665 -0.31382957 -0.07260951 -0.59464571] ...\n", | |
| "S\n", | |
| "[ 0.1098687 -0.74272931 0.56876845 -0.40129504 -1.2811461 ] ...\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for k,v in patterns_map.items():\n", | |
| " print(k)\n", | |
| " print(v[:5], \"...\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(140, 140, 100)\n", | |
| "(100, 140, 140)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "data_signal = []\n", | |
| "for line in data_char:\n", | |
| " signal_line = []\n", | |
| " for char in line:\n", | |
| " pattern = patterns_map[char]\n", | |
| " signal_line.append(pattern)\n", | |
| " signal_line_arr = np.stack(signal_line, axis=0)\n", | |
| " data_signal.append(signal_line_arr)\n", | |
| "signal_arr = np.stack(data_signal, axis=0)\n", | |
| "print(signal_arr.shape)\n", | |
| "signal_arr = np.transpose(signal_arr, axes=[2, 0, 1])\n", | |
| "print(signal_arr.shape)\n", | |
| "\n", | |
| "# padding\n", | |
| "padding_noise_arr = np.random.normal(scale=noise_std, size=(n_channels, 160, 160))\n", | |
| "\n", | |
| "padding_noise_arr[:, 10:10+signal_arr.shape[1], 10:10+signal_arr.shape[2]] = signal_arr\n", | |
| "\n", | |
| "signal_arr = padding_noise_arr" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "n_rows = n_chars\n", | |
| "n_cols = n_chars\n", | |
| "\n", | |
| "template = np.random.normal(scale=noise_std, size=(n_channels, n_chars, n_chars))\n", | |
| "\n", | |
| "all_patterns_2d = []\n", | |
| "\n", | |
| "# left-right\n", | |
| "pattern_2d = np.copy(template)\n", | |
| "pattern_2d[:, 0, 0] = patterns_map[\"X\"]\n", | |
| "pattern_2d[:, 0, 1] = patterns_map[\"M\"]\n", | |
| "pattern_2d[:, 0, 2] = patterns_map[\"A\"]\n", | |
| "pattern_2d[:, 0, 3] = patterns_map[\"S\"]\n", | |
| "\n", | |
| "all_patterns_2d.append(np.copy(pattern_2d))\n", | |
| "all_patterns_2d.append(np.copy(np.flip(pattern_2d, axis=-1)))\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "# up-down\n", | |
| "pattern_2d = np.copy(template)\n", | |
| "pattern_2d[:, 0, 0] = patterns_map[\"X\"]\n", | |
| "pattern_2d[:, 1, 0] = patterns_map[\"M\"]\n", | |
| "pattern_2d[:, 2, 0] = patterns_map[\"A\"]\n", | |
| "pattern_2d[:, 3, 0] = patterns_map[\"S\"]\n", | |
| "\n", | |
| "all_patterns_2d.append(np.copy(pattern_2d))\n", | |
| "all_patterns_2d.append(np.copy(np.flip(pattern_2d, axis=-2)))\n", | |
| "\n", | |
| "# diag\n", | |
| "pattern_2d = np.copy(template)\n", | |
| "pattern_2d[:, 0, 0] = patterns_map[\"X\"]\n", | |
| "pattern_2d[:, 1, 1] = patterns_map[\"M\"]\n", | |
| "pattern_2d[:, 2, 2] = patterns_map[\"A\"]\n", | |
| "pattern_2d[:, 3, 3] = patterns_map[\"S\"]\n", | |
| "\n", | |
| "all_patterns_2d.append(np.copy(pattern_2d))\n", | |
| "all_patterns_2d.append(np.copy(np.flip(pattern_2d, axis=-1)))\n", | |
| "all_patterns_2d.append(np.copy(np.flip(pattern_2d, axis=-2)))\n", | |
| "all_patterns_2d.append(np.copy(np.flip(pattern_2d, axis=(-1, -2))))\n", | |
| "\n", | |
| "all_patterns_2d = np.stack(all_patterns_2d, axis=0)\n", | |
| "\n", | |
| "def pad_pattern_2d(pattern_2d):\n", | |
| " padded_pattern_2d = np.zeros_like(signal_arr)\n", | |
| " padded_pattern_2d = np.expand_dims(padded_pattern_2d, axis=0)\n", | |
| " # repeat pattern_2d along batch axis\n", | |
| " padded_pattern_2d = np.repeat(padded_pattern_2d, pattern_2d.shape[0], axis=0)\n", | |
| "\n", | |
| " # centralize pattern_2d into padded_pattern_2d\n", | |
| " s_axisM2 = slice((padded_pattern_2d.shape[-2] - pattern_2d.shape[-2]) // 2, (padded_pattern_2d.shape[-2] + pattern_2d.shape[-2]) // 2)\n", | |
| " s_axisM1 = slice((padded_pattern_2d.shape[-1] - pattern_2d.shape[-1]) // 2, (padded_pattern_2d.shape[-1] + pattern_2d.shape[-1]) // 2)\n", | |
| " padded_pattern_2d[..., s_axisM2, s_axisM1] = pattern_2d\n", | |
| "\n", | |
| " return padded_pattern_2d\n", | |
| "\n", | |
| "all_padded_pattern_2d = pad_pattern_2d(all_patterns_2d)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "signal_arr.shape=(1, 100, 160, 160)\n", | |
| "all_padded_pattern_2d.shape=(8, 100, 160, 160)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "signal_arr = np.expand_dims(signal_arr, axis=0)\n", | |
| "print(f\"{signal_arr.shape=}\")\n", | |
| "print(f\"{all_padded_pattern_2d.shape=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(8, 100, 160, 160)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "signal_arr_fft = fft2(signal_arr)\n", | |
| "padded_pattern_2d_fft = fft2(all_padded_pattern_2d)\n", | |
| "\n", | |
| "correlation_temp = np.fft.ifft2(signal_arr_fft * np.conj(padded_pattern_2d_fft))\n", | |
| "correlation = fftshift(np.real(correlation_temp))\n", | |
| "\n", | |
| "print(correlation.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(8, 160, 160)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "correlation = np.mean(correlation, axis=1)\n", | |
| "print(correlation.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "per_pattern.shape=(8, 25600)\n", | |
| "per_pattern_matches=array([360, 361, 416, 433, 213, 194, 206, 195])\n", | |
| "all matches: 2378\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "per_pattern = np.reshape(correlation, (correlation.shape[0], -1))\n", | |
| "print(f\"{per_pattern.shape=}\")\n", | |
| "per_pattern_matches = np.sum((np.abs(per_pattern - 4.0)) < 0.50, axis=-1)\n", | |
| "print(f\"{per_pattern_matches=}\")\n", | |
| "# all matches\n", | |
| "print(f\"all matches: {np.sum(per_pattern_matches)}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment