billing: Move checks from process_initial_upgrade into separate function.

This commit is contained in:
Rishi Gupta
2018-08-06 00:47:15 -04:00
parent 5719633992
commit 9f2b8a4a11
2 changed files with 24 additions and 19 deletions

View File

@@ -9,7 +9,6 @@ from django.conf import settings
from django.db import transaction from django.db import transaction
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from django.core.signing import Signer from django.core.signing import Signer
from django.core import signing
import stripe import stripe
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
@@ -179,23 +178,11 @@ def do_subscribe_customer_to_plan(stripe_customer: stripe.Customer, stripe_plan_
requires_billing_update=True, requires_billing_update=True,
extra_data=ujson.dumps({'quantity': current_seat_count})) extra_data=ujson.dumps({'quantity': current_seat_count}))
def process_initial_upgrade(user: UserProfile, plan: str, signed_seat_count: str, def process_initial_upgrade(user: UserProfile, plan: Plan, seat_count: int, stripe_token: str) -> None:
salt: str, stripe_token: str) -> None:
if plan not in [Plan.CLOUD_ANNUAL, Plan.CLOUD_MONTHLY]:
billing_logger.warning("Tampered plan during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered plan', BillingError.CONTACT_SUPPORT)
try:
seat_count = int(unsign_string(signed_seat_count, salt))
except signing.BadSignature:
billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT)
stripe_customer = do_create_customer_with_payment_source(user, stripe_token) stripe_customer = do_create_customer_with_payment_source(user, stripe_token)
do_subscribe_customer_to_plan( do_subscribe_customer_to_plan(
stripe_customer=stripe_customer, stripe_customer=stripe_customer,
stripe_plan_id=Plan.objects.get(nickname=plan).stripe_plan_id, stripe_plan_id=plan.stripe_plan_id,
seat_count=seat_count, seat_count=seat_count,
# TODO: billing address details are passed to us in the request; # TODO: billing address details are passed to us in the request;
# use that to calculate taxes. # use that to calculate taxes.

View File

@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional, Union, cast from typing import Any, Dict, Optional, Tuple, Union, cast
import logging import logging
from django.core import signing
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.validators import validate_email, URLValidator from django.core.validators import validate_email, URLValidator
from django.db import IntegrityError from django.db import IntegrityError
@@ -27,7 +28,7 @@ from zerver.views.push_notifications import validate_token
from zilencer.lib.stripe import STRIPE_PUBLISHABLE_KEY, \ from zilencer.lib.stripe import STRIPE_PUBLISHABLE_KEY, \
get_stripe_customer, get_upcoming_invoice, get_seat_count, \ get_stripe_customer, get_upcoming_invoice, get_seat_count, \
extract_current_subscription, process_initial_upgrade, sign_string, \ extract_current_subscription, process_initial_upgrade, sign_string, \
BillingError unsign_string, BillingError
from zilencer.models import RemotePushDeviceToken, RemoteZulipServer, \ from zilencer.models import RemotePushDeviceToken, RemoteZulipServer, \
Customer, Plan Customer, Plan
@@ -158,6 +159,22 @@ def remote_server_notify_push(request: HttpRequest, entity: Union[UserProfile, R
return json_success() return json_success()
def unsign_and_check_upgrade_parameters(user: UserProfile, plan_nickname: str,
signed_seat_count: str, salt: str) -> Tuple[Plan, int]:
if plan_nickname not in [Plan.CLOUD_ANNUAL, Plan.CLOUD_MONTHLY]:
billing_logger.warning("Tampered plan during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered plan', BillingError.CONTACT_SUPPORT)
plan = Plan.objects.get(nickname=plan_nickname)
try:
seat_count = int(unsign_string(signed_seat_count, salt))
except signing.BadSignature:
billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT)
return plan, seat_count
@zulip_login_required @zulip_login_required
def initial_upgrade(request: HttpRequest) -> HttpResponse: def initial_upgrade(request: HttpRequest) -> HttpResponse:
if not settings.DEVELOPMENT: if not settings.DEVELOPMENT:
@@ -172,8 +189,9 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
if request.method == 'POST': if request.method == 'POST':
try: try:
process_initial_upgrade(user, request.POST['plan'], request.POST['signed_seat_count'], plan, seat_count = unsign_and_check_upgrade_parameters(
request.POST['salt'], request.POST['stripeToken']) user, request.POST['plan'], request.POST['signed_seat_count'], request.POST['salt'])
process_initial_upgrade(user, plan, seat_count, request.POST['stripeToken'])
except BillingError as e: except BillingError as e:
error_message = e.message error_message = e.message
error_description = e.description error_description = e.description