summaryrefslogtreecommitdiff
path: root/makima/makima-vllm/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'makima/makima-vllm/server.py')
-rw-r--r--makima/makima-vllm/server.py390
1 files changed, 0 insertions, 390 deletions
diff --git a/makima/makima-vllm/server.py b/makima/makima-vllm/server.py
deleted file mode 100644
index 2d9ea40..0000000
--- a/makima/makima-vllm/server.py
+++ /dev/null
@@ -1,390 +0,0 @@
-#!/usr/bin/env python3
-"""
-Qwen3-TTS FastAPI Server
-
-Simple HTTP wrapper around Qwen3-TTS for use by makima.
-Supports streaming audio output for real-time playback.
-"""
-
-import io
-import os
-import base64
-import time
-import asyncio
-from typing import Optional, AsyncGenerator
-from contextlib import asynccontextmanager
-
-import numpy as np
-import torch
-import soundfile as sf
-from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
-from fastapi.responses import Response, StreamingResponse
-from pydantic import BaseModel
-
# Process-wide TTS model handle: assigned by the lifespan hook at startup and
# reset to None at shutdown; None means no model is currently loaded.
model = None
-
-
class TTSRequest(BaseModel):
    # Request body accepted by the HTTP TTS endpoints.

    # Text to synthesize.
    text: str
    # Language name handed to the model. Supported: auto, chinese, english,
    # french, german, italian, japanese, korean, portuguese, russian, spanish
    language: str = "english"
    # Voice-cloning reference clip: a WAV file, base64 encoded.
    reference_audio: Optional[str] = None
    # Transcript of the reference clip; when it is missing or blank the
    # endpoints ask the model for x-vector-only voice extraction.
    reference_text: Optional[str] = None
-
-
class TTSResponse(BaseModel):
    # Response body of the JSON TTS endpoint.

    # Generated speech as a WAV file, base64 encoded.
    audio: str
    # Sample rate reported by the model for the generated waveform.
    sample_rate: int
    # Length of the generated audio (number of samples / sample rate).
    duration_seconds: float
-
-
def get_model_name():
    """Return the Qwen3-TTS checkpoint name to load.

    The name is read from the ``QWEN3_TTS_MODEL`` environment variable.
    An unset, empty or whitespace-only value falls back to the default base
    checkpoint, so an accidentally blank variable never yields an empty
    model id.
    """
    configured = os.environ.get("QWEN3_TTS_MODEL", "").strip()
    return configured or "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
-
-
def get_device():
    """Return the torch device string to run inference on.

    ``TTS_DEVICE`` selects the device explicitly (for example ``cpu`` or
    ``cuda:0``); surrounding whitespace is ignored. When the variable is
    unset, blank, or ``auto`` in any letter case, the device is detected:
    macOS always gets ``cpu``; other platforms get ``cuda`` when
    ``torch.cuda.is_available()`` is true and ``cpu`` otherwise.
    """
    import platform

    requested = os.environ.get("TTS_DEVICE", "auto").strip()
    if requested and requested.lower() != "auto":
        return requested

    # MPS has limitations with large output channels, prefer CPU on macOS
    if platform.system() == "Darwin":
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
-
-
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the Qwen3-TTS model before serving and release it at shutdown.

    The checkpoint name comes from ``get_model_name()`` and the device from
    ``get_device()``. Flash Attention 2 is only selected when the device is
    a CUDA device and ``flash_attn`` can be imported; every other
    combination uses eager attention. CPU runs use float32, all other
    devices bfloat16.

    The module-level ``model`` is reset to ``None`` when the application
    stops, including when shutdown is triggered by an exception.
    """
    global model
    model_name = get_model_name()
    print(f"Loading Qwen3-TTS model: {model_name}")
    start = time.time()

    from qwen_tts import Qwen3TTSModel

    device = get_device()
    print(f"Using device: {device}")

    # Flash Attention 2 needs a CUDA device and half precision, so an
    # installed flash_attn package must not be picked up for CPU/float32 runs.
    attn_impl = "eager"
    if str(device).startswith("cuda"):
        try:
            import flash_attn  # noqa: F401 -- imported only to probe availability
            attn_impl = "flash_attention_2"
        except ImportError:
            pass

    if attn_impl == "flash_attention_2":
        print("Using Flash Attention 2")
    else:
        print("Flash Attention not available, using eager attention")

    # Use float32 for CPU (bfloat16 can be slow on CPU)
    dtype = torch.float32 if device == "cpu" else torch.bfloat16

    model = Qwen3TTSModel.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_impl,
        device_map=device,
    )

    print(f"Model loaded in {time.time() - start:.2f}s")
    try:
        yield
    finally:
        # Drop the reference so the model can be garbage collected.
        model = None
-
-
# ASGI application object; the lifespan hook loads the TTS model at startup.
app = FastAPI(
    lifespan=lifespan,
    title="Qwen3-TTS Server",
    description="HTTP API for Qwen3-TTS text-to-speech",
)
-
-
@app.get("/health")
async def health():
    """Report liveness and whether the TTS model is currently loaded."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}
