Real-Time Audio Streaming Advanced
Getting audio from a user's mouth to your AI and back in under 500ms requires careful engineering at every layer. This lesson covers WebSocket audio streaming, buffering strategies that prevent glitches without adding latency, end-to-end optimization techniques, scaling to thousands of concurrent calls, and edge deployment patterns that shave off critical milliseconds.
WebSocket Audio Streaming Server
import asyncio
import json
import struct
import time
from typing import Dict, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dataclasses import dataclass, field
from collections import deque
# Single shared FastAPI application instance; the WebSocket route defined
# further down ("/ws/audio") attaches to it.
app = FastAPI()
@dataclass
class AudioSession:
    """State for one active audio streaming session."""
    session_id: str
    sample_rate: int = 16000
    channels: int = 1
    encoding: str = "pcm_s16le"
    created_at: float = field(default_factory=time.time)
    # Bounded queues: a stalled peer or slow pipeline cannot grow memory
    # without limit (oldest chunks are evicted at maxlen=100).
    input_buffer: deque = field(default_factory=lambda: deque(maxlen=100))
    output_buffer: deque = field(default_factory=lambda: deque(maxlen=100))
    # Running byte counters for monitoring.
    total_input_bytes: int = 0
    total_output_bytes: int = 0
    latency_samples: list = field(default_factory=list)

    def record_latency(self, latency_ms: float):
        """Record one round-trip latency sample, retaining only the newest 100."""
        self.latency_samples.append(latency_ms)
        if len(self.latency_samples) > 100:
            # Trim in place to the most recent 100 samples.
            del self.latency_samples[:-100]

    @property
    def avg_latency_ms(self) -> float:
        """Mean of the retained latency samples (0.0 when none recorded)."""
        samples = self.latency_samples
        return sum(samples) / len(samples) if samples else 0.0

    @property
    def p99_latency_ms(self) -> float:
        """99th-percentile latency over the retained samples (0.0 when empty)."""
        if not self.latency_samples:
            return 0.0
        ordered = sorted(self.latency_samples)
        pos = min(int(len(ordered) * 0.99), len(ordered) - 1)
        return ordered[pos]
class RealtimeAudioServer:
    """Production WebSocket server for real-time audio streaming.

    Handles bidirectional audio: receives user audio, sends back TTS audio.
    Optimized for minimal latency with jitter buffering.
    """

    def __init__(self, voice_pipeline):
        # voice_pipeline must expose ``async process_audio_chunk(bytes, session_id=...)``
        # returning response audio bytes (or a falsy value for "nothing yet").
        self.pipeline = voice_pipeline
        self.sessions: Dict[str, AudioSession] = {}
        self.max_concurrent = 1000  # Max concurrent sessions

    async def handle_session(self, ws: WebSocket):
        """Handle one WebSocket audio session from accept to teardown."""
        await ws.accept()
        session = None
        try:
            # Client must identify itself within 5s or the slot is reclaimed.
            init_msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)

            # Reject BEFORE registering: the original inserted the session
            # first and checked afterwards, so an over-capacity connection
            # counted against itself and briefly occupied a slot.
            if len(self.sessions) >= self.max_concurrent:
                await ws.send_json({"error": "server_at_capacity"})
                return

            session = AudioSession(
                session_id=init_msg["session_id"],
                sample_rate=init_msg.get("sample_rate", 16000),
                encoding=init_msg.get("encoding", "pcm_s16le")
            )
            self.sessions[session.session_id] = session
            await ws.send_json({"status": "connected", "session_id": session.session_id})

            # Start bidirectional streaming: one task per direction plus the
            # pipeline worker bridging input_buffer -> output_buffer.
            tasks = [
                asyncio.create_task(self._receive_audio(ws, session)),
                asyncio.create_task(self._send_audio(ws, session)),
                asyncio.create_task(self._process_pipeline(session)),
            ]
            # Wait for any task to complete (usually receive on disconnect).
            done, pending = await asyncio.wait(
                tasks, return_when=asyncio.FIRST_COMPLETED
            )
            for task in pending:
                task.cancel()
            # Await the cancelled tasks so cancellation actually runs to
            # completion and no "Task was destroyed" warnings leak.
            await asyncio.gather(*pending, return_exceptions=True)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            print(f"Session error: {e}")
        finally:
            # pop() instead of del: safe even if the session was never
            # registered or was already removed by a duplicate session_id.
            if session:
                self.sessions.pop(session.session_id, None)

    async def _receive_audio(self, ws: WebSocket, session: AudioSession):
        """Receive binary audio chunks from the client into the input buffer."""
        while True:
            data = await ws.receive_bytes()
            session.input_buffer.append({
                "audio": data,
                # monotonic clock: immune to wall-clock adjustments, used
                # later for round-trip latency measurement.
                "received_at": time.monotonic()
            })
            session.total_input_bytes += len(data)

    async def _send_audio(self, ws: WebSocket, session: AudioSession):
        """Drain the output buffer, sending processed audio back to the client."""
        while True:
            if session.output_buffer:
                chunk = session.output_buffer.popleft()
                await ws.send_bytes(chunk["audio"])
                # Record round-trip latency (receive -> processed -> sent).
                if "request_time" in chunk:
                    latency = (time.monotonic() - chunk["request_time"]) * 1000
                    session.record_latency(latency)
            else:
                await asyncio.sleep(0.01)  # 10ms poll interval

    async def _process_pipeline(self, session: AudioSession):
        """Feed buffered input audio through the voice AI pipeline."""
        while True:
            if session.input_buffer:
                chunk = session.input_buffer.popleft()
                request_time = chunk["received_at"]
                # Process through ASR -> NLU -> Dialog -> TTS.
                response_audio = await self.pipeline.process_audio_chunk(
                    chunk["audio"],
                    session_id=session.session_id
                )
                if response_audio:
                    session.output_buffer.append({
                        "audio": response_audio,
                        "request_time": request_time
                    })
                    session.total_output_bytes += len(response_audio)
            else:
                await asyncio.sleep(0.005)  # 5ms poll interval
# FastAPI WebSocket endpoint
# NOTE(review): the pipeline is injected as None here; a real deployment must
# construct and pass a working voice pipeline before serving traffic,
# otherwise _process_pipeline will fail on the first audio chunk.
server = RealtimeAudioServer(voice_pipeline=None)  # inject pipeline


@app.websocket("/ws/audio")
async def audio_ws(ws: WebSocket):
    # Delegate the entire session lifecycle to the shared server instance.
    await server.handle_session(ws)
Audio Buffering Strategies
import numpy as np
from collections import deque
from dataclasses import dataclass
import time
@dataclass
class BufferConfig:
    """Configuration for audio jitter buffer.

    The *_ms values are durations of buffered audio; byte<->ms conversion
    elsewhere assumes 16-bit samples (2 bytes each) at ``sample_rate``.
    """
    min_buffer_ms: int = 20  # Minimum buffer before playback
    max_buffer_ms: int = 200  # Maximum buffer size
    target_buffer_ms: int = 60  # Target buffer level
    sample_rate: int = 16000  # samples per second
    channels: int = 1  # mono
class AdaptiveJitterBuffer:
    """Adaptive jitter buffer for smooth audio playback.

    Problem: Audio chunks arrive at irregular intervals over the network.
    Without buffering: choppy audio with gaps and clicks.
    With too much buffering: noticeable delay.
    Solution: Adaptive buffer that adjusts based on network conditions.
    """

    def __init__(self, config: BufferConfig = None):
        self.config = config or BufferConfig()
        self.buffer = deque()            # queued PCM chunks (bytes)
        self.total_buffered_ms = 0.0     # duration of audio currently queued
        self.underrun_count = 0          # reads that found an empty buffer
        self.overrun_count = 0           # chunks dropped to cap buffer size
        # Jitter tracking over the last 50 arrivals.
        self.arrival_times = deque(maxlen=50)
        self.jitter_ms = 0.0
        # Adaptive target, re-tuned from measured jitter on each write.
        self.adaptive_target_ms = self.config.target_buffer_ms

    def write(self, audio_chunk: bytes, timestamp: float = None):
        """Add an audio chunk, update jitter stats, and cap buffer size.

        ``timestamp``, when given, marks the chunk as network-timed so its
        arrival contributes to the jitter estimate.
        """
        now = time.monotonic()
        chunk_ms = self._bytes_to_ms(len(audio_chunk))
        # BUGFIX: "is not None" — a timestamp of 0.0 is falsy but is still a
        # legitimate timestamp; the original skipped jitter tracking for it.
        if timestamp is not None:
            self.arrival_times.append(now)
            if len(self.arrival_times) >= 2:
                intervals = [
                    self.arrival_times[i] - self.arrival_times[i - 1]
                    for i in range(1, len(self.arrival_times))
                ]
                avg_interval = sum(intervals) / len(intervals)
                # Mean absolute deviation of inter-arrival times, in ms.
                jitter_values = [abs(i - avg_interval) for i in intervals]
                self.jitter_ms = (sum(jitter_values) / len(jitter_values)) * 1000
                # Adapt buffer target: 3x measured jitter, clamped to limits.
                self.adaptive_target_ms = max(
                    self.config.min_buffer_ms,
                    min(self.config.max_buffer_ms, self.jitter_ms * 3)
                )
        # Add to buffer.
        self.buffer.append(audio_chunk)
        self.total_buffered_ms += chunk_ms
        # Prevent overflow by dropping the OLDEST audio (latency over loss).
        while self.total_buffered_ms > self.config.max_buffer_ms:
            dropped = self.buffer.popleft()
            self.total_buffered_ms -= self._bytes_to_ms(len(dropped))
            self.overrun_count += 1

    def read(self, duration_ms: int) -> bytes:
        """Read ``duration_ms`` of audio for playback.

        Returns silence if the buffer is empty (underrun) or still below the
        minimum pre-playback level; pads with silence when short.
        """
        needed_bytes = self._ms_to_bytes(duration_ms)
        if not self.buffer:
            self.underrun_count += 1
            # Return silence to prevent an audible glitch.
            return b'\x00' * needed_bytes
        # Wait for minimum buffer before starting playback.
        if self.total_buffered_ms < self.config.min_buffer_ms:
            return b'\x00' * needed_bytes
        result = b''
        while len(result) < needed_bytes and self.buffer:
            chunk = self.buffer.popleft()
            remaining = needed_bytes - len(chunk and result) + len(result) - len(result)
            remaining = needed_bytes - len(result)
            if len(chunk) > remaining:
                # BUGFIX: split the chunk and requeue the tail. The original
                # popped the whole chunk, decremented its full duration, then
                # truncated the result — silently discarding the unread
                # remainder of audio on every read that split a chunk.
                self.buffer.appendleft(chunk[remaining:])
                chunk = chunk[:remaining]
            result += chunk
            self.total_buffered_ms -= self._bytes_to_ms(len(chunk))
        # Pad with silence if not enough data was buffered.
        if len(result) < needed_bytes:
            result += b'\x00' * (needed_bytes - len(result))
        return result

    def get_stats(self) -> dict:
        """Snapshot of buffer health for monitoring dashboards."""
        return {
            "buffered_ms": self.total_buffered_ms,
            "target_ms": self.adaptive_target_ms,
            "jitter_ms": self.jitter_ms,
            "underruns": self.underrun_count,
            "overruns": self.overrun_count,
            "chunks_queued": len(self.buffer)
        }

    def _bytes_to_ms(self, num_bytes: int) -> float:
        """Convert a byte count to milliseconds of 16-bit mono audio."""
        samples = num_bytes / 2  # 16-bit audio = 2 bytes per sample
        return (samples / self.config.sample_rate) * 1000

    def _ms_to_bytes(self, ms: int) -> int:
        """Convert milliseconds to a byte count of 16-bit mono audio."""
        samples = int(self.config.sample_rate * ms / 1000)
        return samples * 2  # 16-bit = 2 bytes per sample
