SpeechLM2#

Note

The SpeechLM2 collection is still in active development and the code is likely to keep changing.

SpeechLM2 refers to a collection that augments pre-trained Large Language Models (LLMs) with speech understanding and generation capabilities.

This collection is designed to be compact, efficient, and to support easy swapping of different LLMs backed by HuggingFace AutoModel. It has a first-class support for using dynamic batch sizes via Lhotse and various model parallelism techniques (e.g., FSDP2, Tensor Parallel, Sequence Parallel) via PyTorch DTensor API.

We currently support three main model types: * SALM (Speech-Augmented Language Model) - a simple but effective approach to augmenting pre-trained LLMs with speech understanding capabilities. * DuplexS2SModel - a full-duplex speech-to-speech model with an ASR encoder, directly predicting discrete audio codes. * DuplexS2SSpeechDecoderModel - a variant of DuplexS2SModel with a separate transformer decoder for speech generation.

Using Pretrained Models#

After installing NeMo, you can load and use a pretrained speechlm2 model as follows:

import nemo.collections.speechlm2 as slm

# Load a pretrained SALM model
model = slm.models.SALM.from_pretrained("model_name_or_path")

# Set model to evaluation mode
model = model.eval()

Inference with Pretrained Models#

SALM#

You can run inference using the loaded pretrained SALM model:

import torch
import torchaudio
import nemo.collections.speechlm2 as slm

model = slm.models.SALM.from_pretrained("path/to/pretrained_checkpoint").eval()

# Load audio file
audio_path = "path/to/audio.wav"
audio_signal, sample_rate = torchaudio.load(audio_path)

# Resample if needed
if sample_rate != 16000:  # Most models expect 16kHz audio
    audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
    sample_rate = 16000

# Prepare audio for model
audio_signal = audio_signal.to(model.device)
audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)

# Create a prompt for SALM model inference
# The audio_locator_tag is a special token that will be replaced with audio embeddings
prompt = [{"role": "user", "content": f"{model.audio_locator_tag}"}]

# Generate response
with torch.no_grad():
    output = model.generate(
        prompts=[prompt],
        audios=audio_signal,
        audio_lens=audio_len,
        generation_config=None  # You can customize generation parameters here
    )

# Process the output tokens
response = model.tokenizer.ids_to_text(output[0])
print(f"Model response: {response}")

DuplexS2SModel#

You can run inference using the loaded pretrained DuplexS2SModel:

import torch
import torchaudio
import nemo.collections.speechlm2 as slm

model = slm.models.DuplexS2SModel.from_pretrained("path/to/pretrained_checkpoint").eval()

# Load audio file
audio_path = "path/to/audio.wav"
audio_signal, sample_rate = torchaudio.load(audio_path)

# Resample if needed
if sample_rate != 16000:  # Most models expect 16kHz audio
    audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
    sample_rate = 16000

# Prepare audio for model
audio_signal = audio_signal.to(model.device)
audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)

# Run offline inference
results = model.offline_inference(
    input_signal=audio_signal,
    input_signal_lens=audio_len
)

# Decode text and audio tokens
transcription = results["text"][0]
audio = results["audio"][0]

Training a Model#

This example demonstrates how to train a SALM model. The remaining models can be trained in a similar manner.

from omegaconf import OmegaConf
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import ModelParallelStrategy

import nemo.collections.speechlm2 as slm
from nemo.collections.speechlm2.data import SALMDataset, DataModule
from nemo.utils.exp_manager import exp_manager

# Load configuration
config_path = "path/to/config.yaml"  # E.g., from examples/speechlm2/conf/salm.yaml
cfg = OmegaConf.load(config_path)

# Initialize PyTorch Lightning trainer
trainer = Trainer(
    max_steps=100000,
    accelerator="gpu",
    devices=1,
    precision="bf16-true",
    strategy=ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=1),
    limit_train_batches=1000,
    val_check_interval=1000,
    use_distributed_sampler=False,
    logger=False,
    enable_checkpointing=False,
)

# Set up experiment manager for logging
exp_manager(trainer, cfg.get("exp_manager", None))

# Initialize model with configuration
model = slm.models.SALM(OmegaConf.to_container(cfg.model, resolve=True))

# Create dataset and datamodule
dataset = SALMDataset(tokenizer=model.tokenizer)
datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)

# Train the model
trainer.fit(model, datamodule)

Example Using Command-Line Training Script#

Alternatively, you can train a model using the provided training scripts in the examples directory:

# Train a SALM model
python examples/speechlm2/salm_train.py \
  --config-path=examples/speechlm2/conf \
  --config-name=salm

# For inference/evaluation
python examples/speechlm2/salm_eval.py \
  pretrained_name=/path/to/checkpoint \
  inputs=/path/to/test_manifest \
  batch_size=64 \
  max_new_tokens=128 \
  output_manifest=generations.jsonl

For more detailed information on training at scale, model parallelism, and SLURM-based training, see training and scaling.

Collection Structure#

The speechlm2 collection is organized into the following key components:

  • Models: Contains implementations of DuplexS2SModel, DuplexS2SSpeechDecoderModel, and SALM

  • Modules: Contains audio perception and speech generation modules

  • Data: Includes dataset classes and data loading utilities

SpeechLM2 Documentation#

For more information, see additional sections in the SpeechLM2 docs: