summaryrefslogblamecommitdiff
path: root/makima/src/server/messages.rs
blob: cecb622af3bf376f5d0f3460eda473a51f392f4f (plain) (tree)


























                                                                  





                                                                        









































                                                             






                                                             






















                                                                             
































































































































































                                                                                   
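//! WebSocket and HTTP API message types for the makima server.
//!
//! Covers the streaming transcription protocol, the text-to-speech (TTS)
//! protocol, and the error body returned by HTTP API endpoints.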

use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

/// Audio encoding format for WebSocket streaming.
///
/// Serialized in lowercase on the wire: `"pcm32f"`, `"pcm16"`, `"raw"`.
/// Derives equality and hashing (like [`TtsAudioEncoding`]) so handlers can compare
/// or key on the negotiated encoding directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum AudioEncoding {
    /// 32-bit floating point PCM samples
    Pcm32f,
    /// 16-bit signed integer PCM samples
    Pcm16,
    /// Raw bytes (will be interpreted as PCM16)
    Raw,
}

/// Handshake the client sends before any audio, describing the stream format.
///
/// Keys are camelCase on the wire (`sampleRate`, `contractId`, `authToken`).
#[derive(Clone, Debug, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct StartMessage {
    /// Sample rate of the incoming audio in Hz (e.g. 16000, 44100, 48000).
    pub sample_rate: u32,
    /// Channel count: 1 for mono, 2 for stereo.
    pub channels: u16,
    /// Sample encoding of the audio payload.
    pub encoding: AudioEncoding,
    /// Contract ID the transcript should be saved to (requires `auth_token`).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub contract_id: Option<String>,
    /// JWT used to authenticate the session, if any.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_token: Option<String>,
}

/// Client request to end the transcription session.
#[derive(Clone, Debug, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct StopMessage {
    /// Why the client is stopping; may be left out, and is serialized as `null` when `None`.
    pub reason: Option<String>,
}

/// Every message a client may send over the transcription WebSocket.
///
/// Internally tagged: the variant is chosen by a camelCase `"type"` field
/// (`"start"` or `"stop"`) that sits alongside the payload's own fields.
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum ClientMessage {
    /// Opens the session; payload is a [`StartMessage`].
    Start(StartMessage),
    /// Ends the session; payload is a [`StopMessage`].
    Stop(StopMessage),
}

/// A transcribed speech segment pushed from the server to the client.
#[derive(Clone, Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TranscriptMessage {
    /// Speaker label, e.g. "Speaker 0" or "Speaker 1".
    pub speaker: String,
    /// Segment start time, in seconds.
    pub start: f32,
    /// Segment end time, in seconds.
    pub end: f32,
    /// Recognized text for this segment.
    pub text: String,
    /// `true` for a final result, `false` for an interim one (wire key `isFinal`).
    pub is_final: bool,
}

/// Every message the server may send over the transcription WebSocket.
///
/// Internally tagged by a camelCase `"type"` field: `"ready"`, `"transcript"`,
/// `"transcriptSaved"`, `"error"` or `"stopped"`.
///
/// NOTE(review): a container-level `rename_all` renames only the variant tags. The fields
/// of the struct-like variants below (`session_id`, `file_id`, `contract_id`) are therefore
/// emitted in snake_case, unlike the camelCase used by the other messages in this file.
/// Changing that (e.g. with `rename_all_fields = "camelCase"`) alters the wire format, so
/// confirm what clients expect first.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum ServerMessage {
    /// The session is set up and ready to receive audio.
    Ready {
        session_id: String,
    },
    /// A transcription result.
    Transcript(TranscriptMessage),
    /// The transcript was persisted to a file.
    TranscriptSaved {
        /// ID of the file the transcript was written to.
        file_id: String,
        /// ID of the contract that owns the file.
        contract_id: String,
    },
    /// Processing failed.
    Error {
        code: String,
        message: String,
    },
    /// The session has ended.
    Stopped {
        reason: String,
    },
}

/// Error body returned by HTTP API endpoints.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ApiError {
    /// Machine-readable error code for programmatic handling.
    pub code: String,
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl ApiError {
    /// Builds an error from any code/message convertible into `String`
    /// (for example `&str` or `String`).
    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
        let code: String = code.into();
        let message: String = message.into();
        Self { code, message }
    }
}

// =============================================================================
// TTS (Text-to-Speech) Message Types
// =============================================================================

/// TTS audio encoding format for WebSocket streaming.
///
/// Serialized in lowercase (`"pcm16"`, `"pcm32f"`) and defaults to
/// [`TtsAudioEncoding::Pcm16`]. Derives `Eq`/`Hash` alongside `PartialEq`, since the
/// enum is field-less and total equality holds.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, ToSchema, PartialEq, Eq, Hash, Default)]
#[serde(rename_all = "lowercase")]
pub enum TtsAudioEncoding {
    /// 16-bit signed integer PCM samples
    #[default]
    Pcm16,
    /// 32-bit floating point PCM samples
    Pcm32f,
}

/// TTS synthesis priority level.
///
/// Serialized in lowercase (`"low"`, `"normal"`, `"high"`) and defaults to
/// [`TtsPriority::Normal`]. The derived ordering follows declaration order, so
/// `Low < Normal < High`; callers can compare or sort requests by priority directly.
/// Keep the variants in ascending order if new ones are added.
#[derive(
    Debug, Clone, Copy, Deserialize, Serialize, ToSchema, PartialEq, Eq, PartialOrd, Ord, Hash,
    Default,
)]
#[serde(rename_all = "lowercase")]
pub enum TtsPriority {
    /// Low priority - may be queued
    Low,
    /// Normal priority (default)
    #[default]
    Normal,
    /// High priority - processed immediately
    High,
}

/// Opens a TTS session. Every field may be omitted on the wire and then takes a default.
#[derive(Clone, Debug, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsStartMessage {
    /// Output sample rate in Hz; 24000 when omitted.
    #[serde(default = "default_tts_sample_rate")]
    pub sample_rate: u32,
    /// Output sample encoding; `TtsAudioEncoding::default()` (PCM16) when omitted.
    #[serde(default)]
    pub encoding: TtsAudioEncoding,
    /// Voice to synthesize with; "makima" when omitted.
    #[serde(default = "default_tts_voice")]
    pub voice: String,
    /// Synthesis language; "English" when omitted.
    #[serde(default = "default_tts_language")]
    pub language: String,
}

/// Serde default for `TtsStartMessage::sample_rate`: 24 kHz.
fn default_tts_sample_rate() -> u32 {
    24_000
}

/// Serde default for `TtsStartMessage::voice`.
fn default_tts_voice() -> String {
    String::from("makima")
}

/// Serde default for `TtsStartMessage::language`.
fn default_tts_language() -> String {
    "English".into()
}

/// Asks the server to synthesize a piece of text.
#[derive(Clone, Debug, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsSpeakMessage {
    /// Text to speak (max 1000 characters; this type itself does not enforce the limit).
    pub text: String,
    /// Scheduling priority; `TtsPriority::default()` (Normal) when omitted.
    #[serde(default)]
    pub priority: TtsPriority,
}

/// Client request to end the current TTS session.
#[derive(Clone, Debug, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsStopMessage {
    /// Why the client is stopping; may be left out, and is serialized as `null` when `None`.
    pub reason: Option<String>,
}

/// Every message a client may send over the TTS WebSocket.
///
/// Internally tagged by a camelCase `"type"` field: `"start"`, `"speak"` or `"stop"`.
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TtsClientMessage {
    /// Begin a TTS session ([`TtsStartMessage`]).
    Start(TtsStartMessage),
    /// Request synthesis of some text ([`TtsSpeakMessage`]).
    Speak(TtsSpeakMessage),
    /// End the current session ([`TtsStopMessage`]).
    Stop(TtsStopMessage),
}

/// Server acknowledgement that a TTS session is open, echoing the settings in effect.
#[derive(Clone, Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsReadyMessage {
    /// Identifier of the newly opened session.
    pub session_id: String,
    /// Sample rate the server will produce, in Hz.
    pub sample_rate: u32,
    /// Encoding the server will produce.
    pub encoding: TtsAudioEncoding,
    /// Voice the server will use.
    pub voice: String,
}

/// One slice of synthesized audio streamed from the server to the client.
#[derive(Clone, Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsAudioChunkMessage {
    /// Audio payload, Base64-encoded.
    pub data: String,
    /// `true` on the last chunk of an utterance (wire key `isFinal`).
    pub is_final: bool,
    /// Offset of this chunk from the start of the audio, in seconds.
    pub timestamp: f64,
}

/// Server notice that synthesis of a request has finished.
#[derive(Clone, Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsCompleteMessage {
    /// How long synthesis took overall, in milliseconds.
    pub duration_ms: u64,
    /// How many audio chunks were sent in total.
    pub total_chunks: u32,
    /// Length of the input text.
    pub text_length: u32,
}

/// Error reported by the server on the TTS WebSocket.
#[derive(Clone, Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsErrorMessage {
    /// Machine-readable error code for programmatic handling.
    pub code: String,
    /// Human-readable description of what went wrong.
    pub message: String,
}

/// Server notice that the TTS session has ended.
#[derive(Clone, Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct TtsStoppedMessage {
    /// Why the session ended.
    pub reason: String,
}

/// Every message the server may send over the TTS WebSocket.
///
/// Internally tagged by a camelCase `"type"` field: `"ready"`, `"audioChunk"`,
/// `"complete"`, `"error"` or `"stopped"`. Payload fields follow each struct's own
/// camelCase naming.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TtsServerMessage {
    /// The session is open and accepts synthesis requests.
    Ready(TtsReadyMessage),
    /// A streamed slice of synthesized audio.
    AudioChunk(TtsAudioChunkMessage),
    /// Synthesis of a request finished.
    Complete(TtsCompleteMessage),
    /// Something went wrong.
    Error(TtsErrorMessage),
    /// The session has ended.
    Stopped(TtsStoppedMessage),
}