End-to-End Latency Optimization
# Latency optimization techniques for <500ms round-trip
class LatencyOptimizer:
    """Catalog of latency-reduction techniques for voice AI pipelines.

    Each static method documents one technique; the bodies are deliberately
    empty — this class serves as structured reference material.
    """

    @staticmethod
    def technique_1_streaming_pipeline():
        """Stream every stage instead of waiting for complete results.

        Replace "audio -> full transcript -> full response -> full audio"
        with: audio chunks -> streaming ASR -> incremental NLU -> streaming
        LLM -> streaming TTS. Typical savings: 2-5 seconds per turn.
        """
        pass

    @staticmethod
    def technique_2_speculative_processing():
        """Begin TTS synthesis on partial LLM output.

        While the LLM streams "Your account balance is $1,234.56. Is
        there...", start synthesizing "Your account balance is" immediately
        as generation continues. Implementation sketch:
        - Accumulate LLM tokens until a full sentence/clause completes.
        - Dispatch each completed sentence to TTS right away.
        - Queue the resulting audio chunks for gapless playback.
        """
        pass

    @staticmethod
    def technique_3_connection_pooling():
        """Keep warm, persistent connections to every downstream service.

        Cold-start costs avoided: ASR WebSocket 100-300ms, LLM HTTP API
        50-200ms, TTS WebSocket 100-200ms — a pooled connection costs ~0ms.
        Pre-establish connections at server startup, run health checks, and
        auto-reconnect on failure.
        """
        pass

    @staticmethod
    def technique_4_audio_codec_selection():
        """Pick the codec that avoids conversions for the transport in use.

        Telephony (8kHz): keep mulaw/alaw natively end to end — both ASR and
        TTS accept telephony codecs directly. WebRTC/browser: Opus at 16kHz
        gives the best quality-to-bandwidth ratio for voice. Internal hops:
        raw PCM (16kHz, 16-bit, mono) so there is no encode/decode overhead.
        """
        pass

    @staticmethod
    def technique_5_prefetch_and_cache():
        """Pre-generate TTS audio for predictable responses.

        Cache greetings/goodbyes, error messages, menu prompts, and the top
        ~20 responses by frequency. At 16kHz PCM one second of audio is
        32KB, so 100 cached phrases fit in roughly 3MB of Redis. Hit rates
        of 25-40% of all utterances are typical.
        """
        pass
# Latency budget breakdown for <500ms target.
# The seven per-stage figures below sum exactly to the 500ms "total" entry.
OPTIMIZED_BUDGET = {
    "network_in": 20,  # ms - user audio to server
    "asr_streaming": 150,  # ms - streaming ASR (first final result)
    "nlu_classification": 30,  # ms - fast intent classifier
    "dialog_response": 100,  # ms - cached/fast response generation
    "tts_first_byte": 150,  # ms - streaming TTS first audio chunk
    "network_out": 20,  # ms - server audio to user
    "buffer_jitter": 30,  # ms - jitter buffer
    "total": 500  # ms - total target
}
# How each optimization technique helps:
# Technique 1 (Streaming): Saves 2000-5000ms
# Technique 2 (Speculative): Saves 500-1500ms
# Technique 3 (Pool): Saves 200-500ms on first call
# Technique 4 (Codec): Saves 10-50ms per conversion avoided
# Technique 5 (Cache): Saves 150-500ms when cache hits
Concurrent Call Handling
import asyncio
from dataclasses import dataclass
from typing import Dict
import time
@dataclass
class ServerCapacity:
    """Server capacity planning for concurrent voice sessions."""
    cpu_cores: int
    ram_gb: int
    gpu_count: int = 0
    gpu_vram_gb: int = 0

    @property
    def max_concurrent_calls(self) -> dict:
        """Estimate max concurrent calls for three deployment architectures."""
        ram_mb = self.ram_gb * 1024

        # All AI on external APIs — this server just routes audio.
        cloud_only = {
            "calls": 50 * self.cpu_cores,  # ~50 calls per core
            "bottleneck": "WebSocket connections and bandwidth",
            "ram_per_call_mb": 5,
            "max_by_ram": ram_mb // 5,
        }

        # ASR on local GPU, TTS on cloud.
        gpu_asr = {
            "calls": 20 * self.gpu_count,  # ~20 concurrent with large-v3
            "bottleneck": "GPU VRAM for Whisper",
            "gpu_vram_per_call_mb": self.gpu_vram_gb * 1024 // 20,
        }

        # Everything on local hardware.
        on_prem = {
            "calls": min(10 * self.gpu_count, 5 * self.cpu_cores),
            "bottleneck": "GPU shared between ASR and TTS",
            "note": "Need separate GPUs for ASR and TTS at scale",
        }

        return {
            "cloud_asr_cloud_tts": cloud_only,
            "self_hosted_whisper_cloud_tts": gpu_asr,
            "fully_self_hosted": on_prem,
        }
class ConnectionManager:
    """Manage concurrent WebSocket connections with backpressure.

    Enforces a global connection cap (via an asyncio.Semaphore) and a
    per-client-IP cap. ``acquire`` never blocks: it refuses immediately
    when either limit is reached.
    """

    def __init__(self, max_connections: int = 1000,
                 max_per_ip: int = 10):
        self.max_connections = max_connections
        self.max_per_ip = max_per_ip
        self.active: Dict[str, dict] = {}     # session_id -> metadata
        self.ip_counts: Dict[str, int] = {}   # client_ip -> live count
        self._semaphore = asyncio.Semaphore(max_connections)

    async def acquire(self, session_id: str, client_ip: str) -> bool:
        """Try to acquire a connection slot; returns False when at a limit."""
        # Check per-IP limit first.
        ip_count = self.ip_counts.get(client_ip, 0)
        if ip_count >= self.max_per_ip:
            return False
        # BUGFIX: use the public locked() accessor instead of peeking at the
        # private Semaphore._value attribute (unstable internal API).
        if self._semaphore.locked():
            return False
        # No await between locked() and acquire(), so within a single event
        # loop no other coroutine can take the permit in between.
        await self._semaphore.acquire()
        self.active[session_id] = {
            "ip": client_ip,
            "connected_at": time.time()
        }
        self.ip_counts[client_ip] = ip_count + 1
        return True

    def release(self, session_id: str):
        """Release the slot held by ``session_id`` (no-op if unknown)."""
        if session_id in self.active:
            ip = self.active.pop(session_id)["ip"]
            # max(0, ...) guards against counter drift going negative.
            self.ip_counts[ip] = max(0, self.ip_counts.get(ip, 1) - 1)
            self._semaphore.release()

    @property
    def stats(self) -> dict:
        """Snapshot of connection-pool utilization."""
        in_use = len(self.active)
        return {
            "active_connections": in_use,
            # Derived from our own bookkeeping rather than Semaphore._value.
            "available_slots": self.max_connections - in_use,
            "unique_ips": len(self.ip_counts),
            "utilization_pct": in_use / self.max_connections * 100
        }
Edge Deployment for Latency
# Edge deployment architecture for voice AI
#
# Why Edge?
# - Network latency to a central server can be 50-200ms
# - Edge reduces this to 5-20ms
# - For voice AI, this saves 100-400ms round-trip
#
# What to deploy at the edge:
# - WebSocket termination
# - Audio preprocessing (noise filtering, VAD)
# - ASR (if using self-hosted Whisper)
# - TTS cache (pre-generated common responses)
# - Jitter buffer
#
# What stays in the cloud:
# - LLM inference (too expensive for edge GPUs)
# - Dialog state database
# - Analytics and logging
# - Model training
# Concrete edge/cloud service split for the architecture sketched in the
# comments above. Latency figures are the author's estimates, not
# measurements — validate against your own deployment before relying on them.
EDGE_DEPLOYMENT_CONFIG = {
    "provider": "cloudflare_workers",  # or "aws_lambda@edge", "fly.io"
    "edge_services": {
        "websocket_proxy": {
            "purpose": "Terminate WebSocket close to user",
            "latency_savings": "50-150ms",
            "resources": "Minimal CPU, no GPU"
        },
        "audio_preprocessor": {
            "purpose": "Noise filtering, volume normalization, VAD",
            "latency_savings": "10-30ms (avoids sending silence to ASR)",
            "resources": "Light CPU"
        },
        "tts_cache": {
            "purpose": "Serve pre-generated audio for common responses",
            "latency_savings": "150-500ms (skip TTS entirely for cached)",
            "resources": "~50MB storage for 500 cached phrases"
        },
        "asr_edge": {
            "purpose": "Run Whisper at the edge for privacy + speed",
            "latency_savings": "100-300ms",
            "resources": "GPU edge node (expensive, only if needed)",
            "alternative": "Stream audio to cloud ASR, accept latency"
        }
    },
    "cloud_services": {
        "llm_inference": "GPT-4o / Claude - stays in cloud",
        "dialog_state": "Redis cluster in central region",
        "analytics": "Async, non-blocking, eventual consistency OK"
    },
    "regions": [
        {"name": "us-east", "pop": "Ashburn, VA", "coverage": "East Coast US"},
        {"name": "us-west", "pop": "San Jose, CA", "coverage": "West Coast US"},
        {"name": "eu-west", "pop": "London, UK", "coverage": "Western Europe"},
        {"name": "ap-south", "pop": "Mumbai, IN", "coverage": "South Asia"},
    ]
}
The 500ms Rule: If your end-to-end latency exceeds 500ms, users perceive the system as "slow" and start speaking over it. If it exceeds 1500ms, they think it's broken. If it exceeds 3 seconds on a phone call, they hang up. Every millisecond you save directly improves user satisfaction and call completion rates. Measure P99 latency, not averages — your worst-case latency is what users remember.
Lilly Tech Systems