mas_tasks/new_queue.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::{collections::HashMap, sync::Arc};
8
9use async_trait::async_trait;
10use chrono::{DateTime, Duration, Utc};
11use cron::Schedule;
12use mas_context::LogContext;
13use mas_data_model::Clock;
14use mas_storage::{
15    RepositoryAccess, RepositoryError,
16    queue::{InsertableJob, Job, JobMetadata, Worker},
17};
18use mas_storage_pg::{DatabaseError, PgRepository};
19use opentelemetry::{
20    KeyValue,
21    metrics::{Counter, Histogram, UpDownCounter},
22};
23use rand::{Rng, RngCore, distributions::Uniform};
24use serde::de::DeserializeOwned;
25use sqlx::{
26    Acquire, Either,
27    postgres::{PgAdvisoryLock, PgListener},
28};
29use thiserror::Error;
30use tokio::{task::JoinSet, time::Instant};
31use tokio_util::sync::CancellationToken;
32use tracing::{Instrument as _, Span};
33use tracing_opentelemetry::OpenTelemetrySpanExt as _;
34use ulid::Ulid;
35
36use crate::{METER, State};
37
38type JobPayload = serde_json::Value;
39
40#[derive(Clone)]
41pub struct JobContext {
42    pub id: Ulid,
43    pub metadata: JobMetadata,
44    pub queue_name: String,
45    pub attempt: usize,
46    pub start: Instant,
47    pub cancellation_token: CancellationToken,
48}
49
50impl JobContext {
51    pub fn span(&self) -> Span {
52        let span = tracing::info_span!(
53            parent: Span::none(),
54            "job.run",
55            job.id = %self.id,
56            job.queue.name = self.queue_name,
57            job.attempt = self.attempt,
58        );
59
60        span.add_link(self.metadata.span_context());
61
62        span
63    }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
67pub enum JobErrorDecision {
68    Retry,
69
70    #[default]
71    Fail,
72}
73
74impl std::fmt::Display for JobErrorDecision {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        match self {
77            Self::Retry => f.write_str("retry"),
78            Self::Fail => f.write_str("fail"),
79        }
80    }
81}
82
83#[derive(Debug, Error)]
84#[error("Job failed to run, will {decision}")]
85pub struct JobError {
86    decision: JobErrorDecision,
87    #[source]
88    error: anyhow::Error,
89}
90
91impl JobError {
92    pub fn retry<T: Into<anyhow::Error>>(error: T) -> Self {
93        Self {
94            decision: JobErrorDecision::Retry,
95            error: error.into(),
96        }
97    }
98
99    pub fn fail<T: Into<anyhow::Error>>(error: T) -> Self {
100        Self {
101            decision: JobErrorDecision::Fail,
102            error: error.into(),
103        }
104    }
105}
106
107pub trait FromJob {
108    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
109    where
110        Self: Sized;
111}
112
113impl<T> FromJob for T
114where
115    T: DeserializeOwned,
116{
117    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
118        serde_json::from_value(payload).map_err(Into::into)
119    }
120}
121
122#[async_trait]
123pub trait RunnableJob: Send + 'static {
124    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>;
125
126    /// Allows the job to set a timeout for its execution. Jobs should then look
127    /// at the cancellation token passed in the [`JobContext`] to handle
128    /// graceful shutdowns.
129    fn timeout(&self) -> Option<std::time::Duration> {
130        None
131    }
132}
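
// A minimal sketch of a handler, assuming a hypothetical `PruneSessions` job and
// a hypothetical `do_the_work()` async fn returning `Result<(), anyhow::Error>`.
// To be registered with `register_handler` below, a real job type would also
// derive `Serialize`/`Deserialize` (for `FromJob`) and implement `InsertableJob`:
//
//     struct PruneSessions;
//
//     #[async_trait]
//     impl RunnableJob for PruneSessions {
//         async fn run(&self, _state: &State, context: JobContext) -> Result<(), JobError> {
//             // Cooperate with the optional timeout by watching the cancellation token
//             tokio::select! {
//                 () = context.cancellation_token.cancelled() => {
//                     Err(JobError::retry(anyhow::anyhow!("cancelled before completion")))
//                 }
//                 result = do_the_work() => result.map_err(JobError::fail),
//             }
//         }
//
//         // Ask the runner to cancel us if we run for more than a minute
//         fn timeout(&self) -> Option<std::time::Duration> {
//             Some(std::time::Duration::from_secs(60))
//         }
//     }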
133
134fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
135    Box::new(job)
136}
137
138#[derive(Debug, Error)]
139pub enum QueueRunnerError {
140    #[error("Failed to setup listener")]
141    SetupListener(#[source] sqlx::Error),
142
143    #[error("Failed to start transaction")]
144    StartTransaction(#[source] sqlx::Error),
145
146    #[error("Failed to commit transaction")]
147    CommitTransaction(#[source] sqlx::Error),
148
149    #[error("Failed to acquire leader lock")]
150    LeaderLock(#[source] sqlx::Error),
151
152    #[error(transparent)]
153    Repository(#[from] RepositoryError),
154
155    #[error(transparent)]
156    Database(#[from] DatabaseError),
157
158    #[error("Invalid schedule expression")]
159    InvalidSchedule(#[from] cron::error::Error),
160
161    #[error("Worker is not the leader")]
162    NotLeader,
163}
164
165// When the worker waits for a notification, we still want to wake it up every
166// second. Because we don't want all the workers to wake up at the same time, we
167// add a random jitter to the sleep duration, so they effectively sleep between
168// 0.9 and 1.1 seconds.
169const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
170const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
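
// A self-contained sketch of the jitter these bounds produce, relying on `rand`'s
// `Uniform` support for `std::time::Duration` (the sampled range is half-open,
// i.e. [0.9s, 1.1s)):
//
//     use rand::{Rng as _, distributions::Uniform};
//
//     let mut rng = rand::thread_rng();
//     let jitter = rng.sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
//     assert!(jitter >= MIN_SLEEP_DURATION && jitter < MAX_SLEEP_DURATION);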
171
172// How many jobs can we run concurrently
173const MAX_CONCURRENT_JOBS: usize = 10;
174
175// How many jobs can we fetch at once
176const MAX_JOBS_TO_FETCH: usize = 5;
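
// These two constants interact in `tick`: a worker only fetches up to its
// remaining concurrency budget, capped at MAX_JOBS_TO_FETCH. For example, with
// 7 jobs already running it fetches min(10 - 7, 5) = 3 jobs; with none running
// it fetches min(10, 5) = 5.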
177
178// How many times a job should be attempted before it is abandoned
179const MAX_ATTEMPTS: usize = 10;
180
181/// Returns the delay to wait before retrying a job
182///
183/// Uses an exponential backoff: 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m20s, 10m40s,
184/// 21m20s, 42m40s
185fn retry_delay(attempt: usize) -> Duration {
186    let attempt = u32::try_from(attempt).unwrap_or(u32::MAX);
187    Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
188}
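
// A few worked values of the formula above, using the chrono `Duration` imported
// at the top of this file:
//
//     assert_eq!(retry_delay(0), Duration::milliseconds(5_000));     // 5s
//     assert_eq!(retry_delay(4), Duration::milliseconds(80_000));    // 1m20s
//     assert_eq!(retry_delay(9), Duration::milliseconds(2_560_000)); // 42m40s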
189
190type JobResult = (std::time::Duration, Result<(), JobError>);
191type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
192
193/// This is a fake job we use to consume jobs from deprecated queues
194struct DeprecatedJob;
195
196#[async_trait]
197impl RunnableJob for DeprecatedJob {
198    async fn run(&self, _state: &State, context: JobContext) -> Result<(), JobError> {
199        tracing::warn!(
200            job.id = %context.id,
201            job.queue.name = context.queue_name,
202            "Consumed a job from a deprecated queue, which can happen after version upgrades. This did nothing other than removing the job from the queue."
203        );
204
205        Ok(())
206    }
207}
208
209struct ScheduleDefinition {
210    schedule_name: &'static str,
211    expression: Schedule,
212    queue_name: &'static str,
213    payload: serde_json::Value,
214}
215
216pub struct QueueWorker {
217    listener: PgListener,
218    registration: Worker,
219    am_i_leader: bool,
220    last_heartbeat: DateTime<Utc>,
221    cancellation_token: CancellationToken,
222    #[expect(dead_code, reason = "This is used on Drop")]
223    cancellation_guard: tokio_util::sync::DropGuard,
224    state: State,
225    schedules: Vec<ScheduleDefinition>,
226    tracker: JobTracker,
227    wakeup_reason: Counter<u64>,
228    tick_time: Histogram<u64>,
229}
230
231impl QueueWorker {
232    #[tracing::instrument(
233        name = "worker.init",
234        skip_all,
235        fields(worker.id)
236    )]
237    pub(crate) async fn new(
238        state: State,
239        cancellation_token: CancellationToken,
240    ) -> Result<Self, QueueRunnerError> {
241        let mut rng = state.rng();
242        let clock = state.clock();
243
244        let mut listener = PgListener::connect_with(&state.pool())
245            .await
246            .map_err(QueueRunnerError::SetupListener)?;
247
248        // We get notifications of leader stepping down on this channel
249        listener
250            .listen("queue_leader_stepdown")
251            .await
252            .map_err(QueueRunnerError::SetupListener)?;
253
254        // We get notifications when a job is available on this channel
255        listener
256            .listen("queue_available")
257            .await
258            .map_err(QueueRunnerError::SetupListener)?;
259
260        let txn = listener
261            .begin()
262            .await
263            .map_err(QueueRunnerError::StartTransaction)?;
264        let mut repo = PgRepository::from_conn(txn);
265
266        let registration = repo.queue_worker().register(&mut rng, clock).await?;
267        tracing::Span::current().record("worker.id", tracing::field::display(registration.id));
268        repo.into_inner()
269            .commit()
270            .await
271            .map_err(QueueRunnerError::CommitTransaction)?;
272
273        tracing::info!(worker.id = %registration.id, "Registered worker");
274        let now = clock.now();
275
276        let wakeup_reason = METER
277            .u64_counter("job.worker.wakeups")
278            .with_description("Counts how many times the worker has been woken up, and for which reason")
279            .build();
280
281        // Pre-create the reasons on the counter
282        wakeup_reason.add(0, &[KeyValue::new("reason", "sleep")]);
283        wakeup_reason.add(0, &[KeyValue::new("reason", "task")]);
284        wakeup_reason.add(0, &[KeyValue::new("reason", "notification")]);
285
286        let tick_time = METER
287            .u64_histogram("job.worker.tick_duration")
288            .with_description(
289                "How much time the worker took to tick, including performing leader duties",
290            )
291            .build();
292
293        // We put a cancellation drop guard in the structure, so that when it gets
294        // dropped, we're sure to cancel the token
295        let cancellation_guard = cancellation_token.clone().drop_guard();
296
297        Ok(Self {
298            listener,
299            registration,
300            am_i_leader: false,
301            last_heartbeat: now,
302            cancellation_token,
303            cancellation_guard,
304            state,
305            schedules: Vec::new(),
306            tracker: JobTracker::new(),
307            wakeup_reason,
308            tick_time,
309        })
310    }
311
312    pub(crate) fn register_handler<T: RunnableJob + InsertableJob + FromJob>(
313        &mut self,
314    ) -> &mut Self {
315        // There is a potential panic here, which is fine as it's going to be caught
316        // within the job task
317        let factory = |payload: JobPayload| {
318            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
319        };
320
321        self.tracker
322            .factories
323            .insert(T::QUEUE_NAME, Arc::new(factory));
324        self
325    }
326
327    /// Register a queue name as deprecated, which will consume leftover jobs
328    pub(crate) fn register_deprecated_queue(&mut self, queue_name: &'static str) -> &mut Self {
329        let factory = |_payload: JobPayload| box_runnable_job(DeprecatedJob);
330        self.tracker.factories.insert(queue_name, Arc::new(factory));
331        self
332    }
333
334    pub(crate) fn add_schedule<T: InsertableJob>(
335        &mut self,
336        schedule_name: &'static str,
337        expression: Schedule,
338        job: T,
339    ) -> &mut Self {
340        let payload = serde_json::to_value(job).expect("failed to serialize job payload");
341
342        self.schedules.push(ScheduleDefinition {
343            schedule_name,
344            expression,
345            queue_name: T::QUEUE_NAME,
346            payload,
347        });
348
349        self
350    }
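
    // Putting the registration API together: a sketch of how this crate's own
    // setup code could wire a worker up. `state`, `shutdown_token`,
    // `CleanupExpiredTokensJob` and `DailyReportJob` are hypothetical stand-ins,
    // and the `?` assumes a caller that can propagate the errors:
    //
    //     use std::str::FromStr as _;
    //
    //     let mut worker = QueueWorker::new(state, shutdown_token).await?;
    //     worker
    //         .register_handler::<CleanupExpiredTokensJob>()
    //         .register_deprecated_queue("some-old-queue")
    //         .add_schedule(
    //             "daily-report",
    //             cron::Schedule::from_str("0 0 3 * * * *")?, // every day at 03:00
    //             DailyReportJob,
    //         );
    //     worker.run().await;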
351
352    pub(crate) async fn run(mut self) {
353        if let Err(e) = self.run_inner().await {
354            tracing::error!(
355                error = &e as &dyn std::error::Error,
356                "Failed to run new queue"
357            );
358        }
359    }
360
361    async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
362        self.setup_schedules().await?;
363
364        while !self.cancellation_token.is_cancelled() {
365            LogContext::new("worker-run-loop")
366                .run(|| self.run_loop())
367                .await?;
368        }
369
370        self.shutdown().await?;
371
372        Ok(())
373    }
374
375    #[tracing::instrument(name = "worker.setup_schedules", skip_all)]
376    pub(crate) async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
377        let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();
378
379        // Start a transaction on the existing PgListener connection
380        let txn = self
381            .listener
382            .begin()
383            .await
384            .map_err(QueueRunnerError::StartTransaction)?;
385
386        let mut repo = PgRepository::from_conn(txn);
387
388        // Setup the entries in the queue_schedules table
389        repo.queue_schedule().setup(&schedules).await?;
390
391        repo.into_inner()
392            .commit()
393            .await
394            .map_err(QueueRunnerError::CommitTransaction)?;
395
396        Ok(())
397    }
398
399    #[tracing::instrument(name = "worker.run_loop", skip_all)]
400    async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
401        self.wait_until_wakeup().await?;
402
403        if self.cancellation_token.is_cancelled() {
404            return Ok(());
405        }
406
407        let start = Instant::now();
408        self.tick().await?;
409
410        if self.am_i_leader {
411            self.perform_leader_duties().await?;
412        }
413
414        let elapsed = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
415        self.tick_time.record(elapsed, &[]);
416
417        Ok(())
418    }
419
420    #[tracing::instrument(name = "worker.shutdown", skip_all)]
421    async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
422        tracing::info!("Shutting down worker");
423
424        let clock = self.state.clock();
425        let mut rng = self.state.rng();
426
427        // Start a transaction on the existing PgListener connection
428        let txn = self
429            .listener
430            .begin()
431            .await
432            .map_err(QueueRunnerError::StartTransaction)?;
433
434        let mut repo = PgRepository::from_conn(txn);
435
436        // Log about any job still running
437        match self.tracker.running_jobs() {
438            0 => {}
439            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
440            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
441        }
442
443        // TODO: we may want to introduce a timeout here, and abort the tasks if they
444        // take too long. It's fine for now, as we don't have long-running
445        // tasks, most of them are idempotent, and the only effect might be that
446        // the worker would shut down 'dirtily', meaning that its tasks would only
447        // be retried later by another worker
448
449        // Wait for all the jobs to finish
450        self.tracker
451            .process_jobs(&mut rng, clock, &mut repo, true)
452            .await?;
453
454        // Tell the other workers we're shutting down
455        // This also releases the leader election lease
456        repo.queue_worker()
457            .shutdown(clock, &self.registration)
458            .await?;
459
460        repo.into_inner()
461            .commit()
462            .await
463            .map_err(QueueRunnerError::CommitTransaction)?;
464
465        Ok(())
466    }
467
468    #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all)]
469    async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
470        let mut rng = self.state.rng();
471
472        // This is to make sure we wake up every second to do the maintenance tasks
473        // We add a little bit of random jitter to the duration, so that we don't get
474        // fully synced workers waking up at the same time after each notification
475        let sleep_duration = rng.sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
476        let wakeup_sleep = tokio::time::sleep(sleep_duration);
477
478        tokio::select! {
479            () = self.cancellation_token.cancelled() => {
480                tracing::debug!("Woke up from cancellation");
481            },
482
483            () = wakeup_sleep => {
484                tracing::debug!("Woke up from sleep");
485                self.wakeup_reason.add(1, &[KeyValue::new("reason", "sleep")]);
486            },
487
488            () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => {
489                tracing::debug!("Joined job task");
490                self.wakeup_reason.add(1, &[KeyValue::new("reason", "task")]);
491            },
492
493            notification = self.listener.recv() => {
494                self.wakeup_reason.add(1, &[KeyValue::new("reason", "notification")]);
495                match notification {
496                    Ok(notification) => {
497                        tracing::debug!(
498                            notification.channel = notification.channel(),
499                            notification.payload = notification.payload(),
500                            "Woke up from notification"
501                        );
502                    },
503                    Err(e) => {
504                        tracing::error!(error = &e as &dyn std::error::Error, "Failed to receive notification");
505                    },
506                }
507            },
508        }
509
510        Ok(())
511    }
512
513    #[tracing::instrument(
514        name = "worker.tick",
515        skip_all,
516        fields(worker.id = %self.registration.id),
517    )]
518    async fn tick(&mut self) -> Result<(), QueueRunnerError> {
519        tracing::debug!("Tick");
520        let clock = self.state.clock();
521        let mut rng = self.state.rng();
522        let now = clock.now();
523
524        // Start a transaction on the existing PgListener connection
525        let txn = self
526            .listener
527            .begin()
528            .await
529            .map_err(QueueRunnerError::StartTransaction)?;
530        let mut repo = PgRepository::from_conn(txn);
531
532        // We send a heartbeat every minute, to avoid writing to the database too often
533        // on a logged table
534        if now - self.last_heartbeat >= chrono::Duration::minutes(1) {
535            tracing::info!("Sending heartbeat");
536            repo.queue_worker()
537                .heartbeat(clock, &self.registration)
538                .await?;
539            self.last_heartbeat = now;
540        }
541
542        // Remove any dead worker leader leases
543        repo.queue_worker()
544            .remove_leader_lease_if_expired(clock)
545            .await?;
546
547        // Try to become (or stay) the leader
548        let leader = repo
549            .queue_worker()
550            .try_get_leader_lease(clock, &self.registration)
551            .await?;
552
553        // Process any job task which finished
554        self.tracker
555            .process_jobs(&mut rng, clock, &mut repo, false)
556            .await?;
557
558        // Compute how many jobs we should fetch at most
559        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
560            .saturating_sub(self.tracker.running_jobs())
561            .min(MAX_JOBS_TO_FETCH);
562
563        if max_jobs_to_fetch == 0 {
564            tracing::warn!("Internal job queue is full, not fetching any new jobs");
565        } else {
566            // Grab a few jobs in the queue
567            let queues = self.tracker.queues();
568            let jobs = repo
569                .queue_job()
570                .reserve(clock, &self.registration, &queues, max_jobs_to_fetch)
571                .await?;
572
573            for Job {
574                id,
575                queue_name,
576                payload,
577                metadata,
578                attempt,
579            } in jobs
580            {
581                let cancellation_token = self.cancellation_token.child_token();
582                let start = Instant::now();
583                let context = JobContext {
584                    id,
585                    metadata,
586                    queue_name,
587                    attempt,
588                    start,
589                    cancellation_token,
590                };
591
592                self.tracker.spawn_job(self.state.clone(), context, payload);
593            }
594        }
595
596        // At this point we are holding a lock on the leader table, so it's important
597        // that we commit as soon as possible to not block the other workers for too long
598        repo.into_inner()
599            .commit()
600            .await
601            .map_err(QueueRunnerError::CommitTransaction)?;
602
603        // Save the new leader state to log any change
604        if leader != self.am_i_leader {
605            // If we flipped state, log it
606            self.am_i_leader = leader;
607            if self.am_i_leader {
608                tracing::info!("I'm the leader now");
609            } else {
610                tracing::warn!("I am no longer the leader");
611            }
612        }
613
614        Ok(())
615    }
616
617    #[tracing::instrument(name = "worker.perform_leader_duties", skip_all)]
618    async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
619        // This should have been checked by the caller, but better safe than sorry
620        if !self.am_i_leader {
621            return Err(QueueRunnerError::NotLeader);
622        }
623
624        let clock = self.state.clock();
625        let mut rng = self.state.rng();
626
627        // Start a transaction on the existing PgListener connection
628        let txn = self
629            .listener
630            .begin()
631            .await
632            .map_err(QueueRunnerError::StartTransaction)?;
633
634        // The thing with the leader election is that it locks the table during the
635        // election, preventing other workers from going through the loop.
636        //
637        // Ideally, we would do the leader duties in the same transaction so that we
638        // make sure only one worker is doing the leader duties, but that
639        // would mean we would lock all the workers for the duration of the
640        // duties, which is not ideal.
641        //
642        // So we do the duties in a separate transaction, in which we take an advisory
643        // lock, so that in the very rare case where two workers think they are the
644        // leader, we still don't have two workers doing the duties at the same time.
645        let lock = PgAdvisoryLock::new("leader-duties");
646
647        let locked = lock
648            .try_acquire(txn)
649            .await
650            .map_err(QueueRunnerError::LeaderLock)?;
651
652        let locked = match locked {
653            Either::Left(locked) => locked,
654            Either::Right(txn) => {
655                tracing::error!("Another worker has the leader lock, aborting");
656                txn.rollback()
657                    .await
658                    .map_err(QueueRunnerError::CommitTransaction)?;
659                return Ok(());
660            }
661        };
662
663        let mut repo = PgRepository::from_conn(locked);
664
665        // Look at the state of schedules in the database
666        let schedules_status = repo.queue_schedule().list().await?;
667
668        let now = clock.now();
669        for schedule in &self.schedules {
670            // Find the schedule status from the database
671            let Some(status) = schedules_status
672                .iter()
673                .find(|s| s.schedule_name == schedule.schedule_name)
674            else {
675                tracing::error!(
676                    "Schedule {} was not found in the database",
677                    schedule.schedule_name
678                );
679                continue;
680            };
681
682            // Figure out if we should schedule a new job
683            if let Some(next_time) = status.last_scheduled_at {
684                if next_time > now {
685                    // We already have a job scheduled in the future, skip
686                    continue;
687                }
688
689                if status.last_scheduled_job_completed == Some(false) {
690                    // The last scheduled job has not completed yet, skip
691                    continue;
692                }
693            }
694
695            let next_tick = schedule.expression.after(&now).next().unwrap();
696
697            tracing::info!(
698                "Scheduling job for {}, next run at {}",
699                schedule.schedule_name,
700                next_tick
701            );
702
703            repo.queue_job()
704                .schedule_later(
705                    &mut rng,
706                    clock,
707                    schedule.queue_name,
708                    schedule.payload.clone(),
709                    serde_json::json!({}),
710                    next_tick,
711                    Some(schedule.schedule_name),
712                )
713                .await?;
714        }
715
716        // We also shut down all the dead workers which haven't checked in within
717        // the last two minutes
718        repo.queue_worker()
719            .shutdown_dead_workers(clock, Duration::minutes(2))
720            .await?;
721
722        // TODO: mark tasks those workers had as lost
723
724        // Mark all the scheduled jobs as available
725        let scheduled = repo.queue_job().schedule_available_jobs(clock).await?;
726        match scheduled {
727            0 => {}
728            1 => tracing::info!("One scheduled job marked as available"),
729            n => tracing::info!("{n} scheduled jobs marked as available"),
730        }
731
732        // Release the leader lock
733        let txn = repo
734            .into_inner()
735            .release_now()
736            .await
737            .map_err(QueueRunnerError::LeaderLock)?;
738
739        txn.commit()
740            .await
741            .map_err(QueueRunnerError::CommitTransaction)?;
742
743        Ok(())
744    }
745
746    /// Process all the pending jobs in the queue.
747    /// This should only be called in tests!
748    ///
749    /// # Errors
750    ///
751    /// This function can fail if the database connection fails.
752    pub async fn process_all_jobs_in_tests(&mut self) -> Result<(), QueueRunnerError> {
753        // In case we haven't setup the schedules yet
754        self.setup_schedules().await?;
755
756        // I swear, I'm the leader!
757        self.am_i_leader = true;
758
759        // First, perform the leader duties. This will make sure that we schedule
760        // recurring jobs.
761        self.perform_leader_duties().await?;
762
763        let clock = self.state.clock();
764        let mut rng = self.state.rng();
765
766        // Grab the connection from the PgListener
767        let txn = self
768            .listener
769            .begin()
770            .await
771            .map_err(QueueRunnerError::StartTransaction)?;
772        let mut repo = PgRepository::from_conn(txn);
773
774        // Spawn all the jobs in the database
775        let queues = self.tracker.queues();
776        let jobs = repo
777            .queue_job()
778            // I really hope that we don't spawn more than 10k jobs in tests
779            .reserve(clock, &self.registration, &queues, 10_000)
780            .await?;
781
782        for Job {
783            id,
784            queue_name,
785            payload,
786            metadata,
787            attempt,
788        } in jobs
789        {
790            let cancellation_token = self.cancellation_token.child_token();
791            let start = Instant::now();
792            let context = JobContext {
793                id,
794                metadata,
795                queue_name,
796                attempt,
797                start,
798                cancellation_token,
799            };
800
801            self.tracker.spawn_job(self.state.clone(), context, payload);
802        }
803
804        self.tracker
805            .process_jobs(&mut rng, clock, &mut repo, true)
806            .await?;
807
808        repo.into_inner()
809            .commit()
810            .await
811            .map_err(QueueRunnerError::CommitTransaction)?;
812
813        Ok(())
814    }
815}
816
817/// Tracks running jobs
818///
819/// This is a separate structure to be able to borrow it mutably at the same
820/// time as the connection to the database is borrowed
821struct JobTracker {
822    /// Stores a mapping from the job queue name to the job factory
823    factories: HashMap<&'static str, JobFactory>,
824
825    /// A join set of all the currently running jobs
826    running_jobs: JoinSet<JobResult>,
827
828    /// Stores a mapping from the Tokio task ID to the job context
829    job_contexts: HashMap<tokio::task::Id, JobContext>,
830
831    /// Stores the last `join_next_with_id` result for processing, in case we
832    /// got woken up in `collect_next_job`
833    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,
834
835    /// A histogram which records the time it takes to process a job
836    job_processing_time: Histogram<u64>,
837
838    /// A counter which records the number of jobs currently in flight
839    in_flight_jobs: UpDownCounter<i64>,
840}
841
842impl JobTracker {
843    fn new() -> Self {
844        let job_processing_time = METER
845            .u64_histogram("job.process.duration")
846            .with_description("The time it takes to process a job in milliseconds")
847            .with_unit("ms")
848            .build();
849
850        let in_flight_jobs = METER
851            .i64_up_down_counter("job.active_tasks")
852            .with_description("The number of jobs currently in flight")
853            .with_unit("{job}")
854            .build();
855
856        Self {
857            factories: HashMap::new(),
858            running_jobs: JoinSet::new(),
859            job_contexts: HashMap::new(),
860            last_join_result: None,
861            job_processing_time,
862            in_flight_jobs,
863        }
864    }
865
866    /// Returns the queue names that are currently being tracked
867    fn queues(&self) -> Vec<&'static str> {
868        self.factories.keys().copied().collect()
869    }
870
871    /// Spawn a job on the job tracker
872    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
873        let factory = self.factories.get(context.queue_name.as_str()).cloned();
874        let task = {
875            let log_context = LogContext::new(format!("job-{}", context.queue_name));
876            let context = context.clone();
877            let span = context.span();
878            log_context
879                .run(async move || {
880                    // We should never panic here, but if we do, the panic stays inside
881                    // the spawned task and doesn't crash the worker
882                    let job = factory.expect("unknown job factory")(payload);
883
884                    let timeout = job.timeout();
885                    // If there is a timeout set on the job, spawn a task which will cancel the
886                    // CancellationToken once the timeout is reached
887                    if let Some(timeout) = timeout {
888                        let context = context.clone();
889
890                        // It's fine to spawn this task without tracking it, as it is quite
891                        // lightweight and has no reason to crash.
892                        tokio::spawn(
893                            context
894                                .cancellation_token
895                                .clone()
896                                // This makes sure the task gets cancelled as soon as the job
897                                // finishes
898                                .run_until_cancelled_owned(async move {
899                                    tokio::time::sleep(timeout).await;
900                                    tracing::warn!(
901                                        job.id = %context.id,
902                                        job.queue.name = %context.queue_name,
903                                        "Job reached timeout, asking for cancellation"
904                                    );
905                                    context.cancellation_token.cancel();
906                                }),
907                        );
908                    }
909
910                    tracing::info!(
911                        job.id = %context.id,
912                        job.queue.name = %context.queue_name,
913                        job.attempt = %context.attempt,
914                        job.timeout = timeout.map(tracing::field::debug),
915                        "Running job"
916                    );
917                    let result = job.run(&state, context.clone()).await;
918
919                    // Cancel the cancellation token to stop any timeout task
920                    // that may be running
921                    context.cancellation_token.cancel();
922
923                    let Some(context_stats) =
924                        LogContext::maybe_with(mas_context::LogContext::stats)
925                    else {
926                        // This should never happen, but if it does it's fine: we're recovering fine
927                        // from panics in those tasks
928                        panic!("Missing log context, this should never happen");
929                    };
930
931                    // We log the result here so that it's attached to the right span & log context
932                    match &result {
933                        Ok(()) => {
934                            tracing::info!(
935                                job.id = %context.id,
936                                job.queue.name = %context.queue_name,
937                                job.attempt = %context.attempt,
938                                "Job completed [{context_stats}]"
939                            );
940                        }
941
942                        Err(JobError {
943                            decision: JobErrorDecision::Fail,
944                            error,
945                        }) => {
946                            tracing::error!(
947                                error = &**error as &dyn std::error::Error,
948                                job.id = %context.id,
949                                job.queue.name = %context.queue_name,
950                                job.attempt = %context.attempt,
951                                "Job failed, not retrying [{context_stats}]"
952                            );
953                        }
954
955                        Err(JobError {
956                            decision: JobErrorDecision::Retry,
957                            error,
958                        }) if context.attempt < MAX_ATTEMPTS => {
959                            let delay = retry_delay(context.attempt);
960                            tracing::warn!(
961                                error = &**error as &dyn std::error::Error,
962                                job.id = %context.id,
963                                job.queue.name = %context.queue_name,
964                                job.attempt = %context.attempt,
965                                "Job failed, will retry in {}s [{context_stats}]",
966                                delay.num_seconds()
967                            );
968                        }
969
970                        Err(JobError {
971                            decision: JobErrorDecision::Retry,
972                            error,
973                        }) => {
974                            tracing::error!(
975                                error = &**error as &dyn std::error::Error,
976                                job.id = %context.id,
977                                job.queue.name = %context.queue_name,
978                                job.attempt = %context.attempt,
979                                "Job failed too many times, abandoning [{context_stats}]"
980                            );
981                        }
982                    }
983
984                    (context_stats.elapsed, result)
985                })
986                .instrument(span)
987        };
988
989        self.in_flight_jobs.add(
990            1,
991            &[KeyValue::new("job.queue.name", context.queue_name.clone())],
992        );
993
994        let handle = self.running_jobs.spawn(task);
995        self.job_contexts.insert(handle.id(), context);
996    }
997
998    /// Returns `true` if there are currently running jobs
999    fn has_jobs(&self) -> bool {
1000        !self.running_jobs.is_empty()
1001    }
1002
1003    /// Returns the number of currently running jobs
1004    ///
1005    /// This also includes the job result which may be stored for processing
1006    fn running_jobs(&self) -> usize {
1007        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
1008    }
1009
1010    async fn collect_next_job(&mut self) {
1011        // Double-check that we don't have a job result stored
1012        if self.last_join_result.is_some() {
1013            tracing::error!(
1014                "Job tracker already had a job result stored, this should never happen!"
1015            );
1016            return;
1017        }
1018
1019        self.last_join_result = self.running_jobs.join_next_with_id().await;
1020    }
1021
1022    /// Process all the jobs which are currently running
1023    ///
1024    /// If `blocking` is `true`, this function will block until all the jobs
1025    /// are finished. Otherwise, it will return as soon as it processed the
1026    /// already finished jobs.
1027    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
1028        &mut self,
1029        rng: &mut (dyn RngCore + Send),
1030        clock: &dyn Clock,
1031        repo: &mut dyn RepositoryAccess<Error = E>,
1032        blocking: bool,
1033    ) -> Result<(), E> {
1034        if self.last_join_result.is_none() {
1035            if blocking {
1036                self.last_join_result = self.running_jobs.join_next_with_id().await;
1037            } else {
1038                self.last_join_result = self.running_jobs.try_join_next_with_id();
1039            }
1040        }
1041
1042        while let Some(result) = self.last_join_result.take() {
1043            match result {
1044                // The job succeeded. The logging and time measurement are already done in the task
1045                Ok((id, (elapsed, Ok(())))) => {
1046                    let context = self
1047                        .job_contexts
1048                        .remove(&id)
1049                        .expect("Job context not found");
1050
1051                    self.in_flight_jobs.add(
1052                        -1,
1053                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
1054                    );
1055
1056                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
1057                    self.job_processing_time.record(
1058                        elapsed_ms,
1059                        &[
1060                            KeyValue::new("job.queue.name", context.queue_name),
1061                            KeyValue::new("job.result", "success"),
1062                        ],
1063                    );
1064
1065                    repo.queue_job()
1066                        .mark_as_completed(clock, context.id)
1067                        .await?;
1068                }
1069
1070                // The job failed. The logging and time measurement are already done in the task
1071                Ok((id, (elapsed, Err(e)))) => {
1072                    let context = self
1073                        .job_contexts
1074                        .remove(&id)
1075                        .expect("Job context not found");
1076
1077                    self.in_flight_jobs.add(
1078                        -1,
1079                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
1080                    );
1081
1082                    let reason = format!("{:?}", e.error);
1083                    repo.queue_job()
1084                        .mark_as_failed(clock, context.id, &reason)
1085                        .await?;
1086
1087                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
1088                    match e.decision {
1089                        JobErrorDecision::Fail => {
1090                            self.job_processing_time.record(
1091                                elapsed_ms,
1092                                &[
1093                                    KeyValue::new("job.queue.name", context.queue_name),
1094                                    KeyValue::new("job.result", "failed"),
1095                                    KeyValue::new("job.decision", "fail"),
1096                                ],
1097                            );
1098                        }
1099
1100                        JobErrorDecision::Retry if context.attempt < MAX_ATTEMPTS => {
1101                            self.job_processing_time.record(
1102                                elapsed_ms,
1103                                &[
1104                                    KeyValue::new("job.queue.name", context.queue_name),
1105                                    KeyValue::new("job.result", "failed"),
1106                                    KeyValue::new("job.decision", "retry"),
1107                                ],
1108                            );
1109
1110                            let delay = retry_delay(context.attempt);
1111                            repo.queue_job()
1112                                .retry(&mut *rng, clock, context.id, delay)
1113                                .await?;
1114                        }
1115
1116                        JobErrorDecision::Retry => {
1117                            self.job_processing_time.record(
1118                                elapsed_ms,
1119                                &[
1120                                    KeyValue::new("job.queue.name", context.queue_name),
1121                                    KeyValue::new("job.result", "failed"),
1122                                    KeyValue::new("job.decision", "abandon"),
1123                                ],
1124                            );
1125                        }
1126                    }
1127                }
1128
1129                // The job crashed (or was aborted)
1130                Err(e) => {
1131                    let id = e.id();
1132                    let context = self
1133                        .job_contexts
1134                        .remove(&id)
1135                        .expect("Job context not found");
1136
1137                    self.in_flight_jobs.add(
1138                        -1,
1139                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
1140                    );
1141
1142                    // This measurement is not accurate as it also includes time spent waiting
1143                    // for the result to be collected, but that's fine: it's only for panicked tasks
1144                    let elapsed = context
1145                        .start
1146                        .elapsed()
1147                        .as_millis()
1148                        .try_into()
1149                        .unwrap_or(u64::MAX);
1150
1151                    let reason = e.to_string();
1152                    repo.queue_job()
1153                        .mark_as_failed(clock, context.id, &reason)
1154                        .await?;
1155
1156                    if context.attempt < MAX_ATTEMPTS {
1157                        let delay = retry_delay(context.attempt);
1158                        tracing::error!(
1159                            error = &e as &dyn std::error::Error,
1160                            job.id = %context.id,
1161                            job.queue.name = %context.queue_name,
1162                            job.attempt = %context.attempt,
1163                            job.elapsed = format!("{elapsed}ms"),
1164                            "Job crashed, will retry in {}s",
1165                            delay.num_seconds()
1166                        );
1167
1168                        self.job_processing_time.record(
1169                            elapsed,
1170                            &[
1171                                KeyValue::new("job.queue.name", context.queue_name),
1172                                KeyValue::new("job.result", "crashed"),
1173                                KeyValue::new("job.decision", "retry"),
1174                            ],
1175                        );
1176
1177                        repo.queue_job()
1178                            .retry(&mut *rng, clock, context.id, delay)
1179                            .await?;
1180                    } else {
1181                        tracing::error!(
1182                            error = &e as &dyn std::error::Error,
1183                            job.id = %context.id,
1184                            job.queue.name = %context.queue_name,
1185                            job.attempt = %context.attempt,
1186                            job.elapsed = format!("{elapsed}ms"),
1187                            "Job crashed too many times, abandoning"
1188                        );
1189
1190                        self.job_processing_time.record(
1191                            elapsed,
1192                            &[
1193                                KeyValue::new("job.queue.name", context.queue_name),
1194                                KeyValue::new("job.result", "crashed"),
1195                                KeyValue::new("job.decision", "abandon"),
1196                            ],
1197                        );
1198                    }
1199                }
1200            }
1201
1202            if blocking {
1203                self.last_join_result = self.running_jobs.join_next_with_id().await;
1204            } else {
1205                self.last_join_result = self.running_jobs.try_join_next_with_id();
1206            }
1207        }
1208
1209        Ok(())
1210    }
1211}