1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{AccessToken, Clock, RefreshToken, RefreshTokenState, Session};
11use mas_storage::oauth2::OAuth2RefreshTokenRepository;
12use rand::RngCore;
13use sqlx::PgConnection;
14use ulid::Ulid;
15use uuid::Uuid;
16
17use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
18
/// [`OAuth2RefreshTokenRepository`] implementation that stores refresh tokens
/// in PostgreSQL.
pub struct PgOAuth2RefreshTokenRepository<'c> {
    /// Exclusively borrowed connection; every query issued by this repository
    /// is executed on it.
    conn: &'c mut PgConnection,
}
24
25impl<'c> PgOAuth2RefreshTokenRepository<'c> {
26 pub fn new(conn: &'c mut PgConnection) -> Self {
29 Self { conn }
30 }
31}
32
/// Raw shape of a row selected from the `oauth2_refresh_tokens` table.
///
/// The `lookup` and `find_by_token` queries map their result columns onto
/// these fields by name; the row is then converted into a [`RefreshToken`]
/// through `TryFrom`, which also validates the combination of nullable
/// columns.
struct OAuth2RefreshTokenLookup {
    /// Identifier of the refresh token row.
    oauth2_refresh_token_id: Uuid,
    /// The refresh token string itself.
    refresh_token: String,
    /// Timestamp recorded when the row was inserted.
    created_at: DateTime<Utc>,
    /// Set once the token has been consumed, `None` otherwise.
    consumed_at: Option<DateTime<Utc>>,
    /// Set once the token has been revoked, `None` otherwise.
    revoked_at: Option<DateTime<Utc>>,
    /// Identifier of the access token stored alongside this refresh token,
    /// if the column is populated.
    oauth2_access_token_id: Option<Uuid>,
    /// Identifier of the OAuth2 session the token belongs to.
    oauth2_session_id: Uuid,
    /// Identifier of the refresh token that replaced this one when it was
    /// consumed, if any was recorded.
    next_oauth2_refresh_token_id: Option<Uuid>,
}
43
44impl TryFrom<OAuth2RefreshTokenLookup> for RefreshToken {
45 type Error = DatabaseInconsistencyError;
46
47 fn try_from(value: OAuth2RefreshTokenLookup) -> Result<Self, Self::Error> {
48 let id = value.oauth2_refresh_token_id.into();
49 let state = match (
50 value.revoked_at,
51 value.consumed_at,
52 value.next_oauth2_refresh_token_id,
53 ) {
54 (None, None, None) => RefreshTokenState::Valid,
55 (Some(revoked_at), None, None) => RefreshTokenState::Revoked { revoked_at },
56 (None, Some(consumed_at), None) => RefreshTokenState::Consumed {
57 consumed_at,
58 next_refresh_token_id: None,
59 },
60 (None, Some(consumed_at), Some(id)) => RefreshTokenState::Consumed {
61 consumed_at,
62 next_refresh_token_id: Some(Ulid::from(id)),
63 },
64 _ => {
65 return Err(DatabaseInconsistencyError::on("oauth2_refresh_tokens")
66 .column("next_oauth2_refresh_token_id")
67 .row(id));
68 }
69 };
70
71 Ok(RefreshToken {
72 id,
73 state,
74 session_id: value.oauth2_session_id.into(),
75 refresh_token: value.refresh_token,
76 created_at: value.created_at,
77 access_token_id: value.oauth2_access_token_id.map(Ulid::from),
78 })
79 }
80}
81
82#[async_trait]
83impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
84 type Error = DatabaseError;
85
86 #[tracing::instrument(
87 name = "db.oauth2_refresh_token.lookup",
88 skip_all,
89 fields(
90 db.query.text,
91 refresh_token.id = %id,
92 ),
93 err,
94 )]
95 async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
96 let res = sqlx::query_as!(
97 OAuth2RefreshTokenLookup,
98 r#"
99 SELECT oauth2_refresh_token_id
100 , refresh_token
101 , created_at
102 , consumed_at
103 , revoked_at
104 , oauth2_access_token_id
105 , oauth2_session_id
106 , next_oauth2_refresh_token_id
107 FROM oauth2_refresh_tokens
108
109 WHERE oauth2_refresh_token_id = $1
110 "#,
111 Uuid::from(id),
112 )
113 .traced()
114 .fetch_optional(&mut *self.conn)
115 .await?;
116
117 let Some(res) = res else { return Ok(None) };
118
119 Ok(Some(res.try_into()?))
120 }
121
122 #[tracing::instrument(
123 name = "db.oauth2_refresh_token.find_by_token",
124 skip_all,
125 fields(
126 db.query.text,
127 ),
128 err,
129 )]
130 async fn find_by_token(
131 &mut self,
132 refresh_token: &str,
133 ) -> Result<Option<RefreshToken>, Self::Error> {
134 let res = sqlx::query_as!(
135 OAuth2RefreshTokenLookup,
136 r#"
137 SELECT oauth2_refresh_token_id
138 , refresh_token
139 , created_at
140 , consumed_at
141 , revoked_at
142 , oauth2_access_token_id
143 , oauth2_session_id
144 , next_oauth2_refresh_token_id
145 FROM oauth2_refresh_tokens
146
147 WHERE refresh_token = $1
148 "#,
149 refresh_token,
150 )
151 .traced()
152 .fetch_optional(&mut *self.conn)
153 .await?;
154
155 let Some(res) = res else { return Ok(None) };
156
157 Ok(Some(res.try_into()?))
158 }
159
160 #[tracing::instrument(
161 name = "db.oauth2_refresh_token.add",
162 skip_all,
163 fields(
164 db.query.text,
165 %session.id,
166 client.id = %session.client_id,
167 refresh_token.id,
168 ),
169 err,
170 )]
171 async fn add(
172 &mut self,
173 rng: &mut (dyn RngCore + Send),
174 clock: &dyn Clock,
175 session: &Session,
176 access_token: &AccessToken,
177 refresh_token: String,
178 ) -> Result<RefreshToken, Self::Error> {
179 let created_at = clock.now();
180 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
181 tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
182
183 sqlx::query!(
184 r#"
185 INSERT INTO oauth2_refresh_tokens
186 (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
187 refresh_token, created_at)
188 VALUES
189 ($1, $2, $3, $4, $5)
190 "#,
191 Uuid::from(id),
192 Uuid::from(session.id),
193 Uuid::from(access_token.id),
194 refresh_token,
195 created_at,
196 )
197 .traced()
198 .execute(&mut *self.conn)
199 .await?;
200
201 Ok(RefreshToken {
202 id,
203 state: RefreshTokenState::default(),
204 session_id: session.id,
205 refresh_token,
206 access_token_id: Some(access_token.id),
207 created_at,
208 })
209 }
210
211 #[tracing::instrument(
212 name = "db.oauth2_refresh_token.consume",
213 skip_all,
214 fields(
215 db.query.text,
216 %refresh_token.id,
217 session.id = %refresh_token.session_id,
218 ),
219 err,
220 )]
221 async fn consume(
222 &mut self,
223 clock: &dyn Clock,
224 refresh_token: RefreshToken,
225 replaced_by: &RefreshToken,
226 ) -> Result<RefreshToken, Self::Error> {
227 let consumed_at = clock.now();
228 let res = sqlx::query!(
229 r#"
230 UPDATE oauth2_refresh_tokens
231 SET consumed_at = $2,
232 next_oauth2_refresh_token_id = $3
233 WHERE oauth2_refresh_token_id = $1
234 "#,
235 Uuid::from(refresh_token.id),
236 consumed_at,
237 Uuid::from(replaced_by.id),
238 )
239 .traced()
240 .execute(&mut *self.conn)
241 .await?;
242
243 DatabaseError::ensure_affected_rows(&res, 1)?;
244
245 refresh_token
246 .consume(consumed_at, replaced_by)
247 .map_err(DatabaseError::to_invalid_operation)
248 }
249
250 #[tracing::instrument(
251 name = "db.oauth2_refresh_token.revoke",
252 skip_all,
253 fields(
254 db.query.text,
255 %refresh_token.id,
256 session.id = %refresh_token.session_id,
257 ),
258 err,
259 )]
260 async fn revoke(
261 &mut self,
262 clock: &dyn Clock,
263 refresh_token: RefreshToken,
264 ) -> Result<RefreshToken, Self::Error> {
265 let revoked_at = clock.now();
266 let res = sqlx::query!(
267 r#"
268 UPDATE oauth2_refresh_tokens
269 SET revoked_at = $2
270 WHERE oauth2_refresh_token_id = $1
271 "#,
272 Uuid::from(refresh_token.id),
273 revoked_at,
274 )
275 .traced()
276 .execute(&mut *self.conn)
277 .await?;
278
279 DatabaseError::ensure_affected_rows(&res, 1)?;
280
281 refresh_token
282 .revoke(revoked_at)
283 .map_err(DatabaseError::to_invalid_operation)
284 }
285
286 #[tracing::instrument(
287 name = "db.oauth2_refresh_token.cleanup_revoked",
288 skip_all,
289 fields(
290 db.query.text,
291 ),
292 err,
293 )]
294 async fn cleanup_revoked(
295 &mut self,
296 since: Option<DateTime<Utc>>,
297 until: DateTime<Utc>,
298 limit: usize,
299 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
300 let res = sqlx::query!(
301 r#"
302 WITH
303 to_delete AS (
304 SELECT oauth2_refresh_token_id
305 FROM oauth2_refresh_tokens
306 WHERE revoked_at IS NOT NULL
307 AND ($1::timestamptz IS NULL OR revoked_at >= $1::timestamptz)
308 AND revoked_at < $2::timestamptz
309 ORDER BY revoked_at ASC
310 LIMIT $3
311 FOR UPDATE
312 ),
313
314 deleted AS (
315 DELETE FROM oauth2_refresh_tokens
316 USING to_delete
317 WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id
318 RETURNING oauth2_refresh_tokens.revoked_at
319 )
320
321 SELECT
322 COUNT(*) as "count!",
323 MAX(revoked_at) as last_revoked_at
324 FROM deleted
325 "#,
326 since,
327 until,
328 i64::try_from(limit).unwrap_or(i64::MAX),
329 )
330 .traced()
331 .fetch_one(&mut *self.conn)
332 .await?;
333
334 Ok((
335 res.count.try_into().unwrap_or(usize::MAX),
336 res.last_revoked_at,
337 ))
338 }
339
340 #[tracing::instrument(
341 name = "db.oauth2_refresh_token.cleanup_consumed",
342 skip_all,
343 fields(
344 db.query.text,
345 ),
346 err,
347 )]
348 async fn cleanup_consumed(
349 &mut self,
350 since: Option<DateTime<Utc>>,
351 until: DateTime<Utc>,
352 limit: usize,
353 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
354 let res = sqlx::query!(
360 r#"
361 WITH
362 to_delete AS (
363 SELECT rts_to_del.oauth2_refresh_token_id
364 FROM oauth2_refresh_tokens rts_to_del
365 LEFT JOIN oauth2_refresh_tokens next_rts
366 ON rts_to_del.next_oauth2_refresh_token_id = next_rts.oauth2_refresh_token_id
367 WHERE rts_to_del.consumed_at IS NOT NULL
368 AND (rts_to_del.next_oauth2_refresh_token_id IS NULL OR next_rts.consumed_at IS NOT NULL)
369 AND ($1::timestamptz IS NULL OR rts_to_del.consumed_at >= $1::timestamptz)
370 AND rts_to_del.consumed_at < $2::timestamptz
371 ORDER BY rts_to_del.consumed_at ASC
372 LIMIT $3
373 ),
374
375 deleted AS (
376 DELETE FROM oauth2_refresh_tokens
377 USING to_delete
378 WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id
379 RETURNING oauth2_refresh_tokens.consumed_at
380 )
381
382 SELECT
383 COUNT(*) as "count!",
384 MAX(consumed_at) as last_consumed_at
385 FROM deleted
386 "#,
387 since,
388 until,
389 i64::try_from(limit).unwrap_or(i64::MAX),
390 )
391 .traced()
392 .fetch_one(&mut *self.conn)
393 .await?;
394
395 Ok((
396 res.count.try_into().unwrap_or(usize::MAX),
397 res.last_consumed_at,
398 ))
399 }
400}