Django 1.8 compatibility: extracting the user from a session

django commit 596564e80808 stores the user id in the session as a
string, which broke our code that extracts the user id and compares
it to the id of a UserProfile object.

(imported from commit 99defd7fea96553550fa19e0b2f3e91a1baac123)
This commit is contained in:
Reid Barton
2015-08-19 11:53:55 -07:00
parent 5ea3bf85de
commit 9db521a931
5 changed files with 29 additions and 14 deletions

View File

@@ -53,6 +53,7 @@ from zerver.lib.push_notifications import num_push_devices_for_user, \
send_apple_push_notification, send_android_push_notification send_apple_push_notification, send_android_push_notification
from zerver.lib.notifications import clear_followup_emails_queue from zerver.lib.notifications import clear_followup_emails_queue
from zerver.lib.narrow import check_supported_events_narrow_filter from zerver.lib.narrow import check_supported_events_narrow_filter
from zerver.lib.session_user import get_session_user
import DNS import DNS
import ujson import ujson
@@ -166,21 +167,21 @@ def do_create_user(email, password, realm, full_name, short_name,
def user_sessions(user_profile): def user_sessions(user_profile):
return [s for s in Session.objects.all() return [s for s in Session.objects.all()
if s.get_decoded().get('_auth_user_id') == user_profile.id] if get_session_user(s) == user_profile.id]
def delete_session(session): def delete_session(session):
return session_engine.SessionStore(session.session_key).delete() return session_engine.SessionStore(session.session_key).delete()
def delete_user_sessions(user_profile): def delete_user_sessions(user_profile):
for session in Session.objects.all(): for session in Session.objects.all():
if session.get_decoded().get('_auth_user_id') == user_profile.id: if get_session_user(session) == user_profile.id:
delete_session(session) delete_session(session)
def delete_realm_user_sessions(realm): def delete_realm_user_sessions(realm):
realm_user_ids = [user_profile.id for user_profile in realm_user_ids = [user_profile.id for user_profile in
UserProfile.objects.filter(realm=realm)] UserProfile.objects.filter(realm=realm)]
for session in Session.objects.filter(expire_date__gte=datetime.datetime.now()): for session in Session.objects.filter(expire_date__gte=datetime.datetime.now()):
if session.get_decoded().get('_auth_user_id') in realm_user_ids: if get_session_user(session) in realm_user_ids:
delete_session(session) delete_session(session)
def delete_all_user_sessions(): def delete_all_user_sessions():

View File

@@ -0,0 +1,13 @@
from __future__ import absolute_import
from django.contrib.auth import SESSION_KEY, get_user_model
def get_session_dict_user(session_dict):
# Compare django.contrib.auth._get_user_session_key
try:
return get_user_model()._meta.pk.to_python(session_dict[SESSION_KEY])
except KeyError:
return None
def get_session_user(session):
return get_session_dict_user(session.get_decoded())

View File

@@ -20,6 +20,7 @@ from zerver.lib.event_queue import get_client_descriptor
from zerver.middleware import record_request_start_data, record_request_stop_data, \ from zerver.middleware import record_request_start_data, record_request_stop_data, \
record_request_restart_data, write_log_line, format_timedelta record_request_restart_data, write_log_line, format_timedelta
from zerver.lib.redis_utils import get_redis_client from zerver.lib.redis_utils import get_redis_client
from zerver.lib.session_user import get_session_user
logger = logging.getLogger('zulip.socket') logger = logging.getLogger('zulip.socket')
@@ -34,10 +35,8 @@ def get_user_profile(session_id):
except djSession.DoesNotExist: except djSession.DoesNotExist:
return None return None
session_store = djsession_engine.SessionStore(djsession.session_key)
try: try:
return UserProfile.objects.get(pk=session_store['_auth_user_id']) return UserProfile.objects.get(pk=get_session_user(djsession))
except (UserProfile.DoesNotExist, KeyError): except (UserProfile.DoesNotExist, KeyError):
return None return None

View File

@@ -19,6 +19,7 @@ from zerver.lib.digest import send_digest_email
from zerver.lib.notifications import enqueue_welcome_emails, one_click_unsubscribe_link from zerver.lib.notifications import enqueue_welcome_emails, one_click_unsubscribe_link
from zerver.lib.test_helpers import AuthedTestCase, find_key_by_email, queries_captured from zerver.lib.test_helpers import AuthedTestCase, find_key_by_email, queries_captured
from zerver.lib.test_runner import slow from zerver.lib.test_runner import slow
from zerver.lib.session_user import get_session_dict_user
import re import re
import ujson import ujson
@@ -90,11 +91,11 @@ class LoginTest(AuthedTestCase):
def test_login(self): def test_login(self):
self.login("hamlet@zulip.com") self.login("hamlet@zulip.com")
user_profile = get_user_profile_by_email('hamlet@zulip.com') user_profile = get_user_profile_by_email('hamlet@zulip.com')
self.assertEqual(self.client.session['_auth_user_id'], user_profile.id) self.assertEqual(get_session_dict_user(self.client.session), user_profile.id)
def test_login_bad_password(self): def test_login_bad_password(self):
self.login("hamlet@zulip.com", "wrongpassword") self.login("hamlet@zulip.com", "wrongpassword")
self.assertIsNone(self.client.session.get('_auth_user_id', None)) self.assertIsNone(get_session_dict_user(self.client.session))
def test_login_nonexist_user(self): def test_login_nonexist_user(self):
result = self.login("xxx@zulip.com", "xxx") result = self.login("xxx@zulip.com", "xxx")
@@ -112,7 +113,7 @@ class LoginTest(AuthedTestCase):
# Ensure the number of queries we make is not O(streams) # Ensure the number of queries we make is not O(streams)
self.assert_length(queries, 67) self.assert_length(queries, 67)
user_profile = get_user_profile_by_email('test@zulip.com') user_profile = get_user_profile_by_email('test@zulip.com')
self.assertEqual(self.client.session['_auth_user_id'], user_profile.id) self.assertEqual(get_session_dict_user(self.client.session), user_profile.id)
def test_register_deactivated(self): def test_register_deactivated(self):
""" """
@@ -143,7 +144,7 @@ class LoginTest(AuthedTestCase):
def test_logout(self): def test_logout(self):
self.login("hamlet@zulip.com") self.login("hamlet@zulip.com")
self.client.post('/accounts/logout/') self.client.post('/accounts/logout/')
self.assertIsNone(self.client.session.get('_auth_user_id', None)) self.assertIsNone(get_session_dict_user(self.client.session))
def test_non_ascii_login(self): def test_non_ascii_login(self):
""" """
@@ -155,14 +156,14 @@ class LoginTest(AuthedTestCase):
# Registering succeeds. # Registering succeeds.
self.register("test", password) self.register("test", password)
user_profile = get_user_profile_by_email(email) user_profile = get_user_profile_by_email(email)
self.assertEqual(self.client.session['_auth_user_id'], user_profile.id) self.assertEqual(get_session_dict_user(self.client.session), user_profile.id)
self.client.post('/accounts/logout/') self.client.post('/accounts/logout/')
self.assertIsNone(self.client.session.get('_auth_user_id', None)) self.assertIsNone(get_session_dict_user(self.client.session))
# Logging in succeeds. # Logging in succeeds.
self.client.post('/accounts/logout/') self.client.post('/accounts/logout/')
self.login(email, password) self.login(email, password)
self.assertEqual(self.client.session['_auth_user_id'], user_profile.id) self.assertEqual(get_session_dict_user(self.client.session), user_profile.id)
def test_register_first_user_with_invites(self): def test_register_first_user_with_invites(self):
""" """

View File

@@ -24,6 +24,7 @@ from zerver.lib.actions import \
from zerver.lib.alert_words import alert_words_in_realm, user_alert_words, \ from zerver.lib.alert_words import alert_words_in_realm, user_alert_words, \
add_user_alert_words, remove_user_alert_words add_user_alert_words, remove_user_alert_words
from zerver.lib.notifications import handle_missedmessage_emails from zerver.lib.notifications import handle_missedmessage_emails
from zerver.lib.session_user import get_session_dict_user
from zerver.middleware import is_slow_query from zerver.middleware import is_slow_query
from zerver.worker import queue_processors from zerver.worker import queue_processors
@@ -951,7 +952,7 @@ class ChangeSettingsTest(AuthedTestCase):
self.client.post('/accounts/logout/') self.client.post('/accounts/logout/')
self.login("hamlet@zulip.com", "foobar1") self.login("hamlet@zulip.com", "foobar1")
user_profile = get_user_profile_by_email('hamlet@zulip.com') user_profile = get_user_profile_by_email('hamlet@zulip.com')
self.assertEqual(self.client.session['_auth_user_id'], user_profile.id) self.assertEqual(get_session_dict_user(self.client.session), user_profile.id)
def test_notify_settings(self): def test_notify_settings(self):
# This is basically a don't-explode test. # This is basically a don't-explode test.