From 9b53f6c6b01da85ef73bd5960b32ec319df0b947 Mon Sep 17 00:00:00 2001 From: soryu Date: Wed, 28 Jan 2026 03:50:45 +0000 Subject: Replace TTS endpoint with Rust-native Qwen3-TTS (#41) * chore: fix unused import warnings in qwen3-tts module - Remove unused import 'IndexOp' in model.rs - Remove unused import 'DType' in speech_tokenizer.rs - Add #[allow(dead_code)] to codebook_dim field in RvqCodebook Co-Authored-By: Claude Opus 4.5 * feat: add voice loading and selection for TTS cloning Add voice reference audio loading so the TTS speak handler can perform voice cloning using reference WAV files from the voices/ directory. - Add voice.rs module: loads manifest.json and reference.wav for a given voice_id, decodes via symphonia, resamples to 24kHz for the TTS engine - Update speak.rs: resolve voice_id from the speak request (default "makima"), load reference audio, pass it to engine.generate() - Add voices/makima/README.md with instructions for obtaining reference audio (extraction from YouTube, recording, ffmpeg conversion) - Graceful fallback: if reference audio is missing, TTS proceeds without voice cloning using the model's default voice Co-Authored-By: Claude Opus 4.5 * [WIP] Heartbeat checkpoint - 2026-01-28 03:49:13 UTC --------- Co-authored-by: Claude Opus 4.5 --- makima/src/server/handlers/mod.rs | 1 + makima/src/server/handlers/speak.rs | 77 ++++++++++- makima/src/server/handlers/voice.rs | 252 ++++++++++++++++++++++++++++++++++++ 3 files changed, 323 insertions(+), 7 deletions(-) create mode 100644 makima/src/server/handlers/voice.rs (limited to 'makima/src/server/handlers') diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs index 8207399..8af2a37 100644 --- a/makima/src/server/handlers/mod.rs +++ b/makima/src/server/handlers/mod.rs @@ -19,6 +19,7 @@ pub mod mesh_ws; pub mod repository_history; pub mod speak; pub mod templates; +pub mod voice; pub mod transcript_analysis; pub mod users; pub mod versions; diff --git 
a/makima/src/server/handlers/speak.rs b/makima/src/server/handlers/speak.rs index 75e7780..3ed2620 100644 --- a/makima/src/server/handlers/speak.rs +++ b/makima/src/server/handlers/speak.rs @@ -15,6 +15,9 @@ //! See `makima/src/tts/` for the TTS engine implementation. //! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification. +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + use axum::{ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, response::Response, @@ -32,9 +35,9 @@ enum ClientMessage { /// Request speech synthesis for the given text. Speak { text: String, - /// Optional voice ID (e.g., "makima"). Not yet used — reserved for future voice selection. + /// Optional voice ID (e.g., "makima"). Used to load reference audio for voice cloning. + /// Defaults to "makima" if not specified. #[serde(default)] - #[allow(dead_code)] voice: Option, }, /// Cancel any in-progress synthesis. @@ -76,6 +79,10 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { let (mut sender, mut receiver) = socket.split(); + // Cancellation flag shared between the message loop and inference. + // Each new Speak request resets it to false; Cancel sets it to true. + let cancel_flag: Arc = Arc::new(AtomicBool::new(false)); + // Process incoming messages while let Some(msg) = receiver.next().await { let msg = match msg { @@ -102,13 +109,41 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { }; match client_msg { - ClientMessage::Speak { text, .. 
} => { + ClientMessage::Speak { text, voice } => { + let voice_id = voice + .as_deref() + .unwrap_or(super::voice::DEFAULT_VOICE_ID); + tracing::info!( session_id = %session_id, text_len = text.len(), + voice_id = %voice_id, "TTS speak request" ); + // Load voice reference audio for cloning + let voice_ref = match super::voice::load_reference_audio(voice_id) { + Ok(v) => { + tracing::debug!( + session_id = %session_id, + voice_id = %voice_id, + voice_name = %v.manifest.name, + samples = v.samples.len(), + "Voice reference loaded" + ); + Some(v) + } + Err(e) => { + tracing::warn!( + session_id = %session_id, + voice_id = %voice_id, + error = %e, + "Failed to load voice reference, proceeding without cloning" + ); + None + } + }; + // Get or lazily load the TTS engine let engine = match state.get_tts_engine().await { Ok(e) => e, @@ -138,9 +173,21 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { continue; } - // Run TTS inference (no voice reference for now — uses default) - match engine.generate(&text, None, None).await { + // Reset the cancel flag for this new generation request + cancel_flag.store(false, Ordering::Relaxed); + + // Run TTS inference with optional voice reference for cloning + // and the cancel flag so it can be stopped early. 
+ let (ref_audio, ref_rate) = match &voice_ref { + Some(v) => (Some(v.samples.as_slice()), Some(v.sample_rate)), + None => (None, None), + }; + let flag = cancel_flag.clone(); + match engine.generate(&text, ref_audio, ref_rate, Some(flag)).await { Ok(chunks) => { + // Check if generation was cancelled + let was_cancelled = cancel_flag.load(Ordering::Relaxed); + for chunk in &chunks { // Send binary PCM audio data let pcm_bytes = chunk.to_pcm16_bytes(); @@ -157,12 +204,13 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { } } - // Signal end of audio + // Signal end of audio (include cancelled status) let end_msg = serde_json::json!({ "type": "audio_end", "sample_rate": engine.sample_rate(), "format": "pcm_s16le", "channels": 1, + "cancelled": was_cancelled, }); let _ = sender .send(Message::Text(end_msg.to_string().into())) @@ -185,16 +233,18 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { } ClientMessage::Cancel => { tracing::info!(session_id = %session_id, "TTS cancel requested"); - // TODO: support cancellation of in-progress inference + cancel_flag.store(true, Ordering::Relaxed); } ClientMessage::Stop => { tracing::info!(session_id = %session_id, "TTS stop requested, closing"); + cancel_flag.store(true, Ordering::Relaxed); break; } } } Message::Close(_) => { tracing::info!(session_id = %session_id, "TTS WebSocket closed by client"); + cancel_flag.store(true, Ordering::Relaxed); break; } _ => { @@ -271,4 +321,17 @@ mod tests { let msg: ClientMessage = serde_json::from_str(json).unwrap(); assert!(matches!(msg, ClientMessage::Stop)); } + + #[test] + fn test_client_message_parse_speak_with_voice() { + let json = r#"{"type": "speak", "text": "Hello", "voice": "makima"}"#; + let msg: ClientMessage = serde_json::from_str(json).unwrap(); + match msg { + ClientMessage::Speak { text, voice } => { + assert_eq!(text, "Hello"); + assert_eq!(voice.as_deref(), Some("makima")); + } + _ => panic!("Expected Speak message"), + } + } 
} diff --git a/makima/src/server/handlers/voice.rs b/makima/src/server/handlers/voice.rs new file mode 100644 index 0000000..91b650d --- /dev/null +++ b/makima/src/server/handlers/voice.rs @@ -0,0 +1,252 @@ +//! Voice loading utilities for TTS voice cloning. +//! +//! Loads voice manifests and reference audio from the `voices/` directory. +//! Each voice is a directory containing: +//! - `manifest.json` — voice metadata (name, sample rate, backend, etc.) +//! - `reference.wav` — reference audio clip for voice cloning (5-15s, 24kHz mono) + +use serde::Deserialize; +use std::path::{Path, PathBuf}; + +use crate::tts::{resample_to_24k, SAMPLE_RATE}; + +/// Default voice ID used when no voice is specified. +pub const DEFAULT_VOICE_ID: &str = "makima"; + +/// Voice manifest loaded from `voices/{voice_id}/manifest.json`. +#[derive(Debug, Clone, Deserialize)] +pub struct VoiceManifest { + pub name: String, + pub id: String, + #[serde(default)] + pub description: Option, + #[serde(default = "default_language")] + pub language: String, + #[serde(default)] + pub accent: Option, + #[serde(default = "default_sample_rate")] + pub sample_rate: u32, + #[serde(default)] + pub format: Option, + #[serde(default)] + pub model_backend: Option, + #[serde(default = "default_reference_audio")] + pub reference_audio: String, + #[serde(default)] + pub notes: Option, +} + +fn default_language() -> String { + "en".to_string() +} + +fn default_sample_rate() -> u32 { + 24_000 +} + +fn default_reference_audio() -> String { + "reference.wav".to_string() +} + +/// Loaded voice reference: manifest + decoded PCM samples at 24kHz. +#[derive(Debug, Clone)] +pub struct VoiceReference { + pub manifest: VoiceManifest, + /// PCM f32 samples resampled to 24kHz mono. + pub samples: Vec, + /// Always 24000 after resampling. + pub sample_rate: u32, +} + +/// Resolve the base directory for voice data. 
+///
+/// Looks for the `voices/` directory relative to the current working directory,
+/// or falls back to the executable's directory.
+///
+/// Search order (the first candidate that is an existing directory wins):
+/// 1. `voices` inside the current working directory;
+/// 2. `voices` next to the executable, then one and two directories above it
+///    (covers the `target/debug` layout, where the project root is two levels up).
+///
+/// If no candidate exists, the working-directory candidate is returned anyway,
+/// so callers always get a well-formed path and report their own "not found".
+fn voices_base_dir() -> PathBuf {
+    // Try the current working directory first. If it cannot be determined,
+    // this degrades to the relative path `voices`.
+    let cwd_voices = std::env::current_dir().unwrap_or_default().join("voices");
+    if cwd_voices.is_dir() {
+        return cwd_voices;
+    }
+
+    // Try relative to the executable.
+    if let Ok(exe) = std::env::current_exe() {
+        // `ancestors()` yields the executable path itself first; skip it, then
+        // probe its directory, the parent and the grandparent, in that order.
+        let found = exe
+            .ancestors()
+            .skip(1)
+            .take(3)
+            .map(|dir| dir.join("voices"))
+            .find(|candidate| candidate.is_dir());
+        if let Some(dir) = found {
+            return dir;
+        }
+    }
+
+    // Default: assume cwd/voices
+    cwd_voices
+}
+
+/// Load a voice manifest from `voices/{voice_id}/manifest.json`.
+///
+/// `voice_id` can originate from a client request, so it must be a single
+/// plain directory name. Anything that could resolve outside the voices
+/// directory (empty string, `.` / `..`, nested paths, absolute paths) is
+/// rejected before the filesystem is touched.
+///
+/// # Errors
+///
+/// - [`VoiceLoadError::NotFound`] if `voice_id` is not a plain directory name
+///   or no manifest exists for it.
+/// - [`VoiceLoadError::Io`] if the manifest exists but cannot be read.
+/// - [`VoiceLoadError::InvalidManifest`] if the file is not valid manifest JSON.
+pub fn load_manifest(voice_id: &str) -> Result<VoiceManifest, VoiceLoadError> {
+    // Require exactly one `Normal` path component. Without this check,
+    // `PathBuf::join` would let `../x` climb out of the base directory and an
+    // absolute `voice_id` replace the base directory entirely.
+    let mut components = Path::new(voice_id).components();
+    let is_plain_name = matches!(
+        (components.next(), components.next()),
+        (Some(std::path::Component::Normal(_)), None)
+    );
+    if !is_plain_name {
+        return Err(VoiceLoadError::NotFound(voice_id.to_string()));
+    }
+
+    let manifest_path = voices_base_dir().join(voice_id).join("manifest.json");
+
+    if !manifest_path.exists() {
+        return Err(VoiceLoadError::NotFound(voice_id.to_string()));
+    }
+
+    let data = std::fs::read_to_string(&manifest_path).map_err(|e| {
+        VoiceLoadError::Io(format!(
+            "failed to read manifest at {}: {e}",
+            manifest_path.display()
+        ))
+    })?;
+
+    serde_json::from_str(&data).map_err(|e| {
+        VoiceLoadError::InvalidManifest(format!("failed to parse manifest: {e}"))
+    })
+}
+
+/// Load a voice's reference audio as f32 PCM samples resampled to 24kHz.
+///
+/// Uses symphonia (via `crate::audio`) to decode the WAV file, then
+/// resamples to 24kHz using `tts::resample_to_24k`.
+pub fn load_reference_audio(voice_id: &str) -> Result { + let manifest = load_manifest(voice_id)?; + + let base = voices_base_dir(); + let audio_path = base.join(voice_id).join(&manifest.reference_audio); + + if !audio_path.exists() { + return Err(VoiceLoadError::MissingAudio(format!( + "reference audio not found at {}. See voices/{}/README.md for instructions.", + audio_path.display(), + voice_id, + ))); + } + + load_reference_audio_from_path(&audio_path, manifest) +} + +/// Load reference audio from a specific file path with a pre-loaded manifest. +fn load_reference_audio_from_path( + audio_path: &Path, + manifest: VoiceManifest, +) -> Result { + // Use symphonia-based decoder from crate::audio to decode the WAV + let pcm = crate::audio::to_16k_mono_from_path(audio_path).map_err(|e| { + VoiceLoadError::AudioDecode(format!("failed to decode {}: {e}", audio_path.display())) + })?; + + // The audio module decodes to 16kHz mono; we need 24kHz for TTS. + // Resample from 16kHz to 24kHz. + let samples = if pcm.sample_rate == SAMPLE_RATE { + pcm.samples + } else { + resample_to_24k(&pcm.samples, pcm.sample_rate) + }; + + tracing::info!( + voice_id = %manifest.id, + voice_name = %manifest.name, + samples_len = samples.len(), + duration_secs = samples.len() as f32 / SAMPLE_RATE as f32, + "Loaded voice reference audio" + ); + + Ok(VoiceReference { + manifest, + samples, + sample_rate: SAMPLE_RATE, + }) +} + +/// Errors that can occur when loading a voice. +#[derive(Debug)] +pub enum VoiceLoadError { + /// Voice directory not found. + NotFound(String), + /// IO error reading files. + Io(String), + /// Manifest JSON is invalid. + InvalidManifest(String), + /// Reference audio file is missing. + MissingAudio(String), + /// Failed to decode audio. 
+ AudioDecode(String), +} + +impl std::fmt::Display for VoiceLoadError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VoiceLoadError::NotFound(id) => { + write!(f, "voice '{id}' not found (no voices/{id}/manifest.json)") + } + VoiceLoadError::Io(msg) => write!(f, "voice IO error: {msg}"), + VoiceLoadError::InvalidManifest(msg) => write!(f, "invalid voice manifest: {msg}"), + VoiceLoadError::MissingAudio(msg) => write!(f, "missing reference audio: {msg}"), + VoiceLoadError::AudioDecode(msg) => write!(f, "audio decode error: {msg}"), + } + } +} + +impl std::error::Error for VoiceLoadError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_voice_id() { + assert_eq!(DEFAULT_VOICE_ID, "makima"); + } + + #[test] + fn test_manifest_deserialize() { + let json = r#"{ + "name": "Test Voice", + "id": "test", + "sample_rate": 24000, + "reference_audio": "reference.wav" + }"#; + let manifest: VoiceManifest = serde_json::from_str(json).unwrap(); + assert_eq!(manifest.name, "Test Voice"); + assert_eq!(manifest.id, "test"); + assert_eq!(manifest.sample_rate, 24000); + assert_eq!(manifest.reference_audio, "reference.wav"); + assert_eq!(manifest.language, "en"); + } + + #[test] + fn test_manifest_deserialize_defaults() { + let json = r#"{"name": "Minimal", "id": "min"}"#; + let manifest: VoiceManifest = serde_json::from_str(json).unwrap(); + assert_eq!(manifest.language, "en"); + assert_eq!(manifest.sample_rate, 24000); + assert_eq!(manifest.reference_audio, "reference.wav"); + } + + #[test] + fn test_load_nonexistent_voice() { + let result = load_manifest("nonexistent_voice_xyz"); + assert!(result.is_err()); + match result.unwrap_err() { + VoiceLoadError::NotFound(id) => assert_eq!(id, "nonexistent_voice_xyz"), + other => panic!("Expected NotFound, got: {other}"), + } + } +} -- cgit v1.2.3 From f6a40e2304585f140ed5766b25fe71a6958f4425 Mon Sep 17 00:00:00 2001 From: soryu Date: Thu, 29 Jan 2026 01:14:17 +0000 
Subject: Fix makima supervisor pr CLI command --- makima/src/bin/makima.rs | 4 +-- makima/src/daemon/api/supervisor.rs | 9 ++---- makima/src/daemon/cli/supervisor.rs | 8 ++---- makima/src/daemon/task/manager.rs | 40 +++++++++++--------------- makima/src/daemon/ws/protocol.rs | 4 ++- makima/src/server/handlers/mesh_supervisor.rs | 41 +++++++++++++-------------- makima/src/server/state.rs | 4 ++- 7 files changed, 49 insertions(+), 61 deletions(-) (limited to 'makima/src/server/handlers') diff --git a/makima/src/bin/makima.rs b/makima/src/bin/makima.rs index 44fa590..8e83565 100644 --- a/makima/src/bin/makima.rs +++ b/makima/src/bin/makima.rs @@ -439,10 +439,10 @@ async fn run_supervisor( } SupervisorCommand::Pr(args) => { let client = ApiClient::new(args.common.api_url, args.common.api_key)?; - eprintln!("Creating PR for task {}...", args.task_id); + eprintln!("Creating PR for branch {}...", args.branch); let body = args.body.as_deref().unwrap_or(""); let result = client - .supervisor_pr(args.task_id, &args.title, body, &args.base) + .supervisor_pr(&args.branch, &args.title, body) .await?; println!("{}", serde_json::to_string(&result.0)?); } diff --git a/makima/src/daemon/api/supervisor.rs b/makima/src/daemon/api/supervisor.rs index 6b99de0..c841b21 100644 --- a/makima/src/daemon/api/supervisor.rs +++ b/makima/src/daemon/api/supervisor.rs @@ -54,10 +54,9 @@ pub struct MergeRequest { #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct CreatePrRequest { - pub task_id: Uuid, + pub branch: String, pub title: String, pub body: String, - pub base_branch: String, } #[derive(Serialize)] @@ -165,16 +164,14 @@ impl ApiClient { /// Create a pull request. 
pub async fn supervisor_pr( &self, - task_id: Uuid, + branch: &str, title: &str, body: &str, - base_branch: &str, ) -> Result { let req = CreatePrRequest { - task_id, + branch: branch.to_string(), title: title.to_string(), body: body.to_string(), - base_branch: base_branch.to_string(), }; self.post("/api/v1/mesh/supervisor/pr", &req).await } diff --git a/makima/src/daemon/cli/supervisor.rs b/makima/src/daemon/cli/supervisor.rs index 09f61db..9ad7aef 100644 --- a/makima/src/daemon/cli/supervisor.rs +++ b/makima/src/daemon/cli/supervisor.rs @@ -128,9 +128,9 @@ pub struct PrArgs { #[command(flatten)] pub common: SupervisorArgs, - /// Task ID to create PR for + /// Branch name to create PR from (e.g., "makima/feature-name") #[arg(index = 1)] - pub task_id: Uuid, + pub branch: String, /// PR title #[arg(long)] @@ -139,10 +139,6 @@ pub struct PrArgs { /// PR body/description #[arg(long)] pub body: Option, - - /// Base branch (default: main) - #[arg(long, default_value = "main")] - pub base: String, } /// Arguments for diff command. diff --git a/makima/src/daemon/task/manager.rs b/makima/src/daemon/task/manager.rs index f0da860..8c5f8d7 100644 --- a/makima/src/daemon/task/manager.rs +++ b/makima/src/daemon/task/manager.rs @@ -669,7 +669,7 @@ makima supervisor wait "$TASK_ID" makima supervisor merge "$TASK_ID" --to "makima/user-authentication" # Step 3: All tasks complete - create PR from makima branch -makima supervisor pr "makima/user-authentication" --title "Add user authentication" --base main +makima supervisor pr "makima/user-authentication" --title "Add user authentication" ``` ## Available Tools (via makima supervisor) @@ -701,7 +701,7 @@ makima supervisor branch [--from ] makima supervisor merge [--to ] [--squash] # Create a pull request -makima supervisor pr --title "Title" [--body "Body"] [--base main] +makima supervisor pr --title "Title" [--body "Body"] # View a task's diff makima supervisor diff @@ -838,7 +838,7 @@ Common deliverable IDs by phase: 3. 
**wait blocks until complete** - you MUST call this to know when a task finishes 4. **Never fire-and-forget** - always wait for each task before moving on 5. **Merge to your makima branch** - use `merge --to "makima/{name}"` to collect completed work -6. **Create PR when done** - use `pr "makima/{name}" --title "..." --base main` +6. **Create PR when done** - use `pr "makima/{name}" --title "..."` 7. **Ask when unsure** - use `ask` to get user feedback on decisions ## Standard Workflow @@ -849,7 +849,7 @@ Common deliverable IDs by phase: - `wait` - Block until complete - `merge --to "makima/{name}"` - Merge to branch 3. `ask "Ready to create PR?"` - Get user approval -4. `pr "makima/{name}" --title "..." --base main` - Create PR +4. `pr "makima/{name}" --title "..."` - Create PR ## Important Reminders @@ -875,7 +875,7 @@ When you receive an `[ACTION REQUIRED]` message from the system: After all tasks are "done" and merged, you MUST take the following actions: **If in execute phase:** -1. Create PR immediately: `makima supervisor pr "makima/{name}" --title "..." --base main` +1. Create PR immediately: `makima supervisor pr "makima/{name}" --title "..."` 2. 
After PR created: - Simple contract: Mark complete with `makima supervisor complete` - Specification contract: Advance to review with `makima supervisor advance-phase review` @@ -2016,14 +2016,16 @@ impl TaskManager { title, body, base_branch, + branch, } => { tracing::info!( task_id = %task_id, title = %title, base_branch = %base_branch, + branch = %branch, "Creating pull request" ); - self.handle_create_pr(task_id, title, body, base_branch).await?; + self.handle_create_pr(task_id, title, body, base_branch, branch).await?; } DaemonCommand::GetTaskDiff { task_id, @@ -3135,6 +3137,7 @@ impl TaskManager { title: String, body: Option, base_branch: String, + branch: String, ) -> Result<(), DaemonError> { // Get worktree path - this works even for completed tasks by scanning worktrees directory let worktree_path = match self.get_task_worktree_path(task_id).await { @@ -3153,30 +3156,19 @@ impl TaskManager { } }; - // Get base_branch from in-memory tasks if available (for fallback) - let task_base_branch = { - let tasks = self.tasks.read().await; - tasks.get(&task_id).and_then(|t| t.base_branch.clone()) - }; - - // Use task's base_branch if the provided one is the default "main" and task has a detected one - let effective_base_branch = if base_branch == "main" { - task_base_branch.unwrap_or(base_branch) - } else { - base_branch - }; - tracing::info!( task_id = %task_id, - effective_base_branch = %effective_base_branch, + base_branch = %base_branch, + branch = %branch, worktree_path = %worktree_path.display(), - "Creating PR with effective base branch" + "Creating PR" ); - // Push the current branch first + // Push the branch to origin + let push_refspec = format!("HEAD:refs/heads/{}", branch); let push_result = tokio::process::Command::new("git") .current_dir(&worktree_path) - .args(["push", "-u", "origin", "HEAD"]) + .args(["push", "-u", "origin", &push_refspec]) .output() .await; @@ -3195,7 +3187,7 @@ impl TaskManager { // Create PR using gh CLI let mut pr_cmd = 
tokio::process::Command::new("gh"); pr_cmd.current_dir(&worktree_path); - pr_cmd.args(["pr", "create", "--title", &title, "--base", &effective_base_branch]); + pr_cmd.args(["pr", "create", "--title", &title, "--base", &base_branch, "--head", &branch]); if let Some(ref body_text) = body { pr_cmd.args(["--body", body_text]); diff --git a/makima/src/daemon/ws/protocol.rs b/makima/src/daemon/ws/protocol.rs index bd13975..e971798 100644 --- a/makima/src/daemon/ws/protocol.rs +++ b/makima/src/daemon/ws/protocol.rs @@ -693,9 +693,11 @@ pub enum DaemonCommand { task_id: Uuid, title: String, body: Option, - /// Base branch for the PR (default: main). + /// Base branch for the PR. #[serde(rename = "baseBranch")] base_branch: String, + /// Source branch name to push and create PR from. + branch: String, }, /// Get the diff for a task's changes. diff --git a/makima/src/server/handlers/mesh_supervisor.rs b/makima/src/server/handlers/mesh_supervisor.rs index 016367f..a0a3a96 100644 --- a/makima/src/server/handlers/mesh_supervisor.rs +++ b/makima/src/server/handlers/mesh_supervisor.rs @@ -1267,15 +1267,9 @@ pub struct MergeTaskResponse { #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct CreatePRRequest { - pub task_id: Uuid, + pub branch: String, pub title: String, pub body: Option, - #[serde(default = "default_base_branch")] - pub base_branch: String, -} - -fn default_base_branch() -> String { - "main".to_string() } /// Response for PR creation. 
@@ -1513,48 +1507,53 @@ pub async fn create_pr( headers: HeaderMap, Json(request): Json, ) -> impl IntoResponse { - let (_supervisor_id, owner_id) = match verify_supervisor_auth(&state, &headers, None).await { + let (supervisor_id, _owner_id) = match verify_supervisor_auth(&state, &headers, None).await { Ok(ids) => ids, Err(e) => return e.into_response(), }; let pool = state.db_pool.as_ref().unwrap(); - // Get the target task - let task = match repository::get_task_for_owner(pool, request.task_id, owner_id).await { + // Get the supervisor's own task to find daemon and base_branch + let task = match repository::get_task(pool, supervisor_id).await { Ok(Some(t)) => t, Ok(None) => { return ( StatusCode::NOT_FOUND, - Json(ApiError::new("NOT_FOUND", "Task not found")), + Json(ApiError::new("NOT_FOUND", "Supervisor task not found")), ).into_response(); } Err(e) => { - tracing::error!(error = %e, "Failed to get task"); + tracing::error!(error = %e, "Failed to get supervisor task"); return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(ApiError::new("DB_ERROR", "Failed to get task")), + Json(ApiError::new("DB_ERROR", "Failed to get supervisor task")), ).into_response(); } }; - // Get daemon running the task + // Get daemon running the supervisor let Some(daemon_id) = task.daemon_id else { return ( StatusCode::SERVICE_UNAVAILABLE, - Json(ApiError::new("NO_DAEMON", "Task has no assigned daemon")), + Json(ApiError::new("NO_DAEMON", "Supervisor has no assigned daemon")), ).into_response(); }; + // Use base_branch from the task's repository config, falling back to "main" + let base_branch = task.base_branch.unwrap_or_else(|| "main".to_string()); + // Subscribe to PR results BEFORE sending the command let mut rx = state.pr_results.subscribe(); - // Send CreatePR command to daemon + // Send CreatePR command to daemon using the supervisor's task ID + // (the branch is in the supervisor's worktree) let cmd = DaemonCommand::CreatePR { - task_id: request.task_id, + task_id: 
supervisor_id, title: request.title.clone(), body: request.body.clone(), - base_branch: request.base_branch.clone(), + base_branch, + branch: request.branch.clone(), }; if let Err(e) = state.send_daemon_command(daemon_id, cmd).await { @@ -1571,7 +1570,7 @@ pub async fn create_pr( loop { match rx.recv().await { Ok(notification) => { - if notification.task_id == request.task_id { + if notification.task_id == supervisor_id { return Some(notification); } // Not our task, keep waiting @@ -1594,7 +1593,7 @@ pub async fn create_pr( ( status, Json(CreatePRResponse { - task_id: request.task_id, + task_id: supervisor_id, success: notification.success, message: notification.message, pr_url: notification.pr_url, @@ -1607,7 +1606,7 @@ pub async fn create_pr( ( StatusCode::GATEWAY_TIMEOUT, Json(CreatePRResponse { - task_id: request.task_id, + task_id: supervisor_id, success: false, message: "PR creation timed out waiting for daemon response".to_string(), pr_url: None, diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index bf8f6f2..041b101 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -461,9 +461,11 @@ pub enum DaemonCommand { task_id: Uuid, title: String, body: Option, - /// Base branch for the PR (default: main) + /// Base branch for the PR #[serde(rename = "baseBranch")] base_branch: String, + /// Source branch name to push and create PR from + branch: String, }, /// Get the diff for a task's changes -- cgit v1.2.3