import argparse
import asyncio
import gc
import re
import logging
import traceback
import os
from functools import partial
import json
import tempfile
import importlib.resources
import itertools
import websockets
import numpy as np
import torch
import torchaudio
from dotenv import load_dotenv
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
infer_batch_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
# --- Configuration ---
load_dotenv()
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_MODEL_PATH = os.path.join(BASE_DIR, "..", "checkpoints", "F5TTS_v1_Base", "model_1250000.safetensors")
DEFAULT_VOCODER_PATH = os.path.join(BASE_DIR, "..", "checkpoints", "vocos-mel-24khz")
MODEL_CKPT_PATH = os.environ.get("F5_TTS_MODEL_PATH", DEFAULT_MODEL_PATH)
VOCODER_PATH = os.environ.get("F5_TTS_VOCODER_PATH", DEFAULT_VOCODER_PATH)
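# Example .env / environment settings (illustrative values; adjust to your setup):
#   F5_TTS_MODEL_PATH=/path/to/F5TTS_v1_Base/model_1250000.safetensors
#   F5_TTS_VOCODER_PATH=/path/to/vocos-mel-24khz
#   DEVICE=auto   # one of: auto, cuda, mps, cpu (see get_device below)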
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Device Configuration ---
def get_device():
"""
    Determines the inference device from the DEVICE environment variable and hardware
    availability, and returns a (device, dtype) pair.
    Priority in 'auto' mode: CUDA > MPS (Apple Silicon) > CPU.
"""
device_env = os.environ.get("DEVICE", "auto").lower()
logger.info(f"Device environment variable set to: '{device_env}'")
if device_env == "auto":
if torch.cuda.is_available():
logger.info("CUDA is available. Using 'cuda' for acceleration.")
return "cuda", torch.float16
elif torch.backends.mps.is_available():
logger.info("Apple MPS is available. Using 'mps' for acceleration on Apple Silicon.")
return "mps", torch.float16
else:
logger.warning("----------------------------------------------------------")
logger.warning("WARNING: No CUDA or MPS device found.")
logger.warning("Falling back to CPU. Performance will be much slower.")
logger.warning("----------------------------------------------------------")
return "cpu", torch.float32
elif device_env == "cuda":
if torch.cuda.is_available():
logger.info("Forcing CUDA device as per environment variable.")
return "cuda", torch.float16
else:
logger.error("FATAL: DEVICE is set to 'cuda' but no CUDA device is available.")
logger.error("Please check your NVIDIA drivers and CUDA installation.")
raise RuntimeError("CUDA not available, but was explicitly requested.")
elif device_env == "mps":
if torch.backends.mps.is_available():
logger.info("Forcing MPS device as per environment variable.")
return "mps", torch.float16
else:
logger.error("FATAL: DEVICE is set to 'mps' but no MPS device is available.")
logger.error("This is unexpected on an Apple Silicon Mac. Check your PyTorch installation.")
raise RuntimeError("MPS not available, but was explicitly requested.")
elif device_env == "cpu":
logger.info("Forcing CPU as per environment variable.")
return "cpu", torch.float32
else:
logger.error(f"Invalid DEVICE environment variable: '{device_env}'. Must be one of 'auto', 'cuda', 'mps', 'cpu'.")
raise ValueError(f"Invalid device specified: {device_env}")
DEVICE, DTYPE = get_device()
# Sentinel objects placed in the audio queue to delimit output.
JOB_END_MARKER = object()      # end of a whole job; the client receives "<JOB_END>"
SEGMENT_END_MARKER = object()  # end of one sentence; the client receives "<SEG_END>"
# --- Utility Functions ---
def split_text_into_sentences(text: str) -> list[str]:
"""
    Splits long text into sentence-like segments at common fullwidth and halfwidth
    punctuation (commas included), keeping each mark attached to the segment it ends.
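
    Example:
        "你好,世界!How are you?" -> ["你好,", "世界!", "How are you?"]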
"""
if not text:
return []
    sentences = re.split(r'(?<=[,。?!;,?!;])', text)
return [s.strip() for s in sentences if s and s.strip()]
class PreloadedModels:
"""Loads and holds the core TTS models in memory for maximum performance."""
def __init__(self, model_name, ckpt_file, vocab_file, vocoder_path, device, dtype):
logger.info(f"Loading core TTS models into memory on {device} with {dtype}...")
self.device = device
self.dtype = dtype
model_cfg = OmegaConf.load(str(importlib.resources.files("f5_tts").joinpath(f"configs/{model_name}.yaml")))
self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
self.model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
self.model = load_model(
self.model_cls, self.model_arc, ckpt_path=ckpt_file,
mel_spec_type=self.mel_spec_type, vocab_file=vocab_file,
ode_method="euler", use_ema=True, device=self.device,
).to(self.device, dtype=self.dtype)
self.vocoder = load_vocoder(
vocoder_name=self.mel_spec_type, is_local=True,
local_path=vocoder_path, device=self.device
)
        if self.device in ['cuda', 'mps']:
            logger.info(f"Compiling model with torch.compile on '{self.device}' for further speedup...")
            self.model = torch.compile(self.model)
        logger.info("Core models loaded and ready.")
class TTSClientSession:
"""
Manages a single client's session using a high-performance, perpetual pipeline
with dedicated asyncio tasks for inference and network sending.
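    A monotonically increasing job_id invalidates in-flight work: starting a new job or
    requesting a stop advances the counter, and both workers discard anything tagged
    with an older id.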
"""
def __init__(self, preloaded_models, ref_audio_path, ref_text, websocket):
self.models = preloaded_models
self.device = self.models.device
self.websocket = websocket
self.is_running = False
self.text_queue = asyncio.Queue()
self.audio_queue = asyncio.Queue()
self.inference_task = None
self.send_task = None
self.job_id_lock = asyncio.Lock()
self.job_id_counter = itertools.count()
self.current_job_id = -1
self.update_reference(ref_audio_path, ref_text)
def update_reference(self, ref_audio_path, ref_text):
"""Processes and sets the reference audio for this specific session."""
self.ref_audio_path, self.ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)
self.audio, self.sr = torchaudio.load(self.ref_audio_path)
async def _warm_up(self):
"""Warms up the model for the new session asynchronously."""
logger.info("Warming up the model for the new session...")
gen_text = "Warm-up text."
loop = asyncio.get_running_loop()
@torch.inference_mode()
def run_warmup_inference():
"""The synchronous, blocking inference call."""
with torch.amp.autocast(device_type=self.device, dtype=self.models.dtype, enabled=(self.device in ['cuda', 'mps'])):
for _ in infer_batch_process(
(self.audio, self.sr), self.ref_text, [gen_text],
self.models.model, self.models.vocoder, progress=None,
device=self.device, streaming=True,
):
pass
await loop.run_in_executor(None, run_warmup_inference)
logger.info("Session warm-up completed.")
async def start(self):
"""Starts the perpetual inference and sending tasks."""
if self.is_running: return
logger.info("Starting session tasks...")
self.is_running = True
self.inference_task = asyncio.create_task(self._inference_worker())
self.send_task = asyncio.create_task(self._send_worker())
await self._increment_job_id() # Initialize first job ID
logger.info("Session tasks started.")
async def stop(self):
"""Stops the perpetual tasks and cleans up."""
if not self.is_running: return
logger.info("Stopping session tasks...")
self.is_running = False
# Cancel and gather tasks to ensure they are stopped
tasks = []
if self.inference_task:
self.inference_task.cancel()
tasks.append(self.inference_task)
if self.send_task:
self.send_task.cancel()
tasks.append(self.send_task)
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("Session tasks stopped.")
async def add_text_to_queue(self, payload):
"""Adds a text payload from the client to the processing queue."""
if self.is_running:
await self.text_queue.put(payload)
async def _increment_job_id(self):
"""Atomically increments the job ID to signal a new job or an interruption."""
async with self.job_id_lock:
self.current_job_id = next(self.job_id_counter)
logger.info(f"Advanced to new job_id: {self.current_job_id}")
return self.current_job_id
async def prepare_for_new_job(self):
"""Starts a new job by incrementing the job ID."""
return await self._increment_job_id()
async def request_stop(self):
"""Stops the current job by incrementing the job ID, making all current work obsolete."""
job_id = await self._increment_job_id()
await self.audio_queue.put((job_id, JOB_END_MARKER))
return job_id
async def _inference_worker(self):
"""
The perpetual, GPU-bound worker. It processes text requests and puts
audio chunks tagged with a job_id into the audio_queue.
"""
loop = asyncio.get_running_loop()
while self.is_running:
try:
request = await self.text_queue.get()
if request is None: break
job_id = self.current_job_id
text_to_process = request.get("text")
if text_to_process == "__FLUSH_AUDIO__":
logger.info(f"Flush signal received for job_id: {job_id}. Queuing JOB END marker.")
await self.audio_queue.put((job_id, JOB_END_MARKER))
continue
if not text_to_process or not text_to_process.strip() or text_to_process.strip() in ",。?!,?!.\n":
logger.warning(f"Skipping inference for empty or invalid text: '{text_to_process}'")
continue
speed = request.get("speed", 1.0)
chunk_size = request.get("chunk_size", 2048)
logger.info(f"Starting new synthesis job for job_id {job_id}.")
sentences = await loop.run_in_executor(None, split_text_into_sentences, text_to_process)
if not sentences:
logger.warning(f"Text '{text_to_process}' resulted in no sentences to synthesize.")
continue
total_audio_chunks = 0
for i, sentence in enumerate(sentences):
logger.info(f"Synthesizing sentence {i+1}/{len(sentences)} for job_id {job_id}: '{sentence}'")
if job_id != self.current_job_id:
logger.warning(f"Job {job_id} was interrupted by new job {self.current_job_id}. Stopping generation.")
break
# Run the blocking inference in a separate thread
infer_func = partial(
infer_batch_process,
(self.audio, self.sr), self.ref_text, [sentence],
self.models.model, self.models.vocoder, progress=None,
device=self.device, streaming=True, chunk_size=chunk_size, speed=speed
)
                    # infer_batch_process returns a generator; exhaust it inside the worker
                    # thread so the GPU-bound work never blocks the event loop.
@torch.inference_mode()
def run_inference_generator():
with torch.amp.autocast(device_type=self.device, dtype=self.models.dtype, enabled=(self.device in ['cuda', 'mps'])):
return list(infer_func())
audio_stream = await loop.run_in_executor(None, run_inference_generator)
for audio_chunk, _ in audio_stream:
if job_id != self.current_job_id: break
if len(audio_chunk) > 0:
total_audio_chunks += 1
clipped_chunk = np.clip(audio_chunk, -1.0, 1.0)
await self.audio_queue.put((job_id, clipped_chunk))
if job_id != self.current_job_id: break
await self.audio_queue.put((job_id, SEGMENT_END_MARKER))
logger.info(f"Finished sentence {i+1}/{len(sentences)}. Segment marker queued.")
logger.info(f"Finished synthesis for job_id {job_id}. Total audio chunks sent: {total_audio_chunks}.")
except asyncio.CancelledError:
logger.info("Inference worker cancelled.")
break
except Exception as e:
logger.error(f"Inference worker error: {e}", exc_info=True)
self.is_running = False
break
logger.info("Inference worker finished.")
async def _send_worker(self):
"""
        The perpetual, network-bound worker. It sends audio chunks to the client,
strictly adhering to the job_id to prevent sending stale data.
"""
active_send_job_id = -1
while self.is_running:
try:
item = await self.audio_queue.get()
if item is None: break
job_id, chunk = item
if job_id < active_send_job_id:
logger.info(f"Discarding stale audio chunk from job {job_id} (current: {active_send_job_id}).")
continue
if job_id > active_send_job_id:
logger.info(f"Send worker switching from job {active_send_job_id} to new job {job_id}.")
active_send_job_id = job_id
if chunk is JOB_END_MARKER:
logger.info(f"Sending JOB END marker for job {job_id}.")
await self.websocket.send("<JOB_END>")
elif chunk is SEGMENT_END_MARKER:
logger.info(f"Sending SEGMENT END marker for job {job_id}.")
await self.websocket.send("<SEG_END>")
else:
await self.websocket.send(chunk.astype(np.float32).tobytes())
except asyncio.CancelledError:
logger.info("Send worker cancelled.")
break
except websockets.exceptions.ConnectionClosed:
logger.warning("Client disconnected during stream.")
self.is_running = False
break
except Exception as e:
logger.error(f"Send worker error: {e}", exc_info=True)
self.is_running = False
break
logger.info("Send worker finished.")
async def handle_client(websocket, preloaded_models):
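    """
    Per-connection handler. The wire protocol, as implemented below:
      1. Client sends a JSON header containing at least 'ref_text'.
      2. Client sends the raw reference-audio bytes (written to a temp .wav file).
      3. Server replies with the text frame "READY" once the session is warmed up.
      4. Client streams JSON payloads with "text" (plus optional "speed", "chunk_size",
         "is_new_job"); "__FLUSH_AUDIO__" flushes the current job and "__STOP_GENERATION__"
         aborts it.
      5. Server streams raw float32 PCM as binary frames, with "<SEG_END>" after each
         sentence and "<JOB_END>" when a job is flushed or stopped.
    """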
ref_audio_file, session = None, None
addr = websocket.remote_address
try:
logger.info(f"Handling new client from {addr}")
# 1. Receive header
header_json = await websocket.recv()
header = json.loads(header_json)
# 2. Receive reference audio
ref_audio_bytes = await websocket.recv()
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as fp:
fp.write(ref_audio_bytes)
ref_audio_file = fp.name
session = TTSClientSession(preloaded_models, ref_audio_file, header['ref_text'], websocket)
await session._warm_up()
await session.start()
await websocket.send("READY")
logger.info(f"Session ready for {addr}. Now handling synthesis jobs.")
async for message in websocket:
payload = json.loads(message)
if payload.get("is_new_job"):
logger.info(f"Received new job signal from {addr}.")
if session: await session.prepare_for_new_job()
if payload.get("text") == "__STOP_GENERATION__":
logger.info(f"Received __STOP_GENERATION__ command from {addr}.")
if session: await session.request_stop()
elif "text" in payload:
logger.info(f"Received text from {addr}: '{payload.get('text', '')[:50]}...'")
if session: await session.add_text_to_queue(payload)
except websockets.exceptions.ConnectionClosedOK:
logger.info(f"Client {addr} disconnected gracefully.")
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"Connection with {addr} lost unexpectedly: {e}")
except Exception as e:
logger.error(f"Error handling client {addr}: {e}")
traceback.print_exc()
finally:
if session:
await session.stop()
if ref_audio_file and os.path.exists(ref_audio_file): os.remove(ref_audio_file)
logger.info(f"Connection with {addr} closed and resources cleaned up.")
async def start_server(host, port, preloaded_models):
    # functools.partial pre-fills preloaded_models; the websockets library calls the handler
    # with just the connection object, which partial forwards to handle_client.
handler = partial(handle_client, preloaded_models=preloaded_models)
server = await websockets.serve(handler, host, port, max_size=5*1024*1024, ping_interval=20, ping_timeout=60)
logger.info(f"WebSocket server started on ws://{host}:{port}")
await server.wait_closed()
async def main():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=7300)
args = parser.parse_args()
try:
models = PreloadedModels(
model_name="F5TTS_v1_Base",
ckpt_file=MODEL_CKPT_PATH,
vocab_file="",
vocoder_path=VOCODER_PATH,
device=DEVICE,
dtype=DTYPE,
)
await start_server(args.host, args.port, models)
    finally:
        gc.collect()

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C surfaces here, outside the event loop, so shutdown is logged reliably.
        logger.info("Server shutting down.")