diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt
new file mode 100644
index 00000000000..950686dd00a
--- /dev/null
+++ b/examples/models/parakeet/CMakeLists.txt
@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.24)
+project(parakeet_runner)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+
+if(CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$")
+  set(CMAKE_TOOLCHAIN_IOS ON)
+else()
+  set(CMAKE_TOOLCHAIN_IOS OFF)
+endif()
+
+# Let files say "include <executorch/path/to/header.h>"
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+
+# Need this for gflags
+set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
+find_package(gflags REQUIRED)
+
+# Find executorch libraries
+list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
+executorch_target_link_options_shared_lib(executorch)
+
+set(link_libraries executorch gflags)
+
+# Common ops for all builds
+list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
+executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)
+
+# CPU-only builds need quantized and custom ops
+if(NOT EXECUTORCH_BUILD_CUDA AND MSVC)
+  list(APPEND link_libraries quantized_ops_lib custom_ops)
+  executorch_target_link_options_shared_lib(quantized_ops_lib)
+  executorch_target_link_options_shared_lib(custom_ops)
+endif()
+
+# XNNPACK
+if(TARGET xnnpack_backend)
+  set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod)
+  if(TARGET kleidiai)
+    list(APPEND xnnpack_backend_libs kleidiai)
+  endif()
+  list(APPEND link_libraries ${xnnpack_backend_libs})
+  executorch_target_link_options_shared_lib(xnnpack_backend)
+endif()
+
+# Needed for cpuinfo where it uses android specific log lib
+if(ANDROID)
+  list(APPEND link_libraries log)
+endif()
+
+# Add the required ExecuTorch extensions
+list(
+  APPEND
+  link_libraries
+  extension_llm_runner
+  extension_module
+  extension_data_loader
+  extension_tensor
+  extension_flat_tensor
+  tokenizers::tokenizers
+)
+
+# Link CUDA backend
+if(EXECUTORCH_BUILD_CUDA)
+  find_package(CUDAToolkit REQUIRED)
+  list(APPEND link_libraries aoti_cuda_backend)
+  if(NOT MSVC)
+    executorch_target_link_options_shared_lib(aoti_cuda_backend)
+  endif()
+endif()
+
+if(EXECUTORCH_BUILD_METAL)
+  list(APPEND link_libraries metal_backend)
+  executorch_target_link_options_shared_lib(metal_backend)
+endif()
+
+add_executable(parakeet_runner main.cpp)
+if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
+  target_link_options_gc_sections(parakeet_runner)
+  if(NOT APPLE AND NOT MSVC)
+    target_link_options(parakeet_runner PRIVATE "LINKER:-s")
+  endif()
+endif()
+
+target_include_directories(
+  parakeet_runner PUBLIC ${_common_include_directories}
+)
+target_link_libraries(parakeet_runner PUBLIC ${link_libraries})
+target_compile_options(parakeet_runner PUBLIC ${_common_compile_options})
+
+# On Windows, copy required DLLs to the executable directory
+if(MSVC AND EXECUTORCH_BUILD_CUDA)
+  add_custom_command(
+    TARGET parakeet_runner
+    POST_BUILD
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:aoti_cuda_shims>
+            $<TARGET_FILE_DIR:parakeet_runner>
+    COMMENT "Copying aoti_cuda_shims.dll to parakeet_runner directory"
+ ) +endif() diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json new file mode 100644 index 00000000000..ea93d257ba7 --- /dev/null +++ b/examples/models/parakeet/CMakePresets.json @@ -0,0 +1,110 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "parakeet-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/parakeet", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "parakeet-cpu", + "displayName": "Parakeet runner (CPU)", + "inherits": ["parakeet-base"] + }, + { + "name": "parakeet-cuda", + "displayName": "Parakeet runner (CUDA)", + "inherits": ["parakeet-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } + }, + { + "name": "parakeet-metal", + "displayName": "Parakeet runner (Metal)", + "inherits": ["parakeet-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + } + ], + "buildPresets": [ + { + "name": "parakeet-cpu", + "displayName": "Build Parakeet runner (CPU)", + "configurePreset": "parakeet-cpu", + "targets": ["parakeet_runner"] + }, + { + "name": "parakeet-cuda", + "displayName": "Build Parakeet runner (CUDA)", + "configurePreset": "parakeet-cuda", + "targets": ["parakeet_runner"] + }, + { + "name": "parakeet-metal", + "displayName": "Build Parakeet runner (Metal)", + "configurePreset": "parakeet-metal", + "targets": ["parakeet_runner"] + } + ], + "workflowPresets": [ + { + "name": "parakeet-cpu", + "displayName": "Configure and build Parakeet runner (CPU)", + "steps": [ + { + "type": "configure", + "name": "parakeet-cpu" + }, + { + "type": "build", + "name": "parakeet-cpu" + } + ] + }, + { + "name": "parakeet-cuda", + "displayName": "Configure and build Parakeet runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "parakeet-cuda" + }, + { + "type": "build", + "name": "parakeet-cuda" + } + ] + }, + { + "name": "parakeet-metal", + "displayName": "Configure and build Parakeet runner (Metal)", + "steps": [ + { + "type": "configure", + "name": "parakeet-metal" + }, + { + "type": "build", + "name": "parakeet-metal" + } + ] + } + ] +} diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md new file mode 100644 index 00000000000..045f22571fd --- /dev/null +++ b/examples/models/parakeet/README.md @@ -0,0 +1,73 @@ +# Parakeet TDT Export for ExecuTorch + +Export [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) speech recognition model to ExecuTorch. 
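+
+The exported `.pte` bundles the NeMo pipeline as separate methods in a single
+program: `preprocessor` (waveform to log-mel features), `encoder`,
+`decoder_predict`, `joint_project_encoder`, `joint_project_decoder`, and
+`joint`, plus model metadata (`vocab_size`, `blank_id`, `sample_rate`, ...)
+exposed as constant methods. Greedy TDT decoding is done by the caller, either
+in Python (`transcribe_executorch` in `export_parakeet_tdt.py`) or by the C++
+runner. Below is a minimal sketch of driving the exported methods from the
+ExecuTorch Python runtime; the file path and the random audio are placeholders.
+
+```python
+import torch
+from executorch.runtime import Runtime
+
+with open("parakeet_tdt_exports/parakeet_tdt.pte", "rb") as f:
+    program = Runtime.get().load_program(f.read())
+
+audio = torch.randn(16000 * 5)  # 5 s of 16 kHz mono audio (placeholder)
+audio_len = torch.tensor([audio.shape[0]], dtype=torch.int64)
+
+# Waveform -> log-mel features
+mel, mel_len = program.load_method("preprocessor").execute([audio, audio_len])
+
+# Features -> encoder states
+encoded, encoded_len = program.load_method("encoder").execute(
+    [mel, torch.tensor([mel_len.item()], dtype=torch.int64)]
+)
+
+# "decoder_predict", "joint_project_*" and "joint" are then called step by
+# step by the greedy TDT decode loop (see transcribe_executorch).
+```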
+ +## Installation + +```bash +pip install nemo_toolkit[asr] torchaudio +``` + +## Export + +Export the model: +```bash +python export_parakeet_tdt.py +``` + +Test transcription on an audio file and compare eager vs lowered results: +```bash +python export_parakeet_tdt.py --audio /path/to/audio.wav +``` + +### Export Arguments + +| Argument | Description | +|----------|-------------| +| `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | +| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) | +| `--audio` | Path to audio file for transcription test | + +**Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting. + +## C++ Runner + +### Building + +First, build ExecuTorch with the LLM preset from the executorch root directory: + +```bash +cmake --workflow --preset llm-release +``` + +Then build the parakeet runner: + +```bash +cd examples/models/parakeet +cmake --workflow --preset parakeet-cpu +``` + +Available presets: +- `parakeet-cpu` - CPU-only build +- `parakeet-cuda` - CUDA acceleration (Linux/Windows) +- `parakeet-metal` - Metal acceleration (macOS) + +### Running + +From the executorch root directory: + +```bash +./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path examples/models/parakeet/parakeet_tdt_exports/parakeet_tdt.pte \ + --audio_path /path/to/audio.wav \ + --tokenizer_path examples/models/parakeet/tokenizer.model +``` + +### Runner Arguments + +| Argument | Description | +|----------|-------------| +| `--model_path` | Path to Parakeet model (.pte) | +| `--audio_path` | Path to input audio file (.wav) | +| `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) | +| `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py new file mode 100644 index 00000000000..509da67051c --- /dev/null +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -0,0 +1,467 @@ +"""Export nvidia/parakeet-tdt-0.6b-v3 components to ExecuTorch.""" + +import os + +import torch +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass +from torch.export import Dim, export + + +def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor: + """Load audio file and resample to target sample rate.""" + try: + import torchaudio + + waveform, sr = torchaudio.load(audio_path) + except Exception: + from scipy.io import wavfile + + sr, data = wavfile.read(audio_path) + if data.dtype == "int16": + data = data.astype("float32") / 32768.0 + elif data.dtype == "int32": + data = data.astype("float32") / 2147483648.0 + waveform = torch.from_numpy(data).unsqueeze(0) + + if waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + if sr != sample_rate: + try: + import torchaudio + + resampler = torchaudio.transforms.Resample(sr, sample_rate) + waveform = resampler(waveform) + except ImportError: + from scipy import signal + + num_samples = int(len(waveform[0]) * sample_rate / sr) + resampled = signal.resample(waveform[0].numpy(), num_samples) + waveform = torch.from_numpy(resampled).unsqueeze(0).float() + + return waveform + + +def greedy_decode_eager( + encoder_output: torch.Tensor, encoder_len: torch.Tensor, model +) -> list[int]: + hypotheses = 
model.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoder_output, + encoded_lengths=encoder_len, + return_hypotheses=True, + ) + return hypotheses[0].y_sequence + + +class DecoderPredict(torch.nn.Module): + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + self.pred_hidden = decoder.pred_hidden + self.pred_rnn_layers = getattr(decoder, "pred_rnn_layers", 2) + + def forward( + self, token: torch.Tensor, h: torch.Tensor, c: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + g, new_state = self.decoder.predict(y=token, state=[h, c], add_sos=False) + return g, new_state[0], new_state[1] + + +def greedy_decode_executorch( + encoder_output: torch.Tensor, + encoder_len: int, + program, + blank_id: int, + vocab_size: int, + num_rnn_layers: int = 2, + pred_hidden: int = 640, + max_symbols_per_step: int = 10, + durations: list[int] | None = None, +) -> list[int]: + if durations is None: + durations = [0, 1, 2, 3, 4] + + hypothesis = [] + num_token_classes = vocab_size + 1 + + encoder_output = encoder_output.transpose(1, 2) + + proj_enc_method = program.load_method("joint_project_encoder") + f_proj = proj_enc_method.execute([encoder_output.contiguous()])[0] + + decoder_predict_method = program.load_method("decoder_predict") + proj_dec_method = program.load_method("joint_project_decoder") + joint_method = program.load_method("joint") + + h = torch.zeros(num_rnn_layers, 1, pred_hidden) + c = torch.zeros(num_rnn_layers, 1, pred_hidden) + + sos_g = torch.zeros(1, 1, pred_hidden) + g_proj = proj_dec_method.execute([sos_g])[0] + + t = 0 + symbols_on_frame = 0 + + while t < encoder_len: + f_t = f_proj[:, t : t + 1, :].contiguous() + + joint_out = joint_method.execute([f_t, g_proj]) + + full_logits = joint_out[0].squeeze() + token_logits = full_logits[:num_token_classes] + duration_logits = full_logits[num_token_classes:] + + k = token_logits.argmax().item() + dur_idx = duration_logits.argmax().item() + dur = durations[dur_idx] + + if k == blank_id: + t += max(dur, 1) + symbols_on_frame = 0 + else: + hypothesis.append(k) + + token = torch.tensor([[k]], dtype=torch.long) + result = decoder_predict_method.execute([token, h, c]) + g = result[0] + h = result[1] + c = result[2] + + g_proj = proj_dec_method.execute([g])[0] + t += dur + + if dur == 0: + symbols_on_frame += 1 + if symbols_on_frame >= max_symbols_per_step: + t += 1 + symbols_on_frame = 0 + else: + symbols_on_frame = 0 + + return hypothesis + + +def transcribe_executorch(audio_path: str, model, et_buffer) -> str: + from executorch.runtime import Runtime + + runtime = Runtime.get() + program = runtime.load_program(et_buffer) + + # Get sample rate from model + sample_rate = model.preprocessor._cfg.sample_rate + + with torch.no_grad(): + audio = load_audio(audio_path, sample_rate=sample_rate) + preprocessor_method = program.load_method("preprocessor") + audio_1d = audio.squeeze(0) + audio_len = torch.tensor([audio_1d.shape[0]], dtype=torch.int64) + proc_result = preprocessor_method.execute([audio_1d, audio_len]) + mel = proc_result[0] + mel_len = proc_result[1].item() + + encoder_method = program.load_method("encoder") + mel_len_tensor = torch.tensor([mel_len], dtype=torch.int64) + enc_result = encoder_method.execute([mel, mel_len_tensor]) + encoded = enc_result[0] + encoded_len = enc_result[1].item() + + vocab_size = model.tokenizer.vocab_size + tokens = greedy_decode_executorch( + encoded, + encoded_len, + program, + blank_id=vocab_size, + vocab_size=vocab_size, + 
num_rnn_layers=model.decoder.pred_rnn_layers, + pred_hidden=model.decoder.pred_hidden, + ) + + return model.tokenizer.ids_to_text(tokens) + + +def transcribe_eager(audio_path: str, model) -> str: + with torch.no_grad(): + audio = load_audio(audio_path) + mel, mel_len = model.preprocessor( + input_signal=audio, length=torch.tensor([audio.shape[1]]) + ) + encoded, encoded_len = model.encoder(audio_signal=mel, length=mel_len) + tokens = greedy_decode_eager(encoded, encoded_len, model) + return model.tokenizer.ids_to_text(tokens) + + +def load_model(): + import nemo.collections.asr as nemo_asr + + model = nemo_asr.models.ASRModel.from_pretrained( + "nvidia/parakeet-tdt-0.6b-v3", map_location="cpu" + ) + model.eval() + model.cpu() + return model + + +class JointAfterProjection(torch.nn.Module): + def __init__(self, joint): + super().__init__() + self.joint = joint + + def forward(self, f, g): + return self.joint.joint_after_projection(f, g) + + +class JointProjectEncoder(torch.nn.Module): + def __init__(self, joint): + super().__init__() + self.joint = joint + + def forward(self, f): + return self.joint.project_encoder(f) + + +class JointProjectDecoder(torch.nn.Module): + def __init__(self, joint): + super().__init__() + self.joint = joint + + def forward(self, g): + return self.joint.project_prednet(g) + + +class PreprocessorWrapper(torch.nn.Module): + def __init__(self, preprocessor): + super().__init__() + self.preprocessor = preprocessor + + def forward( + self, audio: torch.Tensor, length: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + audio_signal = audio.unsqueeze(0) + mel, mel_len = self.preprocessor(input_signal=audio_signal, length=length) + return mel, mel_len + + +def export_all(model): + programs = {} + + preprocessor_wrapper = PreprocessorWrapper(model.preprocessor) + preprocessor_wrapper.eval() + sample_audio = torch.randn(16000 * 10) + sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64) + # The preprocessor definition changes if cuda is available (likely due to making it cuda graphable). + # Unfortunately that new definition is not supported by export, so we need to stop that from happening. 
+ old_cuda_is_available = torch.cuda.is_available + torch.cuda.is_available = lambda: False + programs["preprocessor"] = export( + preprocessor_wrapper, + (sample_audio, sample_length), + dynamic_shapes={ + "audio": {0: Dim("audio_len", min=1600, max=16000 * 600)}, + "length": {}, + }, + strict=False, + ) + torch.cuda.is_available = old_cuda_is_available + + feat_in = getattr(model.encoder, "_feat_in", 128) + audio_signal = torch.randn(1, feat_in, 100) + length = torch.tensor([100], dtype=torch.int64) + programs["encoder"] = export( + model.encoder, + (), + kwargs={"audio_signal": audio_signal, "length": length}, + dynamic_shapes={"audio_signal": {2: Dim.AUTO}, "length": {}}, + strict=False, + ) + + decoder_predict = DecoderPredict(model.decoder) + decoder_predict.eval() + token = torch.tensor([[0]], dtype=torch.long) + num_layers = model.decoder.pred_rnn_layers + pred_hidden = model.decoder.pred_hidden + h = torch.zeros(num_layers, 1, pred_hidden) + c = torch.zeros(num_layers, 1, pred_hidden) + programs["decoder_predict"] = export( + decoder_predict, + (token, h, c), + dynamic_shapes={"token": {}, "h": {}, "c": {}}, + strict=False, + ) + + joint_hidden = model.joint.joint_hidden + + f_proj = torch.randn(1, 1, joint_hidden) + g_proj = torch.randn(1, 1, joint_hidden) + programs["joint"] = export( + JointAfterProjection(model.joint), + (f_proj, g_proj), + dynamic_shapes={"f": {}, "g": {}}, + strict=False, + ) + + enc_output_dim = getattr(model.encoder, "_feat_out", 1024) + + programs["joint_project_encoder"] = export( + JointProjectEncoder(model.joint), + (torch.randn(1, 25, enc_output_dim),), + dynamic_shapes={"f": {1: Dim("enc_time", min=1, max=60000)}}, + strict=False, + ) + + programs["joint_project_decoder"] = export( + JointProjectDecoder(model.joint), + (torch.randn(1, 1, pred_hidden),), + dynamic_shapes={"g": {}}, + strict=False, + ) + + sample_rate = model.preprocessor._cfg.sample_rate + metadata = { + "num_rnn_layers": num_layers, + "pred_hidden": pred_hidden, + "joint_hidden": joint_hidden, + "vocab_size": model.tokenizer.vocab_size, + "blank_id": model.tokenizer.vocab_size, + "sample_rate": sample_rate, + } + + return programs, metadata + + +def lower_to_executorch(programs, metadata=None, backend="portable"): + partitioner = {} + + if backend == "xnnpack": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) + + print("\nLowering to ExecuTorch with XNNPACK...") + for key in programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + partitioner[key] = [XnnpackPartitioner()] + + elif backend in ("cuda", "cuda-windows"): + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir.backend.compile_spec_schema import CompileSpec + from torch._inductor.decomposition import conv1d_to_conv2d + + print( + f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}..." 
+ ) + + for key, ep in programs.items(): + if key != "preprocessor": + programs[key] = ep.run_decompositions( + {torch.ops.aten.conv1d.default: conv1d_to_conv2d} + ) + + for key in programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] + if backend == "cuda-windows": + compile_specs.append( + CompileSpec("platform", "windows".encode("utf-8")) + ) + partitioner[key] = [CudaPartitioner(compile_specs)] + + else: + print("\nLowering to ExecuTorch...") + partitioner = [] + + constant_methods = {} + if metadata: + for key, value in metadata.items(): + constant_methods[key] = value + + et_prog = to_edge_transform_and_lower( + programs, + partitioner=partitioner, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=constant_methods if constant_methods else None, + ) + return et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + +def main(): + import argparse + import sys + + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", default="./parakeet_tdt_exports") + parser.add_argument( + "--audio", type=str, help="Path to audio file for transcription test" + ) + parser.add_argument( + "--backend", + type=str, + default="portable", + choices=["portable", "xnnpack", "cuda", "cuda-windows"], + help="Backend for acceleration (default: portable)", + ) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + print("Loading model...") + model = load_model() + + print("\nExporting components...") + programs, metadata = export_all(model) + + et = lower_to_executorch(programs, metadata=metadata, backend=args.backend) + + pte_path = os.path.join(args.output_dir, "parakeet_tdt.pte") + print(f"\nSaving ExecuTorch program to: {pte_path}") + with open(pte_path, "wb") as f: + et.write_to_file(f) + print(f"Saved {os.path.getsize(pte_path) / (1024 * 1024):.1f} MB") + + # Save .ptd data files (e.g., CUDA delegate data) + if et._tensor_data: + print(f"\nSaving {len(et._tensor_data)} data file(s)...") + et.write_tensor_data_to_file(args.output_dir) + + if args.audio: + print("\n" + "=" * 60) + print("Testing transcription...") + print("=" * 60) + + print("\n[Eager PyTorch]") + eager_text = transcribe_eager(args.audio, model) + print(f" Result: {eager_text}") + + print("\n[ExecuTorch Runtime]") + et_text = transcribe_executorch(args.audio, model, et.buffer) + print(f" Result: {et_text}") + + if eager_text == et_text: + print("\n✓ Transcriptions match!") + else: + print("\n✗ Transcriptions differ!") + print(f" Eager: {eager_text}") + print(f" ET: {et_text}") + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp new file mode 100644 index 00000000000..173cf722d77 --- /dev/null +++ b/examples/models/parakeet/main.cpp @@ -0,0 +1,417 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte).");
+DEFINE_string(audio_path, "", "Path to input audio file (.wav).");
+DEFINE_string(
+    tokenizer_path,
+    "tokenizer.json",
+    "Path to tokenizer file (SentencePiece model or HuggingFace JSON).");
+DEFINE_string(
+    data_path,
+    "",
+    "Path to data file (.ptd) for delegate data (optional, required for CUDA).");
+
+using ::executorch::extension::from_blob;
+using ::executorch::extension::Module;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::EValue;
+
+namespace {
+
+// TDT duration values
+const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
+
+std::vector<int64_t> greedy_decode_executorch(
+    Module& model,
+    const ::executorch::aten::Tensor& encoder_output,
+    int64_t encoder_len,
+    int64_t blank_id,
+    int64_t vocab_size,
+    int64_t num_rnn_layers = 2,
+    int64_t pred_hidden = 640,
+    int64_t max_symbols_per_step = 10) {
+  std::vector<int64_t> hypothesis;
+  int64_t num_token_classes = vocab_size + 1;
+
+  // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim]
+  auto enc_sizes = encoder_output.sizes();
+  int64_t batch = enc_sizes[0];
+  int64_t enc_dim = enc_sizes[1];
+  int64_t time_steps = enc_sizes[2];
+
+  // Create transposed tensor
+  std::vector<float> transposed_data(batch * time_steps * enc_dim);
+  const float* src = encoder_output.const_data_ptr<float>();
+  for (int64_t t = 0; t < time_steps; t++) {
+    for (int64_t d = 0; d < enc_dim; d++) {
+      transposed_data[t * enc_dim + d] = src[d * time_steps + t];
+    }
+  }
+
+  auto transposed_tensor = from_blob(
+      transposed_data.data(),
+      {static_cast<::executorch::aten::SizesType>(batch),
+       static_cast<::executorch::aten::SizesType>(time_steps),
+       static_cast<::executorch::aten::SizesType>(enc_dim)},
+      ::executorch::aten::ScalarType::Float);
+
+  // Project encoder output
+  auto proj_enc_result = model.execute(
+      "joint_project_encoder",
+      std::vector<::executorch::runtime::EValue>{transposed_tensor});
+  if (!proj_enc_result.ok()) {
+    ET_LOG(Error, "joint_project_encoder failed");
+    return hypothesis;
+  }
+  auto f_proj = proj_enc_result.get()[0].toTensor();
+
+  // Initialize LSTM state
+  std::vector<float> h_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
+  std::vector<float> c_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
+
+  auto h = from_blob(
+      h_data.data(),
+      {static_cast<::executorch::aten::SizesType>(num_rnn_layers),
+       1,
+       static_cast<::executorch::aten::SizesType>(pred_hidden)},
+      ::executorch::aten::ScalarType::Float);
+  auto c = from_blob(
+      c_data.data(),
+      {static_cast<::executorch::aten::SizesType>(num_rnn_layers),
+       1,
+       static_cast<::executorch::aten::SizesType>(pred_hidden)},
+      ::executorch::aten::ScalarType::Float);
+
+  // Initialize decoder state with zeros
+  std::vector<float> sos_g_data(1 * 1 * pred_hidden, 0.0f);
+  auto sos_g = from_blob(
+      sos_g_data.data(),
+      {1, 1, static_cast<::executorch::aten::SizesType>(pred_hidden)},
+      ::executorch::aten::ScalarType::Float);
+
+  auto g_proj_result = model.execute(
+      "joint_project_decoder",
+      std::vector<::executorch::runtime::EValue>{sos_g});
+  if (!g_proj_result.ok()) {
+    ET_LOG(Error, "joint_project_decoder failed");
+    return hypothesis;
+  }
+  auto g_proj_tensor = g_proj_result.get()[0].toTensor();
+
+  // Copy g_proj data for reuse
+  std::vector<float> g_proj_data(
+      g_proj_tensor.const_data_ptr<float>(),
+      g_proj_tensor.const_data_ptr<float>() + g_proj_tensor.numel());
+
+  int64_t t = 0;
+  int64_t symbols_on_frame = 0;
+
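+  // TDT greedy decode: for each encoder frame the joint head produces
+  // (vocab_size + 1) token logits followed by one logit per entry in
+  // DURATIONS. The token argmax selects the emitted label (or blank) and the
+  // duration argmax selects how many encoder frames to skip. Blank always
+  // advances by at least one frame; at most max_symbols_per_step labels may
+  // be emitted on a single frame before the loop forces a one-frame advance.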
+  // Scan over encoder output
+  while (t < encoder_len) {
+    // Get encoder frame at time t: f_proj[:, t:t+1, :]
+    const float* f_proj_data = f_proj.const_data_ptr<float>();
+    int64_t proj_dim = f_proj.sizes()[2];
+
+    std::vector<float> f_t_data(1 * 1 * proj_dim);
+    for (int64_t d = 0; d < proj_dim; d++) {
+      f_t_data[d] = f_proj_data[t * proj_dim + d];
+    }
+    auto f_t = from_blob(
+        f_t_data.data(),
+        {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
+        ::executorch::aten::ScalarType::Float);
+
+    auto g_proj = from_blob(
+        g_proj_data.data(),
+        {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
+        ::executorch::aten::ScalarType::Float);
+
+    // Joint network
+    auto joint_result = model.execute(
+        "joint", std::vector<::executorch::runtime::EValue>{f_t, g_proj});
+    if (!joint_result.ok()) {
+      ET_LOG(Error, "joint failed at t=%lld", static_cast<long long>(t));
+      return hypothesis;
+    }
+    auto full_logits = joint_result.get()[0].toTensor();
+
+    // Split logits into token and duration
+    const float* logits_data = full_logits.const_data_ptr<float>();
+
+    // Find argmax for token logits
+    int64_t k = 0;
+    float max_token_logit = logits_data[0];
+    for (int64_t i = 1; i < num_token_classes; i++) {
+      if (logits_data[i] > max_token_logit) {
+        max_token_logit = logits_data[i];
+        k = i;
+      }
+    }
+
+    // Find argmax for duration logits
+    int64_t dur_idx = 0;
+    float max_dur_logit = logits_data[num_token_classes];
+    for (size_t i = 1; i < DURATIONS.size(); i++) {
+      if (logits_data[num_token_classes + i] > max_dur_logit) {
+        max_dur_logit = logits_data[num_token_classes + i];
+        dur_idx = i;
+      }
+    }
+    int64_t dur = DURATIONS[dur_idx];
+
+    if (k == blank_id) {
+      t += std::max(dur, (int64_t)1);
+      symbols_on_frame = 0;
+    } else {
+      hypothesis.push_back(k);
+
+      // Update decoder state
+      std::vector<int64_t> token_data = {k};
+      auto token = from_blob(
+          token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long);
+
+      auto decoder_result = model.execute(
+          "decoder_predict",
+          std::vector<::executorch::runtime::EValue>{token, h, c});
+      if (!decoder_result.ok()) {
+        ET_LOG(Error, "decoder_predict failed");
+        return hypothesis;
+      }
+      auto& outputs = decoder_result.get();
+      auto g = outputs[0].toTensor();
+      auto new_h = outputs[1].toTensor();
+      auto new_c = outputs[2].toTensor();
+
+      // Update h and c
+      std::memcpy(
+          h_data.data(),
+          new_h.const_data_ptr<float>(),
+          h_data.size() * sizeof(float));
+      std::memcpy(
+          c_data.data(),
+          new_c.const_data_ptr<float>(),
+          c_data.size() * sizeof(float));
+
+      // Project decoder output
+      auto proj_dec_result = model.execute(
+          "joint_project_decoder",
+          std::vector<::executorch::runtime::EValue>{g});
+      if (!proj_dec_result.ok()) {
+        ET_LOG(Error, "joint_project_decoder failed");
+        return hypothesis;
+      }
+      auto new_g_proj = proj_dec_result.get()[0].toTensor();
+      std::memcpy(
+          g_proj_data.data(),
+          new_g_proj.const_data_ptr<float>(),
+          g_proj_data.size() * sizeof(float));
+
+      t += dur;
+
+      if (dur == 0) {
+        symbols_on_frame++;
+        if (symbols_on_frame >= max_symbols_per_step) {
+          t++;
+          symbols_on_frame = 0;
+        }
+      } else {
+        symbols_on_frame = 0;
+      }
+    }
+  }
+
+  return hypothesis;
+}
+
+std::string tokens_to_text(
+    const std::vector<int64_t>& tokens,
+    tokenizers::Tokenizer* tokenizer) {
+  // Decode tokens to text one by one
+  std::string result;
+  uint64_t prev_token = 0;
+  for (size_t i = 0; i < tokens.size(); i++) {
+    uint64_t token = static_cast<uint64_t>(tokens[i]);
+    auto decode_result = tokenizer->decode(prev_token, token);
+    if (decode_result.ok()) {
+      result += decode_result.get();
+    }
+    prev_token = token;
+  }
+
+  return result;
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  if (FLAGS_audio_path.empty()) {
+    ET_LOG(Error, "audio_path flag must be provided.");
+    return 1;
+  }
+
+  // Load model (which includes the bundled preprocessor)
+  ET_LOG(Info, "Loading model from: %s", FLAGS_model_path.c_str());
+  std::unique_ptr<Module> model;
+  if (!FLAGS_data_path.empty()) {
+    ET_LOG(Info, "Loading data from: %s", FLAGS_data_path.c_str());
+    model = std::make_unique<Module>(
+        FLAGS_model_path, FLAGS_data_path, Module::LoadMode::Mmap);
+  } else {
+    model = std::make_unique<Module>(FLAGS_model_path, Module::LoadMode::Mmap);
+  }
+  auto model_load_error = model->load();
+  if (model_load_error != Error::Ok) {
+    ET_LOG(Error, "Failed to load model.");
+    return 1;
+  }
+
+  // Load audio
+  ET_LOG(Info, "Loading audio from: %s", FLAGS_audio_path.c_str());
+  std::vector<float> audio_data =
+      ::executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path);
+  ET_LOG(Info, "Loaded %zu audio samples", audio_data.size());
+
+  auto audio_tensor = from_blob(
+      audio_data.data(),
+      {static_cast<::executorch::aten::SizesType>(audio_data.size())},
+      ::executorch::aten::ScalarType::Float);
+  std::vector<int64_t> audio_len_data = {
+      static_cast<int64_t>(audio_data.size())};
+  auto audio_len_tensor = from_blob(
+      audio_len_data.data(), {1}, ::executorch::aten::ScalarType::Long);
+
+  ET_LOG(Info, "Running preprocessor...");
+  auto proc_result = model->execute(
+      "preprocessor",
+      std::vector<::executorch::runtime::EValue>{
+          audio_tensor, audio_len_tensor});
+  if (!proc_result.ok()) {
+    ET_LOG(Error, "Preprocessor forward failed.");
+    return 1;
+  }
+  auto& proc_outputs = proc_result.get();
+  auto mel = proc_outputs[0].toTensor();
+  auto mel_len_tensor_out = proc_outputs[1].toTensor();
+  int64_t mel_len_value = mel_len_tensor_out.const_data_ptr<int64_t>()[0];
+
+  // Create mel_len tensor for encoder
+  std::vector<int64_t> mel_len_data = {mel_len_value};
+  auto mel_len =
+      from_blob(mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long);
+
+  ET_LOG(
+      Info,
+      "Mel spectrogram shape: [%ld, %ld, %ld], mel_len: %lld",
+      static_cast<long>(mel.sizes()[0]),
+      static_cast<long>(mel.sizes()[1]),
+      static_cast<long>(mel.sizes()[2]),
+      static_cast<long long>(mel_len_value));
+
+  // Run encoder
+  ET_LOG(Info, "Running encoder...");
+  auto enc_result = model->execute(
+      "encoder", std::vector<::executorch::runtime::EValue>{mel, mel_len});
+  if (!enc_result.ok()) {
+    ET_LOG(Error, "Encoder forward failed.");
+    return 1;
+  }
+  auto& enc_outputs = enc_result.get();
+  auto encoded = enc_outputs[0].toTensor();
+  int64_t encoded_len = enc_outputs[1].toTensor().const_data_ptr<int64_t>()[0];
+
+  ET_LOG(
+      Info,
+      "Encoder output shape: [%ld, %ld, %ld], len=%ld",
+      static_cast<long>(encoded.sizes()[0]),
+      static_cast<long>(encoded.sizes()[1]),
+      static_cast<long>(encoded.sizes()[2]),
+      static_cast<long>(encoded_len));
+
+  // Query model metadata from constant_methods
+  std::vector<::executorch::runtime::EValue> empty_inputs;
+  auto num_rnn_layers_result = model->execute("num_rnn_layers", empty_inputs);
+  auto pred_hidden_result = model->execute("pred_hidden", empty_inputs);
+  auto vocab_size_result = model->execute("vocab_size", empty_inputs);
+  auto blank_id_result = model->execute("blank_id", empty_inputs);
+  auto sample_rate_result = model->execute("sample_rate", empty_inputs);
+
+  if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() ||
+      !vocab_size_result.ok() || !blank_id_result.ok() ||
+      !sample_rate_result.ok()) {
+    ET_LOG(
+        Error,
+        "Failed to query model metadata. Make sure the model was exported with constant_methods.");
+    return 1;
+  }
+
+  int64_t vocab_size = vocab_size_result.get()[0].toInt();
+  int64_t blank_id = blank_id_result.get()[0].toInt();
+  int64_t num_rnn_layers = num_rnn_layers_result.get()[0].toInt();
+  int64_t pred_hidden = pred_hidden_result.get()[0].toInt();
+  int64_t sample_rate = sample_rate_result.get()[0].toInt();
+
+  ET_LOG(
+      Info,
+      "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld",
+      static_cast<long long>(vocab_size),
+      static_cast<long long>(blank_id),
+      static_cast<long long>(num_rnn_layers),
+      static_cast<long long>(pred_hidden),
+      static_cast<long long>(sample_rate));
+
+  ET_LOG(Info, "Running TDT greedy decode...");
+  auto tokens = greedy_decode_executorch(
+      *model,
+      encoded,
+      encoded_len,
+      blank_id,
+      vocab_size,
+      num_rnn_layers,
+      pred_hidden);
+
+  ET_LOG(Info, "Decoded %zu tokens", tokens.size());
+
+  // Load tokenizer
+  ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str());
+  auto tokenizer =
+      ::executorch::extension::llm::load_tokenizer(FLAGS_tokenizer_path);
+  if (!tokenizer || !tokenizer->is_loaded()) {
+    ET_LOG(
+        Error,
+        "Failed to load tokenizer from: %s",
+        FLAGS_tokenizer_path.c_str());
+    return 1;
+  }
+
+  // Convert tokens to text
+  std::string text = tokens_to_text(tokens, tokenizer.get());
+  std::cout << "Transcription: " << text << std::endl;
+
+  ET_LOG(Info, "Done!");
+  return 0;
+}