move callback url info to the backend

This commit is contained in:
wh1te909
2024-11-04 21:58:37 +00:00
parent 4a5bfee616
commit 46c5128418
7 changed files with 72 additions and 27 deletions

View File

@@ -1,10 +1,11 @@
import subprocess import subprocess
import pyotp import pyotp
from django.conf import settings
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from accounts.models import User from accounts.models import User
from tacticalrmm.helpers import get_webdomain from tacticalrmm.util_settings import get_webdomain
class Command(BaseCommand): class Command(BaseCommand):
@@ -26,7 +27,7 @@ class Command(BaseCommand):
user.save(update_fields=["totp_key"]) user.save(update_fields=["totp_key"])
url = pyotp.totp.TOTP(code).provisioning_uri( url = pyotp.totp.TOTP(code).provisioning_uri(
username, issuer_name=get_webdomain() username, issuer_name=get_webdomain(settings.CORS_ORIGIN_WHITELIST[0])
) )
subprocess.run(f'qr "{url}"', shell=True) subprocess.run(f'qr "{url}"', shell=True)
self.stdout.write( self.stdout.write(

View File

@@ -1,11 +1,12 @@
import pyotp import pyotp
from django.conf import settings
from rest_framework.serializers import ( from rest_framework.serializers import (
ModelSerializer, ModelSerializer,
ReadOnlyField, ReadOnlyField,
SerializerMethodField, SerializerMethodField,
) )
from tacticalrmm.helpers import get_webdomain from tacticalrmm.util_settings import get_webdomain
from .models import APIKey, Role, User from .models import APIKey, Role, User
@@ -63,7 +64,7 @@ class TOTPSetupSerializer(ModelSerializer):
def get_qr_url(self, obj): def get_qr_url(self, obj):
return pyotp.totp.TOTP(obj.totp_key).provisioning_uri( return pyotp.totp.TOTP(obj.totp_key).provisioning_uri(
obj.username, issuer_name=get_webdomain() obj.username, issuer_name=get_webdomain(settings.CORS_ORIGIN_WHITELIST[0])
) )

View File

@@ -3,7 +3,7 @@ from urllib.parse import urlparse
from django.conf import settings from django.conf import settings
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from tacticalrmm.helpers import get_root_domain, get_webdomain from tacticalrmm.util_settings import get_backend_url, get_root_domain, get_webdomain
from tacticalrmm.utils import get_certs from tacticalrmm.utils import get_certs
@@ -29,8 +29,16 @@ class Command(BaseCommand):
self.stdout.write(settings.NATS_SERVER_VER) self.stdout.write(settings.NATS_SERVER_VER)
case "frontend": case "frontend":
self.stdout.write(settings.CORS_ORIGIN_WHITELIST[0]) self.stdout.write(settings.CORS_ORIGIN_WHITELIST[0])
case "backend_url":
self.stdout.write(
get_backend_url(
settings.ALLOWED_HOSTS[0],
settings.TRMM_PROTO,
settings.TRMM_BACKEND_PORT,
)
)
case "webdomain": case "webdomain":
self.stdout.write(get_webdomain()) self.stdout.write(get_webdomain(settings.CORS_ORIGIN_WHITELIST[0]))
case "djangoadmin": case "djangoadmin":
url = f"https://{settings.ALLOWED_HOSTS[0]}/{settings.ADMIN_URL}" url = f"https://{settings.ALLOWED_HOSTS[0]}/{settings.ADMIN_URL}"
self.stdout.write(url) self.stdout.write(url)

View File

@@ -7,6 +7,7 @@ For details, see: https://license.tacticalrmm.com/ee
import re import re
from allauth.socialaccount.models import SocialAccount, SocialApp from allauth.socialaccount.models import SocialAccount, SocialApp
from django.conf import settings
from django.contrib.auth import logout from django.contrib.auth import logout
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
@@ -16,11 +17,16 @@ from rest_framework import status
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer, ReadOnlyField from rest_framework.serializers import (
ModelSerializer,
ReadOnlyField,
SerializerMethodField,
)
from rest_framework.views import APIView from rest_framework.views import APIView
from accounts.permissions import AccountsPerms from accounts.permissions import AccountsPerms
from logs.models import AuditLog from logs.models import AuditLog
from tacticalrmm.util_settings import get_backend_url
from tacticalrmm.utils import get_core_settings from tacticalrmm.utils import get_core_settings
from .permissions import SSOLoginPerms from .permissions import SSOLoginPerms
@@ -29,6 +35,15 @@ from .permissions import SSOLoginPerms
class SocialAppSerializer(ModelSerializer): class SocialAppSerializer(ModelSerializer):
server_url = ReadOnlyField(source="settings.server_url") server_url = ReadOnlyField(source="settings.server_url")
role = ReadOnlyField(source="settings.role") role = ReadOnlyField(source="settings.role")
callback_url = SerializerMethodField()
javascript_origin_url = SerializerMethodField()
def get_callback_url(self, obj):
backend_url = self.context["backend_url"]
return f"{backend_url}/accounts/oidc/{obj.provider_id}/login/callback/"
def get_javascript_origin_url(self, obj):
return self.context["frontend_url"]
class Meta: class Meta:
model = SocialApp model = SocialApp
@@ -42,6 +57,8 @@ class SocialAppSerializer(ModelSerializer):
"server_url", "server_url",
"settings", "settings",
"role", "role",
"callback_url",
"javascript_origin_url",
] ]
@@ -49,8 +66,16 @@ class GetAddSSOProvider(APIView):
permission_classes = [IsAuthenticated, AccountsPerms] permission_classes = [IsAuthenticated, AccountsPerms]
def get(self, request): def get(self, request):
ctx = {
"backend_url": get_backend_url(
settings.ALLOWED_HOSTS[0],
settings.TRMM_PROTO,
settings.TRMM_BACKEND_PORT,
),
"frontend_url": settings.CORS_ORIGIN_WHITELIST[0],
}
providers = SocialApp.objects.all() providers = SocialApp.objects.all()
return Response(SocialAppSerializer(providers, many=True).data) return Response(SocialAppSerializer(providers, many=True, context=ctx).data)
class InputSerializer(ModelSerializer): class InputSerializer(ModelSerializer):
server_url = ReadOnlyField() server_url = ReadOnlyField()

