mas_storage_pg/oauth2/
refresh_token.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{AccessToken, Clock, RefreshToken, RefreshTokenState, Session};
11use mas_storage::oauth2::OAuth2RefreshTokenRepository;
12use rand::RngCore;
13use sqlx::PgConnection;
14use ulid::Ulid;
15use uuid::Uuid;
16
17use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
18
19/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL
20/// connection
21pub struct PgOAuth2RefreshTokenRepository<'c> {
22    conn: &'c mut PgConnection,
23}
24
25impl<'c> PgOAuth2RefreshTokenRepository<'c> {
26    /// Create a new [`PgOAuth2RefreshTokenRepository`] from an active
27    /// PostgreSQL connection
28    pub fn new(conn: &'c mut PgConnection) -> Self {
29        Self { conn }
30    }
31}
32
33struct OAuth2RefreshTokenLookup {
34    oauth2_refresh_token_id: Uuid,
35    refresh_token: String,
36    created_at: DateTime<Utc>,
37    consumed_at: Option<DateTime<Utc>>,
38    revoked_at: Option<DateTime<Utc>>,
39    oauth2_access_token_id: Option<Uuid>,
40    oauth2_session_id: Uuid,
41    next_oauth2_refresh_token_id: Option<Uuid>,
42}
43
44impl TryFrom<OAuth2RefreshTokenLookup> for RefreshToken {
45    type Error = DatabaseInconsistencyError;
46
47    fn try_from(value: OAuth2RefreshTokenLookup) -> Result<Self, Self::Error> {
48        let id = value.oauth2_refresh_token_id.into();
49        let state = match (
50            value.revoked_at,
51            value.consumed_at,
52            value.next_oauth2_refresh_token_id,
53        ) {
54            (None, None, None) => RefreshTokenState::Valid,
55            (Some(revoked_at), None, None) => RefreshTokenState::Revoked { revoked_at },
56            (None, Some(consumed_at), None) => RefreshTokenState::Consumed {
57                consumed_at,
58                next_refresh_token_id: None,
59            },
60            (None, Some(consumed_at), Some(id)) => RefreshTokenState::Consumed {
61                consumed_at,
62                next_refresh_token_id: Some(Ulid::from(id)),
63            },
64            _ => {
65                return Err(DatabaseInconsistencyError::on("oauth2_refresh_tokens")
66                    .column("next_oauth2_refresh_token_id")
67                    .row(id));
68            }
69        };
70
71        Ok(RefreshToken {
72            id,
73            state,
74            session_id: value.oauth2_session_id.into(),
75            refresh_token: value.refresh_token,
76            created_at: value.created_at,
77            access_token_id: value.oauth2_access_token_id.map(Ulid::from),
78        })
79    }
80}
81
82#[async_trait]
83impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
84    type Error = DatabaseError;
85
86    #[tracing::instrument(
87        name = "db.oauth2_refresh_token.lookup",
88        skip_all,
89        fields(
90            db.query.text,
91            refresh_token.id = %id,
92        ),
93        err,
94    )]
95    async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
96        let res = sqlx::query_as!(
97            OAuth2RefreshTokenLookup,
98            r#"
99                SELECT oauth2_refresh_token_id
100                     , refresh_token
101                     , created_at
102                     , consumed_at
103                     , revoked_at
104                     , oauth2_access_token_id
105                     , oauth2_session_id
106                     , next_oauth2_refresh_token_id
107                FROM oauth2_refresh_tokens
108
109                WHERE oauth2_refresh_token_id = $1
110            "#,
111            Uuid::from(id),
112        )
113        .traced()
114        .fetch_optional(&mut *self.conn)
115        .await?;
116
117        let Some(res) = res else { return Ok(None) };
118
119        Ok(Some(res.try_into()?))
120    }
121
122    #[tracing::instrument(
123        name = "db.oauth2_refresh_token.find_by_token",
124        skip_all,
125        fields(
126            db.query.text,
127        ),
128        err,
129    )]
130    async fn find_by_token(
131        &mut self,
132        refresh_token: &str,
133    ) -> Result<Option<RefreshToken>, Self::Error> {
134        let res = sqlx::query_as!(
135            OAuth2RefreshTokenLookup,
136            r#"
137                SELECT oauth2_refresh_token_id
138                     , refresh_token
139                     , created_at
140                     , consumed_at
141                     , revoked_at
142                     , oauth2_access_token_id
143                     , oauth2_session_id
144                     , next_oauth2_refresh_token_id
145                FROM oauth2_refresh_tokens
146
147                WHERE refresh_token = $1
148            "#,
149            refresh_token,
150        )
151        .traced()
152        .fetch_optional(&mut *self.conn)
153        .await?;
154
155        let Some(res) = res else { return Ok(None) };
156
157        Ok(Some(res.try_into()?))
158    }
159
160    #[tracing::instrument(
161        name = "db.oauth2_refresh_token.add",
162        skip_all,
163        fields(
164            db.query.text,
165            %session.id,
166            client.id = %session.client_id,
167            refresh_token.id,
168        ),
169        err,
170    )]
171    async fn add(
172        &mut self,
173        rng: &mut (dyn RngCore + Send),
174        clock: &dyn Clock,
175        session: &Session,
176        access_token: &AccessToken,
177        refresh_token: String,
178    ) -> Result<RefreshToken, Self::Error> {
179        let created_at = clock.now();
180        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
181        tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
182
183        sqlx::query!(
184            r#"
185                INSERT INTO oauth2_refresh_tokens
186                    (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
187                     refresh_token, created_at)
188                VALUES
189                    ($1, $2, $3, $4, $5)
190            "#,
191            Uuid::from(id),
192            Uuid::from(session.id),
193            Uuid::from(access_token.id),
194            refresh_token,
195            created_at,
196        )
197        .traced()
198        .execute(&mut *self.conn)
199        .await?;
200
201        Ok(RefreshToken {
202            id,
203            state: RefreshTokenState::default(),
204            session_id: session.id,
205            refresh_token,
206            access_token_id: Some(access_token.id),
207            created_at,
208        })
209    }
210
211    #[tracing::instrument(
212        name = "db.oauth2_refresh_token.consume",
213        skip_all,
214        fields(
215            db.query.text,
216            %refresh_token.id,
217            session.id = %refresh_token.session_id,
218        ),
219        err,
220    )]
221    async fn consume(
222        &mut self,
223        clock: &dyn Clock,
224        refresh_token: RefreshToken,
225        replaced_by: &RefreshToken,
226    ) -> Result<RefreshToken, Self::Error> {
227        let consumed_at = clock.now();
228        let res = sqlx::query!(
229            r#"
230                UPDATE oauth2_refresh_tokens
231                SET consumed_at = $2,
232                    next_oauth2_refresh_token_id = $3
233                WHERE oauth2_refresh_token_id = $1
234            "#,
235            Uuid::from(refresh_token.id),
236            consumed_at,
237            Uuid::from(replaced_by.id),
238        )
239        .traced()
240        .execute(&mut *self.conn)
241        .await?;
242
243        DatabaseError::ensure_affected_rows(&res, 1)?;
244
245        refresh_token
246            .consume(consumed_at, replaced_by)
247            .map_err(DatabaseError::to_invalid_operation)
248    }
249
250    #[tracing::instrument(
251        name = "db.oauth2_refresh_token.revoke",
252        skip_all,
253        fields(
254            db.query.text,
255            %refresh_token.id,
256            session.id = %refresh_token.session_id,
257        ),
258        err,
259    )]
260    async fn revoke(
261        &mut self,
262        clock: &dyn Clock,
263        refresh_token: RefreshToken,
264    ) -> Result<RefreshToken, Self::Error> {
265        let revoked_at = clock.now();
266        let res = sqlx::query!(
267            r#"
268                UPDATE oauth2_refresh_tokens
269                SET revoked_at = $2
270                WHERE oauth2_refresh_token_id = $1
271            "#,
272            Uuid::from(refresh_token.id),
273            revoked_at,
274        )
275        .traced()
276        .execute(&mut *self.conn)
277        .await?;
278
279        DatabaseError::ensure_affected_rows(&res, 1)?;
280
281        refresh_token
282            .revoke(revoked_at)
283            .map_err(DatabaseError::to_invalid_operation)
284    }
285
286    #[tracing::instrument(
287        name = "db.oauth2_refresh_token.cleanup_revoked",
288        skip_all,
289        fields(
290            db.query.text,
291        ),
292        err,
293    )]
294    async fn cleanup_revoked(
295        &mut self,
296        since: Option<DateTime<Utc>>,
297        until: DateTime<Utc>,
298        limit: usize,
299    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
300        let res = sqlx::query!(
301            r#"
302                WITH
303                    to_delete AS (
304                        SELECT oauth2_refresh_token_id
305                        FROM oauth2_refresh_tokens
306                        WHERE revoked_at IS NOT NULL
307                          AND ($1::timestamptz IS NULL OR revoked_at >= $1::timestamptz)
308                          AND revoked_at < $2::timestamptz
309                        ORDER BY revoked_at ASC
310                        LIMIT $3
311                        FOR UPDATE
312                    ),
313
314                    deleted AS (
315                        DELETE FROM oauth2_refresh_tokens
316                        USING to_delete
317                        WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id
318                        RETURNING oauth2_refresh_tokens.revoked_at
319                    )
320
321                SELECT
322                    COUNT(*) as "count!",
323                    MAX(revoked_at) as last_revoked_at
324                FROM deleted
325            "#,
326            since,
327            until,
328            i64::try_from(limit).unwrap_or(i64::MAX),
329        )
330        .traced()
331        .fetch_one(&mut *self.conn)
332        .await?;
333
334        Ok((
335            res.count.try_into().unwrap_or(usize::MAX),
336            res.last_revoked_at,
337        ))
338    }
339
340    #[tracing::instrument(
341        name = "db.oauth2_refresh_token.cleanup_consumed",
342        skip_all,
343        fields(
344            db.query.text,
345        ),
346        err,
347    )]
348    async fn cleanup_consumed(
349        &mut self,
350        since: Option<DateTime<Utc>>,
351        until: DateTime<Utc>,
352        limit: usize,
353    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
354        // We only consider a token as consumed if also the next token has its
355        // `consumed_at` set. This makes the query a bit expensive to compute,
356        // but is optimised to two index scans and a nested join using the
357        // `oauth2_refresh_token_not_consumed_idx` and
358        // `oauth2_refresh_token_consumed_at_idx` indexes.
359        let res = sqlx::query!(
360            r#"
361                WITH
362                    to_delete AS (
363                        SELECT rts_to_del.oauth2_refresh_token_id
364                        FROM oauth2_refresh_tokens rts_to_del
365                        LEFT JOIN oauth2_refresh_tokens next_rts
366                          ON rts_to_del.next_oauth2_refresh_token_id = next_rts.oauth2_refresh_token_id
367                        WHERE rts_to_del.consumed_at IS NOT NULL
368                          AND (rts_to_del.next_oauth2_refresh_token_id IS NULL OR next_rts.consumed_at IS NOT NULL)
369                          AND ($1::timestamptz IS NULL OR rts_to_del.consumed_at >= $1::timestamptz)
370                          AND rts_to_del.consumed_at < $2::timestamptz
371                        ORDER BY rts_to_del.consumed_at ASC
372                        LIMIT $3
373                    ),
374
375                    deleted AS (
376                        DELETE FROM oauth2_refresh_tokens
377                        USING to_delete
378                        WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id
379                        RETURNING oauth2_refresh_tokens.consumed_at
380                    )
381
382                SELECT
383                    COUNT(*) as "count!",
384                    MAX(consumed_at) as last_consumed_at
385                FROM deleted
386            "#,
387            since,
388            until,
389            i64::try_from(limit).unwrap_or(i64::MAX),
390        )
391        .traced()
392        .fetch_one(&mut *self.conn)
393        .await?;
394
395        Ok((
396            res.count.try_into().unwrap_or(usize::MAX),
397            res.last_consumed_at,
398        ))
399    }
400}