diff --git a/zerver/lib/sessions.py b/zerver/lib/sessions.py index a6aaff6ad0..1c47973c17 100644 --- a/zerver/lib/sessions.py +++ b/zerver/lib/sessions.py @@ -1,13 +1,15 @@ import logging +from datetime import timedelta from django.conf import settings from django.contrib.auth import SESSION_KEY, get_user_model from django.contrib.sessions.models import Session from django.utils.timezone import now as timezone_now from importlib import import_module -from typing import List, Mapping, Optional +from typing import Any, List, Mapping, Optional from zerver.models import Realm, UserProfile, get_user_profile_by_id +from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime session_engine = import_module(settings.SESSION_ENGINE) @@ -53,3 +55,26 @@ def delete_all_deactivated_user_sessions() -> None: if not user_profile.is_active or user_profile.realm.deactivated: logging.info("Deactivating session for deactivated user %s" % (user_profile.id,)) delete_session(session) + +def set_expirable_session_var(session: Session, var_name: str, var_value: Any, expiry_seconds: int) -> None: + expire_at = datetime_to_timestamp(timezone_now() + timedelta(seconds=expiry_seconds)) + session[var_name] = {'value': var_value, 'expire_at': expire_at} + +def get_expirable_session_var(session: Session, var_name: str, default_value: Any=None, + delete: bool=False) -> Any: + if var_name not in session: + return default_value + + try: + value, expire_at = (session[var_name]['value'], session[var_name]['expire_at']) + except (KeyError, TypeError) as e: + logging.warning("get_expirable_session_var: Variable {}: {}".format(var_name, e)) + return default_value + + if timestamp_to_datetime(expire_at) < timezone_now(): + del session[var_name] + return default_value + + if delete: + del session[var_name] + return value diff --git a/zerver/tests/test_sessions.py b/zerver/tests/test_sessions.py index 6f16483178..87a1b5e667 100644 --- a/zerver/tests/test_sessions.py +++ b/zerver/tests/test_sessions.py @@ -1,3 +1,5 @@ +from datetime import timedelta +from django.utils.timezone import now as timezone_now from typing import Any, Callable from zerver.lib.sessions import ( @@ -7,6 +9,8 @@ from zerver.lib.sessions import ( delete_realm_user_sessions, delete_all_user_sessions, delete_all_deactivated_user_sessions, + get_expirable_session_var, + set_expirable_session_var, ) from zerver.models import ( @@ -15,6 +19,7 @@ from zerver.models import ( from zerver.lib.test_classes import ZulipTestCase +import mock class TestSessions(ZulipTestCase): @@ -93,3 +98,35 @@ class TestSessions(ZulipTestCase): delete_all_deactivated_user_sessions() result = self.client_get("/") self.assertEqual('/login/', result.url) + +class TestExpirableSessionVars(ZulipTestCase): + def setUp(self) -> None: + self.session = self.client.session + super().setUp() + + def test_set_and_get_basic(self) -> None: + start_time = timezone_now() + with mock.patch('zerver.lib.sessions.timezone_now', return_value=start_time): + set_expirable_session_var(self.session, 'test_set_and_get_basic', 'some_value', expiry_seconds=10) + value = get_expirable_session_var(self.session, 'test_set_and_get_basic') + self.assertEqual(value, 'some_value') + with mock.patch('zerver.lib.sessions.timezone_now', return_value=start_time + timedelta(seconds=11)): + value = get_expirable_session_var(self.session, 'test_set_and_get_basic') + self.assertEqual(value, None) + + def test_set_and_get_with_delete(self) -> None: + set_expirable_session_var(self.session, 'test_set_and_get_with_delete', 'some_value', expiry_seconds=10) + value = get_expirable_session_var(self.session, 'test_set_and_get_with_delete', delete=True) + self.assertEqual(value, 'some_value') + self.assertEqual(get_expirable_session_var(self.session, 'test_set_and_get_with_delete'), None) + + def test_get_var_not_set(self) -> None: + value = get_expirable_session_var(self.session, 'test_get_var_not_set', default_value='default') + self.assertEqual(value, 'default') + + def test_get_var_is_not_expirable(self) -> None: + self.session["test_get_var_is_not_expirable"] = 0 + with mock.patch('zerver.lib.sessions.logging.warning') as mock_warn: + value = get_expirable_session_var(self.session, 'test_get_var_is_not_expirable', default_value='default') + self.assertEqual(value, 'default') + mock_warn.assert_called_once()