View File

@@ -6,10 +6,8 @@ import secrets
import string import string
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
from urllib.parse import urlparse
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
import tldextract
from cryptography import x509 from cryptography import x509
from django.conf import settings from django.conf import settings
from django.utils import timezone as djangotime from django.utils import timezone as djangotime
@@ -104,16 +102,6 @@ def date_is_in_past(*, datetime_obj: "datetime", agent_tz: str) -> bool:
return djangotime.now() > utc_time return djangotime.now() > utc_time
def get_webdomain() -> str:
return urlparse(settings.CORS_ORIGIN_WHITELIST[0]).netloc
def get_root_domain(subdomain) -> str:
no_fetch_extract = tldextract.TLDExtract(suffix_list_urls=())
extracted = no_fetch_extract(subdomain)
return f"{extracted.domain}.{extracted.suffix}"
def rand_range(min: int, max: int) -> float: def rand_range(min: int, max: int) -> float:
""" """
Input is milliseconds. Input is milliseconds.

View File

@@ -3,7 +3,8 @@ import sys
from contextlib import suppress from contextlib import suppress
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from tacticalrmm.helpers import get_root_domain, get_webdomain
from tacticalrmm.util_settings import get_backend_url, get_root_domain, get_webdomain
BASE_DIR = Path(__file__).resolve().parent.parent BASE_DIR = Path(__file__).resolve().parent.parent
@@ -117,12 +118,12 @@ SWAGGER_ENABLED = False
REDIS_HOST = "127.0.0.1" REDIS_HOST = "127.0.0.1"
TRMM_LOG_LEVEL = "ERROR" TRMM_LOG_LEVEL = "ERROR"
TRMM_LOG_TO = "file" TRMM_LOG_TO = "file"
TRMM_PROTO = "https"
TRMM_BACKEND_PORT = None
if not DOCKER_BUILD: if not DOCKER_BUILD:
ALLOWED_HOSTS = [] ALLOWED_HOSTS = []
CORS_ORIGIN_WHITELIST = [] CORS_ORIGIN_WHITELIST = []
TRMM_PROTO = "https"
TRMM_BACKEND_PORT = None
with suppress(ImportError): with suppress(ImportError):
from ee.sso.sso_settings import * # noqa from ee.sso.sso_settings import * # noqa
@@ -154,16 +155,14 @@ if "GHACTIONS" in os.environ:
if not DOCKER_BUILD: if not DOCKER_BUILD:
TRMM_ROOT_DOMAIN = get_root_domain(ALLOWED_HOSTS[0]) TRMM_ROOT_DOMAIN = get_root_domain(ALLOWED_HOSTS[0])
frontend_domain = get_webdomain().split(":")[0] frontend_domain = get_webdomain(CORS_ORIGIN_WHITELIST[0]).split(":")[0]
ALLOWED_HOSTS.append(frontend_domain) ALLOWED_HOSTS.append(frontend_domain)
if DEBUG: if DEBUG:
ALLOWED_HOSTS.append("*") ALLOWED_HOSTS.append("*")
backend_url = f"{TRMM_PROTO}://{ALLOWED_HOSTS[0]}" backend_url = get_backend_url(ALLOWED_HOSTS[0], TRMM_PROTO, TRMM_BACKEND_PORT)
if TRMM_BACKEND_PORT:
backend_url = f"{backend_url}:{TRMM_BACKEND_PORT}"
SESSION_COOKIE_DOMAIN = TRMM_ROOT_DOMAIN SESSION_COOKIE_DOMAIN = TRMM_ROOT_DOMAIN
CSRF_COOKIE_DOMAIN = TRMM_ROOT_DOMAIN CSRF_COOKIE_DOMAIN = TRMM_ROOT_DOMAIN

View File

@@ -0,0 +1,23 @@
# this file must not import anything from django settings to avoid circular import issues
from urllib.parse import urlparse
import tldextract
def get_webdomain(url: str) -> str:
return urlparse(url).netloc
def get_root_domain(subdomain) -> str:
no_fetch_extract = tldextract.TLDExtract(suffix_list_urls=())
extracted = no_fetch_extract(subdomain)
return f"{extracted.domain}.{extracted.suffix}"
def get_backend_url(subdomain, proto, port) -> str:
url = f"{proto}://{subdomain}"
if port:
url = f"{url}:{port}"
return url