Last active
November 2, 2024 03:30
-
-
Save KatsuhiroMorishita/f97e129ac87b8518346daebf822ad1d5 to your computer and use it in GitHub Desktop.
Fixes for MP3 support in BirdNET as of September 2, 2024
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from itertools import count, islice | |
from pathlib import Path | |
from typing import Any, Generator, Iterable, Optional, Tuple, Union | |
import numpy as np | |
import numpy.typing as npt | |
import requests | |
import soundfile as sf | |
from ordered_set import OrderedSet | |
from scipy.signal import butter, lfilter, resample | |
from tqdm import tqdm | |
from birdnet.types import Species | |
def get_species_from_file(species_file: Path, /, *, encoding: str = "utf8") -> OrderedSet[Species]:
    """Load the lines of a text file into an ``OrderedSet``.

    Args:
        species_file: Text file to read; its content is split on line breaks.
        encoding: Text encoding passed to ``Path.read_text``.

    Returns:
        An ``OrderedSet`` built from the file's lines.
    """
    lines = species_file.read_text(encoding).splitlines()
    return OrderedSet(lines)
def bandpass_signal(
    audio_signal: npt.NDArray[np.float32],
    rate: int,
    fmin: int,
    fmax: int,
    new_fmin: int,
    new_fmax: int,
) -> npt.NDArray[np.float32]:
    """Apply a 5th-order Butterworth filter when ``[fmin, fmax]`` is narrower than ``[new_fmin, new_fmax]``.

    The filter type is chosen by comparing the two ranges:

    * ``fmin > new_fmin`` and ``fmax == new_fmax``: high-pass at ``fmin``
    * ``fmin == new_fmin`` and ``fmax < new_fmax``: low-pass at ``fmax``
    * ``fmin > new_fmin`` and ``fmax < new_fmax``: band-pass ``fmin..fmax``
    * any other combination: the signal is not filtered

    Args:
        audio_signal: Samples to filter.
        rate: Sampling rate; half of it is used as the Nyquist frequency.
        fmin: Lower cutoff frequency.
        fmax: Upper cutoff frequency.
        new_fmin: Lower bound of the reference range.
        new_fmax: Upper bound of the reference range.

    Returns:
        The (possibly filtered) signal converted to ``float32``.
    """
    assert rate > 0
    assert fmin >= 0
    assert fmin < fmax
    assert new_fmin >= 0
    assert new_fmin < new_fmax

    filter_order = 5
    nyquist = 0.5 * rate
    lower_is_tighter = fmin > new_fmin
    upper_is_tighter = fmax < new_fmax

    # Pick the normalized cutoff(s) and the filter type; None means "leave as is".
    if lower_is_tighter and fmax == new_fmax:
        cutoff, filter_type = fmin / nyquist, "high"
    elif fmin == new_fmin and upper_is_tighter:
        cutoff, filter_type = fmax / nyquist, "low"
    elif lower_is_tighter and upper_is_tighter:
        cutoff, filter_type = [fmin / nyquist, fmax / nyquist], "band"
    else:
        cutoff, filter_type = None, None

    if filter_type is not None:
        numerator, denominator = butter(filter_order, cutoff, btype=filter_type)
        audio_signal = lfilter(numerator, denominator, audio_signal)

    return audio_signal.astype(np.float32)
def chunk_signal(
    audio_signal: npt.NDArray[np.float32],
    rate: int,
    chunk_size: float,
    chunk_overlap: float,
    min_chunk_size: float,
) -> Generator[Tuple[float, float, npt.NDArray[np.float32]], None, None]:
    """Yield fixed-length, optionally overlapping chunks of a signal.

    Args:
        audio_signal: The signal to split.
        rate: The sampling rate.
        chunk_size: Duration of one chunk in seconds.
        chunk_overlap: Overlap between consecutive chunks in seconds.
        min_chunk_size: Minimum duration of real signal the final chunk must
            contain; a shorter remainder is dropped.

    Yields:
        ``(start, end, chunk)`` tuples, where ``start``/``end`` are the chunk
        bounds in seconds and ``chunk`` holds ``round(rate * chunk_size)``
        samples (zero-padded at the end of the signal).
    """
    assert rate > 0
    assert min_chunk_size > 0
    assert chunk_overlap >= 0
    assert chunk_overlap < chunk_size

    hop_seconds = chunk_size - chunk_overlap

    # Durations converted to sample counts.
    frames_per_chunk = round(rate * chunk_size)
    frames_per_hop = round(rate * hop_seconds)
    frames_minimum = round(rate * min_chunk_size)

    # Sample offset at which the final chunk begins.
    overhang = audio_signal.size - frames_per_chunk + frames_per_hop - 1
    final_offset = round(overhang / frames_per_hop) * frames_per_hop
    if final_offset < 0:
        # Signal shorter than one chunk: start the only chunk at the beginning.
        final_offset = 0
    elif audio_signal.size - final_offset < frames_minimum:
        # Too little real signal would remain in the final chunk: skip it.
        final_offset -= frames_per_hop

    # Pad with one chunk of zeros so every slice has the full chunk length.
    # TODO maybe add noise instead of empty signal
    padding = np.zeros(shape=frames_per_chunk, dtype=audio_signal.dtype)
    padded = np.concatenate((audio_signal, padding))

    begin: float = 0.0
    finish: float = chunk_size
    for offset in range(0, final_offset + 1, frames_per_hop):
        yield begin, finish, padded[offset:offset + frames_per_chunk]
        begin += hop_seconds
        finish = begin + chunk_size
