rate_limit: Move functions called by external code to RateLimitedObject.

This commit is contained in:
Mateusz Mandera
2020-03-04 14:05:25 +01:00
committed by Tim Abbott
parent 2b51b3c6c5
commit 85df6201f6
11 changed files with 102 additions and 109 deletions

View File

@@ -28,7 +28,7 @@ from zerver.lib.exceptions import JsonableError, ErrorCode, \
from zerver.lib.types import ViewFuncT
from zerver.lib.validator import to_non_negative_int
from zerver.lib.rate_limiter import rate_limit_request_by_entity, RateLimitedUser
from zerver.lib.rate_limiter import RateLimitedUser
from zerver.lib.request import REQ, has_request_variables
from functools import wraps
@@ -744,8 +744,7 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non
if the user has been rate limited, otherwise returns and modifies request to contain
the rate limit information"""
entity = RateLimitedUser(user, domain=domain)
rate_limit_request_by_entity(request, entity)
RateLimitedUser(user, domain=domain).rate_limit_request(request)
def rate_limit(domain: str='api_by_user') -> Callable[[ViewFuncT], ViewFuncT]:
"""Rate-limits a view. Takes an optional 'domain' param if you wish to

View File

@@ -19,7 +19,7 @@ from zerver.lib.email_validation import email_allowed_for_realm, \
validate_email_not_already_in_realm
from zerver.lib.name_restrictions import is_reserved_subdomain, is_disposable_domain
from zerver.lib.rate_limiter import RateLimited, get_rate_limit_result_from_request, \
RateLimitedObject, rate_limit_entity
RateLimitedObject
from zerver.lib.request import JsonableError
from zerver.lib.send_email import send_email, FromAddress
from zerver.lib.subdomains import get_subdomain, is_root_domain_available
@@ -314,7 +314,7 @@ class RateLimitedPasswordResetByEmail(RateLimitedObject):
return settings.RATE_LIMITING_RULES['password_reset_form_by_email']
def rate_limit_password_reset_form_by_email(email: str) -> None:
ratelimited, _ = rate_limit_entity(RateLimitedPasswordResetByEmail(email))
ratelimited, _ = RateLimitedPasswordResetByEmail(email).rate_limit()
if ratelimited:
raise RateLimited

View File

@@ -20,7 +20,7 @@ from zerver.lib.queue import queue_json_publish
from zerver.lib.utils import generate_random_token
from zerver.lib.upload import upload_message_file
from zerver.lib.send_email import FromAddress
from zerver.lib.rate_limiter import RateLimitedObject, rate_limit_entity
from zerver.lib.rate_limiter import RateLimitedObject
from zerver.lib.exceptions import RateLimited
from zerver.models import Stream, Recipient, MissedMessageEmailAddress, \
get_display_recipient, \
@@ -453,8 +453,7 @@ class RateLimitedRealmMirror(RateLimitedObject):
return self.realm.string_id
def rate_limit_mirror_by_realm(recipient_realm: Realm) -> None:
entity = RateLimitedRealmMirror(recipient_realm)
ratelimited = rate_limit_entity(entity)[0]
ratelimited = RateLimitedRealmMirror(recipient_realm).rate_limit()[0]
if ratelimited:
raise RateLimited()

View File

@@ -34,6 +34,80 @@ class RateLimitedObject(ABC):
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype)
for keytype in ['list', 'zset', 'block']]
def rate_limit(self) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)
ratelimited, time = is_ratelimited(self)
if ratelimited:
statsd.incr("ratelimiter.limited.%s.%s" % (type(self), str(self)))
else:
try:
incr_ratelimit(self)
except RateLimiterLockingException:
logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % (
type(self).__name__, str(self)))
# rate-limit users who are hitting the API so hard we can't update our stats.
ratelimited = True
return ratelimited, time
def rate_limit_request(self, request: HttpRequest) -> None:
ratelimited, time = self.rate_limit()
entity_type = type(self).__name__
if not hasattr(request, '_ratelimit'):
request._ratelimit = {}
request._ratelimit[entity_type] = RateLimitResult(
entity=self,
secs_to_freedom=time,
over_limit=ratelimited
)
# Abort this request if the user is over their rate limits
if ratelimited:
# Pass information about what kind of entity got limited in the exception:
raise RateLimited(entity_type)
calls_remaining, time_reset = self.api_calls_left()
request._ratelimit[entity_type].remaining = calls_remaining
request._ratelimit[entity_type].secs_to_freedom = time_reset
def block_access(self, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
_, _, blocking_key = self.get_keys()
with client.pipeline() as pipe:
pipe.set(blocking_key, 1)
pipe.expire(blocking_key, seconds)
pipe.execute()
def unblock_access(self) -> None:
_, _, blocking_key = self.get_keys()
client.delete(blocking_key)
def clear_history(self) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
'''
for key in self.get_keys():
client.delete(key)
def max_api_calls(self) -> int:
"Returns the API rate limit for the highest limit"
return self.rules()[-1][1]
def max_api_window(self) -> int:
"Returns the API time window for the highest limit"
return self.rules()[-1][0]
def api_calls_left(self) -> Tuple[int, float]:
"""Returns how many API calls in this range this client has, as well as when
the rate-limit will be reset to 0"""
max_window = self.max_api_window()
max_calls = self.max_api_calls()
return _get_api_calls_left(self, max_window, max_calls)
@abstractmethod
def key_fragment(self) -> str:
pass
@@ -71,14 +145,6 @@ def bounce_redis_key_prefix_for_testing(test_name: str) -> None:
global KEY_PREFIX
KEY_PREFIX = test_name + ':' + str(os.getpid()) + ':'
def max_api_calls(entity: RateLimitedObject) -> int:
"Returns the API rate limit for the highest limit"
return entity.rules()[-1][1]
def max_api_window(entity: RateLimitedObject) -> int:
"Returns the API time window for the highest limit"
return entity.rules()[-1][0]
def add_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='api_by_user') -> None:
"Add a rate-limiting rule to the ratelimiter"
global rules
@@ -95,26 +161,6 @@ def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='ap
global rules
rules[domain] = [x for x in rules[domain] if x[0] != range_seconds and x[1] != num_requests]
def block_access(entity: RateLimitedObject, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
_, _, blocking_key = entity.get_keys()
with client.pipeline() as pipe:
pipe.set(blocking_key, 1)
pipe.expire(blocking_key, seconds)
pipe.execute()
def unblock_access(entity: RateLimitedObject) -> None:
_, _, blocking_key = entity.get_keys()
client.delete(blocking_key)
def clear_history(entity: RateLimitedObject) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
'''
for key in entity.get_keys():
client.delete(key)
def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]:
list_key, set_key, _ = entity.get_keys()
# Count the number of values in our sorted set
@@ -142,13 +188,6 @@ def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls
return calls_left, time_reset
def api_calls_left(entity: RateLimitedObject) -> Tuple[int, float]:
"""Returns how many API calls in this range this client has, as well as when
the rate-limit will be reset to 0"""
max_window = max_api_window(entity)
max_calls = max_api_calls(entity)
return _get_api_calls_left(entity, max_window, max_calls)
def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]:
"Returns a tuple of (rate_limited, time_till_free)"
list_key, set_key, blocking_key = entity.get_keys()
@@ -219,7 +258,7 @@ def incr_ratelimit(entity: RateLimitedObject) -> None:
pipe.watch(list_key)
# Get the last elem that we'll trim (so we can remove it from our sorted set)
last_val = pipe.lindex(list_key, max_api_calls(entity) - 1)
last_val = pipe.lindex(list_key, entity.max_api_calls() - 1)
# Restart buffered execution
pipe.multi()
@@ -228,7 +267,7 @@ def incr_ratelimit(entity: RateLimitedObject) -> None:
pipe.lpush(list_key, now)
# Trim our list to the oldest rule we have
pipe.ltrim(list_key, 0, max_api_calls(entity) - 1)
pipe.ltrim(list_key, 0, entity.max_api_calls() - 1)
# Add our new value to the sorted set that we keep
# We need to put the score and val both as timestamp,
@@ -240,7 +279,7 @@ def incr_ratelimit(entity: RateLimitedObject) -> None:
pipe.zrem(set_key, last_val)
# Set the TTL for our keys as well
api_window = max_api_window(entity)
api_window = entity.max_api_window()
pipe.expire(list_key, api_window)
pipe.expire(set_key, api_window)
@@ -255,45 +294,6 @@ def incr_ratelimit(entity: RateLimitedObject) -> None:
continue
def rate_limit_entity(entity: RateLimitedObject) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)
ratelimited, time = is_ratelimited(entity)
if ratelimited:
statsd.incr("ratelimiter.limited.%s.%s" % (type(entity), str(entity)))
else:
try:
incr_ratelimit(entity)
except RateLimiterLockingException:
logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % (
type(entity).__name__, str(entity)))
# rate-limit users who are hitting the API so hard we can't update our stats.
ratelimited = True
return ratelimited, time
def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject) -> None:
ratelimited, time = rate_limit_entity(entity)
entity_type = type(entity).__name__
if not hasattr(request, '_ratelimit'):
request._ratelimit = {}
request._ratelimit[entity_type] = RateLimitResult(
entity=entity,
secs_to_freedom=time,
over_limit=ratelimited
)
# Abort this request if the user is over their rate limits
if ratelimited:
# Pass information about what kind of entity got limited in the exception:
raise RateLimited(entity_type)
calls_remaining, time_reset = api_calls_left(entity)
request._ratelimit[entity_type].remaining = calls_remaining
request._ratelimit[entity_type].secs_to_freedom = time_reset
class RateLimitResult:
def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool,
remaining: Optional[int]=None) -> None:

View File

@@ -6,8 +6,7 @@ from django.conf import settings
from django.core.management.base import BaseCommand, CommandError, \
CommandParser
from zerver.lib.rate_limiter import RateLimitedUser, client, max_api_calls, \
max_api_window
from zerver.lib.rate_limiter import RateLimitedUser, client
from zerver.models import get_user_profile_by_id
@@ -29,7 +28,7 @@ class Command(BaseCommand):
user_id = int(key.split(':')[1])
user = get_user_profile_by_id(user_id)
entity = RateLimitedUser(user)
max_calls = max_api_calls(entity)
max_calls = entity.max_api_calls()
age = int(client.ttl(key))
if age < 0:
@@ -40,7 +39,7 @@ class Command(BaseCommand):
logging.error("Redis health check found key with more elements \
than max_api_calls! (trying to trim) %s %s" % (key, count))
if trim_func is not None:
client.expire(key, max_api_window(entity))
client.expire(key, entity.max_api_window())
trim_func(key, max_calls)
def handle(self, *args: Any, **options: Any) -> None:

View File

@@ -2,8 +2,7 @@ from argparse import ArgumentParser
from typing import Any
from zerver.lib.management import CommandError, ZulipBaseCommand
from zerver.lib.rate_limiter import RateLimitedUser, block_access, \
unblock_access
from zerver.lib.rate_limiter import RateLimitedUser
from zerver.models import UserProfile, get_user_profile_by_api_key
@@ -59,7 +58,6 @@ class Command(ZulipBaseCommand):
print("Applying operation to User ID: %s: %s" % (user.id, operation))
if operation == 'block':
block_access(RateLimitedUser(user, domain=options['domain']),
options['seconds'])
RateLimitedUser(user, domain=options['domain']).block_access(options['seconds'])
elif operation == 'unblock':
unblock_access(RateLimitedUser(user, domain=options['domain']))
RateLimitedUser(user, domain=options['domain']).unblock_access()

View File

@@ -26,7 +26,7 @@ from zerver.lib.db import reset_queries
from zerver.lib.exceptions import ErrorCode, JsonableError, RateLimited
from zerver.lib.html_to_text import get_content_description
from zerver.lib.queue import queue_json_publish
from zerver.lib.rate_limiter import RateLimitResult, max_api_calls
from zerver.lib.rate_limiter import RateLimitResult
from zerver.lib.response import json_error, json_response_from_error
from zerver.lib.subdomains import get_subdomain
from zerver.lib.utils import statsd
@@ -352,7 +352,7 @@ class RateLimitMiddleware(MiddlewareMixin):
def set_response_headers(self, response: HttpResponse,
rate_limit_results: List[RateLimitResult]) -> None:
# The limit on the action that was requested is the minimum of the limits that get applied:
limit = min([max_api_calls(result.entity) for result in rate_limit_results])
limit = min([result.entity.max_api_calls() for result in rate_limit_results])
response['X-RateLimit-Limit'] = str(limit)
# Same principle applies to remaining api calls:
if all(result.remaining for result in rate_limit_results):

View File

@@ -40,7 +40,7 @@ from zerver.lib.exceptions import RateLimited
from zerver.lib.mobile_auth_otp import otp_decrypt_api_key
from zerver.lib.validator import validate_login_email, \
check_bool, check_dict_only, check_list, check_string, Validator
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule, clear_history
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
from zerver.lib.request import JsonableError
from zerver.lib.storage import static_path
from zerver.lib.upload import resize_avatar, MEDIUM_AVATAR_SIZE
@@ -524,7 +524,7 @@ class RateLimitAuthenticationTests(ZulipTestCase):
attempt_authentication(username, wrong_password)
finally:
# Clean up to avoid affecting other tests.
clear_history(RateLimitedAuthenticationByUsername(username))
RateLimitedAuthenticationByUsername(username).clear_history()
remove_ratelimit_rule(10, 2, domain='authenticate_by_username')
def test_email_auth_backend_user_based_rate_limiting(self) -> None:

View File

@@ -7,7 +7,6 @@ from zerver.forms import email_is_not_mit_mailing_list
from zerver.lib.rate_limiter import (
add_ratelimit_rule,
clear_history,
remove_ratelimit_rule,
RateLimitedUser,
RateLimiterLockingException,
@@ -71,7 +70,7 @@ class RateLimitTests(ZulipTestCase):
def test_headers(self) -> None:
user = self.example_user('hamlet')
clear_history(RateLimitedUser(user))
RateLimitedUser(user).clear_history()
result = self.send_api_message(user, "some stuff")
self.assertTrue('X-RateLimit-Remaining' in result)
@@ -80,7 +79,7 @@ class RateLimitTests(ZulipTestCase):
def test_ratelimit_decrease(self) -> None:
user = self.example_user('hamlet')
clear_history(RateLimitedUser(user))
RateLimitedUser(user).clear_history()
result = self.send_api_message(user, "some stuff")
limit = int(result['X-RateLimit-Remaining'])
@@ -90,7 +89,7 @@ class RateLimitTests(ZulipTestCase):
def test_hit_ratelimits(self) -> None:
user = self.example_user('cordelia')
clear_history(RateLimitedUser(user))
RateLimitedUser(user).clear_history()
start_time = time.time()
for i in range(6):
@@ -116,7 +115,7 @@ class RateLimitTests(ZulipTestCase):
@mock.patch('zerver.lib.rate_limiter.logger.warning')
def test_hit_ratelimiterlockingexception(self, mock_warn: mock.MagicMock) -> None:
user = self.example_user('cordelia')
clear_history(RateLimitedUser(user))
RateLimitedUser(user).clear_history()
with mock.patch('zerver.lib.rate_limiter.incr_ratelimit',
side_effect=RateLimiterLockingException):

View File

@@ -13,7 +13,7 @@ from zerver.lib.actions import create_stream_if_needed
from zerver.lib.email_mirror import RateLimitedRealmMirror
from zerver.lib.email_mirror_helpers import encode_email_address
from zerver.lib.queue import MAX_REQUEST_RETRIES
from zerver.lib.rate_limiter import RateLimiterLockingException, clear_history
from zerver.lib.rate_limiter import RateLimiterLockingException
from zerver.lib.remote_server import PushNotificationBouncerRetryLaterError
from zerver.lib.send_email import FromAddress
from zerver.lib.test_helpers import simulated_queue_client
@@ -359,7 +359,7 @@ class WorkerTest(ZulipTestCase):
mock_warn: MagicMock) -> None:
fake_client = self.FakeClient()
realm = get_realm('zulip')
clear_history(RateLimitedRealmMirror(realm))
RateLimitedRealmMirror(realm).clear_history()
stream = get_stream('Denmark', realm)
stream_to_address = encode_email_address(stream)
data = [

View File

@@ -55,7 +55,7 @@ from zerver.lib.dev_ldap_directory import init_fakeldap
from zerver.lib.email_validation import email_allowed_for_realm, \
validate_email_not_already_in_realm
from zerver.lib.mobile_auth_otp import is_valid_otp
from zerver.lib.rate_limiter import clear_history, rate_limit_request_by_entity, RateLimitedObject
from zerver.lib.rate_limiter import RateLimitedObject
from zerver.lib.request import JsonableError
from zerver.lib.users import check_full_name, validate_user_custom_profile_field
from zerver.lib.redis_utils import get_redis_client, get_dict_from_redis, put_dict_in_redis
@@ -191,8 +191,7 @@ class RateLimitedAuthenticationByUsername(RateLimitedObject):
return rate_limiting_rules
def rate_limit_authentication_by_username(request: HttpRequest, username: str) -> None:
entity = RateLimitedAuthenticationByUsername(username)
rate_limit_request_by_entity(request, entity)
RateLimitedAuthenticationByUsername(username).rate_limit_request(request)
def auth_rate_limiting_already_applied(request: HttpRequest) -> bool:
return hasattr(request, '_ratelimit') and 'RateLimitedAuthenticationByUsername' in request._ratelimit
@@ -224,7 +223,7 @@ def rate_limit_auth(auth_func: AuthFuncT, *args: Any, **kwargs: Any) -> Optional
result = auth_func(*args, **kwargs)
if result is not None:
# Authentication succeeded, clear the rate-limiting record.
clear_history(RateLimitedAuthenticationByUsername(username))
RateLimitedAuthenticationByUsername(username).clear_history()
return result