1use async_trait::async_trait;
9use chrono::{DateTime, Duration, Utc};
10use mas_data_model::{AccessToken, AccessTokenState, Clock, Session};
11use mas_storage::oauth2::OAuth2AccessTokenRepository;
12use rand::RngCore;
13use sqlx::PgConnection;
14use ulid::Ulid;
15use uuid::Uuid;
16
17use crate::{DatabaseError, tracing::ExecuteExt};
18
19pub struct PgOAuth2AccessTokenRepository<'c> {
22 conn: &'c mut PgConnection,
23}
24
25impl<'c> PgOAuth2AccessTokenRepository<'c> {
26 pub fn new(conn: &'c mut PgConnection) -> Self {
29 Self { conn }
30 }
31}
32
33struct OAuth2AccessTokenLookup {
34 oauth2_access_token_id: Uuid,
35 oauth2_session_id: Uuid,
36 access_token: String,
37 created_at: DateTime<Utc>,
38 expires_at: Option<DateTime<Utc>>,
39 revoked_at: Option<DateTime<Utc>>,
40 first_used_at: Option<DateTime<Utc>>,
41}
42
43impl From<OAuth2AccessTokenLookup> for AccessToken {
44 fn from(value: OAuth2AccessTokenLookup) -> Self {
45 let state = match value.revoked_at {
46 None => AccessTokenState::Valid,
47 Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
48 };
49
50 Self {
51 id: value.oauth2_access_token_id.into(),
52 state,
53 session_id: value.oauth2_session_id.into(),
54 access_token: value.access_token,
55 created_at: value.created_at,
56 expires_at: value.expires_at,
57 first_used_at: value.first_used_at,
58 }
59 }
60}
61
62#[async_trait]
63impl OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'_> {
64 type Error = DatabaseError;
65
66 async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
67 let res = sqlx::query_as!(
68 OAuth2AccessTokenLookup,
69 r#"
70 SELECT oauth2_access_token_id
71 , access_token
72 , created_at
73 , expires_at
74 , revoked_at
75 , oauth2_session_id
76 , first_used_at
77
78 FROM oauth2_access_tokens
79
80 WHERE oauth2_access_token_id = $1
81 "#,
82 Uuid::from(id),
83 )
84 .fetch_optional(&mut *self.conn)
85 .await?;
86
87 let Some(res) = res else { return Ok(None) };
88
89 Ok(Some(res.into()))
90 }
91
92 #[tracing::instrument(
93 name = "db.oauth2_access_token.find_by_token",
94 skip_all,
95 fields(
96 db.query.text,
97 ),
98 err,
99 )]
100 async fn find_by_token(
101 &mut self,
102 access_token: &str,
103 ) -> Result<Option<AccessToken>, Self::Error> {
104 let res = sqlx::query_as!(
105 OAuth2AccessTokenLookup,
106 r#"
107 SELECT oauth2_access_token_id
108 , access_token
109 , created_at
110 , expires_at
111 , revoked_at
112 , oauth2_session_id
113 , first_used_at
114
115 FROM oauth2_access_tokens
116
117 WHERE access_token = $1
118 "#,
119 access_token,
120 )
121 .fetch_optional(&mut *self.conn)
122 .await?;
123
124 let Some(res) = res else { return Ok(None) };
125
126 Ok(Some(res.into()))
127 }
128
129 #[tracing::instrument(
130 name = "db.oauth2_access_token.add",
131 skip_all,
132 fields(
133 db.query.text,
134 %session.id,
135 client.id = %session.client_id,
136 access_token.id,
137 ),
138 err,
139 )]
140 async fn add(
141 &mut self,
142 rng: &mut (dyn RngCore + Send),
143 clock: &dyn Clock,
144 session: &Session,
145 access_token: String,
146 expires_after: Option<Duration>,
147 ) -> Result<AccessToken, Self::Error> {
148 let created_at = clock.now();
149 let expires_at = expires_after.map(|d| created_at + d);
150 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
151
152 tracing::Span::current().record("access_token.id", tracing::field::display(id));
153
154 sqlx::query!(
155 r#"
156 INSERT INTO oauth2_access_tokens
157 (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
158 VALUES
159 ($1, $2, $3, $4, $5)
160 "#,
161 Uuid::from(id),
162 Uuid::from(session.id),
163 &access_token,
164 created_at,
165 expires_at,
166 )
167 .traced()
168 .execute(&mut *self.conn)
169 .await?;
170
171 Ok(AccessToken {
172 id,
173 state: AccessTokenState::default(),
174 access_token,
175 session_id: session.id,
176 created_at,
177 expires_at,
178 first_used_at: None,
179 })
180 }
181
182 #[tracing::instrument(
183 name = "db.oauth2_access_token.revoke",
184 skip_all,
185 fields(
186 db.query.text,
187 session.id = %access_token.session_id,
188 %access_token.id,
189 ),
190 err,
191 )]
192 async fn revoke(
193 &mut self,
194 clock: &dyn Clock,
195 access_token: AccessToken,
196 ) -> Result<AccessToken, Self::Error> {
197 let revoked_at = clock.now();
198 let res = sqlx::query!(
199 r#"
200 UPDATE oauth2_access_tokens
201 SET revoked_at = $2
202 WHERE oauth2_access_token_id = $1
203 "#,
204 Uuid::from(access_token.id),
205 revoked_at,
206 )
207 .traced()
208 .execute(&mut *self.conn)
209 .await?;
210
211 DatabaseError::ensure_affected_rows(&res, 1)?;
212
213 access_token
214 .revoke(revoked_at)
215 .map_err(DatabaseError::to_invalid_operation)
216 }
217
218 #[tracing::instrument(
219 name = "db.oauth2_access_token.mark_used",
220 skip_all,
221 fields(
222 db.query.text,
223 session.id = %access_token.session_id,
224 %access_token.id,
225 ),
226 err,
227 )]
228 async fn mark_used(
229 &mut self,
230 clock: &dyn Clock,
231 mut access_token: AccessToken,
232 ) -> Result<AccessToken, Self::Error> {
233 let now = clock.now();
234 let res = sqlx::query!(
235 r#"
236 UPDATE oauth2_access_tokens
237 SET first_used_at = $2
238 WHERE oauth2_access_token_id = $1
239 "#,
240 Uuid::from(access_token.id),
241 now,
242 )
243 .execute(&mut *self.conn)
244 .await?;
245
246 DatabaseError::ensure_affected_rows(&res, 1)?;
247
248 access_token.first_used_at = Some(now);
249
250 Ok(access_token)
251 }
252
253 #[tracing::instrument(
254 name = "db.oauth2_access_token.cleanup_revoked",
255 skip_all,
256 fields(
257 db.query.text,
258 ),
259 err,
260 )]
261 async fn cleanup_revoked(
262 &mut self,
263 since: Option<DateTime<Utc>>,
264 until: DateTime<Utc>,
265 limit: usize,
266 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
267 let res = sqlx::query!(
268 r#"
269 WITH
270 to_delete AS (
271 SELECT oauth2_access_token_id
272 FROM oauth2_access_tokens
273 WHERE revoked_at IS NOT NULL
274 AND ($1::timestamptz IS NULL OR revoked_at >= $1::timestamptz)
275 AND revoked_at < $2::timestamptz
276 ORDER BY revoked_at ASC
277 LIMIT $3
278 FOR UPDATE
279 ),
280
281 deleted AS (
282 DELETE FROM oauth2_access_tokens
283 USING to_delete
284 WHERE oauth2_access_tokens.oauth2_access_token_id = to_delete.oauth2_access_token_id
285 RETURNING oauth2_access_tokens.revoked_at
286 )
287
288 SELECT
289 COUNT(*) as "count!",
290 MAX(revoked_at) as last_revoked_at
291 FROM deleted
292 "#,
293 since,
294 until,
295 i64::try_from(limit).unwrap_or(i64::MAX),
296 )
297 .traced()
298 .fetch_one(&mut *self.conn)
299 .await?;
300
301 Ok((
302 res.count.try_into().unwrap_or(usize::MAX),
303 res.last_revoked_at,
304 ))
305 }
306
307 #[tracing::instrument(
308 name = "db.oauth2_access_token.cleanup_expired",
309 skip_all,
310 fields(
311 db.query.text,
312 ),
313 err,
314 )]
315 async fn cleanup_expired(
316 &mut self,
317 since: Option<DateTime<Utc>>,
318 until: DateTime<Utc>,
319 limit: usize,
320 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
321 let res = sqlx::query!(
322 r#"
323 WITH
324 to_delete AS (
325 SELECT oauth2_access_token_id
326 FROM oauth2_access_tokens
327 WHERE expires_at IS NOT NULL
328 AND ($1::timestamptz IS NULL OR expires_at >= $1::timestamptz)
329 AND expires_at < $2::timestamptz
330 ORDER BY expires_at ASC
331 LIMIT $3
332 FOR UPDATE
333 ),
334
335 deleted AS (
336 DELETE FROM oauth2_access_tokens
337 USING to_delete
338 WHERE oauth2_access_tokens.oauth2_access_token_id = to_delete.oauth2_access_token_id
339 RETURNING oauth2_access_tokens.expires_at
340 )
341
342 SELECT
343 COUNT(*) as "count!",
344 MAX(expires_at) as last_expires_at
345 FROM deleted
346 "#,
347 since,
348 until,
349 i64::try_from(limit).unwrap_or(i64::MAX),
350 )
351 .traced()
352 .fetch_one(&mut *self.conn)
353 .await?;
354
355 Ok((
356 res.count.try_into().unwrap_or(usize::MAX),
357 res.last_expires_at,
358 ))
359 }
360}