diff --git a/.sqlx/query-8646787e97a2e9c59fa9a1f2f510a48574aa64161b0ccc37be8bb58deaff5783.json b/.sqlx/query-8646787e97a2e9c59fa9a1f2f510a48574aa64161b0ccc37be8bb58deaff5783.json new file mode 100644 index 0000000..1ad93a8 --- /dev/null +++ b/.sqlx/query-8646787e97a2e9c59fa9a1f2f510a48574aa64161b0ccc37be8bb58deaff5783.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE durable.task\n SET state = 'ready',\n running_on = NULL\n WHERE id = $1\n AND running_on = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "8646787e97a2e9c59fa9a1f2f510a48574aa64161b0ccc37be8bb58deaff5783" +} diff --git a/.sqlx/query-cf31973a833c40d4d43aa01102e38178092c2ba065b7c7f367ce1cd2a90d2dce.json b/.sqlx/query-cf31973a833c40d4d43aa01102e38178092c2ba065b7c7f367ce1cd2a90d2dce.json new file mode 100644 index 0000000..a2bfd3a --- /dev/null +++ b/.sqlx/query-cf31973a833c40d4d43aa01102e38178092c2ba065b7c7f367ce1cd2a90d2dce.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE durable.task\n SET state = 'ready',\n running_on = NULL\n WHERE id = ANY($1::bigint[])\n AND running_on = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "cf31973a833c40d4d43aa01102e38178092c2ba065b7c7f367ce1cd2a90d2dce" +} diff --git a/crates/durable-runtime/src/worker.rs b/crates/durable-runtime/src/worker.rs index 56e2915..1006f9f 100644 --- a/crates/durable-runtime/src/worker.rs +++ b/crates/durable-runtime/src/worker.rs @@ -13,7 +13,7 @@ use rand::Rng; use serde_json::value::RawValue; use sqlx::postgres::PgNotification; use sqlx::types::Json; -use tokio::sync::{broadcast, Mutex, Notify, Semaphore}; +use tokio::sync::{broadcast, mpsc, Mutex, Notify, Semaphore}; use tokio::task::JoinSet; use tokio::time::Instant; use tracing::Instrument; @@ -344,7 +344,8 @@ impl Worker { // // How we do things here is that we order workers by id, each worker looks at // the one just in front of it and schedules a liveness check for just after - // that worker would expire. The exception to this is the + // that worker would expire. The worker with the oldest ID is then responsible + // for checking the one with the newest ID. let _guard = ShutdownGuard::new(&shared.shutdown); let mut shutdown = std::pin::pin!(shared.shutdown.wait()); @@ -559,8 +560,9 @@ impl Worker { let shutdown = self.shared.shutdown.clone(); let _guard = ShutdownGuard::new(&shutdown); let mut shutdown = std::pin::pin!(shutdown.wait()); + let (tx, mut rx) = tokio::sync::mpsc::channel::(1024); - self.spawn_new_tasks().await?; + self.spawn_new_tasks(&tx).await?; self.load_leader_id().await?; 'outer: loop { @@ -568,20 +570,50 @@ impl Worker { biased; _ = shutdown.as_mut() => break 'outer, - _ = self.tasks.join_next(), if !self.tasks.is_empty() => None, - event = self.event_source.next() => Some(event?), + _ = self.tasks.join_next(), if !self.tasks.is_empty() => LoopEvent::TaskComplete, + id = rx.recv() => LoopEvent::TaskFailed(id.expect("failed task channel closed unexpectedly")), + event = self.event_source.next() => LoopEvent::Event(event?), }; // Clean up any tasks that have completed already. while self.tasks.try_join_next().is_some() {} let event = match event { - Some(event) => event, - None => { + LoopEvent::Event(event) => event, + LoopEvent::TaskComplete => { if self.blocked { - self.spawn_new_tasks().await?; + self.spawn_new_tasks(&tx).await?; + } + + continue; + } + LoopEvent::TaskFailed(id) => { + let mut failed = vec![id]; + + let mut count = 0; + while let Ok(id) = rx.try_recv() { + failed.push(id); + + count += 1; + if count > 1024 { + break; + } } + sqlx::query!( + " + UPDATE durable.task + SET state = 'ready', + running_on = NULL + WHERE id = ANY($1::bigint[]) + AND running_on = $2 + ", + &failed, + self.worker_id + ) + .execute(&self.shared.pool) + .await?; + continue; } }; @@ -595,7 +627,7 @@ impl Worker { running_on: Some(id), .. }) if id != self.worker_id => (), - Event::Task(_) => self.spawn_new_tasks().await?, + Event::Task(_) => self.spawn_new_tasks(&tx).await?, Event::TaskSuspend(_) => { self.shared.suspend.notify_waiters(); } @@ -613,7 +645,7 @@ impl Worker { // We don't know what we missed so do everything. Event::Lagged => { - self.spawn_new_tasks().await?; + self.spawn_new_tasks(&tx).await?; self.load_leader_id().await?; self.shared.suspend.notify_waiters(); } @@ -647,7 +679,7 @@ impl Worker { /// Spawn all new tasks that are scheduled on this server and also those /// that aren't scheduled on any server. - async fn spawn_new_tasks(&mut self) -> anyhow::Result<()> { + async fn spawn_new_tasks(&mut self, failure: &mpsc::Sender) -> anyhow::Result<()> { let max_tasks = self.shared.config.max_tasks; let allowed = max_tasks.saturating_sub(self.tasks.len()); if allowed == 0 { @@ -709,6 +741,7 @@ impl Worker { let shared = self.shared.clone(); let engine = self.engine.clone(); let worker_id = self.worker_id; + let failures = failure.clone(); tracing::trace!( target: "durable_runtime::worker::spawn_new_tasks", @@ -723,6 +756,10 @@ impl Worker { .await { tracing::error!(task_id, "worker task exited with an error: {e}"); + + // An error here means we are already shutting down and normal shutdown recovery + // should take care of any remaining tasks. + let _ = failures.send(task_id).await; } }; @@ -750,13 +787,52 @@ impl Worker { ) -> anyhow::Result<()> { let task_id = task.id; - let status = - match AssertUnwindSafe(Self::run_task_impl(shared.clone(), engine, task, worker_id)) - .catch_unwind() - .await - { + // We are using the loop here to do some early breaks. + #[allow(clippy::never_loop)] + let status = loop { + let future = Self::run_task_impl(shared.clone(), engine, task, worker_id); + break match AssertUnwindSafe(future).catch_unwind().await { Ok(Ok(status)) => status, Ok(Err(error)) => { + match find_sqlx_error(&error) { + // These errors are external to the runtime and should usually be resolvable + // if the workflow is retried somewhere else at a later point in time. + // + // This should also help to reduce the number of workflow aborts that are + // "not the fault" of the workflow itself. These will (eventually) instead + // get turned into worker errors, which can be handled at a higher level. + Some( + sqlx::Error::PoolTimedOut + | sqlx::Error::WorkerCrashed + | sqlx::Error::Io(_), + ) => { + // Attempt to reset the task state so it can be picked up again. + // + // If this fails then the task failure gets reported to the main event + // loop which can ensure it gets retried. + sqlx::query!( + " + UPDATE durable.task + SET state = 'ready', + running_on = NULL + WHERE id = $1 + AND running_on = $2 + ", + task_id, + worker_id + ) + .execute(&shared.pool) + .await?; + + break TaskStatus::Suspend; + } + Some(sqlx::Error::PoolClosed) => { + // Nothing we can do, since we can't make database queries. + break TaskStatus::Suspend; + } + _ => (), + } + let message = format!("{error:?}\n"); let result = sqlx::query!( @@ -803,6 +879,7 @@ impl Worker { TaskStatus::ExitFailure } }; + }; match status { TaskStatus::NotScheduledOnWorker => { @@ -1087,3 +1164,13 @@ impl EventSource for PgEventSource { } } } + +enum LoopEvent { + Event(Event), + TaskComplete, + TaskFailed(i64), +} + +fn find_sqlx_error(error: &anyhow::Error) -> Option<&sqlx::Error> { + error.chain().filter_map(|e| e.downcast_ref()).next() +}