def fillup_with_silence(audio_chunk: npt.NDArray[np.float32], target_length: int) -> npt.NDArray[np.float32]:
    """Right-pad a chunk with zeros until it has ``target_length`` samples.

    Args:
        audio_chunk: The chunk to pad; must not be longer than ``target_length``.
        target_length: Desired number of samples.

    Returns:
        ``audio_chunk`` itself when it already has the target length, otherwise
        a new array with zeros (same dtype) appended.
    """
    missing = target_length - len(audio_chunk)
    assert missing >= 0
    if missing == 0:
        return audio_chunk
    padding = np.zeros(missing, dtype=audio_chunk.dtype)
    return np.concatenate((audio_chunk, padding))
def flat_sigmoid(x: npt.NDArray[np.float32], sensitivity: float) -> npt.NDArray[np.float32]:
    """Evaluate ``1 / (1 + exp(sensitivity * x))`` with ``x`` clipped to ``[-15, 15]``.

    Args:
        x: Input values.
        sensitivity: Factor multiplied with the clipped input inside ``exp``.

    Returns:
        The element-wise result.
    """
    clipped = np.clip(x, -15, 15)
    scaled = sensitivity * clipped
    return 1.0 / (1.0 + np.exp(scaled))
def sigmoid_inverse(x: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
    """Return ``log(x / (1 - x))`` element-wise."""
    odds = x / (1 - x)
    return np.log(odds)
def get_app_data_path() -> Path:
    """Return the per-user application data directory for the current OS.

    Returns:
        ``%APPDATA%`` on Windows, ``~/Library/Application Support`` on macOS
        and ``~/.local/share`` on other POSIX systems.

    Raises:
        OSError: If ``os.name`` is neither ``'nt'`` nor ``'posix'``.
    """
    if os.name == 'nt':  # Windows
        roaming = os.getenv('APPDATA')
        assert roaming is not None
        return Path(roaming)

    if os.name != 'posix':
        raise OSError('Unsupported operating system')

    # Darwin is macOS; every other POSIX system is treated like Linux.
    on_macos = os.uname().sysname == 'Darwin'
    base = '~/Library/Application Support' if on_macos else '~/.local/share'
    return Path(os.path.expanduser(base))
def get_birdnet_app_data_folder() -> Path:
    """Return the ``birdnet`` sub-folder of the application data directory."""
    return get_app_data_path() / "birdnet"
def download_file(url: str, file_path: Path) -> None:
    """Download ``url`` in one request and store the body at ``file_path``.

    Args:
        url: Address to fetch.
        file_path: Destination file; its parent directory must already exist.

    Raises:
        ValueError: If the response status code is not 200.
    """
    assert file_path.parent.is_dir()

    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        raise ValueError(f"Failed to download the file. Status code: {response.status_code}")

    with open(file_path, 'wb') as target:
        target.write(response.content)
def download_file_tqdm(
    url: str,
    file_path: Path,
    *,
    download_size: Optional[int] = None,
    description: Optional[str] = None,
) -> None:
    """Stream ``url`` to ``file_path`` while showing a tqdm progress bar.

    Args:
        url: Address to fetch.
        file_path: Destination file; its parent directory must already exist.
        download_size: Expected size in bytes. When given it overrides the
            ``content-length`` header for the progress total and the final
            size check.
        description: Label shown next to the progress bar.

    Raises:
        ValueError: If the response status code is not 200, or if the expected
            size is non-zero and differs from the number of bytes received.
    """
    assert file_path.parent.is_dir()

    block_size = 1024

    # A streamed response keeps its connection checked out until it is fully
    # consumed or closed; the context manager closes it on every exit path.
    with requests.get(url, stream=True, timeout=30) as response:
        status_code = response.status_code
        # Reject error responses before touching the destination file so an
        # HTTP error page is never written to disk.
        if status_code != 200:
            raise ValueError(f"Failed to download the file. Status code: {status_code}")

        total_size = int(response.headers.get('content-length', 0))
        if download_size is not None:
            total_size = download_size

        with tqdm(total=total_size, unit='iB', unit_scale=True, desc=description) as tqdm_bar:
            with open(file_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    tqdm_bar.update(len(data))
                    file.write(data)
            received = tqdm_bar.n

    # A total of 0 means the size is unknown, so there is nothing to compare.
    if total_size not in (0, received):
        raise ValueError(f"Failed to download the file. Status code: {status_code}")
def itertools_batched(iterable: Iterable, n: int) -> Generator[Any, None, None]:
    """Yield successive tuples of up to ``n`` items from ``iterable``.

    Backport of ``itertools.batched``:
    https://docs.python.org/3.12/library/itertools.html#itertools.batched

    ``batched('ABCDEFG', 3)`` yields ``ABC DEF G``.

    Raises:
        ValueError: If ``n`` is smaller than one (raised on first iteration).
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    # Keep pulling n-sized tuples until an empty tuple signals exhaustion.
    yield from iter(lambda: tuple(islice(source, n)), ())
def get_chunks_with_overlap(
    total_duration_s: Union[int, float],
    chunk_duration_s: Union[int, float],
    overlap_duration_s: Union[int, float],
) -> Generator[Tuple[float, float], None, None]:
    """Yield ``(start, end)`` windows that cover ``total_duration_s``.

    Consecutive windows start ``chunk_duration_s - overlap_duration_s`` apart.
    The final window is truncated so that it ends at ``total_duration_s``.

    Args:
        total_duration_s: Total duration to cover; must be positive.
        chunk_duration_s: Length of one window; must be positive.
        overlap_duration_s: Overlap between windows; must satisfy
            ``0 <= overlap_duration_s < chunk_duration_s``.

    Yields:
        ``(start, end)`` pairs as floats.
    """
    assert total_duration_s > 0
    assert chunk_duration_s > 0
    assert 0 <= overlap_duration_s < chunk_duration_s

    # Work in floats throughout; values that already are floats stay untouched.
    overlap = overlap_duration_s if isinstance(overlap_duration_s, float) else float(overlap_duration_s)
    chunk = chunk_duration_s if isinstance(chunk_duration_s, float) else float(chunk_duration_s)
    total = total_duration_s if isinstance(total_duration_s, float) else float(total_duration_s)

    hop = chunk - overlap
    begin = 0.0
    while True:
        assert begin < total
        finish = begin + chunk
        if not finish < total:
            break
        yield begin, finish
        begin += hop

    # Last window: clamp the end to the total duration.
    yield begin, total
def resample_array(x: npt.NDArray, sample_rate: int, target_sample_rate: int) -> npt.NDArray:
    """Resample a one-dimensional array to ``target_sample_rate``.

    Args:
        x: One-dimensional input samples.
        sample_rate: Sampling rate of ``x``; must be positive.
        target_sample_rate: Desired sampling rate; must be positive.

    Returns:
        ``x`` itself when both rates are equal, otherwise the output of
        ``scipy.signal.resample`` with
        ``round(len(x) / sample_rate * target_sample_rate)`` samples.
    """
    assert x.ndim == 1
    assert sample_rate > 0
    assert target_sample_rate > 0

    if sample_rate != target_sample_rate:
        output_length = round(len(x) / sample_rate * target_sample_rate)
        converted: npt.NDArray = resample(x, output_length)
        return converted
    return x
def load_audio_in_chunks_with_overlap(
    audio_path: Path,
    /,
    *,
    chunk_duration_s: float = 3,
    overlap_duration_s: float = 0,
    target_sample_rate: int = 48000,
) -> Generator[Tuple[float, float, npt.NDArray[np.float32]], None, None]:
    """Yield consecutive audio chunks of a file, resampled to ``target_sample_rate``.

    Sample rate and duration are taken from ``soundfile.info``. Files with an
    ``.mp3`` extension are decoded once, in full, with ``librosa`` and then
    sliced; every other file is read chunk by chunk with ``soundfile.read``.

    Args:
        audio_path: Path of the audio file; must be an existing file.
        chunk_duration_s: Duration of one chunk in seconds.
        overlap_duration_s: Overlap between consecutive chunks in seconds.
        target_sample_rate: Sampling rate of the yielded chunks.

    Yields:
        ``(start, end, audio)`` where ``start``/``end`` are the chunk bounds in
        seconds and ``audio`` holds the resampled ``float32`` samples.
    """
    assert audio_path.is_file()

    sf_info = sf.info(audio_path)
    sample_rate = sf_info.samplerate
    timestamps = get_chunks_with_overlap(
        float(sf_info.duration),
        float(chunk_duration_s),
        float(overlap_duration_s),
    )

    # MP3 support: decode the whole file up front with librosa.
    is_mp3 = audio_path.suffix.lower() == ".mp3"
    decoded = None
    if is_mp3:
        import librosa  # imported lazily so it is only needed for MP3 input

        # Pass the native rate explicitly; without `sr`, librosa.load
        # resamples to 22050 Hz.
        # NOTE(review): librosa.load downmixes to mono by default, while the
        # soundfile branch keeps the file's channel layout - confirm intended.
        decoded, _ = librosa.load(audio_path, sr=sample_rate)

    for start, end in timestamps:
        start_samples = round(start * sample_rate)
        end_samples = round(end * sample_rate)

        if decoded is not None:
            audio = decoded[start_samples:end_samples]
        else:
            audio, _ = sf.read(audio_path, start=start_samples, stop=end_samples)

        # Stop once no samples are left, i.e. the reported duration exceeds
        # the data that could actually be read.
        if len(audio) == 0:
            return

        audio = resample_array(audio, sample_rate, target_sample_rate)
        yield start, end, audio.astype(np.float32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment