summaryrefslogblamecommitdiff
path: root/src/main.rs
blob: a2bf354d0f0358599d0ed3056d9748efade52229 (plain) (tree)











































































































































































                                                                                                                      
use std::{convert::Infallible, env, net::SocketAddr, time::Duration};

use axum::{
    extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State},
    http::Method,
    response::IntoResponse,
    routing::{get, post},
    Json, Router,
};
use axum::response::sse::{Event, KeepAlive, Sse};
use serde::{Deserialize, Serialize};
use tower_http::{cors::{Any, CorsLayer}, trace::TraceLayer};
use tracing::info;
use tracing_subscriber::EnvFilter;
use tokio_stream::{wrappers::IntervalStream, StreamExt as _};
use futures::StreamExt as _; // for websocket split/next
use futures::stream::Stream; // for SSE return type

/// Application state shared with every handler through axum's `State` extractor.
///
/// It carries no data yet; it exists so handlers already have the extractor
/// wired in when fields are added.
#[derive(Default, Clone)]
struct AppState {}

/// JSON body produced by the `GET /health` handler.
#[derive(Serialize)]
struct HealthResponse {
    /// Liveness marker string (the handler in this file sends `"ok"`).
    status: &'static str,
}

/// JSON body produced by the `GET /api/hello` handler.
#[derive(Serialize)]
struct HelloResponse {
    /// Static greeting text sent to the client.
    message: &'static str,
}

/// JSON body returned by `POST /api/echo`: the request body, echoed back unchanged.
///
/// `#[serde(transparent)]` makes the struct (de)serialize exactly as its single
/// field, so any JSON value round-trips. `#[serde(flatten)]` must not be used
/// here: it only accepts maps/structs and fails to serialize arrays, strings,
/// numbers and booleans, which would turn those echo requests into a 500.
#[derive(Deserialize, Serialize)]
#[serde(transparent)]
struct EchoPayload {
    /// The JSON value being echoed.
    rest: serde_json::Value,
}

#[tokio::main]
async fn main() {
    // Logging setup
    let filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| EnvFilter::new("info,tower_http=info,axum::rejection=trace"));
    tracing_subscriber::fmt()
        .with_env_filter(filter)
        .with_target(false)
        .compact()
        .init();

    // Shared app state (extend as needed)
    let state = AppState::default();

    // CORS to allow local frontend dev at 5173 and others
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
        .allow_headers(Any);

    // Router
    let app = Router::new()
        .route("/", get(root))
        .route("/health", get(health))
        .route("/api/hello", get(hello))
        .route("/api/echo", post(echo))
        .route("/api/stream/transcript", get(stream_transcript_sse))
        .route("/ws", get(ws_handler))
        .with_state(state)
        .layer(cors)
        .layer(TraceLayer::new_for_http());

    // Bind address
    let port = env::var("PORT").ok().and_then(|p| p.parse().ok()).unwrap_or(8080);
    let addr = SocketAddr::from(([0, 0, 0, 0], port));
    info!(%addr, "starting soryu backend");
    let listener = tokio::net::TcpListener::bind(addr).await.expect("bind port");
    axum::serve(listener, app).await.expect("server error");
}

/// `GET /` — describes the service and lists the endpoints it exposes.
async fn root() -> impl IntoResponse {
    const ENDPOINTS: [&str; 5] = [
        "/health",
        "/api/hello",
        "/api/echo",
        "/api/stream/transcript",
        "/ws",
    ];

    let index = serde_json::json!({
        "service": "soryu-backend",
        "endpoints": ENDPOINTS,
    });
    Json(index)
}

/// `GET /health` — liveness probe that always answers `{"status":"ok"}`.
async fn health() -> impl IntoResponse {
    let body = HealthResponse { status: "ok" };
    Json(body)
}

/// `GET /api/hello` — returns a fixed greeting; the shared state is extracted but unused.
async fn hello(State(_): State<AppState>) -> impl IntoResponse {
    let greeting = HelloResponse {
        message: "Hello from Soryu backend",
    };
    Json(greeting)
}

/// `POST /api/echo` — parses the request body as JSON and responds with it
/// wrapped in [`EchoPayload`].
async fn echo(request: Json<serde_json::Value>) -> impl IntoResponse {
    let Json(rest) = request;
    Json(EchoPayload { rest })
}

// ---
// Streaming transcript (SSE) skeleton
// Provides a low-latency server-sent events stream that emits example
// transcript chunks every ~250ms.
async fn stream_transcript_sse(State(_state): State<AppState>) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    const SAMPLE_LINES: &[&str] = &[
        "speaker_a: hey there, can you hear me?",
        "speaker_b: loud and clear — let's begin.",
        "speaker_a: streaming transcript looks smooth so far.",
        "speaker_b: agreed, latency feels low.",
        "speaker_a: wrapping up the demo now.",
    ];

    let mut idx = 0usize;
    let interval = tokio::time::interval(Duration::from_millis(250));
    let stream = IntervalStream::new(interval).map(move |_| {
        let line = SAMPLE_LINES[idx % SAMPLE_LINES.len()];
        idx += 1;
        let data = serde_json::json!({
            "type": "transcript_chunk",
            "text": line,
            "ts_ms": chrono::Utc::now().timestamp_millis(),
        });
        Ok(Event::default().json_data(&data).unwrap())
    });

    Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(10)))
}

// ---
// WebSocket transcript stream skeleton
/// `GET /ws` — upgrades the request to a WebSocket and hands the connection to
/// [`handle_socket`]. The shared [`AppState`] is extracted but currently unused.
async fn ws_handler(upgrade: WebSocketUpgrade, State(_): State<AppState>) -> impl IntoResponse {
    upgrade.on_upgrade(handle_socket)
}

async fn handle_socket(mut socket: WebSocket) {
    const SAMPLE_LINES: &[&str] = &[
        "speaker_a: hey there, can you hear me?",
        "speaker_b: loud and clear — let's begin.",
        "speaker_a: streaming transcript looks smooth so far.",
        "speaker_b: agreed, latency feels low.",
        "speaker_a: wrapping up the demo now.",
    ];

    // Spawn a task to drain inbound messages (optional for skeleton)
    let mut recv_socket = socket.split().1;
    tokio::spawn(async move {
        while let Some(Ok(_msg)) = recv_socket.next().await {
            // Ignore inbound for skeleton; handle pings/acks here if needed
        }
    });

    // Send a small transcript stream
    for line in SAMPLE_LINES {
        let payload = serde_json::json!({
            "type": "transcript_chunk",
            "text": line,
            "ts_ms": chrono::Utc::now().timestamp_millis(),
        })
        .to_string();
        if socket.send(Message::Text(payload)).await.is_err() {
            return;
        }
        tokio::time::sleep(Duration::from_millis(250)).await;
    }

    let done = serde_json::json!({"type":"done"}).to_string();
    let _ = socket.send(Message::Text(done)).await;
}