# Source code for timApp.tests.server.test_user_sessions

from timApp.auth.accesstype import AccessType
from timApp.auth.session.model import UserSession
from timApp.auth.session.util import verify_session_for
from timApp.tests.server.timroutetest import TimRouteTest
from timApp.timdb.sqa import db


class UserSessionsTest(TimRouteTest):
    """Route tests for user session tracking.

    Every test enables ``SESSIONS_ENABLE`` through ``temp_config`` and starts
    from an empty ``UserSession`` table, then drives the login, logout,
    ``/user/sessions/*`` and ``/view`` routes as Test User 1.
    """

    def forget_session(self) -> None:
        """Forget the current session. Simulates user e.g. closing an incognito window."""
        with self.client.session_transaction() as s:
            s.clear()

    def latest_session(self) -> UserSession:
        """Get latest session of Test User 1.

        "Latest" means the row with the greatest ``logged_in_at``.
        NOTE(review): ``.first()`` yields ``None`` when the user has no
        sessions; callers in this class always log in first.
        """
        return (
            UserSession.query.filter_by(user_id=self.test_user_1.id)
            .order_by(UserSession.logged_in_at.desc())
            .first()
        )

    def assert_session_expired_state(
        self, session_ids: list[str], state: list[bool], msg: str
    ) -> None:
        """Assert the state of Test User 1's sessions.

        :param session_ids: Session IDs to look up, in order.
        :param state: Expected ``expired`` value for each ID, in the same order.
        :param msg: Message to report when the assertion fails.
        """
        self.assertEqual(
            [
                # .one() makes a missing or duplicated session row fail loudly.
                UserSession.query.filter_by(
                    user_id=self.test_user_1.id,
                    session_id=sess,
                )
                .one()
                .expired
                for sess in session_ids
            ],
            state,
            msg,
        )

    # Backward-compatible alias for the original (misspelled) helper name.
    assert_sesion_expired_state = assert_session_expired_state

    def _clear_sessions(self) -> None:
        """Delete every stored ``UserSession`` row and commit."""
        UserSession.query.delete()
        db.session.commit()

    def _open_sessions(self, count: int) -> list[str]:
        """Log in as Test User 1 ``count`` times, forgetting the session after each login.

        :param count: Number of separate logins to perform.
        :return: The session ID created by each login, oldest first.
        """
        session_ids = []
        for _ in range(count):
            self.login_test1(manual=True)
            session_ids.append(self.latest_session().session_id)
            # Drop the client-side session so the next login creates a new one.
            self.forget_session()
        return session_ids

    def test_session_basic(self):
        """Test that sessions are tracked when SESSIONS_ENABLE is True."""
        self.logout()
        with self.temp_config(
            {
                "SESSIONS_ENABLE": True,
                "SESSIONS_MAX_CONCURRENT_SESSIONS_PER_USER": None,
            }
        ):
            self._clear_sessions()

            self.login_test1(manual=True)
            sessions: list[UserSession] = UserSession.query.all()
            self.assertEqual(len(sessions), 1)
            self.assertEqual(sessions[0].user.name, self.test_user_1.name)
            self.assertEqual(sessions[0].expired, False)
            self.get(
                "/user/sessions/current",
                expect_content={
                    "sessionId": sessions[0].session_id,
                    "valid": True,
                },
            )

            # Logging out must keep the row but mark it expired.
            self.logout()
            sessions = UserSession.query.all()
            self.assertEqual(len(sessions), 1)
            self.assertEqual(sessions[0].user.name, self.test_user_1.name)
            self.assertEqual(sessions[0].expired, True)

    def test_session_access_block(self):
        """Test that expired sessions cannot access documents."""
        # Test User 2 owns the document; Test User 1 only gets view access.
        self.login_test2()
        d = self.create_doc()
        self.test_user_1.grant_access(d, AccessType.view)
        db.session.commit()

        self.logout()
        with self.temp_config(
            {
                "SESSIONS_ENABLE": True,
                "SESSIONS_MAX_CONCURRENT_SESSIONS_PER_USER": 1,
            }
        ):
            self._clear_sessions()

            self.login_test1(manual=True)
            self.get(f"/view/{d.id}", expect_status=200)

            # A second concurrent login exceeds the limit of 1; the view
            # request made with it is expected to be answered with status 490.
            self.forget_session()
            self.login_test1(manual=True)
            self.get(f"/view/{d.id}", expect_status=490)

    def test_session_validity(self):
        """Test cases where the session can invalidate itself."""
        self.logout()
        with self.temp_config(
            {
                "SESSIONS_ENABLE": True,
                "SESSIONS_MAX_CONCURRENT_SESSIONS_PER_USER": 1,
            }
        ):
            self._clear_sessions()

            self.login_test1(manual=True)
            lsess = self.latest_session()
            self.get(
                "/user/sessions/current",
                expect_content={
                    "sessionId": lsess.session_id,
                    "valid": True,
                },
            )

            self.logout()
            lsess = self.latest_session()
            self.assertEqual(
                lsess.expired,
                True,
                "Session should be expired after logging out",
            )
            self.get("/user/sessions/current", expect_status=403)

            self.login_test1(manual=True)
            lsess = self.latest_session()
            # Remember this session so its state can be re-checked at the end.
            prev_id = lsess.session_id
            self.assertEqual(lsess.expired, False)
            self.get(
                "/user/sessions/current",
                expect_content={
                    "sessionId": lsess.session_id,
                    "valid": True,
                },
            )

            self.forget_session()
            self.login_test1(manual=True)
            # lsess still refers to the session created before this login.
            self.assertEqual(
                lsess.expired, False, "The previous session should still be valid"
            )
            lsess = self.latest_session()
            self.assertEqual(
                lsess.expired,
                True,
                "The new session should be automatically expired (max concurrent sessions reached)",
            )
            self.get(
                "/user/sessions/current",
                expect_content={
                    "sessionId": lsess.session_id,
                    "valid": False,
                },
            )

            verify_session_for(self.test_user_1.name, lsess.session_id)
            db.session.commit()

            # Session is valid after verifying it
            self.get(
                "/user/sessions/current",
                expect_content={
                    "sessionId": lsess.session_id,
                    "valid": True,
                },
            )

            prev_session = UserSession.query.filter_by(
                user_id=self.test_user_1.id, session_id=prev_id
            ).first()
            self.assertEqual(
                prev_session.expired,
                True,
                "The previously unexpired session should now be expired",
            )

    def test_session_verify_remote(self):
        """Test that a session can be verified remotely."""
        self.logout()
        with self.temp_config(
            {
                "SESSIONS_ENABLE": True,
                "SESSIONS_MAX_CONCURRENT_SESSIONS_PER_USER": 1,
                "DIST_RIGHTS_SEND_SECRET": "yyy",
                "DIST_RIGHTS_RECEIVE_SECRET": "yyy",
            }
        ):
            self._clear_sessions()
            session_ids = self._open_sessions(3)

            # Request without a secret is expected to be rejected with 403.
            self.post(
                "/user/sessions/verify",
                data={
                    "session_id": session_ids[0],
                    "username": self.test_user_1.name,
                },
                expect_status=403,
            )

            # Request with a wrong secret is expected to be rejected with 400.
            self.post(
                "/user/sessions/verify",
                data={
                    "session_id": session_ids[0],
                    "username": self.test_user_1.name,
                    "secret": "xxx",
                },
                expect_status=400,
            )

            # An unknown session ID with the right secret is still answered with 200.
            self.post(
                "/user/sessions/verify",
                data={
                    "session_id": "invalid session id",
                    "username": self.test_user_1.name,
                    "secret": "yyy",
                },
                expect_status=200,
            )

            self.post(
                "/user/sessions/verify",
                data={
                    "session_id": session_ids[0],
                    "username": self.test_user_1.name,
                    "secret": "yyy",
                },
            )

            self.assert_session_expired_state(
                session_ids,
                [False, True, True],
                "Specific session should be valid, others expired",
            )

            # Mark all sessions as not expired to test verification of the latest session
            UserSession.query.filter_by(user_id=self.test_user_1.id).update(
                {"expired_at": None}
            )
            db.session.commit()

            # No session_id given: the latest session is the one expected to be verified.
            self.post(
                "/user/sessions/verify",
                data={
                    "username": self.test_user_1.name,
                    "secret": "yyy",
                },
            )

            self.assert_session_expired_state(
                session_ids,
                [True, True, False],
                "Latest session should be valid, others expired",
            )

    def test_session_invalidate(self) -> None:
        """Test that sessions can be invalidated via /user/sessions/invalidate."""
        self.logout()
        with self.temp_config(
            {
                "SESSIONS_ENABLE": True,
                "SESSIONS_MAX_CONCURRENT_SESSIONS_PER_USER": 1,
                "DIST_RIGHTS_SEND_SECRET": "yyy",
                "DIST_RIGHTS_RECEIVE_SECRET": "yyy",
            }
        ):
            self._clear_sessions()
            session_ids = self._open_sessions(3)

            self.assert_session_expired_state(
                session_ids,
                [False, True, True],
                "Oldest session should be valid, others expired",
            )

            # Request without a secret is expected to be rejected with 403.
            self.post(
                "/user/sessions/invalidate",
                data={
                    "username": self.test_user_1.name,
                    "session_id": session_ids[0],
                },
                expect_status=403,
            )

            # Request with a wrong secret is expected to be rejected with 400.
            self.post(
                "/user/sessions/invalidate",
                data={
                    "username": self.test_user_1.name,
                    "session_id": session_ids[1],
                    "secret": "xxx",
                },
                expect_status=400,
            )

            # An unknown session ID with the right secret is still answered with 200.
            self.post(
                "/user/sessions/invalidate",
                data={
                    "username": self.test_user_1.name,
                    "session_id": "invalid session id",
                    "secret": "yyy",
                },
                expect_status=200,
            )

            self.assert_session_expired_state(
                session_ids,
                [False, True, True],
                "Expiration state should not have changed after wrong invalidation attempts",
            )

            self.post(
                "/user/sessions/invalidate",
                data={
                    "username": self.test_user_1.name,
                    "session_id": session_ids[0],
                    "secret": "yyy",
                },
            )

            self.assert_session_expired_state(
                session_ids,
                [True, True, True],
                "Oldest session should be invalidated",
            )

            # Manually un-expire every session so the bulk invalidation
            # below has something to expire.
            for session_id in session_ids:
                sess = UserSession.query.filter_by(session_id=session_id).first()
                sess.expired_at = None
            db.session.commit()

            self.assert_session_expired_state(
                session_ids,
                [False, False, False],
                "All sessions should be valid after manual validation",
            )

            # No session_id given: all of the user's sessions are expected to expire.
            self.post(
                "/user/sessions/invalidate",
                data={
                    "username": self.test_user_1.name,
                    "secret": "yyy",
                },
            )

            self.assert_session_expired_state(
                session_ids,
                [True, True, True],
                "All sessions should be expired after invalidation",
            )