1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 Clock, User,
12 personal::session::{PersonalSession, PersonalSessionOwner, SessionState},
13};
14use mas_storage::{
15 Page, Pagination,
16 pagination::Node,
17 personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
18};
19use oauth2_types::scope::Scope;
20use rand::RngCore;
21use sea_query::{
22 Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
23 extension::postgres::PgExpr as _,
24};
25use sea_query_binder::SqlxBinder as _;
26use sqlx::PgConnection;
27use ulid::Ulid;
28use uuid::Uuid;
29
30use crate::{
31 DatabaseError,
32 errors::DatabaseInconsistencyError,
33 filter::{Filter, StatementExt as _},
34 iden::PersonalSessions,
35 pagination::QueryBuilderExt as _,
36 tracing::ExecuteExt as _,
37};
38
/// An implementation of [`PersonalSessionRepository`] backed by a PostgreSQL
/// connection.
pub struct PgPersonalSessionRepository<'c> {
    // Mutable borrow of the connection every query in this repository runs on.
    conn: &'c mut PgConnection,
}
44
45impl<'c> PgPersonalSessionRepository<'c> {
46 pub fn new(conn: &'c mut PgConnection) -> Self {
49 Self { conn }
50 }
51}
52
/// Raw row of the `personal_sessions` table as fetched by the queries in this
/// module, before it is converted into a [`PersonalSession`].
///
/// `#[enum_def]` generates `PersonalSessionLookupIden`, whose variants are
/// used as column aliases when building queries with `sea_query`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionLookup {
    personal_session_id: Uuid,
    // Owner columns: the conversion into `PersonalSession` requires exactly
    // one of the two to be set.
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    // Each entry is parsed as one scope token during conversion.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` maps to a valid session, `Some` to a revoked one.
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
67
68impl Node<Ulid> for PersonalSessionLookup {
69 fn cursor(&self) -> Ulid {
70 self.personal_session_id.into()
71 }
72}
73
74impl TryFrom<PersonalSessionLookup> for PersonalSession {
75 type Error = DatabaseInconsistencyError;
76
77 fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
78 let id = Ulid::from(value.personal_session_id);
79 let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
80 let scope = scope.map_err(|e| {
81 DatabaseInconsistencyError::on("personal_sessions")
82 .column("scope")
83 .row(id)
84 .source(e)
85 })?;
86
87 let state = match value.revoked_at {
88 None => SessionState::Valid,
89 Some(revoked_at) => SessionState::Revoked { revoked_at },
90 };
91
92 let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
93 (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
94 (None, Some(owner_oauth2_client_id)) => {
95 PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
96 }
97 _ => {
98 return Err(DatabaseInconsistencyError::on("personal_sessions")
100 .column("owner_user_id, owner_oauth2_client_id")
101 .row(id));
102 }
103 };
104
105 Ok(PersonalSession {
106 id,
107 state,
108 owner,
109 actor_user_id: Ulid::from(value.actor_user_id),
110 human_name: value.human_name,
111 scope,
112 created_at: value.created_at,
113 last_active_at: value.last_active_at,
114 last_active_ip: value.last_active_ip,
115 })
116 }
117}
118
119#[async_trait]
120impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
121 type Error = DatabaseError;
122
123 #[tracing::instrument(
124 name = "db.personal_session.lookup",
125 skip_all,
126 fields(
127 db.query.text,
128 session.id = %id,
129 ),
130 err,
131 )]
132 async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
133 let res = sqlx::query_as!(
134 PersonalSessionLookup,
135 r#"
136 SELECT personal_session_id
137 , owner_user_id
138 , owner_oauth2_client_id
139 , actor_user_id
140 , scope_list
141 , created_at
142 , revoked_at
143 , human_name
144 , last_active_at
145 , last_active_ip as "last_active_ip: IpAddr"
146 FROM personal_sessions
147
148 WHERE personal_session_id = $1
149 "#,
150 Uuid::from(id),
151 )
152 .traced()
153 .fetch_optional(&mut *self.conn)
154 .await?;
155
156 let Some(session) = res else { return Ok(None) };
157
158 Ok(Some(session.try_into()?))
159 }
160
161 #[tracing::instrument(
162 name = "db.personal_session.add",
163 skip_all,
164 fields(
165 db.query.text,
166 session.id,
167 session.scope = %scope,
168 ),
169 err,
170 )]
171 async fn add(
172 &mut self,
173 rng: &mut (dyn RngCore + Send),
174 clock: &dyn Clock,
175 owner: PersonalSessionOwner,
176 actor_user: &User,
177 human_name: String,
178 scope: Scope,
179 ) -> Result<PersonalSession, Self::Error> {
180 let created_at = clock.now();
181 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
182 tracing::Span::current().record("session.id", tracing::field::display(id));
183
184 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
185
186 let (owner_user_id, owner_oauth2_client_id) = match owner {
187 PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
188 PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
189 };
190
191 sqlx::query!(
192 r#"
193 INSERT INTO personal_sessions
194 ( personal_session_id
195 , owner_user_id
196 , owner_oauth2_client_id
197 , actor_user_id
198 , human_name
199 , scope_list
200 , created_at
201 )
202 VALUES ($1, $2, $3, $4, $5, $6, $7)
203 "#,
204 Uuid::from(id),
205 owner_user_id,
206 owner_oauth2_client_id,
207 Uuid::from(actor_user.id),
208 &human_name,
209 &scope_list,
210 created_at,
211 )
212 .traced()
213 .execute(&mut *self.conn)
214 .await?;
215
216 Ok(PersonalSession {
217 id,
218 state: SessionState::Valid,
219 owner,
220 actor_user_id: actor_user.id,
221 human_name,
222 scope,
223 created_at,
224 last_active_at: None,
225 last_active_ip: None,
226 })
227 }
228
229 #[tracing::instrument(
230 name = "db.personal_session.revoke",
231 skip_all,
232 fields(
233 db.query.text,
234 %session.id,
235 %session.scope,
236 ),
237 err,
238 )]
239 async fn revoke(
240 &mut self,
241 clock: &dyn Clock,
242 session: PersonalSession,
243 ) -> Result<PersonalSession, Self::Error> {
244 let finished_at = clock.now();
245 let res = sqlx::query!(
246 r#"
247 UPDATE personal_sessions
248 SET revoked_at = $2
249 WHERE personal_session_id = $1
250 "#,
251 Uuid::from(session.id),
252 finished_at,
253 )
254 .traced()
255 .execute(&mut *self.conn)
256 .await?;
257
258 DatabaseError::ensure_affected_rows(&res, 1)?;
259
260 session
261 .finish(finished_at)
262 .map_err(DatabaseError::to_invalid_operation)
263 }
264
265 #[tracing::instrument(
266 name = "db.personal_session.list",
267 skip_all,
268 fields(
269 db.query.text,
270 ),
271 err,
272 )]
273 async fn list(
274 &mut self,
275 filter: PersonalSessionFilter<'_>,
276 pagination: Pagination,
277 ) -> Result<Page<PersonalSession>, Self::Error> {
278 let (sql, arguments) = Query::select()
279 .expr_as(
280 Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
281 PersonalSessionLookupIden::PersonalSessionId,
282 )
283 .expr_as(
284 Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
285 PersonalSessionLookupIden::OwnerUserId,
286 )
287 .expr_as(
288 Expr::col((
289 PersonalSessions::Table,
290 PersonalSessions::OwnerOAuth2ClientId,
291 )),
292 PersonalSessionLookupIden::OwnerOauth2ClientId,
293 )
294 .expr_as(
295 Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
296 PersonalSessionLookupIden::ActorUserId,
297 )
298 .expr_as(
299 Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
300 PersonalSessionLookupIden::HumanName,
301 )
302 .expr_as(
303 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
304 PersonalSessionLookupIden::ScopeList,
305 )
306 .expr_as(
307 Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
308 PersonalSessionLookupIden::CreatedAt,
309 )
310 .expr_as(
311 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
312 PersonalSessionLookupIden::RevokedAt,
313 )
314 .expr_as(
315 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
316 PersonalSessionLookupIden::LastActiveAt,
317 )
318 .expr_as(
319 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
320 PersonalSessionLookupIden::LastActiveIp,
321 )
322 .from(PersonalSessions::Table)
323 .apply_filter(filter)
324 .generate_pagination(
325 (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
326 pagination,
327 )
328 .build_sqlx(PostgresQueryBuilder);
329
330 let edges: Vec<PersonalSessionLookup> = sqlx::query_as_with(&sql, arguments)
331 .traced()
332 .fetch_all(&mut *self.conn)
333 .await?;
334
335 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
336
337 Ok(page)
338 }
339
340 #[tracing::instrument(
341 name = "db.personal_session.count",
342 skip_all,
343 fields(
344 db.query.text,
345 ),
346 err,
347 )]
348 async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
349 let (sql, arguments) = Query::select()
350 .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
351 .from(PersonalSessions::Table)
352 .apply_filter(filter)
353 .build_sqlx(PostgresQueryBuilder);
354
355 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
356 .traced()
357 .fetch_one(&mut *self.conn)
358 .await?;
359
360 count
361 .try_into()
362 .map_err(DatabaseError::to_invalid_operation)
363 }
364}
365
366impl Filter for PersonalSessionFilter<'_> {
367 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
368 sea_query::Condition::all()
369 .add_option(self.owner_user().map(|user| {
370 Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
371 .eq(Uuid::from(user.id))
372 }))
373 .add_option(self.owner_oauth2_client().map(|client| {
374 Expr::col((
375 PersonalSessions::Table,
376 PersonalSessions::OwnerOAuth2ClientId,
377 ))
378 .eq(Uuid::from(client.id))
379 }))
380 .add_option(self.actor_user().map(|user| {
381 Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
382 .eq(Uuid::from(user.id))
383 }))
384 .add_option(self.device().map(|device| -> SimpleExpr {
385 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
386 Condition::any()
387 .add(
388 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
389 PersonalSessions::Table,
390 PersonalSessions::ScopeList,
391 )))),
392 )
393 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
394 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
395 )))
396 .into()
397 } else {
398 Expr::val(false).into()
400 }
401 }))
402 .add_option(self.state().map(|state| match state {
403 PersonalSessionState::Active => {
404 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
405 }
406 PersonalSessionState::Revoked => {
407 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
408 }
409 }))
410 .add_option(self.scope().map(|scope| {
411 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
412 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
413 }))
414 .add_option(self.last_active_before().map(|last_active_before| {
415 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
416 .lt(last_active_before)
417 }))
418 .add_option(self.last_active_after().map(|last_active_after| {
419 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
420 .gt(last_active_after)
421 }))
422 }
423}