-
-
@app.post("/tts", response_model=TTSResponse)
async def generate_tts(request: TTSRequest):
    """Generate speech from text and return it as a base64-encoded WAV.

    The voice is cloned from ``request.reference_audio``. When
    ``request.reference_text`` is missing or blank, the model is asked for
    x-vector-only voice extraction.

    Returns:
        TTSResponse with the encoded audio, its sample rate and duration.

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when the text
            is blank, ``reference_audio`` is missing, or it cannot be
            decoded as audio; 500 when generation or encoding fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Client errors are validated before the generation try-block so they are
    # reported as 400 instead of being rewrapped as 500 by the generic handler.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="No text provided")

    # Voice cloning requires reference audio
    if not request.reference_audio:
        raise HTTPException(
            status_code=400,
            detail="reference_audio is required for the Base model. Please provide a base64-encoded WAV file."
        )

    try:
        audio_bytes = base64.b64decode(request.reference_audio)
        audio_data, audio_sr = sf.read(io.BytesIO(audio_bytes))
    except (ValueError, RuntimeError) as e:
        # Bad base64 (binascii.Error is a ValueError) or unreadable audio data.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid reference_audio: {e}",
        ) from e

    # qwen-tts expects tuple of (audio, sample_rate) for numpy input
    ref_audio = (audio_data, audio_sr)

    # Use x_vector_only_mode if no reference text provided (simpler voice extraction)
    ref_text = request.reference_text
    use_x_vector_only = ref_text is None or ref_text.strip() == ""

    try:
        start = time.time()

        wavs, sample_rate = model.generate_voice_clone(
            text=request.text,
            language=request.language,
            ref_audio=ref_audio,
            ref_text=None if use_x_vector_only else ref_text,
            x_vector_only_mode=use_x_vector_only,
            max_new_tokens=2048,
            temperature=0.9,
            top_k=50,
            repetition_penalty=1.05,
        )

        # Get first waveform
        waveform = wavs[0] if isinstance(wavs, list) else wavs

        # Convert to numpy if tensor; detach/float first because numpy cannot
        # represent bfloat16 or tensors that still require grad.
        if torch.is_tensor(waveform):
            waveform = waveform.detach().cpu().float().numpy()

        # Ensure 1D array
        if waveform.ndim > 1:
            waveform = waveform.squeeze()

        # Encode as WAV
        buffer = io.BytesIO()
        sf.write(buffer, waveform, sample_rate, format="WAV")
        wav_bytes = buffer.getvalue()

        duration = len(waveform) / sample_rate
        elapsed = time.time() - start
        # An empty waveform has zero duration; avoid dividing by it in the log.
        rtf = f"{elapsed / duration:.2f}" if duration > 0 else "n/a"
        print(f"Generated {duration:.2f}s audio in {elapsed:.2f}s (RTF: {rtf})")

        return TTSResponse(
            audio=base64.b64encode(wav_bytes).decode("utf-8"),
            sample_rate=sample_rate,
            duration_seconds=duration,
        )

    except HTTPException:
        # Preserve deliberate HTTP errors instead of converting them to 500.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
-
-
@app.post("/tts/raw")
async def generate_tts_raw(request: TTSRequest):
    """Generate speech and return raw WAV bytes (``audio/wav``).

    The voice is cloned from ``request.reference_audio``. When
    ``request.reference_text`` is missing or blank, the model is asked for
    x-vector-only voice extraction.

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when the text
            is blank, ``reference_audio`` is missing, or it cannot be
            decoded as audio; 500 when generation or encoding fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Client errors are validated before the generation try-block so they are
    # reported as 400 instead of being rewrapped as 500 by the generic handler.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="No text provided")

    # Voice cloning requires reference audio
    if not request.reference_audio:
        raise HTTPException(
            status_code=400,
            detail="reference_audio is required for the Base model."
        )

    try:
        audio_bytes = base64.b64decode(request.reference_audio)
        audio_data, audio_sr = sf.read(io.BytesIO(audio_bytes))
    except (ValueError, RuntimeError) as e:
        # Bad base64 (binascii.Error is a ValueError) or unreadable audio data.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid reference_audio: {e}",
        ) from e

    ref_audio = (audio_data, audio_sr)

    ref_text = request.reference_text
    use_x_vector_only = ref_text is None or ref_text.strip() == ""

    try:
        wavs, sample_rate = model.generate_voice_clone(
            text=request.text,
            language=request.language,
            ref_audio=ref_audio,
            ref_text=None if use_x_vector_only else ref_text,
            x_vector_only_mode=use_x_vector_only,
            max_new_tokens=2048,
            temperature=0.9,
            top_k=50,
            repetition_penalty=1.05,
        )

        waveform = wavs[0] if isinstance(wavs, list) else wavs

        # detach/float first: numpy cannot represent bfloat16 or tensors that
        # still require grad.
        if torch.is_tensor(waveform):
            waveform = waveform.detach().cpu().float().numpy()

        if waveform.ndim > 1:
            waveform = waveform.squeeze()

        # Return raw WAV
        buffer = io.BytesIO()
        sf.write(buffer, waveform, sample_rate, format="WAV")

        return Response(
            content=buffer.getvalue(),
            media_type="audio/wav",
        )

    except HTTPException:
        # Preserve deliberate HTTP errors instead of converting them to 500.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
-
-
def _to_pcm16(samples):
    """Convert float audio samples to an int16 PCM array, clipping to [-1.0, 1.0]."""
    if torch.is_tensor(samples):
        # detach/float first: numpy cannot represent bfloat16 or grad tensors.
        samples = samples.detach().cpu().float().numpy()
    samples = np.asarray(samples)
    if samples.ndim > 1:
        samples = samples.squeeze()
    # Without clipping, values outside [-1, 1] overflow int16 and wrap around.
    return (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)


@app.websocket("/tts/stream")
async def stream_tts(websocket: WebSocket):
    """
    WebSocket endpoint for streaming TTS.

    Protocol:
    - Client sends JSON: {"text": "...", "language": "english", "reference_audio": "base64...", "reference_text": "..."}
    - Server sends binary PCM16 chunks as they're generated
    - Server sends JSON {"type": "audio_end", "sample_rate": 24000, "duration_seconds": ...} when done
    - Server sends JSON {"type": "error", "message": "..."} on error

    The connection is always closed by the server once the request has been
    handled, whether it succeeded or failed.
    """
    await websocket.accept()

    try:
        if model is None:
            await websocket.send_json({"type": "error", "message": "Model not loaded"})
            return

        # Wait for request
        data = await websocket.receive_json()
        if not isinstance(data, dict):
            await websocket.send_json({"type": "error", "message": "Request must be a JSON object"})
            return

        text = data.get("text")
        language = data.get("language", "english")
        ref_audio_b64 = data.get("reference_audio")
        ref_text = data.get("reference_text")

        # A missing, null or non-string text is treated like blank text.
        if not isinstance(text, str) or not text.strip():
            await websocket.send_json({"type": "error", "message": "No text provided"})
            return

        if not ref_audio_b64:
            await websocket.send_json({"type": "error", "message": "reference_audio is required"})
            return

        # Decode reference audio
        try:
            audio_bytes = base64.b64decode(ref_audio_b64)
            audio_data, audio_sr = sf.read(io.BytesIO(audio_bytes))
        except (TypeError, ValueError, RuntimeError) as e:
            await websocket.send_json({"type": "error", "message": f"Invalid reference_audio: {e}"})
            return
        ref_audio = (audio_data, audio_sr)

        use_x_vector_only = ref_text is None or ref_text.strip() == ""

        print(f"Streaming TTS for {len(text)} chars...")
        start = time.time()

        # Use streaming mode (non_streaming_mode=False is default)
        # This returns a generator that yields audio chunks
        generator = model.generate_voice_clone(
            text=text,
            language=language,
            ref_audio=ref_audio,
            ref_text=None if use_x_vector_only else ref_text,
            x_vector_only_mode=use_x_vector_only,
            max_new_tokens=2048,
            temperature=0.9,
            top_k=50,
            repetition_penalty=1.05,
            non_streaming_mode=False,  # Enable streaming
        )

        total_samples = 0
        sample_rate = 24000

        # Check if generator is actually a generator or just the result
        if hasattr(generator, '__iter__') and not isinstance(generator, tuple):
            for chunk_data in generator:
                # chunk_data might be (wav_chunk, sr) or just wav_chunk
                if isinstance(chunk_data, tuple):
                    wav_chunk, sample_rate = chunk_data
                else:
                    wav_chunk = chunk_data

                pcm16 = _to_pcm16(wav_chunk)
                total_samples += len(pcm16)

                # Send binary audio chunk
                await websocket.send_bytes(pcm16.tobytes())

                # Yield to allow other tasks
                await asyncio.sleep(0)
        else:
            # Non-streaming fallback - model returned full result
            wavs, sample_rate = generator if isinstance(generator, tuple) else (generator, 24000)
            waveform = wavs[0] if isinstance(wavs, list) else wavs
            pcm16 = _to_pcm16(waveform)

            # Send in chunks for better streaming behavior
            chunk_size = max(1, sample_rate // 4)  # 250ms chunks; never a zero step
            for offset in range(0, len(pcm16), chunk_size):
                chunk = pcm16[offset:offset + chunk_size]
                total_samples += len(chunk)
                await websocket.send_bytes(chunk.tobytes())
                await asyncio.sleep(0)

        duration = total_samples / sample_rate
        elapsed = time.time() - start
        # Nothing streamed means zero duration; avoid dividing by it in the log.
        rtf = f"{elapsed / duration:.2f}" if duration > 0 else "n/a"
        print(f"Streamed {duration:.2f}s audio in {elapsed:.2f}s (RTF: {rtf})")

        # Send completion message
        await websocket.send_json({
            "type": "audio_end",
            "sample_rate": sample_rate,
            "duration_seconds": duration,
        })

    except WebSocketDisconnect:
        print("Client disconnected during streaming")
    except Exception as e:
        import traceback
        traceback.print_exc()
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            # Best effort: the socket may already be gone.
            pass
    finally:
        try:
            await websocket.close()
        except Exception:
            # Best effort: closing an already-closed socket raises.
            pass
-
-
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces; the port is read from PORT (default 8100).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8100")),
        # Increase keep-alive timeout to avoid connection resets
        timeout_keep_alive=120,
    )