Skip to content

Instantly share code, notes, and snippets.

@rnett
Last active December 2, 2021 20:05
Show Gist options
  • Save rnett/fb8b2646ae3dcd5e37933262e3c0c813 to your computer and use it in GitHub Desktop.
Save rnett/fb8b2646ae3dcd5e37933262e3c0c813 to your computer and use it in GitHub Desktop.
Audio Denoiser using Lip Reading.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Audio Denoiser using Lip Reading.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/rnett/fb8b2646ae3dcd5e37933262e3c0c813/audio-denoiser-using-lip-reading.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"metadata": {
"id": "4_jxmQj2MIOS",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"# Downloading and preprocessing data\n"
]
},
{
"metadata": {
"pycharm": {
"metadata": false
},
"id": "pDq23_xo7CUU",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Files"
]
},
{
"metadata": {
"id": "r83v0t7msyEp",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"Raw video files are stored in `files`.\n",
"\n",
"Each video is then loaded and split into frames, and the mouth region is cropped out. The resulting arrays are saved as HDF5 files under `train/videos` and `test/videos`.\n"
]
},
{
"metadata": {
"id": "MUVnBSuVY0D-",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"from typing import List, Dict, Iterable\n",
"import numpy as np\n",
"import matplotlib as mpl\n",
"mpl.rc('image', cmap='gray')\n",
"from matplotlib import pyplot as plt\n",
"\n",
"#!pip install git+https://github.com/avivga/face-detection.git\n",
"# !pip3 install git+https://github.com/avivga/mediaio.git\n",
"# !pip3 install imageio\n",
"# !pip3 install imageio-ffmpeg\n",
" \n",
"# import sys\n",
" \n",
"# print(\"Version:\", sys.version)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "yFbgnqff_Jgj",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"File constants"
]
},
{
"metadata": {
"pycharm": {
"metadata": false,
"name": "#%%\n"
},
"id": "Ym3XZ_39Wq67",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"\n",
"# raw .mpg downloads from the GRID corpus\n",
"raw_dir = 'files'\n",
"\n",
"# train / test split directories\n",
"train_dir = 'train'\n",
"test_dir = 'test'\n",
"\n",
"# per-split sub-directories holding the preprocessed HDF5 files\n",
"video_dir = 'videos'\n",
"audio_dir = 'audio'"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "h9iwenmr_ML7",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"Download videos"
]
},
{
"metadata": {
"id": "ZR7wsmkMyDlA",
"colab_type": "code",
"outputId": "7f47c1d0-e755-427c-d2db-db9af9a81753",
"pycharm": {},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"cell_type": "code",
"source": [
"import requests, zipfile, io\n",
"\n",
"import shutil\n",
"import os\n",
"\n",
"\n",
"#@title Force file refresh\n",
"force = False #@param {type:\"boolean\"}\n",
"\n",
"# GRID corpus speakers to download; files are renamed with a _s<N> suffix\n",
"n_speakers = 3\n",
"\n",
"if not os.path.isdir(raw_dir) or force:\n",
"\n",
"    try:\n",
"        shutil.rmtree(raw_dir)\n",
"    except FileNotFoundError:\n",
"        pass\n",
"\n",
"    os.mkdir(raw_dir)\n",
"\n",
"    for speaker in range(1, n_speakers + 1):\n",
"        url = \"http://spandh.dcs.shef.ac.uk/gridcorpus/s{}/video/s{}.mpg_vcd.zip\".format(speaker, speaker)\n",
"        r = requests.get(url, timeout=60)\n",
"        # fail with the HTTP error rather than a confusing BadZipFile on an error page\n",
"        r.raise_for_status()\n",
"        with zipfile.ZipFile(io.BytesIO(r.content)) as z:\n",
"            z.extractall(\"tmp\")\n",
"\n",
"        # the archive holds folders of .mpg files (plus Thumbs.db files, which are skipped)\n",
"        for folder in os.listdir(\"tmp\"):\n",
"            folder_path = os.path.join(\"tmp\", folder)\n",
"            if not os.path.isdir(folder_path):\n",
"                continue\n",
"            for f1 in os.listdir(folder_path):\n",
"                if f1 != 'Thumbs.db':\n",
"                    shutil.move(os.path.join(folder_path, f1),\n",
"                                os.path.join(raw_dir, f1.replace('.mpg', f\"_s{speaker}.mpg\")))\n",
"\n",
"        shutil.rmtree('tmp')\n",
"\n",
"print(\"Have {0} data files\".format(len(os.listdir(raw_dir))))\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Have 3000 data files\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "nhx0aGVd-TR6",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"test_s3_limit = 25\n",
"test_others_limit = 25\n",
"train_limit = 220\n",
"\n",
"test_s3_count = 0\n",
"test_other_count = 0\n",
"train1_count = 0\n",
"train2_count = 0\n",
"\n",
"if not (os.path.isdir(test_dir) and os.path.isdir(train_dir)) or force:\n",
"\n",
"    for split_dir in (test_dir, train_dir):\n",
"        try:\n",
"            shutil.rmtree(split_dir)\n",
"        except FileNotFoundError:\n",
"            pass\n",
"        os.mkdir(split_dir)\n",
"\n",
"    # Speaker s3 is only used for testing. Speakers s1 and s2 provide the training set\n",
"    # plus a few extra test files. Each file goes to exactly ONE split, so the test set\n",
"    # never overlaps the training set. Sorting makes the split reproducible.\n",
"    for file in sorted(os.listdir(raw_dir)):\n",
"        src = os.path.join(raw_dir, file)\n",
"        if '_s3' in file:\n",
"            if test_s3_count < test_s3_limit:\n",
"                shutil.copy(src, test_dir)\n",
"                test_s3_count += 1\n",
"        elif '_s1' in file or '_s2' in file:\n",
"            if test_other_count < test_others_limit:\n",
"                shutil.copy(src, test_dir)\n",
"                test_other_count += 1\n",
"            elif '_s1' in file and train1_count < train_limit // 2:\n",
"                shutil.copy(src, train_dir)\n",
"                train1_count += 1\n",
"            elif '_s2' in file and train2_count < train_limit // 2:\n",
"                shutil.copy(src, train_dir)\n",
"                train2_count += 1\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "uIiTuv51nwBX",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"import cv2, os, gc\n",
"import h5py, imageio\n",
"from PIL import Image\n",
"from mediaio.video_io import VideoFileReader\n",
"from imutils import face_utils\n",
"import numpy as np\n",
"import argparse\n",
"import imutils\n",
"import dlib\n",
"import cv2\n",
"\n",
"def rgb2gray(rgb):\n",
"    return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])\n",
"\n",
"# (height, width) of the mouth crop saved for each frame\n",
"lip_size = (30, 75)\n",
"\n",
"# number of frames read from each video\n",
"frames_per_video = 75\n",
"\n",
"# dlib models are expensive to build, so they are created once and reused for every video\n",
"_dlib_models = {}\n",
"\n",
"def _get_dlib_models():\n",
"    \"\"\"Return (face detector, 68-landmark predictor), creating them on first use.\"\"\"\n",
"    if not _dlib_models:\n",
"        _dlib_models['detector'] = dlib.get_frontal_face_detector()\n",
"        _dlib_models['predictor'] = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')\n",
"    return _dlib_models['detector'], _dlib_models['predictor']\n",
"\n",
"class VideoLoader:\n",
"    \"\"\"Crops the mouth region out of one .mpg file and saves it as HDF5.\n",
"\n",
"    Output: `<train_dir or test_dir>/<video_dir>/<name>.hdf5`, dataset 'video',\n",
"    float32 with shape (frames_per_video, lip_size[0], lip_size[1]).\n",
"    \"\"\"\n",
"    def __init__(self, file):\n",
"        self.file = file\n",
"        self.training = train_dir in file\n",
"        self.name = self.file.split(\"/\",1)[1].replace(\".mpg\", \"\")\n",
"        self.loaded = False\n",
"\n",
"    def _mouth_roi(self, gray, detector, predictor):\n",
"        \"\"\"Return the resized mouth crop of one grayscale frame, or None if no face is found.\"\"\"\n",
"        roi = None\n",
"        # if several faces are detected, the last one's mouth is used\n",
"        for rect in detector(gray, 1):\n",
"            shape = face_utils.shape_to_np(predictor(gray, rect))\n",
"            l, m = face_utils.FACIAL_LANDMARKS_IDXS['mouth']\n",
"            (x, y, w, h) = cv2.boundingRect(np.array([shape[l:m]]))\n",
"            crop = gray[y:y + h, x:x + w]\n",
"            crop = imutils.resize(crop, width=250, inter=cv2.INTER_CUBIC)\n",
"            # Image.LANCZOS is the filter the deprecated Image.ANTIALIAS alias pointed to\n",
"            roi = np.array(Image.fromarray(crop).resize((lip_size[1], lip_size[0]), Image.LANCZOS))\n",
"        return roi\n",
"\n",
"    def load_and_save(self):\n",
"        \"\"\"Crop every frame and write the HDF5 file.\n",
"\n",
"        Returns True on success. Returns False if some frame has no detectable\n",
"        mouth, or if the video could not be processed; in that case the error is\n",
"        printed, not raised.\n",
"        \"\"\"\n",
"        try:\n",
"            detector, predictor = _get_dlib_models()\n",
"            data = np.zeros(shape=(frames_per_video, lip_size[0], lip_size[1]), dtype=np.float32)\n",
"\n",
"            with imageio.get_reader(self.file) as reader:\n",
"                for i in range(frames_per_video):\n",
"                    # imageio yields RGB frames, so convert with RGB2GRAY (not BGR2GRAY)\n",
"                    gray = cv2.cvtColor(reader.get_next_data(), cv2.COLOR_RGB2GRAY)\n",
"                    roi = self._mouth_roi(gray, detector, predictor)\n",
"                    if roi is None:\n",
"                        print(\"\\nCould not find mouth for video\", self.file)\n",
"                        return False\n",
"                    data[i] = roi\n",
"\n",
"            out_dir = train_dir if self.training else test_dir\n",
"            with h5py.File(out_dir + '/' + video_dir + '/' + self.name + '.hdf5', 'w') as h5f:\n",
"                h5f.create_dataset('video', data=data, compression=\"gzip\")\n",
"            return True\n",
"        except Exception as e:\n",
"            # best effort: report and skip this video rather than abort the whole extraction\n",
"            print(\"\\nFailed to process video\", self.file, \"-\", e)\n",
"            return False\n",
"        finally:\n",
"            gc.collect()\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "9vCf-apD_VW4",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"## Video\n",
"\n",
"Here we extract the video frames from each file, crop out the mouth region, and save the result."
]
},
{
"metadata": {
"id": "zy1RgbHyRGPT",
"colab_type": "code",
"pycharm": {},
"outputId": "f6cc5582-c9b8-4653-d68d-b35207154838",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"cell_type": "code",
"source": [
"import sys\n",
"\n",
"#@title Force video refresh\n",
"video_force = False #@param {type:\"boolean\"}\n",
"limit = 200 #@param {type:\"slider\", min:10, max:1000, step:10}\n",
"test_limit = 40 #@param {type:\"slider\", min:10, max:1000, step:10}\n",
"\n",
"def extract_videos(split_dir, max_videos, label):\n",
"    \"\"\"Run VideoLoader on the .mpg files in `split_dir` until `max_videos` succeed.\n",
"\n",
"    Output goes to `split_dir/video_dir`. Does nothing if that directory already\n",
"    exists, unless `video_force` is set.\n",
"    \"\"\"\n",
"    out_dir = split_dir + '/' + video_dir\n",
"    if os.path.isdir(out_dir) and not video_force:\n",
"        return\n",
"\n",
"    try:\n",
"        shutil.rmtree(out_dir)\n",
"    except FileNotFoundError:\n",
"        pass\n",
"    os.mkdir(out_dir)\n",
"\n",
"    # only real video files: the split directory also holds the videos/ and audio/ output dirs\n",
"    videos = (VideoLoader(split_dir + '/' + f)\n",
"              for f in sorted(os.listdir(split_dir)) if f.endswith('.mpg'))\n",
"\n",
"    done = 0\n",
"    # `<` (not `<=`) so exactly max_videos files are produced\n",
"    while done < max_videos:\n",
"        video = next(videos, None)\n",
"        if video is None:\n",
"            print(f\"\\nFinished with {done} {label} videos\")\n",
"            break\n",
"        if video.load_and_save():\n",
"            done += 1\n",
"        sys.stdout.write('\\r{}/{} ({} %)'.format(done, max_videos, int(100 * done / max_videos)))\n",
"        sys.stdout.flush()\n",
"\n",
"print(\"Training videos\")\n",
"extract_videos(train_dir, limit, \"training\")\n",
"\n",
"print(\"\\nTest videos:\")\n",
"extract_videos(test_dir, test_limit, \"test\")\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Training videos\n",
"\n",
"Test videos:\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "MVlBYV8L_dys",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"This provides methods for loading videos."
]
},
{
"metadata": {
"id": "Cy6-_EHrpJG4",
"colab_type": "code",
"pycharm": {},
"outputId": "0d054b54-0184-406e-8baa-b73207ad3d78",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"cell_type": "code",
"source": [
"import os\n",
"\n",
"class Video:\n",
"    \"\"\"Lazy handle to one preprocessed mouth-crop video stored as HDF5.\"\"\"\n",
"    def __init__(self, name, training):\n",
"        self.name = name.replace('.hdf5', '')\n",
"        split_dir = train_dir if training else test_dir\n",
"        self.file = split_dir + '/' + video_dir + '/' + name\n",
"\n",
"    def data(self):\n",
"        \"\"\"Read and return the 'video' dataset (re-read from disk on every call).\"\"\"\n",
"        # context manager closes the file even if the read fails\n",
"        with h5py.File(self.file, 'r') as h5f:\n",
"            return h5f['video'][:]\n",
"\n",
"def get_videos(limit=10, training=True):\n",
"    \"\"\"Return up to `limit` Video handles for the train (or test) split, in sorted file order.\"\"\"\n",
"    split_dir = train_dir if training else test_dir\n",
"    files = sorted(os.listdir(split_dir + '/' + video_dir))[:limit]\n",
"    return [Video(f, training) for f in files]\n",
"\n",
"print(len(os.listdir(train_dir + '/' + video_dir)), \" Training Videos\")\n",
"print(len(os.listdir(test_dir + '/' + video_dir)), \"Test Videos\")\n",
"\n",
"videos = get_videos(5)\n",
"\n",
"video_shape = np.shape(videos[0].data())\n",
"\n",
"print(\"Shape: \", video_shape)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"201 Training Videos\n",
"41 Test Videos\n",
"Shape: (75, 30, 75)\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "bIyKCgla_hZ3",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"An example video.\n",
"\n",
"I'm not sure what's up with the colors, but it shouldn't matter."
]
},
{
"metadata": {
"id": "gxVNq9DjulY3",
"colab_type": "code",
"pycharm": {},
"outputId": "4ea02bff-909e-409e-b529-0f8aad7dcad6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 874
}
},
"cell_type": "code",
"source": [
"\n",
"# read the file once instead of once per use\n",
"sample_video = videos[0].data()\n",
"video_shape = np.shape(sample_video)\n",
"print('Video shape:', video_shape, sample_video.dtype)\n",
"# show every 15th frame of the mouth crop\n",
"for i in range(0, video_shape[0], 15):\n",
"    plt.imshow(sample_video[i])\n",
"    plt.title(f'Frame {i}')\n",
"    plt.show()\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Video shape: (75, 30, 75) float32\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
}
]
},
{
"metadata": {
"id": "AADkvcGUZQoi",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"Successfully loaded the visual data; now for the audio."
]
},
{
"metadata": {
"id": "-fEODXR936sk",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
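"## Audio\n",
"\n",
"\n",
"Here we extract the audio from the video files and add Gaussian noise to make a noisy copy. We then compute a log-scaled **mel spectrogram** of both (not MFCCs).\n",
"\n",
"Finally, we save everything for later use."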
]
},
{
"metadata": {
"id": "9K6M6YZu3OmX",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"video_frame_rate = 25  # video frames per second\n",
"framerate = 22050  # audio sample rate (Hz) used for loading and the mel filterbank\n",
"# one STFT window covers one video frame's worth of samples (22050 // 25 = 882)\n",
"n_fft = framerate // video_frame_rate\n",
"# hop by a quarter window\n",
"frame_step = n_fft // 4"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "5e57VlcA3c6t",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"def reconstruct_audio(magnitude, phase, target_length=65664):\n",
"    \"\"\"Invert a dB mel spectrogram back to a waveform using the given STFT phase.\n",
"\n",
"    magnitude: dB-scaled mel spectrogram, as produced by signal_to_spectrogram.\n",
"    phase: complex unit phase from librosa's magphase of the original signal.\n",
"    target_length: output length in samples. Shorter signals are zero-padded and\n",
"        longer ones truncated (65664 is the length of the stored audio).\n",
"    \"\"\"\n",
"    magnitude = librosa.db_to_amplitude(magnitude)\n",
"\n",
"    mel_filterbank = librosa.filters.mel(\n",
"        sr=framerate,\n",
"        n_fft=n_fft,\n",
"        n_mels=80,\n",
"        fmin=0,\n",
"        fmax=8000\n",
"    )\n",
"\n",
"    # the pseudo-inverse only approximately undoes the mel projection\n",
"    # NOTE(review): it can produce small negative magnitudes\n",
"    magnitude = np.dot(np.linalg.pinv(mel_filterbank), magnitude)\n",
"\n",
"    wave = librosa.istft(magnitude * phase, hop_length=frame_step)\n",
"\n",
"    # np.pad rejects negative widths, so truncate first, then pad up to target_length\n",
"    wave = wave[:target_length]\n",
"    return np.pad(wave, (0, target_length - len(wave)), 'constant')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "IMRtSoUuduPY",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"def signal_to_spectrogram(signal):\n",
"    \"\"\"Return (dB mel spectrogram, complex phase) for a 1-D audio signal.\n",
"\n",
"    The STFT uses the module-level n_fft / frame_step, and the 80-band mel\n",
"    projection covers 0-8000 Hz. The phase is kept at full STFT resolution so\n",
"    that the spectrogram can later be inverted with reconstruct_audio.\n",
"    \"\"\"\n",
"    stft_matrix = librosa.core.stft(signal.astype(np.float64), n_fft=n_fft, hop_length=frame_step)\n",
"    linear_mag, phase = librosa.core.magphase(stft_matrix)\n",
"\n",
"    mel_basis = librosa.filters.mel(sr=framerate, n_fft=n_fft, n_mels=80, fmin=0, fmax=8000)\n",
"    mel_db = librosa.amplitude_to_db(np.dot(mel_basis, linear_mag))\n",
"\n",
"    return mel_db, phase"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "7rq-5o2UtpJH",
"colab_type": "code",
"pycharm": {},
"outputId": "3354fc5f-1f3f-4861-e160-8d2e8082ec3f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"cell_type": "code",
"source": [
"audio_force = False #@param {type:\"boolean\"}\n",
"\n",
"import librosa, scipy\n",
"import shutil\n",
"import tempfile\n",
"import urllib.request\n",
"import cv2, os, gc\n",
"import h5py\n",
"import sys\n",
"\n",
"# standard deviation of the Gaussian noise used to build the noisy copy\n",
"noise_std = 0.05\n",
"# seed for that noise, so regenerated datasets are reproducible\n",
"noise_seed = 0\n",
"\n",
"def extract_audio(split_dir, rng):\n",
"    \"\"\"Write clean/noisy audio and mel spectrograms for every processed video in `split_dir`.\n",
"\n",
"    Only files that already have a video HDF5 are processed. Output goes to\n",
"    `split_dir/audio_dir/<name>.hdf5`. Does nothing if that directory already\n",
"    exists, unless `audio_force` is set.\n",
"    \"\"\"\n",
"    out_dir = split_dir + '/' + audio_dir\n",
"    if os.path.isdir(out_dir) and not audio_force:\n",
"        return\n",
"\n",
"    try:\n",
"        shutil.rmtree(out_dir)\n",
"    except FileNotFoundError:\n",
"        pass\n",
"    os.mkdir(out_dir)\n",
"\n",
"    # only load files we have video for\n",
"    names = sorted(f.replace('.hdf5', '') for f in os.listdir(split_dir + '/' + video_dir))\n",
"    total = len(names)\n",
"\n",
"    for done, name in enumerate(names, start=1):\n",
"        wave, _ = librosa.load(split_dir + '/' + name + '.mpg', mono=True, sr=framerate)\n",
"        noisy = wave + rng.normal(0, noise_std, len(wave))\n",
"\n",
"        mel_spectrogram, phase = signal_to_spectrogram(wave)\n",
"        noisy_spectrogram, _ = signal_to_spectrogram(noisy)\n",
"\n",
"        # add any preprocessing here!\n",
"\n",
"        with h5py.File(out_dir + '/' + name + '.hdf5', 'w') as h5f:\n",
"            h5f.create_dataset('spectrogram', data=mel_spectrogram.astype('float32'), compression=\"gzip\")\n",
"            h5f.create_dataset('audio', data=wave.astype('float32'), compression=\"gzip\")\n",
"            h5f.create_dataset('noisy_spectrogram', data=noisy_spectrogram.astype('float32'), compression=\"gzip\")\n",
"            h5f.create_dataset('noisy_audio', data=noisy.astype('float32'), compression=\"gzip\")\n",
"            h5f.create_dataset('phase', data=phase, compression=\"gzip\")\n",
"\n",
"        sys.stdout.write('\\r{}/{} ({} %)'.format(done, total, int(100 * done / total)))\n",
"        sys.stdout.flush()\n",
"\n",
"noise_rng = np.random.RandomState(noise_seed)\n",
"\n",
"print(\"Training audio\")\n",
"extract_audio(train_dir, noise_rng)\n",
"\n",
"print(\"\\nTesting audio\")\n",
"extract_audio(test_dir, noise_rng)\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Training audio\n",
"\n",
"Testing audio\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "AEIxgWb2ByMC",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"This is code for loading audio files.\n",
"\n",
"Below we show a sample audio file and its mel spectrogram."
]
},
{
"metadata": {
"id": "ltj7OJqGpJRC",
"colab_type": "code",
"pycharm": {},
"outputId": "bcf71d90-85d0-451d-b948-7fbcad72a876",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 492
}
},
"cell_type": "code",
"source": [
"#@title Default title text { run: \"auto\" }\n",
"noise_level = 0.5 #@param {type:\"slider\", min:0, max:1, step:0.05}\n",
"import librosa.display\n",
"\n",
"# seeds the global NumPy RNG that Audio.get_data draws its extra 'noisy_audio' noise from\n",
"np.random.seed(0)\n",
"\n",
"# STFT window = one video frame of audio; hop = a quarter window\n",
"n_fft = framerate // video_frame_rate\n",
"hop_length = n_fft // 4\n",
"\n",
"class Audio:\n",
"    \"\"\"Lazy handle to one preprocessed audio HDF5 file; nothing is read until asked for.\n",
"\n",
"    The audio extraction cell writes the datasets 'spectrogram', 'audio',\n",
"    'noisy_spectrogram', 'noisy_audio' and 'phase'.\n",
"    \"\"\"\n",
"    def __init__(self, name, training):\n",
"        self.name = name.replace('.hdf5', '')\n",
"        split_dir = train_dir if training else test_dir\n",
"        self.file = split_dir + '/' + audio_dir + '/' + name\n",
"\n",
"    def get_data(self, keys):\n",
"        \"\"\"Read one or more datasets from the file.\n",
"\n",
"        keys: a dataset name, or an iterable of names. Names missing from the\n",
"            file are skipped.\n",
"        Returns the bare array when exactly one dataset was found, otherwise a\n",
"        dict name -> array. Every read of 'noisy_audio' adds fresh Gaussian\n",
"        noise (std-dev `noise_level`) on top of the stored noisy audio.\n",
"        \"\"\"\n",
"        if isinstance(keys, str):\n",
"            keys = [keys]\n",
"\n",
"        data = {}\n",
"        # context manager closes the file even if a read fails\n",
"        with h5py.File(self.file, 'r') as h5f:\n",
"            for k in keys:\n",
"                if k not in h5f:\n",
"                    continue\n",
"                value = h5f[k][:]\n",
"                if k == 'noisy_audio':\n",
"                    value = np.random.normal(0, noise_level, np.shape(value)) + value\n",
"                data[k] = value\n",
"\n",
"        if len(data) == 1:\n",
"            return next(iter(data.values()))\n",
"        return data\n",
"\n",
"    def __getitem__(self, item):\n",
"        return self.get_data(item)\n",
"\n",
"    def audio(self):\n",
"        return self['audio']\n",
"\n",
"    def spectrogram(self):\n",
"        return self['spectrogram']\n",
"\n",
"    def phase(self):\n",
"        return self['phase']\n",
"\n",
"    def noisy_audio(self):\n",
"        return self['noisy_audio']\n",
"\n",
"    def noisy_spectrogram(self):\n",
"        return self['noisy_spectrogram']\n",
"\n",
"    def noise_spectrogram(self):\n",
"        # NOTE(review): the extraction cell in this notebook never writes 'noise_spectrogram',\n",
"        # so for those files this returns an empty dict\n",
"        return self['noise_spectrogram']\n",
"\n",
"    def reconstruct_audio(self, spectrogram):\n",
"        \"\"\"Invert `spectrogram` to a waveform using this file's stored phase.\"\"\"\n",
"        return reconstruct_audio(spectrogram, self['phase'])\n",
" \n",
"def get_audios(limit=10, training=True) -> List[Audio]:\n",
"    \"\"\"Return up to `limit` Audio handles for the train (or test) split, in sorted file order.\"\"\"\n",
"    split_dir = train_dir if training else test_dir\n",
"    # sorted: os.listdir order is arbitrary, which made the sample non-reproducible\n",
"    files = sorted(os.listdir(split_dir + '/' + audio_dir))[:limit]\n",
"    return [Audio(f, training) for f in files]\n",
"\n",
"audios = get_audios(5)\n",
"\n",
"# read the sample file once instead of once per property\n",
"sample = audios[0].get_data(['audio', 'spectrogram', 'phase'])\n",
"\n",
"spectrogram_shape = np.shape(sample['spectrogram'])\n",
"phase_shape = sample['phase'].shape\n",
"audio_shape = np.shape(sample['audio'])\n",
"\n",
"print(\"Sound Shape: \", audio_shape, sample['audio'].dtype)\n",
"print(\"Spectrogram Shape: \", spectrogram_shape, sample['spectrogram'].dtype)\n",
"print(\"Phase Shape: \", phase_shape, sample['phase'].dtype)\n",
"\n",
"# librosa 0.10 removed display.waveplot in favour of display.waveshow\n",
"plot_waveform = getattr(librosa.display, 'waveshow', None) or librosa.display.waveplot\n",
"\n",
"print(\"Waveform\")\n",
"plot_waveform(sample['audio'], sr=framerate)\n",
"plt.show()\n",
"\n",
"print(\"Spectrogram\")\n",
"plt.imshow(sample['spectrogram'])\n",
"plt.show()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Sound Shape: (65664,) float32\n",
"Spectrogram Shape: (80, 299) float32\n",
"Phase Shape: (442, 299) complex64\n",
"Waveform\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Spectrogram\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
}
]
},
{
"metadata": {
"id": "KlLsduShHZo6",
"colab_type": "code",
"pycharm": {},
"outputId": "f72ded93-228f-40bd-8569-4344f3b7ca25",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 392
}
},
"cell_type": "code",
"source": [
"import time\n",
"import IPython.display as ipy_display\n",
"\n",
"print(\"Reconstructed Waveform\")\n",
"\n",
"# time one round trip of the clean spectrogram back to a waveform\n",
"start = time.time()\n",
"recon = audios[0].reconstruct_audio(audios[0].spectrogram())\n",
"end = time.time()\n",
"\n",
"print(\"Time: \", end - start)\n",
"print(\"Recon Shape:\", recon.shape)\n",
"\n",
"librosa.display.waveplot(recon)\n",
"plt.show()\n",
"\n",
"# the last expression renders an audio player for the reconstruction\n",
"ipy_display.Audio(recon, rate=framerate)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Reconstructed Waveform\n",
"Time: 0.05281829833984375\n",
"Recon Shape: (65664,)\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<IPython.lib.display.Audio object>"
],
"text/html": [
"\n",
" <audio controls=\"controls\" >\n",
" <source src=\"data:audio/wav;base64,\" type=\"audio/wav\" />\n",
" Your browser does not support the audio element.\n",
" </audio>\n",
" "
]
},
"metadata": {
"tags": []
},
"execution_count": 96
}
]
},
{
"metadata": {
"id": "gSPUXfj6XQ3k",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"Remove old files"
]
},
{
"metadata": {
"id": "dTnSG4lOXXHN",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"#shutil.rmtree('files')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "-KLR9J3phXNL",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Combined"
]
},
{
"metadata": {
"id": "_p2x5TXTMCmE",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"class AudioVideo:\n",
"    \"\"\"Pairs the Video and Audio HDF5 files that share one file name.\"\"\"\n",
"    def __init__(self, name, training=True):\n",
"        self.video = Video(name, training)\n",
"        self.audio = Audio(name, training)\n",
"\n",
"    def get_data(self, keys):\n",
"        \"\"\"Fetch datasets by name. 'video' comes from the video file; every other name comes from the audio file.\n",
"\n",
"        keys: a name, or an iterable of names. Returns the single array when exactly\n",
"        one dataset was found, otherwise a dict name -> array. Names missing from\n",
"        the audio file are skipped.\n",
"        \"\"\"\n",
"        keys = [keys] if isinstance(keys, str) else list(keys)\n",
"\n",
"        data = {}\n",
"        if 'video' in keys:\n",
"            data['video'] = self.video.data()\n",
"\n",
"        audio_keys = [k for k in keys if k != 'video']\n",
"        if audio_keys:\n",
"            found = self.audio[audio_keys]\n",
"            if isinstance(found, dict):\n",
"                data.update(found)\n",
"            elif len(audio_keys) == 1:\n",
"                data[audio_keys[0]] = found\n",
"            else:\n",
"                # Several names were requested but Audio returned a bare array (only one\n",
"                # name resolved), so look them up one by one to label it correctly.\n",
"                for k in audio_keys:\n",
"                    value = self.audio[k]\n",
"                    if not isinstance(value, dict):\n",
"                        data[k] = value\n",
"\n",
"        if len(data) == 1:\n",
"            return next(iter(data.values()))\n",
"        return data\n",
"\n",
"    def __getitem__(self, item: Iterable[str]):\n",
"        return self.get_data(item)\n",
"\n",
"    def audio_data(self):\n",
"        return self['audio']\n",
"\n",
"    def video_data(self):\n",
"        return self['video']\n",
"\n",
"    def spectrogram(self):\n",
"        return self['spectrogram']\n",
"\n",
"    def phase(self):\n",
"        return self['phase']\n",
"\n",
"    def noisy_audio(self):\n",
"        return self['noisy_audio']\n",
"\n",
"    def noisy_spectrogram(self):\n",
"        return self['noisy_spectrogram']\n",
"\n",
"\n",
"def get_audio_and_video(limit=100, training=True):\n",
"    \"\"\"Yield AudioVideo pairs for the first `limit` processed videos (sorted), cycling forever.\n",
"\n",
"    Raises ValueError on the first next() if the split has no processed videos,\n",
"    instead of spinning forever over an empty list.\n",
"    \"\"\"\n",
"    split_dir = train_dir if training else test_dir\n",
"    files = sorted(os.listdir(split_dir + '/' + video_dir))[:limit]\n",
"    if not files:\n",
"        raise ValueError('no processed videos in ' + split_dir + '/' + video_dir)\n",
"\n",
"    while True:\n",
"        for f in files:\n",
"            yield AudioVideo(f, training)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "DkFTUmqMhCoR",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Training Utilities"
]
},
{
"metadata": {
"id": "vldFdtwVhPAb",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Data generators\n",
"\n",
"Our data is quite large, and Keras supports training and prediction from generators, so we take advantage of this."
]
},
{
"metadata": {
"id": "_sG4xKle4ZSr",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"def get_training_data(limit=100, batch=10, use_video=False):\n",
"    \"\"\"Endless generator of training batches drawn from the first `limit` training files.\n",
"\n",
"    Each yielded dict holds 'noisy_spectrogram' (model input) and 'spectrogram'\n",
"    (target), both of shape (batch,) + spectrogram_shape. When use_video is True\n",
"    it also holds 'video', of shape (batch,) + video_shape.\n",
"    \"\"\"\n",
"    av_getter = get_audio_and_video(limit)\n",
"    requested = ['spectrogram', 'noisy_spectrogram'] + (['video'] if use_video else [])\n",
"\n",
"    while True:\n",
"        batch_data = {}\n",
"        if use_video:\n",
"            batch_data['video'] = np.empty((batch,) + video_shape, dtype='uint8')\n",
"        batch_data['noisy_spectrogram'] = np.empty((batch,) + spectrogram_shape, dtype='float32')\n",
"        batch_data['spectrogram'] = np.empty((batch,) + spectrogram_shape, dtype='float32')\n",
"\n",
"        for row in range(batch):\n",
"            sample = next(av_getter)[requested]\n",
"            for key, array in batch_data.items():\n",
"                array[row] = sample[key]\n",
"\n",
"        yield batch_data\n",
" \n",
"def get_testing_data(limit=100, use_video=False):\n",
"    \"\"\"Load the first `limit` test samples into memory as stacked arrays.\n",
"\n",
"    Returns a dict with 'spectrogram_input' (noisy spectrograms), 'audio' (clean\n",
"    waveforms), 'noisy_audio' and 'phase_input', plus 'video_input' when\n",
"    use_video is True.\n",
"    \"\"\"\n",
"    av_getter = get_audio_and_video(limit, training=False)\n",
"\n",
"    noisy_spectrogram = np.empty((limit,) + spectrogram_shape, dtype='float32')\n",
"    audio = np.empty((limit,) + audio_shape, dtype='float32')\n",
"    noisy_audio = np.empty((limit,) + audio_shape, dtype='float32')\n",
"    phase = np.empty((limit,) + phase_shape, dtype='complex64')\n",
"    if use_video:\n",
"        video = np.empty((limit,) + video_shape, dtype='uint8')\n",
"\n",
"    # 'video' is requested (and read from disk) only when it is actually used\n",
"    keys = ['noisy_spectrogram', 'audio', 'noisy_audio', 'phase']\n",
"    if use_video:\n",
"        keys.append('video')\n",
"\n",
"    for i in range(limit):\n",
"        data = next(av_getter)[keys]\n",
"        noisy_spectrogram[i] = data['noisy_spectrogram']\n",
"        audio[i] = data['audio']\n",
"        noisy_audio[i] = data['noisy_audio']\n",
"        phase[i] = data['phase']\n",
"        if use_video:\n",
"            video[i] = data['video']\n",
"\n",
"    result = {\n",
"        'spectrogram_input': noisy_spectrogram,\n",
"        'audio': audio,\n",
"        'noisy_audio': noisy_audio,\n",
"        'phase_input': phase\n",
"    }\n",
"    if use_video:\n",
"        result['video_input'] = video\n",
"    return result"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "0bDj-2J57Q9j",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"## Tensorflow Inverse STFT\n",
"\n",
"TensorFlow's istft doesn't match librosa's stft, so we implemented our own. We use it to turn the spectrogram back into audio inside the model. That isn't really necessary now, but we implemented it so we could try a residual network (it didn't work)."
]
},
{
"metadata": {
"id": "RAIBdh6pH9mP",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"import tensorflow.signal\n",
"\n",
"from tensorflow import convert_to_tensor, map_fn\n",
"import tensorflow.keras.backend as K\n",
"from tensorflow import float32 as tf_float32\n",
"import tensorflow_probability as tfp\n",
"\n",
"def log10(x):\n",
"    \"\"\"Base-10 logarithm via TensorFlow's natural log: log10(x) = ln(x) / ln(10).\"\"\"\n",
"    natural_log = tf.log(x)\n",
"    # Build ln(10) in the same dtype as the input's log so the division type-checks.\n",
"    ln_ten = tf.log(tf.constant(10, dtype=natural_log.dtype))\n",
"    return natural_log / ln_ten\n",
"\n",
"def tf_db_to_amp(signal):\n",
" return 10.0**(0.05 * signal)\n",
"\n",
"def tf_amp_to_db(signal, amin=1e-10, top_db=80.0):\n",
"    \"\"\"Convert an amplitude tensor to decibels (modelled on librosa's power_to_db with ref=1.0).\n",
"\n",
"    The amplitude is squared to power, floored at `amin`, converted with 10 * log10, and\n",
"    clipped so that no value lies more than `top_db` below the tensor's maximum.\n",
"\n",
"    Args:\n",
"        signal: float tensor of amplitudes.\n",
"        amin: floor applied to the power before taking the log (avoids log(0)).\n",
"        top_db: dynamic range kept below the maximum value.\n",
"\n",
"    Returns:\n",
"        Tensor with the same shape and dtype as `signal`, in dB.\n",
"    \"\"\"\n",
"    power = signal ** 2\n",
"\n",
"    log_spec = 10.0 * log10(tf.math.maximum(amin, power))\n",
"    # Reference level 1.0 (contributes 0 dB); built in the input dtype so float64 inputs work.\n",
"    ref = tf.constant(1.0, dtype=power.dtype)\n",
"    log_spec -= 10.0 * log10(tf.math.maximum(amin, ref))\n",
"    # Tensors have no .max(); reduce over all elements to find the peak level.\n",
"    log_spec = tf.math.maximum(log_spec, tf.reduce_max(log_spec) - top_db)\n",
"\n",
"    return log_spec\n",
"\n",
"tf_mel = tf.constant(\n",
" librosa.filters.mel(\n",
" sr=framerate,\n",
" n_fft=n_fft,\n",
" n_mels=80,\n",
" fmin=0,\n",
" fmax=8000\n",
" )\n",
")\n",
" \n",
"ifft_window = librosa.filters.get_window('hann', n_fft, fftbins=True)\n",
"ifft_window = librosa.util.pad_center(ifft_window, n_fft)\n",
"ifft_window = tf.constant(ifft_window, dtype='float32')\n",
"\n",
"ifft_window_sum = tf.constant(librosa.filters.window_sumsquare('hann',\n",
" 299,\n",
" win_length=n_fft,\n",
" n_fft=n_fft,\n",
" hop_length=frame_step))\n",
" \n",
"approx_nonzero_indices = ifft_window_sum > librosa.util.tiny(ifft_window_sum)\n",
" \n",
"divisor = tf.where(approx_nonzero_indices, ifft_window_sum, tf.ones(ifft_window_sum.shape[0]))\n",
" \n",
"def tf_istft(stft_matrix, hop_length, window='hann', center=True):\n",
"    \"\"\"Inverse STFT by overlap-add, written in TensorFlow to mirror librosa.istft.\n",
"\n",
"    Relies on the module-level Hann synthesis window (`ifft_window`) and the precomputed\n",
"    window-sum normaliser (`divisor`). Only window='hann' is supported, and only the\n",
"    frame count / hop length that `divisor` was built for.\n",
"\n",
"    Args:\n",
"        stft_matrix: complex tensor with one STFT frame per column (assumed one-sided,\n",
"            1 + n_fft // 2 rows -- TODO confirm against callers).\n",
"        hop_length: number of samples between the starts of consecutive frames.\n",
"        window: window name; must be 'hann'.\n",
"        center: if True, trim n_fft // 2 samples from both ends (undoing centred framing).\n",
"\n",
"    Returns:\n",
"        1-D float32 tensor with the reconstructed signal.\n",
"\n",
"    Raises:\n",
"        ValueError: if `window` is not 'hann', or the signal length implied by the frame\n",
"            count and hop length differs from the length of `divisor`.\n",
"    \"\"\"\n",
"    if window != 'hann':\n",
"        raise ValueError(\"tf_istft only supports window='hann', got %r\" % (window,))\n",
"\n",
"    n_frames = int(stft_matrix.shape[1])\n",
"    expected_signal_len = n_fft + hop_length * (n_frames - 1)\n",
"\n",
"    # `divisor` was precomputed for a fixed number of frames; fail with a clear message\n",
"    # instead of a shape error in the division below.\n",
"    divisor_len = int(divisor.shape[0])\n",
"    if expected_signal_len != divisor_len:\n",
"        raise ValueError(\n",
"            \"tf_istft: %d frames with hop %d give %d samples, but the precomputed \"\n",
"            \"window-sum normaliser has %d samples\"\n",
"            % (n_frames, hop_length, expected_signal_len, divisor_len))\n",
"\n",
"    y = tf.zeros((expected_signal_len,), 'float32')\n",
"    i = tf.constant(0, dtype='int32')\n",
"\n",
"    # Overlap-add one windowed frame per iteration (see tf_loop_body).\n",
"    _, y = tf.while_loop(\n",
"        lambda i, y: i < tf.constant(n_frames),\n",
"        lambda i, y: tf_loop_body(y, i, hop_length, stft_matrix, ifft_window, expected_signal_len, n_fft),\n",
"        loop_vars=[i, y])\n",
"\n",
"    # Normalise by the summed squared window (1 where that sum is ~0).\n",
"    y /= divisor\n",
"\n",
"    if center:\n",
"        y = y[int(n_fft // 2):-int(n_fft // 2)]\n",
"\n",
"    return y\n",
"\n",
"def tf_loop_body(y, i, hop_length, stft_matrix, ifft_window, expected_signal_len, n_fft):\n",
"    \"\"\"One overlap-add step of tf_istft: synthesise frame `i` and add it into `y`.\n",
"\n",
"    Column `i` is extended to a full spectrum by appending the conjugated, reversed\n",
"    interior bins. It is inverse-FFT'd (real part), multiplied by `ifft_window`, and\n",
"    zero-padded so it starts at sample i * hop_length. Finally it is added to `y`.\n",
"\n",
"    Returns:\n",
"        [i + 1, updated y] -- the next tf.while_loop state.\n",
"    \"\"\"\n",
"    offset = i * hop_length\n",
"\n",
"    half_spectrum = tf.squeeze(stft_matrix[:, i])\n",
"    mirrored_bins = tf.math.conj(half_spectrum[-2:0:-1])\n",
"    full_spectrum = tf.concat((half_spectrum, mirrored_bins), 0)\n",
"\n",
"    frame = ifft_window * tf.math.real(tf.signal.ifft(full_spectrum))\n",
"\n",
"    paddings = [[offset, expected_signal_len - (offset + n_fft)]]\n",
"    placed_frame = tf.pad(frame, paddings, mode='CONSTANT')\n",
"\n",
"    return [tf.add(i, 1), tf.add(y, placed_frame)]\n",
"\n",
"import functools\n",
"\n",
"def tf_reconstruct_audio(mp, target_length=65664):\n",
"    \"\"\"Rebuild a waveform from a dB mel spectrogram and a complex phase matrix.\n",
"\n",
"    Args:\n",
"        mp: pair (magnitude, phase). `magnitude` is a dB-scaled mel spectrogram whose\n",
"            rows match the bands of `tf_mel`. `phase` is a complex64 matrix with one\n",
"            column per frame (presumably unit-magnitude STFT phase -- TODO confirm).\n",
"        target_length: number of output samples. Shorter waveforms are zero-padded at the\n",
"            end and longer ones truncated, so the result always has this length.\n",
"\n",
"    Returns:\n",
"        1-D float32 tensor of `target_length` samples.\n",
"    \"\"\"\n",
"    magnitude, phase = mp[0], mp[1]\n",
"\n",
"    # dB -> linear amplitude, then map mel bands back to linear-frequency bins with the\n",
"    # pseudo-inverse of the mel filterbank.\n",
"    magnitude = tf_db_to_amp(magnitude)\n",
"    magnitude = tfp.math.pinv(tf_mel) @ tf.cast(magnitude, 'float64')\n",
"\n",
"    mag_phase = tf.cast(magnitude, 'complex64') * phase\n",
"    wave = tf_istft(mag_phase, hop_length=frame_step)\n",
"\n",
"    # Force a fixed length: tensor_reconstruct_audio declares a static output shape.\n",
"    pad = target_length - int(wave.shape[0])\n",
"    if pad > 0:\n",
"        wave = tf.pad(wave, [[0, pad]], 'constant')\n",
"    elif pad < 0:\n",
"        wave = wave[:target_length]\n",
"\n",
"    return wave\n",
"\n",
"def tensor_reconstruct_audio(ip):\n",
"    \"\"\"Batch version of tf_reconstruct_audio, used by the model's Lambda layer.\n",
"\n",
"    `ip` (the [spectrogram, phase] inputs) is mapped over its leading batch axis with\n",
"    tf.map_fn, and the float32 result is given the static shape (batch, 65664).\n",
"    \"\"\"\n",
"    waves = tf.map_fn(tf_reconstruct_audio, ip, dtype=tf_float32, infer_shape=False)\n",
"    # infer_shape=False leaves the per-example shape unknown, so declare it explicitly.\n",
"    waves.set_shape((None, 65664))\n",
"    return waves\n",
" \n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "F5vpLYMCctGg",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"# with tf.Session() as sess:\n",
"# m = tf.constant(audios[0].spectrogram())\n",
"# p = tf.constant(audios[0].phase())\n",
"# recon = tf_reconstruct_audio([m, p])\n",
"# start = time.time()\n",
"# recon = sess.run(recon)\n",
"# end = time.time()\n",
"\n",
"# print(\"Time: \", end - start)\n",
"\n",
"# print(\"Recon Shape:\", recon.shape)\n",
"\n",
"# librosa.display.waveplot(recon)\n",
"# plt.show()\n",
"\n",
"# import IPython.display as ipy_display\n",
"\n",
"# ipy_display.Audio(recon, rate=framerate)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "uJJURBxOq94A",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"# Denoising Autoencoder#\n",
"\n",
"Audio encoder and decoder from [here](https://github.com/avivga/audio-visual-speech-enhancement).\n",
"\n",
"Video encoder from [here](https://github.com/rizkiarm/LipNet)."
]
},
{
"metadata": {
"id": "7JD6KOuZuXC-",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"from tensorflow.keras.layers import Input, Flatten, Dense, Reshape, Concatenate, Dropout\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.optimizers import SGD, Adadelta, Adam\n",
"from tensorflow.keras.layers import Conv3D, ZeroPadding3D, UpSampling1D, ZeroPadding1D\n",
"from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Cropping2D, Cropping1D\n",
"from tensorflow.keras.layers import MaxPooling3D, MaxPooling2D, UpSampling2D\n",
"from tensorflow.keras.layers import Dense, Activation, SpatialDropout3D, Flatten\n",
"from tensorflow.keras.layers import Bidirectional, TimeDistributed, Subtract\n",
"from tensorflow.keras.layers import GRU, LSTM, Lambda, LeakyReLU, ZeroPadding2D\n",
"from tensorflow.keras.layers import BatchNormalization, MaxPooling1D, Conv1D\n",
"from tensorflow.keras.layers import SpatialDropout2D, Conv1D, Permute\n",
"\n",
"import tensorflow.keras.backend as K\n",
"\n",
"from progressbar import progressbar as tqdm\n",
"\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' \n",
"tf.logging.set_verbosity(tf.logging.ERROR)\n",
"\n",
"class Autoencoder:\n",
"    \"\"\"Spectrogram denoising autoencoder, optionally conditioned on lip video, with a GAN-style discriminator.\n",
"\n",
"    __init__ creates the layer stacks, get_autoencoder() wires them into Keras models,\n",
"    compile() compiles those models, and train() runs the alternating updates.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, use_video=False):\n",
"        \"\"\"Create the (not yet connected) layer stacks.\n",
"\n",
"        Args:\n",
"            use_video: if True, get_autoencoder() adds a video encoder whose output is\n",
"                concatenated with the spectrogram encoding.\n",
"        \"\"\"\n",
"        self.use_video = use_video\n",
"\n",
"        # (decoded spectrogram, phase) -> 65664-sample waveform.\n",
"        self.reconstruct_layer = Lambda(tensor_reconstruct_audio, name='reconstruct_audio', output_shape=(65664,))\n",
"\n",
"        # Video encoder (adapted from LipNet): three Conv3D blocks that keep the time axis\n",
"        # and downsample the spatial axes, a per-frame flatten, then a bidirectional GRU.\n",
"        self.video = [\n",
"            ZeroPadding3D(padding=(1, 3, 3), name='zero1'),\n",
"            Conv3D(32, (3, 5, 5), strides=(1, 2, 2), kernel_initializer='he_normal', name='conv1'),\n",
"            BatchNormalization(name='batc1'),\n",
"            Activation('relu', name='actv1'),\n",
"            SpatialDropout3D(0.3),\n",
"            MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2), name='max1'),\n",
"\n",
"            ZeroPadding3D(padding=(1, 2, 2), name='zero2'),\n",
"            Conv3D(64, (3, 5, 5), strides=(1, 1, 1), kernel_initializer='he_normal', name='conv2'),\n",
"            BatchNormalization(name='batc2'),\n",
"            Activation('relu', name='actv2'),\n",
"            SpatialDropout3D(0.3),\n",
"            MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2), name='max2'),\n",
"\n",
"            ZeroPadding3D(padding=(1, 1, 1), name='zero3'),\n",
"            Conv3D(92, (3, 3, 3), strides=(1, 1, 1), kernel_initializer='he_normal', name='conv3'),\n",
"            BatchNormalization(name='batc3'),\n",
"            Activation('relu', name='actv3'),\n",
"            SpatialDropout3D(0.3),\n",
"            MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2), name='max3'),\n",
"\n",
"            TimeDistributed(Flatten()),\n",
"\n",
"            Bidirectional(GRU(256, return_sequences=True, kernel_initializer='Orthogonal', name='gru1'), merge_mode='concat'),\n",
"\n",
"            Dropout(0.3)\n",
"        ]\n",
"\n",
"        # Spectrogram encoder: strided Conv2D blocks (80 x 299 -> 3 x 75 x 128). Then the\n",
"        # time axis is moved to axis 1 and each time step is flattened to 3 * 128 = 384 features.\n",
"        self.spectrogram_encoder = [\n",
"            Reshape((spectrogram_shape) + (1,)),\n",
"\n",
"            Conv2D(64, kernel_size=(5, 5), strides=(2, 2), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2D(64, kernel_size=(4, 4), strides=(1, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2D(128, kernel_size=(4, 4), strides=(2, 2), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2D(128, kernel_size=(2, 2), strides=(2, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2D(128, kernel_size=(2, 2), strides=(2, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2D(128, kernel_size=(2, 2), strides=(2, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Permute((2, 1, 3)),\n",
"            TimeDistributed(Flatten())\n",
"        ]\n",
"\n",
"        # Latent block: a per-time-step Dense(384). Each step is reshaped to (3, 128) and\n",
"        # the result is permuted to (3, time, 128) to feed the decoder.\n",
"        self.spectrogram_latent = [\n",
"            TimeDistributed(Dense(384)),\n",
"            TimeDistributed(BatchNormalization()),\n",
"            TimeDistributed(LeakyReLU()),\n",
"\n",
"            TimeDistributed(Reshape((3, 128))),\n",
"            Permute((2, 1, 3)),\n",
"        ]\n",
"\n",
"        # Decoder: transposed convolutions mirroring the encoder (3 x 75 -> 96 x 300), a 1x1\n",
"        # projection to one channel, and a crop to the 80 x 299 spectrogram shape.\n",
"        self.spectrogram_decoder = [\n",
"            Conv2DTranspose(128, kernel_size=(2, 2), strides=(2, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2DTranspose(128, kernel_size=(2, 2), strides=(2, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2DTranspose(128, kernel_size=(2, 2), strides=(2, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2DTranspose(128, kernel_size=(4, 4), strides=(2, 2), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2DTranspose(64, kernel_size=(4, 4), strides=(1, 1), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2DTranspose(64, kernel_size=(5, 5), strides=(2, 2), padding='same'),\n",
"            BatchNormalization(),\n",
"            LeakyReLU(),\n",
"\n",
"            Conv2DTranspose(1, kernel_size=(1, 1), strides=(1, 1), padding='same'),\n",
"\n",
"            Cropping2D(((8, 8), (0, 1))),\n",
"\n",
"            Reshape((80, 299), name='spectrogram_output'),\n",
"        ]\n",
"\n",
"        # Discriminator: two strided Conv2D + LeakyReLU blocks and a sigmoid real/fake score.\n",
"        self.discriminator = [\n",
"            Reshape((80, 299, 1)),\n",
"\n",
"            Conv2D(32, 4, strides=2, activation=None, padding='same'),\n",
"            LeakyReLU(alpha=0.1),\n",
"\n",
"            Conv2D(64, 4, strides=2, activation=None, padding='same'),\n",
"            LeakyReLU(alpha=0.1),\n",
"\n",
"            Flatten(),\n",
"\n",
"            Dense(1, activation='sigmoid')\n",
"        ]\n",
"\n",
"    def get_autoencoder(self):\n",
"        \"\"\"Connect the layer stacks into Keras models and print their summaries.\n",
"\n",
"        Sets:\n",
"            self.spectrogram_model: noisy spectrogram [+ video] -> clean spectrogram (80, 299).\n",
"            self.discriminator_model: spectrogram (80, 299) -> real/fake probability.\n",
"            self.discriminator_on_model: spectrogram_model followed by the discriminator.\n",
"            self.model: spectrogram_model followed by waveform reconstruction; it takes an\n",
"                extra 'phase_input'.\n",
"\n",
"        Returns:\n",
"            (self.model, self.spectrogram_model, self.discriminator_model,\n",
"             self.discriminator_on_model)\n",
"        \"\"\"\n",
"        # Spectrogram encoder\n",
"        spectrogram_input = Input(shape=spectrogram_shape, name='spectrogram_input')\n",
"        s = spectrogram_input\n",
"        for layer in self.spectrogram_encoder:\n",
"            s = layer(s)\n",
"\n",
"        spectrogram_encoder_model = Model(inputs=spectrogram_input, outputs=s)\n",
"        print(\"Spectrogram encoder output:\", spectrogram_encoder_model.output_shape[1:])\n",
"        print(\"Spectrogram encoder:\")\n",
"        spectrogram_encoder_model.summary()\n",
"        print(\"\\n\\n\\n\")\n",
"\n",
"        inputs = [spectrogram_input]\n",
"\n",
"        # Optional video encoder; its per-time-step features are concatenated with the\n",
"        # spectrogram encoding along the feature axis.\n",
"        if self.use_video:\n",
"            video_input = Input(shape=video_shape, name='video_input')\n",
"            v = Reshape(video_shape + (1,))(video_input)\n",
"            for layer in self.video:\n",
"                v = layer(v)\n",
"\n",
"            self.video_encoder_model = Model(inputs=video_input, outputs=v)\n",
"            print(\"Video encoder output:\", self.video_encoder_model.output_shape[1:])\n",
"            print(\"Video encoder:\")\n",
"            self.video_encoder_model.summary()\n",
"            print(\"\\n\\n\\n\")\n",
"\n",
"            inputs.append(video_input)\n",
"            encoding = Concatenate(axis=2)([s, v])\n",
"        else:\n",
"            encoding = s\n",
"\n",
"        embedding_shape = K.int_shape(encoding)\n",
"        print(\"Embedding Shape:\", embedding_shape)\n",
"\n",
"        # Latent block: applied to the real graph (s) and to a standalone input (se),\n",
"        # so the block's summary can be printed on its own.\n",
"        se_in = Input(embedding_shape[1:])\n",
"        se = se_in\n",
"        s = encoding\n",
"        for layer in self.spectrogram_latent:\n",
"            s = layer(s)\n",
"            se = layer(se)\n",
"\n",
"        print(\"\\n\\n\\n\")\n",
"        print(\"Spectrogram Latent:\")\n",
"        Model(inputs=se_in, outputs=se).summary()\n",
"        print(\"\\n\\n\\n\")\n",
"\n",
"        # Decoder, likewise applied to a standalone input (for its summary) and to the real graph.\n",
"        sd_in = Input((3, 75, 128))\n",
"        sd = sd_in\n",
"        for layer in self.spectrogram_decoder:\n",
"            sd = layer(sd)\n",
"            s = layer(s)\n",
"\n",
"        spectrogram_decoder_model = Model(inputs=sd_in, outputs=sd)\n",
"        print(\"Spectrogram decoder output:\", spectrogram_decoder_model.output_shape[1:])\n",
"        print(\"Spectrogram decoder:\")\n",
"        spectrogram_decoder_model.summary()\n",
"        print(\"\\n\\n\\n\")\n",
"\n",
"        self.spectrogram_model = Model(inputs=inputs, outputs=s)\n",
"\n",
"        # Discriminator: standalone on spectrograms, and stacked on the generator's output.\n",
"        discrim_in = Input((80, 299), name='discriminator_input')\n",
"        discrim = discrim_in\n",
"        discrim_on_decoded = s\n",
"        for layer in self.discriminator:\n",
"            discrim = layer(discrim)\n",
"            discrim_on_decoded = layer(discrim_on_decoded)\n",
"\n",
"        self.discriminator_model = Model(inputs=discrim_in, outputs=discrim)\n",
"        self.discriminator_on_model = Model(inputs=inputs, outputs=discrim_on_decoded)\n",
"\n",
"        print(\"Discriminator:\")\n",
"        self.discriminator_model.summary()\n",
"        print(\"\\n\\n\\n\")\n",
"\n",
"        print(\"Discriminator on model:\")\n",
"        self.discriminator_on_model.summary()\n",
"        print(\"\\n\\n\\n\")\n",
"\n",
"        # Full model: decoded mel spectrogram + phase -> waveform.\n",
"        phase_input = Input((442, 299), name='phase_input', dtype='complex64')\n",
"        reconstructed = self.reconstruct_layer([s, phase_input])\n",
"        self.model = Model(inputs=inputs + [phase_input], outputs=[reconstructed])\n",
"\n",
"        print(\"Full Model:\")\n",
"        self.model.summary()\n",
"        print(\"\\n\\n\\n\")\n",
"\n",
"        return self.model, self.spectrogram_model, self.discriminator_model, self.discriminator_on_model\n",
"\n",
"    def compile(self, ops, losses):\n",
"        \"\"\"Compile the autoencoder, the discriminator and the stacked autoencoder->discriminator model.\n",
"\n",
"        Keras fixes the set of trainable weights when a model is compiled. The discriminator\n",
"        layers are therefore frozen only while discriminator_on_model is compiled, so its\n",
"        adversarial updates change the autoencoder but not the discriminator. Toggling\n",
"        `trainable` later (e.g. during training) has no effect without recompiling.\n",
"\n",
"        Args:\n",
"            ops: three optimizers, for (spectrogram_model, discriminator_model,\n",
"                discriminator_on_model).\n",
"            losses: three losses in the same order.\n",
"\n",
"        Returns:\n",
"            Tuple of the three compile() return values.\n",
"        \"\"\"\n",
"        spectrogram_result = self.spectrogram_model.compile(ops[0], loss=losses[0])\n",
"        discriminator_result = self.discriminator_model.compile(ops[1], loss=losses[1])\n",
"\n",
"        for layer in self.discriminator:\n",
"            layer.trainable = False\n",
"        try:\n",
"            stacked_result = self.discriminator_on_model.compile(ops[2], loss=losses[2])\n",
"        finally:\n",
"            # discriminator_model was compiled with these layers trainable; restore the flag.\n",
"            for layer in self.discriminator:\n",
"                layer.trainable = True\n",
"\n",
"        return spectrogram_result, discriminator_result, stacked_result\n",
"\n",
"    def train(self, data_limit=100, data_batch=10, epochs=10, steps_per_epoch=10, autoencoder_reps=3, discriminator_reps=2, d_on_ae_reps=1):\n",
"        \"\"\"Alternate autoencoder, discriminator and adversarial updates.\n",
"\n",
"        Each step draws one batch from get_data(data_limit, data_batch) and runs:\n",
"          1. `autoencoder_reps` updates of spectrogram_model (noisy -> clean spectrogram);\n",
"          2. `discriminator_reps` updates of discriminator_model on real (label 1) and\n",
"             generated (label 0) spectrograms;\n",
"          3. `d_on_ae_reps` updates of discriminator_on_model with label 1, which trains the\n",
"             autoencoder to fool the discriminator. The discriminator is frozen in that\n",
"             model by compile().\n",
"\n",
"        Args:\n",
"            data_limit, data_batch: passed to get_data as (limit, batch); data_batch is\n",
"                also the length of the label lists.\n",
"            epochs, steps_per_epoch: number of outer loops and of batches per loop.\n",
"            autoencoder_reps, discriminator_reps, d_on_ae_reps: updates per batch per phase.\n",
"\n",
"        Returns:\n",
"            (ae_loss_history, discrim_loss_history, d_on_ae_loss_history), one averaged loss\n",
"            per step for each phase.\n",
"        \"\"\"\n",
"        data_gen = self.get_data(data_limit, data_batch)\n",
"\n",
"        ae_loss_history = []\n",
"        discrim_loss_history = []\n",
"        d_on_ae_loss_history = []\n",
"\n",
"        for i in tqdm(range(epochs)):\n",
"            for j in tqdm(range(steps_per_epoch)):\n",
"                data = next(data_gen)\n",
"\n",
"                x_real = data['spectrogram']\n",
"                y_real = [1]*data_batch\n",
"\n",
"                if self.use_video:\n",
"                    x_gen = [data['noisy_spectrogram'], data['video']]\n",
"                else:\n",
"                    x_gen = [data['noisy_spectrogram']]\n",
"\n",
"                # The autoencoder wants its output to be classified as real.\n",
"                y_gen = [1]*data_batch\n",
"\n",
"                # train autoencoder\n",
"                ae_loss = 0\n",
"                for k in range(autoencoder_reps):\n",
"                    ae_loss += self.spectrogram_model.train_on_batch(x_gen, x_real)\n",
"                if autoencoder_reps > 0:\n",
"                    ae_loss /= autoencoder_reps\n",
"\n",
"                x_fake = self.spectrogram_model.predict(x_gen)\n",
"                y_fake = [0]*data_batch\n",
"\n",
"                # train discriminator (its trainable weights were fixed in compile())\n",
"                real_loss = 0\n",
"                fake_loss = 0\n",
"                for k in range(discriminator_reps):\n",
"                    real_loss += self.discriminator_model.train_on_batch(x_real, y_real)\n",
"                    fake_loss += self.discriminator_model.train_on_batch(x_fake, y_fake)\n",
"\n",
"                discrim_loss = 0.5*(real_loss + fake_loss)\n",
"                if discriminator_reps > 0:\n",
"                    discrim_loss /= discriminator_reps\n",
"\n",
"                # train the autoencoder through the (frozen) discriminator\n",
"                d_on_ae_loss = 0\n",
"                for k in range(d_on_ae_reps):\n",
"                    d_on_ae_loss += self.discriminator_on_model.train_on_batch(x_gen, y_gen)\n",
"                if d_on_ae_reps > 0:\n",
"                    d_on_ae_loss /= d_on_ae_reps\n",
"\n",
"                ae_loss_history.append(ae_loss)\n",
"                discrim_loss_history.append(discrim_loss)\n",
"                d_on_ae_loss_history.append(d_on_ae_loss)\n",
"\n",
"        return ae_loss_history, discrim_loss_history, d_on_ae_loss_history\n",
"\n",
"    def get_data(self, limit=100, batch=10):\n",
"        \"\"\"Return the training-batch generator from get_training_data for this model's inputs.\"\"\"\n",
"        return get_training_data(limit, batch, self.use_video)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ZL2XP9L9h8zN",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Compile the model.\n",
"\n",
"Using Adam with mean squared error (MSE) for the autoencoder, and Adam with binary crossentropy for the discriminator and for the combined (autoencoder + discriminator) model."
]
},
{
"metadata": {
"id": "wcO7ZQdFCjUB",
"colab_type": "code",
"pycharm": {},
"outputId": "61845869-f919-4a32-9d00-fb392a33f060",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 10557
}
},
"cell_type": "code",
"source": [
"from tensorflow.keras.optimizers import Adam\n",
"\n",
"#@title Autoencoder Type\n",
"use_video = True #@param {type:\"boolean\"}\n",
"#use_audio = True #@param {type:\"boolean\"}\n",
"#use_spectrogram = False #@param {type:\"boolean\"}\n",
"\n",
"autoencoder = Autoencoder(\n",
" use_video=use_video)\n",
"\n",
"model, spec, discrim, d_on_g = autoencoder.get_autoencoder()\n",
"\n",
"#TODO try other optimizers (Adam, Adagrad)\n",
"#SGD(0.02,momentum=0.9)\n",
"autoencoder.compile(ops = [\n",
" Adam(lr=5e-3),\n",
" Adam(lr=0.0002,beta_1=0.5,beta_2=0.999,epsilon=1e-3),\n",
" Adam(lr=0.0002,beta_1=0.5,beta_2=0.999,epsilon=1e-3)\n",
"], losses = [\n",
" 'mean_squared_error', #mean_squared_error\n",
" 'binary_crossentropy', \n",
" 'binary_crossentropy'\n",
"])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Spectrogram encoder output: (75, 384)\n",
"Spectrogram encoder:\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"spectrogram_input (InputLaye (None, 80, 299) 0 \n",
"_________________________________________________________________\n",
"reshape_8 (Reshape) (None, 80, 299, 1) 0 \n",
"_________________________________________________________________\n",
"conv2d_16 (Conv2D) (None, 40, 150, 64) 1664 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_26 (B (None, 40, 150, 64) 256 \n",
"_________________________________________________________________\n",
"leaky_re_lu_30 (LeakyReLU) (None, 40, 150, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_17 (Conv2D) (None, 40, 150, 64) 65600 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_27 (B (None, 40, 150, 64) 256 \n",
"_________________________________________________________________\n",
"leaky_re_lu_31 (LeakyReLU) (None, 40, 150, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_18 (Conv2D) (None, 20, 75, 128) 131200 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_28 (B (None, 20, 75, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_32 (LeakyReLU) (None, 20, 75, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_19 (Conv2D) (None, 10, 75, 128) 65664 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_29 (B (None, 10, 75, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_33 (LeakyReLU) (None, 10, 75, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_20 (Conv2D) (None, 5, 75, 128) 65664 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_30 (B (None, 5, 75, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_34 (LeakyReLU) (None, 5, 75, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_21 (Conv2D) (None, 3, 75, 128) 65664 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_31 (B (None, 3, 75, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_35 (LeakyReLU) (None, 3, 75, 128) 0 \n",
"_________________________________________________________________\n",
"permute_4 (Permute) (None, 75, 3, 128) 0 \n",
"_________________________________________________________________\n",
"time_distributed_13 (TimeDis (None, 75, 384) 0 \n",
"=================================================================\n",
"Total params: 398,016\n",
"Trainable params: 396,736\n",
"Non-trainable params: 1,280\n",
"_________________________________________________________________\n",
"None\n",
"\n",
"\n",
"\n",
"\n",
"Video encoder output: (75, 512)\n",
"Video encoder:\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"video_input (InputLayer) (None, 75, 30, 75) 0 \n",
"_________________________________________________________________\n",
"reshape_11 (Reshape) (None, 75, 30, 75, 1) 0 \n",
"_________________________________________________________________\n",
"zero1 (ZeroPadding3D) (None, 77, 36, 81, 1) 0 \n",
"_________________________________________________________________\n",
"conv1 (Conv3D) (None, 75, 16, 39, 32) 2432 \n",
"_________________________________________________________________\n",
"batc1 (BatchNormalizationV1) (None, 75, 16, 39, 32) 128 \n",
"_________________________________________________________________\n",
"actv1 (Activation) (None, 75, 16, 39, 32) 0 \n",
"_________________________________________________________________\n",
"spatial_dropout3d_6 (Spatial (None, 75, 16, 39, 32) 0 \n",
"_________________________________________________________________\n",
"max1 (MaxPooling3D) (None, 75, 8, 19, 32) 0 \n",
"_________________________________________________________________\n",
"zero2 (ZeroPadding3D) (None, 77, 12, 23, 32) 0 \n",
"_________________________________________________________________\n",
"conv2 (Conv3D) (None, 75, 8, 19, 64) 153664 \n",
"_________________________________________________________________\n",
"batc2 (BatchNormalizationV1) (None, 75, 8, 19, 64) 256 \n",
"_________________________________________________________________\n",
"actv2 (Activation) (None, 75, 8, 19, 64) 0 \n",
"_________________________________________________________________\n",
"spatial_dropout3d_7 (Spatial (None, 75, 8, 19, 64) 0 \n",
"_________________________________________________________________\n",
"max2 (MaxPooling3D) (None, 75, 4, 9, 64) 0 \n",
"_________________________________________________________________\n",
"zero3 (ZeroPadding3D) (None, 77, 6, 11, 64) 0 \n",
"_________________________________________________________________\n",
"conv3 (Conv3D) (None, 75, 4, 9, 92) 159068 \n",
"_________________________________________________________________\n",
"batc3 (BatchNormalizationV1) (None, 75, 4, 9, 92) 368 \n",
"_________________________________________________________________\n",
"actv3 (Activation) (None, 75, 4, 9, 92) 0 \n",
"_________________________________________________________________\n",
"spatial_dropout3d_8 (Spatial (None, 75, 4, 9, 92) 0 \n",
"_________________________________________________________________\n",
"max3 (MaxPooling3D) (None, 75, 2, 4, 92) 0 \n",
"_________________________________________________________________\n",
"time_distributed_12 (TimeDis (None, 75, 736) 0 \n",
"_________________________________________________________________\n",
"bidirectional_2 (Bidirection (None, 75, 512) 1525248 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 75, 512) 0 \n",
"=================================================================\n",
"Total params: 1,841,164\n",
"Trainable params: 1,840,788\n",
"Non-trainable params: 376\n",
"_________________________________________________________________\n",
"None\n",
"\n",
"\n",
"\n",
"\n",
"Embedding Shape: (None, 75, 896)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"Spectrogram Latent:\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_5 (InputLayer) (None, 75, 896) 0 \n",
"_________________________________________________________________\n",
"time_distributed_14 (TimeDis (None, 75, 384) 344448 \n",
"_________________________________________________________________\n",
"time_distributed_15 (TimeDis (None, 75, 384) 1536 \n",
"_________________________________________________________________\n",
"time_distributed_16 (TimeDis (None, 75, 384) 0 \n",
"_________________________________________________________________\n",
"time_distributed_17 (TimeDis (None, 75, 3, 128) 0 \n",
"_________________________________________________________________\n",
"permute_5 (Permute) (None, 3, 75, 128) 0 \n",
"=================================================================\n",
"Total params: 345,984\n",
"Trainable params: 345,216\n",
"Non-trainable params: 768\n",
"_________________________________________________________________\n",
"None\n",
"\n",
"\n",
"\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Spectrogram decoder output: (80, 299)\n",
"Spectrogram decoder:\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_6 (InputLayer) (None, 3, 75, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_transpose_14 (Conv2DT (None, 6, 75, 128) 65664 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_33 (B (None, 6, 75, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_37 (LeakyReLU) (None, 6, 75, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_transpose_15 (Conv2DT (None, 12, 75, 128) 65664 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_34 (B (None, 12, 75, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_38 (LeakyReLU) (None, 12, 75, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_transpose_16 (Conv2DT (None, 24, 75, 128) 65664 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_35 (B (None, 24, 75, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_39 (LeakyReLU) (None, 24, 75, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_transpose_17 (Conv2DT (None, 48, 150, 128) 262272 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_36 (B (None, 48, 150, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_40 (LeakyReLU) (None, 48, 150, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_transpose_18 (Conv2DT (None, 48, 150, 64) 131136 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_37 (B (None, 48, 150, 64) 256 \n",
"_________________________________________________________________\n",
"leaky_re_lu_41 (LeakyReLU) (None, 48, 150, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_transpose_19 (Conv2DT (None, 96, 300, 64) 102464 \n",
"_________________________________________________________________\n",
"batch_normalization_v1_38 (B (None, 96, 300, 64) 256 \n",
"_________________________________________________________________\n",
"leaky_re_lu_42 (LeakyReLU) (None, 96, 300, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_transpose_20 (Conv2DT (None, 96, 300, 1) 65 \n",
"_________________________________________________________________\n",
"cropping2d_2 (Cropping2D) (None, 80, 299, 1) 0 \n",
"_________________________________________________________________\n",
"spectrogram_output (Reshape) (None, 80, 299) 0 \n",
"=================================================================\n",
"Total params: 695,489\n",
"Trainable params: 694,209\n",
"Non-trainable params: 1,280\n",
"_________________________________________________________________\n",
"None\n",
"\n",
"\n",
"\n",
"\n",
"Discriminator:\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"discriminator_input (InputLa (None, 80, 299) 0 \n",
"_________________________________________________________________\n",
"reshape_10 (Reshape) (None, 80, 299, 1) 0 \n",
"_________________________________________________________________\n",
"conv2d_22 (Conv2D) (None, 40, 150, 32) 544 \n",
"_________________________________________________________________\n",
"leaky_re_lu_43 (LeakyReLU) (None, 40, 150, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_23 (Conv2D) (None, 20, 75, 64) 32832 \n",
"_________________________________________________________________\n",
"leaky_re_lu_44 (LeakyReLU) (None, 20, 75, 64) 0 \n",
"_________________________________________________________________\n",
"flatten_8 (Flatten) (None, 96000) 0 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 1) 96001 \n",
"=================================================================\n",
"Total params: 129,377\n",
"Trainable params: 129,377\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"None\n",
"\n",
"\n",
"\n",
"\n",
"Discriminator on model:\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"video_input (InputLayer) (None, 75, 30, 75) 0 \n",
"__________________________________________________________________________________________________\n",
"spectrogram_input (InputLayer) (None, 80, 299) 0 \n",
"__________________________________________________________________________________________________\n",
"reshape_11 (Reshape) (None, 75, 30, 75, 1 0 video_input[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape_8 (Reshape) (None, 80, 299, 1) 0 spectrogram_input[0][0] \n",
"__________________________________________________________________________________________________\n",
"zero1 (ZeroPadding3D) (None, 77, 36, 81, 1 0 reshape_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_16 (Conv2D) (None, 40, 150, 64) 1664 reshape_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv1 (Conv3D) (None, 75, 16, 39, 3 2432 zero1[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_26 (Batc (None, 40, 150, 64) 256 conv2d_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"batc1 (BatchNormalizationV1) (None, 75, 16, 39, 3 128 conv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_30 (LeakyReLU) (None, 40, 150, 64) 0 batch_normalization_v1_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"actv1 (Activation) (None, 75, 16, 39, 3 0 batc1[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_17 (Conv2D) (None, 40, 150, 64) 65600 leaky_re_lu_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"spatial_dropout3d_6 (SpatialDro (None, 75, 16, 39, 3 0 actv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_27 (Batc (None, 40, 150, 64) 256 conv2d_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"max1 (MaxPooling3D) (None, 75, 8, 19, 32 0 spatial_dropout3d_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_31 (LeakyReLU) (None, 40, 150, 64) 0 batch_normalization_v1_27[0][0] \n",
"__________________________________________________________________________________________________\n",
"zero2 (ZeroPadding3D) (None, 77, 12, 23, 3 0 max1[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_18 (Conv2D) (None, 20, 75, 128) 131200 leaky_re_lu_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2 (Conv3D) (None, 75, 8, 19, 64 153664 zero2[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_28 (Batc (None, 20, 75, 128) 512 conv2d_18[0][0] \n",
"__________________________________________________________________________________________________\n",
"batc2 (BatchNormalizationV1) (None, 75, 8, 19, 64 256 conv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_32 (LeakyReLU) (None, 20, 75, 128) 0 batch_normalization_v1_28[0][0] \n",
"__________________________________________________________________________________________________\n",
"actv2 (Activation) (None, 75, 8, 19, 64 0 batc2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_19 (Conv2D) (None, 10, 75, 128) 65664 leaky_re_lu_32[0][0] \n",
"__________________________________________________________________________________________________\n",
"spatial_dropout3d_7 (SpatialDro (None, 75, 8, 19, 64 0 actv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_29 (Batc (None, 10, 75, 128) 512 conv2d_19[0][0] \n",
"__________________________________________________________________________________________________\n",
"max2 (MaxPooling3D) (None, 75, 4, 9, 64) 0 spatial_dropout3d_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_33 (LeakyReLU) (None, 10, 75, 128) 0 batch_normalization_v1_29[0][0] \n",
"__________________________________________________________________________________________________\n",
"zero3 (ZeroPadding3D) (None, 77, 6, 11, 64 0 max2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_20 (Conv2D) (None, 5, 75, 128) 65664 leaky_re_lu_33[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv3 (Conv3D) (None, 75, 4, 9, 92) 159068 zero3[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_30 (Batc (None, 5, 75, 128) 512 conv2d_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"batc3 (BatchNormalizationV1) (None, 75, 4, 9, 92) 368 conv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_34 (LeakyReLU) (None, 5, 75, 128) 0 batch_normalization_v1_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"actv3 (Activation) (None, 75, 4, 9, 92) 0 batc3[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_21 (Conv2D) (None, 3, 75, 128) 65664 leaky_re_lu_34[0][0] \n",
"__________________________________________________________________________________________________\n",
"spatial_dropout3d_8 (SpatialDro (None, 75, 4, 9, 92) 0 actv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_31 (Batc (None, 3, 75, 128) 512 conv2d_21[0][0] \n",
"__________________________________________________________________________________________________\n",
"max3 (MaxPooling3D) (None, 75, 2, 4, 92) 0 spatial_dropout3d_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_35 (LeakyReLU) (None, 3, 75, 128) 0 batch_normalization_v1_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_12 (TimeDistri (None, 75, 736) 0 max3[0][0] \n",
"__________________________________________________________________________________________________\n",
"permute_4 (Permute) (None, 75, 3, 128) 0 leaky_re_lu_35[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_2 (Bidirectional) (None, 75, 512) 1525248 time_distributed_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_13 (TimeDistri (None, 75, 384) 0 permute_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_2 (Dropout) (None, 75, 512) 0 bidirectional_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_2 (Concatenate) (None, 75, 896) 0 time_distributed_13[0][0] \n",
" dropout_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_14 (TimeDistri (None, 75, 384) 344448 concatenate_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_15 (TimeDistri (None, 75, 384) 1536 time_distributed_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_16 (TimeDistri (None, 75, 384) 0 time_distributed_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_17 (TimeDistri (None, 75, 3, 128) 0 time_distributed_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"permute_5 (Permute) (None, 3, 75, 128) 0 time_distributed_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_14 (Conv2DTran (None, 6, 75, 128) 65664 permute_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_33 (Batc (None, 6, 75, 128) 512 conv2d_transpose_14[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_37 (LeakyReLU) (None, 6, 75, 128) 0 batch_normalization_v1_33[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_15 (Conv2DTran (None, 12, 75, 128) 65664 leaky_re_lu_37[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_34 (Batc (None, 12, 75, 128) 512 conv2d_transpose_15[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_38 (LeakyReLU) (None, 12, 75, 128) 0 batch_normalization_v1_34[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_16 (Conv2DTran (None, 24, 75, 128) 65664 leaky_re_lu_38[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_35 (Batc (None, 24, 75, 128) 512 conv2d_transpose_16[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_39 (LeakyReLU) (None, 24, 75, 128) 0 batch_normalization_v1_35[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_17 (Conv2DTran (None, 48, 150, 128) 262272 leaky_re_lu_39[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_36 (Batc (None, 48, 150, 128) 512 conv2d_transpose_17[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_40 (LeakyReLU) (None, 48, 150, 128) 0 batch_normalization_v1_36[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_18 (Conv2DTran (None, 48, 150, 64) 131136 leaky_re_lu_40[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_37 (Batc (None, 48, 150, 64) 256 conv2d_transpose_18[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_41 (LeakyReLU) (None, 48, 150, 64) 0 batch_normalization_v1_37[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_19 (Conv2DTran (None, 96, 300, 64) 102464 leaky_re_lu_41[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_38 (Batc (None, 96, 300, 64) 256 conv2d_transpose_19[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_42 (LeakyReLU) (None, 96, 300, 64) 0 batch_normalization_v1_38[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_20 (Conv2DTran (None, 96, 300, 1) 65 leaky_re_lu_42[1][0] \n",
"__________________________________________________________________________________________________\n",
"cropping2d_2 (Cropping2D) (None, 80, 299, 1) 0 conv2d_transpose_20[1][0] \n",
"__________________________________________________________________________________________________\n",
"spectrogram_output (Reshape) (None, 80, 299) 0 cropping2d_2[1][0] \n",
"__________________________________________________________________________________________________\n",
"reshape_10 (Reshape) (None, 80, 299, 1) 0 spectrogram_output[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_22 (Conv2D) (None, 40, 150, 32) 544 reshape_10[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_43 (LeakyReLU) (None, 40, 150, 32) 0 conv2d_22[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_23 (Conv2D) (None, 20, 75, 64) 32832 leaky_re_lu_43[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_44 (LeakyReLU) (None, 20, 75, 64) 0 conv2d_23[1][0] \n",
"__________________________________________________________________________________________________\n",
"flatten_8 (Flatten) (None, 96000) 0 leaky_re_lu_44[1][0] \n",
"__________________________________________________________________________________________________\n",
"dense_5 (Dense) (None, 1) 96001 flatten_8[1][0] \n",
"==================================================================================================\n",
"Total params: 3,410,030\n",
"Trainable params: 3,406,326\n",
"Non-trainable params: 3,704\n",
"__________________________________________________________________________________________________\n",
"None\n",
"\n",
"\n",
"\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n",
"/usr/local/lib/python3.6/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"<dtype: 'complex64'>\n",
"Full Model:\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"video_input (InputLayer) (None, 75, 30, 75) 0 \n",
"__________________________________________________________________________________________________\n",
"spectrogram_input (InputLayer) (None, 80, 299) 0 \n",
"__________________________________________________________________________________________________\n",
"reshape_11 (Reshape) (None, 75, 30, 75, 1 0 video_input[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape_8 (Reshape) (None, 80, 299, 1) 0 spectrogram_input[0][0] \n",
"__________________________________________________________________________________________________\n",
"zero1 (ZeroPadding3D) (None, 77, 36, 81, 1 0 reshape_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_16 (Conv2D) (None, 40, 150, 64) 1664 reshape_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv1 (Conv3D) (None, 75, 16, 39, 3 2432 zero1[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_26 (Batc (None, 40, 150, 64) 256 conv2d_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"batc1 (BatchNormalizationV1) (None, 75, 16, 39, 3 128 conv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_30 (LeakyReLU) (None, 40, 150, 64) 0 batch_normalization_v1_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"actv1 (Activation) (None, 75, 16, 39, 3 0 batc1[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_17 (Conv2D) (None, 40, 150, 64) 65600 leaky_re_lu_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"spatial_dropout3d_6 (SpatialDro (None, 75, 16, 39, 3 0 actv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_27 (Batc (None, 40, 150, 64) 256 conv2d_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"max1 (MaxPooling3D) (None, 75, 8, 19, 32 0 spatial_dropout3d_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_31 (LeakyReLU) (None, 40, 150, 64) 0 batch_normalization_v1_27[0][0] \n",
"__________________________________________________________________________________________________\n",
"zero2 (ZeroPadding3D) (None, 77, 12, 23, 3 0 max1[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_18 (Conv2D) (None, 20, 75, 128) 131200 leaky_re_lu_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2 (Conv3D) (None, 75, 8, 19, 64 153664 zero2[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_28 (Batc (None, 20, 75, 128) 512 conv2d_18[0][0] \n",
"__________________________________________________________________________________________________\n",
"batc2 (BatchNormalizationV1) (None, 75, 8, 19, 64 256 conv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_32 (LeakyReLU) (None, 20, 75, 128) 0 batch_normalization_v1_28[0][0] \n",
"__________________________________________________________________________________________________\n",
"actv2 (Activation) (None, 75, 8, 19, 64 0 batc2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_19 (Conv2D) (None, 10, 75, 128) 65664 leaky_re_lu_32[0][0] \n",
"__________________________________________________________________________________________________\n",
"spatial_dropout3d_7 (SpatialDro (None, 75, 8, 19, 64 0 actv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_29 (Batc (None, 10, 75, 128) 512 conv2d_19[0][0] \n",
"__________________________________________________________________________________________________\n",
"max2 (MaxPooling3D) (None, 75, 4, 9, 64) 0 spatial_dropout3d_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_33 (LeakyReLU) (None, 10, 75, 128) 0 batch_normalization_v1_29[0][0] \n",
"__________________________________________________________________________________________________\n",
"zero3 (ZeroPadding3D) (None, 77, 6, 11, 64 0 max2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_20 (Conv2D) (None, 5, 75, 128) 65664 leaky_re_lu_33[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv3 (Conv3D) (None, 75, 4, 9, 92) 159068 zero3[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_30 (Batc (None, 5, 75, 128) 512 conv2d_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"batc3 (BatchNormalizationV1) (None, 75, 4, 9, 92) 368 conv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_34 (LeakyReLU) (None, 5, 75, 128) 0 batch_normalization_v1_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"actv3 (Activation) (None, 75, 4, 9, 92) 0 batc3[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_21 (Conv2D) (None, 3, 75, 128) 65664 leaky_re_lu_34[0][0] \n",
"__________________________________________________________________________________________________\n",
"spatial_dropout3d_8 (SpatialDro (None, 75, 4, 9, 92) 0 actv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_31 (Batc (None, 3, 75, 128) 512 conv2d_21[0][0] \n",
"__________________________________________________________________________________________________\n",
"max3 (MaxPooling3D) (None, 75, 2, 4, 92) 0 spatial_dropout3d_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_35 (LeakyReLU) (None, 3, 75, 128) 0 batch_normalization_v1_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_12 (TimeDistri (None, 75, 736) 0 max3[0][0] \n",
"__________________________________________________________________________________________________\n",
"permute_4 (Permute) (None, 75, 3, 128) 0 leaky_re_lu_35[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_2 (Bidirectional) (None, 75, 512) 1525248 time_distributed_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_13 (TimeDistri (None, 75, 384) 0 permute_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_2 (Dropout) (None, 75, 512) 0 bidirectional_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_2 (Concatenate) (None, 75, 896) 0 time_distributed_13[0][0] \n",
" dropout_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_14 (TimeDistri (None, 75, 384) 344448 concatenate_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_15 (TimeDistri (None, 75, 384) 1536 time_distributed_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_16 (TimeDistri (None, 75, 384) 0 time_distributed_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_17 (TimeDistri (None, 75, 3, 128) 0 time_distributed_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"permute_5 (Permute) (None, 3, 75, 128) 0 time_distributed_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_14 (Conv2DTran (None, 6, 75, 128) 65664 permute_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_33 (Batc (None, 6, 75, 128) 512 conv2d_transpose_14[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_37 (LeakyReLU) (None, 6, 75, 128) 0 batch_normalization_v1_33[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_15 (Conv2DTran (None, 12, 75, 128) 65664 leaky_re_lu_37[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_34 (Batc (None, 12, 75, 128) 512 conv2d_transpose_15[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_38 (LeakyReLU) (None, 12, 75, 128) 0 batch_normalization_v1_34[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_16 (Conv2DTran (None, 24, 75, 128) 65664 leaky_re_lu_38[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_35 (Batc (None, 24, 75, 128) 512 conv2d_transpose_16[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_39 (LeakyReLU) (None, 24, 75, 128) 0 batch_normalization_v1_35[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_17 (Conv2DTran (None, 48, 150, 128) 262272 leaky_re_lu_39[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_36 (Batc (None, 48, 150, 128) 512 conv2d_transpose_17[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_40 (LeakyReLU) (None, 48, 150, 128) 0 batch_normalization_v1_36[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_18 (Conv2DTran (None, 48, 150, 64) 131136 leaky_re_lu_40[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_37 (Batc (None, 48, 150, 64) 256 conv2d_transpose_18[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_41 (LeakyReLU) (None, 48, 150, 64) 0 batch_normalization_v1_37[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_19 (Conv2DTran (None, 96, 300, 64) 102464 leaky_re_lu_41[1][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_v1_38 (Batc (None, 96, 300, 64) 256 conv2d_transpose_19[1][0] \n",
"__________________________________________________________________________________________________\n",
"leaky_re_lu_42 (LeakyReLU) (None, 96, 300, 64) 0 batch_normalization_v1_38[1][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_transpose_20 (Conv2DTran (None, 96, 300, 1) 65 leaky_re_lu_42[1][0] \n",
"__________________________________________________________________________________________________\n",
"cropping2d_2 (Cropping2D) (None, 80, 299, 1) 0 conv2d_transpose_20[1][0] \n",
"__________________________________________________________________________________________________\n",
"spectrogram_output (Reshape) (None, 80, 299) 0 cropping2d_2[1][0] \n",
"__________________________________________________________________________________________________\n",
"phase_input (InputLayer) (None, 442, 299) 0 \n",
"__________________________________________________________________________________________________\n",
"reconstruct_audio (Lambda) (None, 65664) 0 spectrogram_output[1][0] \n",
" phase_input[0][0] \n",
"==================================================================================================\n",
"Total params: 3,280,653\n",
"Trainable params: 3,276,949\n",
"Non-trainable params: 3,704\n",
"__________________________________________________________________________________________________\n",
"None\n",
"\n",
"\n",
"\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(None, None, None)"
]
},
"metadata": {
"tags": []
},
"execution_count": 103
}
]
},
{
"metadata": {
"id": "hLL_ydwVICNN",
"colab_type": "code",
"pycharm": {},
"outputId": "73185117-05e2-4346-bb3a-634b8fc41fff",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"cell_type": "code",
"source": [
"# Pull one batch from `autoencoder.get_data()` and report its total size in megabytes\n",
"# (every value is expected to expose `.nbytes`, e.g. a numpy array).\n",
"data = next(autoencoder.get_data())\n",
"size = sum(array.nbytes for array in data.values())\n",
"print(\"Size: \", size / 10**6, \"MB\")"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Size: 3.6011 MB\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "3cg5_fYA3B5Y",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"Fit the model using a generator."
]
},
{
"metadata": {
"id": "cHSZQRC8QEup",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"For a quick test run, train with `autoencoder.train(data_limit=200, data_batch=20, steps_per_epoch=10, epochs=10)`; the cell below does a longer 50-epoch run."
]
},
{
"metadata": {
"id": "rKwZkytZrlpK",
"colab_type": "code",
"pycharm": {},
"outputId": "2836df38-ff59-4404-e444-556958abb17b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 884
}
},
"cell_type": "code",
"source": [
"# Run the adversarial training loop; `train` returns three loss histories\n",
"# (autoencoder, discriminator, and discriminator-on-autoencoder) that are plotted below.\n",
"ae_hist, discrim_hist, d_on_ae_hist = autoencoder.train(\n",
"    data_limit=200,\n",
"    data_batch=20,\n",
"    steps_per_epoch=10,\n",
"    epochs=50,\n",
"    autoencoder_reps=3,\n",
"    discriminator_reps=5,\n",
"    d_on_ae_reps=1,\n",
")"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"100% (10 of 10) |########################| Elapsed Time: 0:01:31 Time: 0:01:31\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (10 of 10) |########################| Elapsed Time: 0:00:47 Time: 0:00:47\n",
"100% (50 of 50) |########################| Elapsed Time: 0:40:19 Time: 0:40:19\n"
],
"name": "stderr"
}
]
},
{
"metadata": {
"id": "bXwWO-eReua-",
"colab_type": "code",
"pycharm": {},
"outputId": "7c40c8d3-4c6e-41e0-86f4-2bd4b6724f64",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 283
}
},
"cell_type": "code",
"source": [
"# Plot the full loss history of each network on one set of axes.\n",
"loss_curves = [\n",
"    (ae_hist, 'Autoencoder'),\n",
"    (discrim_hist, 'Discriminator'),\n",
"    (d_on_ae_hist, 'Together'),\n",
"]\n",
"for curve, curve_label in loss_curves:\n",
"    plt.plot(curve, label=curve_label)\n",
"plt.legend()\n",
"plt.xlabel('Step')\n",
"plt.ylabel('Loss')\n",
"plt.show()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
}
]
},
{
"metadata": {
"id": "YC4yK7FK34qg",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Losses with the initial high-loss warm-up steps excluded"
]
},
{
"metadata": {
"id": "Uiokrr8J36ks",
"colab_type": "code",
"outputId": "3b02e3eb-f52e-4e4d-f930-6eaa199bf586",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 283
}
},
"cell_type": "code",
"source": [
"# Drop the initial high-loss warm-up steps so the later curves are readable.\n",
"# `limit` is twice the average of the three mean losses, and `start` is the first\n",
"# step after the last one at which any of the three losses exceeded `limit`.\n",
"# Set START_OVERRIDE to an int (e.g. 40) to force a fixed cutoff instead.\n",
"START_OVERRIDE = None\n",
"\n",
"mean = (np.mean(ae_hist) + np.mean(discrim_hist) + np.mean(d_on_ae_hist)) / 3\n",
"limit = mean * 2\n",
"\n",
"n_steps = min(len(ae_hist), len(discrim_hist), len(d_on_ae_hist))\n",
"start = 0\n",
"for i, losses in enumerate(zip(ae_hist, discrim_hist, d_on_ae_hist)):\n",
"    if any(loss > limit for loss in losses):\n",
"        start = i + 1\n",
"if start >= n_steps:\n",
"    # The very last step spiked, so there is no settled tail; show everything.\n",
"    start = 0\n",
"if START_OVERRIDE is not None:\n",
"    start = START_OVERRIDE\n",
"\n",
"for curve, curve_label in ((ae_hist, 'Autoencoder'), (discrim_hist, 'Discriminator'), (d_on_ae_hist, 'Together')):\n",
"    # Use the real step numbers on the x-axis instead of re-basing the slice at 0.\n",
"    plt.plot(range(start, len(curve)), curve[start:], label=curve_label)\n",
"plt.legend()\n",
"plt.title('Losses from step {}'.format(start))\n",
"plt.xlabel('Step')\n",
"plt.ylabel('Loss')\n",
"plt.show()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
}
]
},
{
"metadata": {
"id": "fWLje2N33GX7",
"colab_type": "text",
"pycharm": {}
},
"cell_type": "markdown",
"source": [
"Helper that denoises a single sample with the trained model (used in the tests below):"
]
},
{
"metadata": {
"id": "OkMYsWtpmrPl",
"colab_type": "code",
"pycharm": {},
"colab": {}
},
"cell_type": "code",
"source": [
"def clean(av, use_video):\n",
"    \"\"\"Denoise a single sample `av` with the trained networks.\n",
"\n",
"    Each input is wrapped in a one-element list to form a batch of one, and the\n",
"    first (only) prediction is used.\n",
"\n",
"    If `use_video` is true, `spec` is given the noisy spectrogram and the sample's\n",
"    video data, and its prediction is turned into a waveform by\n",
"    `reconstruct_audio` together with the sample's phase. Otherwise `model` is\n",
"    given the noisy spectrogram and the phase, and its prediction is returned\n",
"    directly (presumably already a waveform via the model's reconstruction layer).\n",
"    \"\"\"\n",
"    if not use_video:\n",
"        return model.predict([[av.noisy_spectrogram()], [av.phase()]])[0]\n",
"    predicted = spec.predict([[av.noisy_spectrogram()], [av.video_data()]])[0]\n",
"    return reconstruct_audio(predicted, av.phase())"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ytszGuuDezYw",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Testing"
]
},
{
"metadata": {
"id": "Ak2KW608fCBM",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Test on seen data"
]
},
{
"metadata": {
"id": "zNUkuYGKfu5q",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### Graphs"
]
},
{
"metadata": {
"id": "JvwlI25B7AxD",
"colab_type": "code",
"pycharm": {},
"cellView": "form",
"outputId": "87b41ebc-6dd2-4d4f-e0d6-464812a32128",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1007
}
},
"cell_type": "code",
"source": [
"#@title\n",
"import librosa.display, librosa.output\n",
"\n",
"\n",
"def _show_waveform(caption, samples):\n",
"    \"\"\"Print `caption`, then draw and show the waveform of `samples`.\"\"\"\n",
"    print(caption)\n",
"    librosa.display.waveplot(samples)\n",
"    plt.show()\n",
"\n",
"\n",
"# Draw one sample from the data generator and compare its clean, noisy and\n",
"# denoised versions.\n",
"gen = get_audio_and_video()\n",
"test = next(gen)\n",
"\n",
"audio = test.audio_data()\n",
"_show_waveform(\"Known\", audio)\n",
"\n",
"print(\"Known Spectrogram\")\n",
"plt.imshow(test.spectrogram())\n",
"plt.show()\n",
"\n",
"noisy = test.noisy_audio()\n",
"_show_waveform(\"Noisy\", noisy)\n",
"\n",
"# `clean` handles both the video and audio-only paths and returns the denoised waveform.\n",
"p = clean(test, autoencoder.use_video)\n",
"_show_waveform(\"Predicted\", p)\n",
"\n",
"# TODO: try facial point recognition instead of raw video.\n",
"# NOTE: Jane Zang is doing similar research."
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Known\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Known Spectrogram\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Noisy\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Predicted\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
}
]
},
{
"metadata": {
"id": "SO9tgJ9wfq6p",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### Audio"
]
},
{
"metadata": {
"id": "VfWjzdU4vD_A",
"colab_type": "code",
"pycharm": {},
"cellView": "form",
"outputId": "8b348aeb-c4f1-43ca-d613-27bcd124a8b2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
}
},
"cell_type": "code",
"source": [
"#@title\n",
"# Play the clean, noisy and denoised clips back-to-back in a single player.\n",
"clips = [test.audio_data(), test.noisy_audio(), p]\n",
"all_validation_audio = np.concatenate(clips)\n",
"player = ipy_display.Audio(all_validation_audio, rate=framerate)\n",
"ipy_display.display(player)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.lib.display.Audio object>"
],
"text/html": [
"\n",
" <audio controls=\"controls\" >\n",
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment