Skip to content

Instantly share code, notes, and snippets.

@istupakov
Last active March 8, 2025 13:43
Show Gist options
  • Save istupakov/717e55862f8629f5561ed7137c0dbbfc to your computer and use it in GitHub Desktop.
GigaAM CTC v2 export to onnx (with metadata for sherpa_onnx) and inference examples
import onnx
import gigaam
from gigaam.onnx_utils import VOCAB
# Export GigaAM CTC v2 to ONNX and attach the metadata sherpa_onnx expects.
onnx_dir = "gigaam-onnx"
model_type = "v2_ctc"

model = gigaam.load_model(
    model_type,
    fp16_encoder=False,  # only fp32 tensors
    use_flash=False,  # disable flash attention
)
model.to_onnx(dir_path=onnx_dir)

# The next part creates tokens.txt and adds metadata needed only for sherpa_onnx.
# encoding="utf-8" is forced so the Cyrillic vocabulary is written identically
# on every platform (the original relied on the locale default encoding).
with open(f"{onnx_dir}/tokens.txt", "w", encoding="utf-8") as f:
    # The CTC blank token "<blk>" is appended last, so its id is len(VOCAB).
    f.writelines(f"{token} {i}\n" for i, token in enumerate(VOCAB + ["<blk>"]))

filename = f"{onnx_dir}/{model_type}.onnx"
model = onnx.load(filename)
meta_data = {
    "model_type": "EncDecCTCModel",
    "vocab_size": len(VOCAB) + 1,  # vocabulary + blank
    "normalize_type": "",
    "subsampling_factor": 4,
    "is_giga_am": 1,
    "model_name": f"GigaAM {model_type}",
    "model_author": "GigaChat Team",
    "model_license": "MIT License",
    "language": "Russian",
    "url": "https://github.com/salute-developers/GigaAM",
}
for key, value in meta_data.items():
    # metadata_props.add() appends a StringStringEntryProto in place;
    # its return value is not needed.
    model.metadata_props.add(key=key, value=str(value))
onnx.save(model, filename)
from onnxruntime.quantization import QuantType, quantize_dynamic
# Post-training dynamic quantization: weights become uint8,
# activations are quantized on the fly at inference time.
onnx_dir = "gigaam-onnx"
model_type = "v2_ctc"

fp32_path = f"{onnx_dir}/{model_type}.onnx"
int8_path = f"{onnx_dir}/{model_type}.int8.onnx"
quantize_dynamic(model_input=fp32_path, model_output=int8_path, weight_type=QuantType.QUInt8)
import numpy as np
import onnxruntime as rt
import soundfile as sf
# Feature extraction (log-mel) runs on the default (CPU) provider; the encoder
# prefers CUDA and falls back to CPU when it is unavailable.
_encoder_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
gigaam_features = rt.InferenceSession("gigaam-onnx/features.onnx")
gigaam_encoder = rt.InferenceSession("gigaam-onnx/v2_ctc.onnx", providers=_encoder_providers)
def transcribe(waveforms):
    """Greedy CTC decode of a batch of waveforms into Russian text.

    Parameters
    ----------
    waveforms : array of shape (batch, num_samples), float32
        16 kHz mono audio, one utterance per row (see the caller's
        ``sf.read(...).T``).

    Returns
    -------
    list[str]
        One transcript per input waveform.
    """
    (features,) = gigaam_features.run(["features"], {"waveforms": waveforms})
    (log_probs,) = gigaam_encoder.run(
        ["log_probs"],
        {"features": features, "feature_lengths": [features.shape[-1]] * features.shape[0]},
    )
    # Vocabulary: space + 32 Cyrillic letters "а".."я"; the blank id follows them.
    tokens = [" "] + [chr(ord("а") + i) for i in range(32)]
    blank_token_idx = 33
    results = []
    for indices in np.argmax(log_probs, axis=-1):
        if indices.size == 0:
            results.append("")
            continue
        # Collapse repeated frames: keep the last frame of each run.
        # The original `indices[np.diff(indices).nonzero()]` dropped the
        # final run entirely (only harmless when the audio ends in blank);
        # appending a sentinel non-zero diff keeps it.
        indices = indices[np.append(np.diff(indices), 1).nonzero()]
        indices = indices[indices != blank_token_idx]
        results.append("".join(tokens[i] for i in indices))
    return results
# Read the sample as a (samples, channels) float32 array and decode it;
# transcribe expects (batch, samples), hence the transpose.
audio, sr = sf.read("sample.wav", always_2d=True, dtype=np.float32)
assert sr == 16000
print(transcribe(audio.T))
import sherpa_onnx
import soundfile as sf
# Decode the same sample with sherpa_onnx, using the exported model and
# the tokens.txt produced by the export script.
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
    "gigaam-onnx/v2_ctc.onnx", "gigaam-onnx/tokens.txt"
)
samples, rate = sf.read("sample.wav")
stream = recognizer.create_stream()
stream.accept_waveform(rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment