diff options
Diffstat (limited to 'makima/makima-vllm/server.py')
| -rw-r--r-- | makima/makima-vllm/server.py | 390 |
1 files changed, 0 insertions, 390 deletions
diff --git a/makima/makima-vllm/server.py b/makima/makima-vllm/server.py deleted file mode 100644 index 2d9ea40..0000000 --- a/makima/makima-vllm/server.py +++ /dev/null @@ -1,390 +0,0 @@ -#!/usr/bin/env python3 -""" -Qwen3-TTS FastAPI Server - -Simple HTTP wrapper around Qwen3-TTS for use by makima. -Supports streaming audio output for real-time playback. -""" - -import io -import os -import base64 -import time -import asyncio -from typing import Optional, AsyncGenerator -from contextlib import asynccontextmanager - -import numpy as np -import torch -import soundfile as sf -from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect -from fastapi.responses import Response, StreamingResponse -from pydantic import BaseModel - -# Global model instance -model = None - - -class TTSRequest(BaseModel): - text: str - # Supported: auto, chinese, english, french, german, italian, japanese, korean, portuguese, russian, spanish - language: str = "english" - # Reference audio for voice cloning (base64 encoded WAV) - reference_audio: Optional[str] = None - reference_text: Optional[str] = None - - -class TTSResponse(BaseModel): - # Base64 encoded WAV audio - audio: str - sample_rate: int - duration_seconds: float - - -def get_model_name(): - """Get model name from environment or use default.""" - return os.environ.get("QWEN3_TTS_MODEL", "Qwen/Qwen3-TTS-12Hz-0.6B-Base") - - -def get_device(): - """Get device to use for inference.""" - device = os.environ.get("TTS_DEVICE", "auto") - if device == "auto": - # MPS has limitations with large output channels, prefer CPU on macOS - import platform - if platform.system() == "Darwin": - return "cpu" - elif torch.cuda.is_available(): - return "cuda" - else: - return "cpu" - return device - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Load model on startup.""" - global model - model_name = get_model_name() - print(f"Loading Qwen3-TTS model: {model_name}") - start = time.time() - - from qwen_tts import Qwen3TTSModel - - # Check if flash attention is available - try: - import flash_attn - attn_impl = "flash_attention_2" - print("Using Flash Attention 2") - except ImportError: - attn_impl = "eager" - print("Flash Attention not available, using eager attention") - - device = get_device() - print(f"Using device: {device}") - - # Use float32 for CPU (bfloat16 can be slow on CPU) - dtype = torch.float32 if device == "cpu" else torch.bfloat16 - - model = Qwen3TTSModel.from_pretrained( - model_name, - torch_dtype=dtype, - attn_implementation=attn_impl, - device_map=device, - ) - - print(f"Model loaded in {time.time() - start:.2f}s") - yield - # Cleanup - model = None - - -app = FastAPI( - title="Qwen3-TTS Server", - description="HTTP API for Qwen3-TTS text-to-speech", - lifespan=lifespan, -) - - -@app.get("/health") -async def health(): - """Health check endpoint.""" - return {"status": "ok", "model_loaded": model is not None} - - -@app.post("/tts", response_model=TTSResponse) -async def generate_tts(request: TTSRequest): - """Generate speech from text.""" - if model is None: - raise HTTPException(status_code=503, detail="Model not loaded") - - try: - start = time.time() - - # Decode reference audio if provided - ref_audio = None - if request.reference_audio: - audio_bytes = base64.b64decode(request.reference_audio) - audio_data, audio_sr = sf.read(io.BytesIO(audio_bytes)) - # qwen-tts expects tuple of (audio, sample_rate) for numpy input - ref_audio = (audio_data, audio_sr) - - # Voice cloning requires reference audio - if ref_audio is None: - raise HTTPException( - status_code=400, - detail="reference_audio is required for the Base model. Please provide a base64-encoded WAV file." - ) - - # Use x_vector_only_mode if no reference text provided (simpler voice extraction) - use_x_vector_only = request.reference_text is None or request.reference_text.strip() == "" - - wavs, sample_rate = model.generate_voice_clone( - text=request.text, - language=request.language, - ref_audio=ref_audio, - ref_text=request.reference_text if not use_x_vector_only else None, - x_vector_only_mode=use_x_vector_only, - max_new_tokens=2048, - temperature=0.9, - top_k=50, - repetition_penalty=1.05, - ) - - # Get first waveform - waveform = wavs[0] if isinstance(wavs, list) else wavs - - # Convert to numpy if tensor - if torch.is_tensor(waveform): - waveform = waveform.cpu().numpy() - - # Ensure 1D array - if waveform.ndim > 1: - waveform = waveform.squeeze() - - # Encode as WAV - buffer = io.BytesIO() - sf.write(buffer, waveform, sample_rate, format="WAV") - audio_bytes = buffer.getvalue() - - duration = len(waveform) / sample_rate - elapsed = time.time() - start - print(f"Generated {duration:.2f}s audio in {elapsed:.2f}s (RTF: {elapsed/duration:.2f})") - - return TTSResponse( - audio=base64.b64encode(audio_bytes).decode("utf-8"), - sample_rate=sample_rate, - duration_seconds=duration, - ) - - except Exception as e: - import traceback - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/tts/raw") -async def generate_tts_raw(request: TTSRequest): - """Generate speech and return raw WAV bytes.""" - if model is None: - raise HTTPException(status_code=503, detail="Model not loaded") - - try: - # Decode reference audio if provided - ref_audio = None - if request.reference_audio: - audio_bytes = base64.b64decode(request.reference_audio) - audio_data, audio_sr = sf.read(io.BytesIO(audio_bytes)) - ref_audio = (audio_data, audio_sr) - - # Voice cloning requires reference audio - if ref_audio is None: - raise HTTPException( - status_code=400, - detail="reference_audio is required for the Base model." - ) - - use_x_vector_only = request.reference_text is None or request.reference_text.strip() == "" - - wavs, sample_rate = model.generate_voice_clone( - text=request.text, - language=request.language, - ref_audio=ref_audio, - ref_text=request.reference_text if not use_x_vector_only else None, - x_vector_only_mode=use_x_vector_only, - max_new_tokens=2048, - temperature=0.9, - top_k=50, - repetition_penalty=1.05, - ) - - waveform = wavs[0] if isinstance(wavs, list) else wavs - - if torch.is_tensor(waveform): - waveform = waveform.cpu().numpy() - - if waveform.ndim > 1: - waveform = waveform.squeeze() - - # Return raw WAV - buffer = io.BytesIO() - sf.write(buffer, waveform, sample_rate, format="WAV") - - return Response( - content=buffer.getvalue(), - media_type="audio/wav", - ) - - except Exception as e: - import traceback - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) - - -@app.websocket("/tts/stream") -async def stream_tts(websocket: WebSocket): - """ - WebSocket endpoint for streaming TTS. - - Protocol: - - Client sends JSON: {"text": "...", "language": "english", "reference_audio": "base64...", "reference_text": "..."} - - Server sends binary PCM16 chunks as they're generated - - Server sends JSON {"type": "audio_end", "sample_rate": 24000} when done - - Server sends JSON {"type": "error", "message": "..."} on error - """ - await websocket.accept() - - if model is None: - await websocket.send_json({"type": "error", "message": "Model not loaded"}) - await websocket.close() - return - - try: - # Wait for request - data = await websocket.receive_json() - text = data.get("text", "") - language = data.get("language", "english") - ref_audio_b64 = data.get("reference_audio") - ref_text = data.get("reference_text") - - if not text.strip(): - await websocket.send_json({"type": "error", "message": "No text provided"}) - await websocket.close() - return - - # Decode reference audio - ref_audio = None - if ref_audio_b64: - audio_bytes = base64.b64decode(ref_audio_b64) - audio_data, audio_sr = sf.read(io.BytesIO(audio_bytes)) - ref_audio = (audio_data, audio_sr) - - if ref_audio is None: - await websocket.send_json({"type": "error", "message": "reference_audio is required"}) - await websocket.close() - return - - use_x_vector_only = ref_text is None or ref_text.strip() == "" - - print(f"Streaming TTS for {len(text)} chars...") - start = time.time() - - # Use streaming mode (non_streaming_mode=False is default) - # This returns a generator that yields audio chunks - generator = model.generate_voice_clone( - text=text, - language=language, - ref_audio=ref_audio, - ref_text=ref_text if not use_x_vector_only else None, - x_vector_only_mode=use_x_vector_only, - max_new_tokens=2048, - temperature=0.9, - top_k=50, - repetition_penalty=1.05, - non_streaming_mode=False, # Enable streaming - ) - - total_samples = 0 - sample_rate = 24000 - - # Check if generator is actually a generator or just the result - if hasattr(generator, '__iter__') and not isinstance(generator, tuple): - for chunk_data in generator: - # chunk_data might be (wav_chunk, sr) or just wav_chunk - if isinstance(chunk_data, tuple): - wav_chunk, sample_rate = chunk_data - else: - wav_chunk = chunk_data - - if torch.is_tensor(wav_chunk): - wav_chunk = wav_chunk.cpu().numpy() - - if wav_chunk.ndim > 1: - wav_chunk = wav_chunk.squeeze() - - # Convert to PCM16 - pcm16 = (wav_chunk * 32767).astype(np.int16) - total_samples += len(pcm16) - - # Send binary audio chunk - await websocket.send_bytes(pcm16.tobytes()) - - # Yield to allow other tasks - await asyncio.sleep(0) - else: - # Non-streaming fallback - model returned full result - wavs, sample_rate = generator if isinstance(generator, tuple) else (generator, 24000) - waveform = wavs[0] if isinstance(wavs, list) else wavs - - if torch.is_tensor(waveform): - waveform = waveform.cpu().numpy() - - if waveform.ndim > 1: - waveform = waveform.squeeze() - - # Send in chunks for better streaming behavior - chunk_size = sample_rate // 4 # 250ms chunks - for i in range(0, len(waveform), chunk_size): - chunk = waveform[i:i + chunk_size] - pcm16 = (chunk * 32767).astype(np.int16) - total_samples += len(pcm16) - await websocket.send_bytes(pcm16.tobytes()) - await asyncio.sleep(0) - - duration = total_samples / sample_rate - elapsed = time.time() - start - print(f"Streamed {duration:.2f}s audio in {elapsed:.2f}s (RTF: {elapsed/duration:.2f})") - - # Send completion message - await websocket.send_json({ - "type": "audio_end", - "sample_rate": sample_rate, - "duration_seconds": duration, - }) - - except WebSocketDisconnect: - print("Client disconnected during streaming") - except Exception as e: - import traceback - traceback.print_exc() - try: - await websocket.send_json({"type": "error", "message": str(e)}) - except: - pass - finally: - try: - await websocket.close() - except: - pass - - -if __name__ == "__main__": - import uvicorn - port = int(os.environ.get("PORT", "8100")) - uvicorn.run( - app, - host="0.0.0.0", - port=port, - # Increase keep-alive timeout to avoid connection resets - timeout_keep_alive=120, - ) |
