diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index ad1aae16..18b31712 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -105,7 +105,7 @@ services: image: postgres:13-alpine restart: always environment: - POSTGRES_DB: tacticalrmm + POSTGRES_DB: ${POSTGRES_DB} POSTGRES_USER: ${POSTGRES_USER} POSTGRES_PASSWORD: ${POSTGRES_PASS} volumes: @@ -145,6 +145,7 @@ services: TRMM_PASS: ${TRMM_PASS} HTTP_PROTOCOL: ${HTTP_PROTOCOL} APP_PORT: ${APP_PORT} + POSTGRES_DB: ${POSTGRES_DB} depends_on: - postgres-dev - meshcentral-dev diff --git a/.devcontainer/entrypoint.sh b/.devcontainer/entrypoint.sh index 3c59e208..588f66ef 100644 --- a/.devcontainer/entrypoint.sh +++ b/.devcontainer/entrypoint.sh @@ -102,6 +102,7 @@ EOF echo "${localvars}" > ${WORKSPACE_DIR}/api/tacticalrmm/tacticalrmm/local_settings.py # run migrations and init scripts + "${VIRTUAL_ENV}"/bin/python manage.py pre_update_tasks "${VIRTUAL_ENV}"/bin/python manage.py migrate --no-input "${VIRTUAL_ENV}"/bin/python manage.py collectstatic --no-input "${VIRTUAL_ENV}"/bin/python manage.py initial_db_setup diff --git a/.devcontainer/requirements.txt b/.devcontainer/requirements.txt index c851ea11..c1cbafae 100644 --- a/.devcontainer/requirements.txt +++ b/.devcontainer/requirements.txt @@ -1,8 +1,8 @@ # To ensure app dependencies are ported from your virtual environment/host machine into your container, run 'pip freeze > requirements.txt' in the terminal to overwrite this file asgiref==3.5.0 -celery==5.2.3 +celery==5.2.6 channels==3.0.4 -channels_redis==3.3.1 +channels_redis==3.4.0 daphne==3.0.2 Django==4.0.3 django-cors-headers==3.11.0 @@ -11,29 +11,29 @@ django-rest-knox==4.2.0 djangorestframework==3.13.1 future==0.18.2 msgpack==1.0.3 -nats-py==2.0.0 +nats-py==2.1.0 packaging==21.3 -psycopg2-binary==2.9.3 +psycopg-binary==3.0.11 pycryptodome==3.14.1 pyotp==2.6.0 -pytz==2021.3 +pytz==2022.1 qrcode==7.3.1 -redis==4.1.3 +redis==4.2.2 requests==2.27.1 -twilio==7.6.0 -urllib3==1.26.8 +twilio==7.8.1 +urllib3==1.26.9 validators==0.18.2 -websockets==10.1 -drf_spectacular==0.21.2 +websockets==10.2 +drf_spectacular==0.22.0 # dev -black==22.1.0 -Werkzeug==2.0.2 +black==22.3.0 django-extensions==3.1.5 -Pygments==2.11.2 isort==5.10.1 -mypy==0.931 -types-pytz==2021.3.4 -model-bakery==1.4.0 +mypy==0.942 +types-pytz==2021.3.6 +model-bakery==1.5.0 coverage==6.3.2 django-silk==4.3.0 +django-stubs==1.10.1 +djangorestframework-stubs==1.5.0 \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 8acb8e75..1a0fb602 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,8 +8,18 @@ "python.analysis.diagnosticSeverityOverrides": { "reportUnusedImport": "error", "reportDuplicateImport": "error", + "reportGeneralTypeIssues": "none" }, "python.analysis.typeCheckingMode": "basic", + "mypy.runUsingActiveInterpreter": true, + "python.linting.enabled": true, + "python.linting.mypyEnabled": true, + "python.linting.mypyArgs": [ + "--ignore-missing-imports", + "--follow-imports=silent", + "--show-column-numbers", + "--strict" + ], "python.formatting.provider": "black", "editor.formatOnSave": true, "vetur.format.defaultFormatter.js": "prettier", @@ -64,5 +74,13 @@ "usePlaceholders": true, "completeUnimported": true, "staticcheck": true, - } + }, + "mypy.targets": [ + "api/tacticalrmm" + ], + "python.linting.ignorePatterns": [ + "**/site-packages/**/*.py", + ".vscode/*.py", + "**env/**" + ] } \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json deleted file mode 100644 index 10e4059d..00000000 --- a/.vscode/tasks.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - // See https://go.microsoft.com/fwlink/?LinkId=733558 - // for the documentation about the tasks.json format - "version": "2.0.0", - "tasks": [ - { - "label": "docker debug", - "type": "shell", - "command": "docker-compose", - "args": [ - "-p", - "trmm", - "-f", - ".devcontainer/docker-compose.yml", - "-f", - ".devcontainer/docker-compose.debug.yml", - "up", - "-d", - "--build" - ] - } - ] -} \ No newline at end of file diff --git a/api/tacticalrmm/agents/migrations/0047_alter_agent_plat_alter_agent_site.py b/api/tacticalrmm/agents/migrations/0047_alter_agent_plat_alter_agent_site.py new file mode 100644 index 00000000..cbfd0e2c --- /dev/null +++ b/api/tacticalrmm/agents/migrations/0047_alter_agent_plat_alter_agent_site.py @@ -0,0 +1,26 @@ +# Generated by Django 4.0.3 on 2022-04-07 17:28 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('clients', '0020_auto_20211226_0547'), + ('agents', '0046_alter_agenthistory_command'), + ] + + operations = [ + migrations.AlterField( + model_name='agent', + name='plat', + field=models.CharField(default='windows', max_length=255), + ), + migrations.AlterField( + model_name='agent', + name='site', + field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.RESTRICT, related_name='agents', to='clients.site'), + preserve_default=False, + ), + ] diff --git a/api/tacticalrmm/agents/models.py b/api/tacticalrmm/agents/models.py index 7f007ed6..91d57fb8 100644 --- a/api/tacticalrmm/agents/models.py +++ b/api/tacticalrmm/agents/models.py @@ -2,7 +2,7 @@ import asyncio import re from collections import Counter from distutils.version import LooseVersion -from typing import Any, Optional, List, Dict, TYPE_CHECKING +from typing import Any, Optional, List, Dict, Union, Sequence, cast, TYPE_CHECKING from django.core.cache import cache import msgpack @@ -25,6 +25,11 @@ if TYPE_CHECKING: from alerts.models import AlertTemplate, Alert from autotasks.models import AutomatedTask from checks.models import Check + from clients.models import Client + from winupdate.models import WinUpdatePolicy + +# type helpers +Disk = Union[Dict[str, Any], str] class Agent(BaseAuditModel): @@ -32,7 +37,7 @@ class Agent(BaseAuditModel): version = models.CharField(default="0.1.0", max_length=255) operating_system = models.CharField(null=True, blank=True, max_length=255) - plat = models.CharField(max_length=255, null=True, blank=True) + plat = models.CharField(max_length=255, default="windows") goarch = models.CharField(max_length=255, null=True, blank=True) plat_release = models.CharField(max_length=255, null=True, blank=True) hostname = models.CharField(max_length=255) @@ -75,9 +80,7 @@ class Agent(BaseAuditModel): site = models.ForeignKey( "clients.Site", related_name="agents", - null=True, - blank=True, - on_delete=models.SET_NULL, + on_delete=models.RESTRICT, ) policy = models.ForeignKey( "automation.Policy", @@ -87,15 +90,15 @@ class Agent(BaseAuditModel): on_delete=models.SET_NULL, ) - def __str__(self): + def __str__(self) -> str: return self.hostname @property - def client(self): + def client(self) -> "Client": return self.site.client @property - def timezone(self): + def timezone(self) -> str: # return the default timezone unless the timezone is explicity set per agent if self.time_zone: return self.time_zone @@ -103,11 +106,11 @@ class Agent(BaseAuditModel): return get_core_settings().default_time_zone @property - def is_posix(self): + def is_posix(self) -> bool: return self.plat == "linux" or self.plat == "darwin" @property - def arch(self): + def arch(self) -> Optional[str]: if self.is_posix: return self.goarch @@ -119,7 +122,7 @@ class Agent(BaseAuditModel): return None @property - def winagent_dl(self): + def winagent_dl(self) -> Optional[str]: if self.arch == "64": return settings.DL_64 elif self.arch == "32": @@ -127,7 +130,7 @@ class Agent(BaseAuditModel): return None @property - def win_inno_exe(self): + def win_inno_exe(self) -> Optional[str]: if self.arch == "64": return f"winagent-v{settings.LATEST_AGENT_VER}.exe" elif self.arch == "32": @@ -135,7 +138,7 @@ class Agent(BaseAuditModel): return None @property - def status(self): + def status(self) -> str: offline = djangotime.now() - djangotime.timedelta(minutes=self.offline_time) overdue = djangotime.now() - djangotime.timedelta(minutes=self.overdue_time) @@ -150,7 +153,7 @@ class Agent(BaseAuditModel): return "offline" @property - def checks(self): + def checks(self) -> Dict[str, Any]: from checks.models import CheckResult total, passing, failing, warning, info = 0, 0, 0, 0, 0 @@ -185,10 +188,10 @@ class Agent(BaseAuditModel): return ret @property - def cpu_model(self): + def cpu_model(self) -> List[str]: if self.is_posix: try: - return self.wmi_detail["cpus"] + return cast(List[str], self.wmi_detail["cpus"]) except: return ["unknown cpu model"] @@ -202,12 +205,12 @@ class Agent(BaseAuditModel): return ["unknown cpu model"] @property - def graphics(self): + def graphics(self) -> str: if self.is_posix: try: if not self.wmi_detail["gpus"]: return "No graphics cards" - return self.wmi_detail["gpus"] + return cast(str, self.wmi_detail["gpus"]) except: return "Error getting graphics cards" @@ -231,7 +234,7 @@ class Agent(BaseAuditModel): return "Graphics info requires agent v1.4.14" @property - def local_ips(self): + def local_ips(self) -> str: if self.is_posix: try: return ", ".join(self.wmi_detail["local_ips"]) @@ -258,15 +261,15 @@ class Agent(BaseAuditModel): ret.append(ip) if len(ret) == 1: - return ret[0] + return cast(str, ret[0]) else: return ", ".join(ret) if ret else "error getting local ips" @property - def make_model(self): + def make_model(self) -> str: if self.is_posix: try: - return self.wmi_detail["make_model"] + return cast(str, self.wmi_detail["make_model"]) except: return "error getting make/model" @@ -292,17 +295,17 @@ class Agent(BaseAuditModel): try: comp_sys_prod = self.wmi_detail["comp_sys_prod"][0] - return [x["Version"] for x in comp_sys_prod if "Version" in x][0] + return cast(str, [x["Version"] for x in comp_sys_prod if "Version" in x][0]) except: pass return "unknown make/model" @property - def physical_disks(self): + def physical_disks(self) -> Sequence[Disk]: if self.is_posix: try: - return self.wmi_detail["disks"] + return cast(List[Disk], self.wmi_detail["disks"]) except: return ["unknown disk"] @@ -327,7 +330,7 @@ class Agent(BaseAuditModel): except: return ["unknown disk"] - def is_supported_script(self, platforms: list) -> bool: + def is_supported_script(self, platforms: List[str]) -> bool: return self.plat.lower() in platforms if platforms else True def get_checks_with_policies( @@ -357,21 +360,25 @@ class Agent(BaseAuditModel): else: return self.add_task_results(tasks) - def get_agent_policies(self) -> "Dict[str, Policy]": + def get_agent_policies(self) -> "Dict[str, Optional[Policy]]": site_policy = getattr(self.site, f"{self.monitoring_type}_policy", None) client_policy = getattr(self.client, f"{self.monitoring_type}_policy", None) default_policy = getattr( get_core_settings(), f"{self.monitoring_type}_policy", None ) - # prefetch excluded objects on polices + # prefetch excluded objects on polices only if policy is not None models.prefetch_related_objects( - [self.policy, site_policy, client_policy, default_policy], + [ + policy + for policy in [self.policy, site_policy, client_policy, default_policy] + if policy + ], "excluded_agents", "excluded_sites", "excluded_clients", "policychecks__script", - "policytasks", + "autotasks", ) return { @@ -471,57 +478,51 @@ class Agent(BaseAuditModel): return "ok" # auto approves updates - def approve_updates(self): + def approve_updates(self) -> None: patch_policy = self.get_patch_policy() - updates = list() + severity_list = list() if patch_policy.critical == "approve": - updates |= self.winupdates.filter( # type: ignore - severity="Critical", installed=False - ).exclude(action="approve") + severity_list.append("Critical") if patch_policy.important == "approve": - updates |= self.winupdates.filter( # type: ignore - severity="Important", installed=False - ).exclude(action="approve") + severity_list.append("Important") if patch_policy.moderate == "approve": - updates |= self.winupdates.filter( # type: ignore - severity="Moderate", installed=False - ).exclude(action="approve") + severity_list.append("Moderate") if patch_policy.low == "approve": - updates |= self.winupdates.filter(severity="Low", installed=False).exclude( # type: ignore - action="approve" - ) + severity_list.append("Low") if patch_policy.other == "approve": - updates |= self.winupdates.filter(severity="", installed=False).exclude( # type: ignore - action="approve" - ) + severity_list.append("") - for update in updates: - update.action = "approve" - update.save(update_fields=["action"]) + self.winupdates.filter(severity__in=severity_list, installed=False).exclude( + action="approve" + ).update(action="approve") # returns agent policy merged with a client or site specific policy - def get_patch_policy(self): + def get_patch_policy(self) -> "WinUpdatePolicy": # check if site has a patch policy and if so use it patch_policy = None - agent_policy = self.winupdatepolicy.first() # type: ignore + + agent_policy = self.winupdatepolicy.first() + + if not agent_policy: + raise WinUpdatePolicy.DoesNotExist policies = self.get_agent_policies() - processed_policies = list() + processed_policies: "List[int]" = list() for _, policy in policies.items(): if ( policy and policy.active and policy.pk not in processed_policies - and policy.winupdatepolicy.exists() # type: ignore + and policy.winupdatepolicy.exists() ): - patch_policy = policy.winupdatepolicy.first() # type: ignore + patch_policy = policy.winupdatepolicy.first() # if policy still doesn't exist return the agent patch policy if not patch_policy: @@ -567,13 +568,13 @@ class Agent(BaseAuditModel): # sets alert template assigned in the following order: policy, site, client, global # sets None if nothing is found - def set_alert_template(self): + def set_alert_template(self) -> "Optional[AlertTemplate]": core = get_core_settings() policies = self.get_agent_policies() # loop through all policies applied to agent and return an alert_template if found - processed_policies = list() + processed_policies: List[int] = list() for key, policy in policies.items(): # default alert_template will override a default policy with alert template applied if ( @@ -663,7 +664,7 @@ class Agent(BaseAuditModel): def get_checks_from_policies(self) -> "List[Check]": from automation.models import Policy - cached_checks = cache.get(f"{self.site.name}_checks") + cached_checks = cache.get(f"site_{self.site.id}_checks") if cached_checks and isinstance(cached_checks, list): return cached_checks @@ -673,27 +674,29 @@ class Agent(BaseAuditModel): # get agent checks based on policies checks = Policy.get_policy_checks(self) - cache.set(f"{self.site.name}_checks", checks, 300) + cache.set(f"site_{self.site.id}_checks", checks, 300) return checks def get_tasks_from_policies(self) -> "List[AutomatedTask]": from automation.models import Policy - cached_tasks = cache.get(f"{self.site.name}_tasks") + cached_tasks = cache.get(f"site_{self.site.id}_tasks") if cached_tasks and isinstance(cached_tasks, list): return cached_tasks else: # get agent tasks based on policies tasks = Policy.get_policy_tasks(self) - cache.set(f"{self.site.name}_tasks", tasks, 300) + cache.set(f"site_{self.site.id}_tasks", tasks, 300) return tasks def _do_nats_debug(self, agent, message): DebugLog.error(agent=agent, log_type="agent_issues", message=message) - async def nats_cmd(self, data: dict, timeout: int = 30, wait: bool = True): + async def nats_cmd( + self, data: Dict[Any, Any], timeout: int = 30, wait: bool = True + ) -> Any: options = { "servers": f"tls://{settings.ALLOWED_HOSTS[0]}:4222", "user": "tacticalrmm", @@ -731,13 +734,13 @@ class Agent(BaseAuditModel): await nc.close() @staticmethod - def serialize(agent): + def serialize(class_name: "Agent") -> Dict[str, Any]: # serializes the agent and returns json from .serializers import AgentAuditSerializer - return AgentAuditSerializer(agent).data + return AgentAuditSerializer(class_name).data - def delete_superseded_updates(self): + def delete_superseded_updates(self) -> None: try: pks = [] # list of pks to delete kbs = list(self.winupdates.values_list("kb", flat=True)) # type: ignore @@ -745,12 +748,12 @@ class Agent(BaseAuditModel): dupes = [k for k, v in d.items() if v > 1] for dupe in dupes: - titles = self.winupdates.filter(kb=dupe).values_list("title", flat=True) # type: ignore + titles = self.winupdates.filter(kb=dupe).values_list("title", flat=True) # extract the version from the title and sort from oldest to newest # skip if no version info is available therefore nothing to parse try: vers = [ - re.search(r"\(Version(.*?)\)", i).group(1).strip() # type: ignore + re.search(r"\(Version(.*?)\)", i).group(1).strip() for i in titles ] sorted_vers = sorted(vers, key=LooseVersion) @@ -758,16 +761,18 @@ class Agent(BaseAuditModel): continue # append all but the latest version to our list of pks to delete for ver in sorted_vers[:-1]: - q = self.winupdates.filter(kb=dupe).filter(title__contains=ver) # type: ignore + q = self.winupdates.filter(kb=dupe).filter(title__contains=ver) pks.append(q.first().pk) pks = list(set(pks)) - self.winupdates.filter(pk__in=pks).delete() # type: ignore + self.winupdates.filter(pk__in=pks).delete() except: pass - def should_create_alert(self, alert_template=None): - return ( + def should_create_alert( + self, alert_template: "Optional[AlertTemplate]" = None + ) -> bool: + return bool( self.overdue_dashboard_alert or self.overdue_email_alert or self.overdue_text_alert @@ -781,7 +786,7 @@ class Agent(BaseAuditModel): ) ) - def send_outage_email(self): + def send_outage_email(self) -> None: CORE = get_core_settings() CORE.send_mail( @@ -795,7 +800,7 @@ class Agent(BaseAuditModel): alert_template=self.alert_template, ) - def send_recovery_email(self): + def send_recovery_email(self) -> None: CORE = get_core_settings() CORE.send_mail( @@ -809,7 +814,7 @@ class Agent(BaseAuditModel): alert_template=self.alert_template, ) - def send_outage_sms(self): + def send_outage_sms(self) -> None: CORE = get_core_settings() CORE.send_sms( @@ -817,7 +822,7 @@ class Agent(BaseAuditModel): alert_template=self.alert_template, ) - def send_recovery_sms(self): + def send_recovery_sms(self) -> None: CORE = get_core_settings() CORE.send_sms( @@ -844,7 +849,7 @@ class Note(models.Model): note = models.TextField(null=True, blank=True) entry_time = models.DateTimeField(auto_now_add=True) - def __str__(self): + def __str__(self) -> str: return self.agent.hostname @@ -872,26 +877,26 @@ class AgentCustomField(models.Model): default=list, ) - def __str__(self): + def __str__(self) -> str: return self.field.name @property - def value(self): + def value(self) -> Union[List[Any], bool, str]: if self.field.type == "multiple": - return self.multiple_value + return cast(List[str], self.multiple_value) elif self.field.type == "checkbox": return self.bool_value else: - return self.string_value + return cast(str, self.string_value) - def save_to_field(self, value): + def save_to_field(self, value: Union[List[Any], bool, str]) -> None: if self.field.type in [ "text", "number", "single", "datetime", ]: - self.string_value = value + self.string_value = cast(str, value) self.save() elif self.field.type == "multiple": self.multiple_value = value.split(",") @@ -937,5 +942,5 @@ class AgentHistory(models.Model): ) script_results = models.JSONField(null=True, blank=True) - def __str__(self): + def __str__(self) -> str: return f"{self.agent.hostname} - {self.type}" diff --git a/api/tacticalrmm/agents/tests.py b/api/tacticalrmm/agents/tests.py index 18c17b0a..d52d31ee 100644 --- a/api/tacticalrmm/agents/tests.py +++ b/api/tacticalrmm/agents/tests.py @@ -2,8 +2,8 @@ import json import os from itertools import cycle from unittest.mock import patch - import pytz +from typing import TYPE_CHECKING from django.conf import settings from django.test import modify_settings from django.utils import timezone as djangotime @@ -24,6 +24,9 @@ from .serializers import ( ) from .tasks import auto_self_agent_update_task +if TYPE_CHECKING: + from clients.models import Client, Site + base_url = "/agents" @@ -33,19 +36,19 @@ base_url = "/agents" } ) class TestAgentsList(TacticalTestCase): - def setUp(self): + def setUp(self) -> None: self.authenticate() self.setup_coresettings() - def test_get_agents(self): + def test_get_agents(self) -> None: url = f"{base_url}/" # 36 total agents - company1 = baker.make("clients.Client") - company2 = baker.make("clients.Client") - site1 = baker.make("clients.Site", client=company1) - site2 = baker.make("clients.Site", client=company1) - site3 = baker.make("clients.Site", client=company2) + company1: "Client" = baker.make("clients.Client") + company2: "Client" = baker.make("clients.Client") + site1: "Site" = baker.make("clients.Site", client=company1) + site2: "Site" = baker.make("clients.Site", client=company1) + site3: "Site" = baker.make("clients.Site", client=company2) baker.make_recipe( "agents.online_agent", site=site1, monitoring_type="server", _quantity=15 @@ -129,7 +132,7 @@ class TestAgentViews(TacticalTestCase): url = f"{base_url}/{self.agent.agent_id}/" data = { - "site": site.id, + "site": site.pk, "monitoring_type": "workstation", "description": "asjdk234andasd", "offline_time": 4, @@ -160,7 +163,7 @@ class TestAgentViews(TacticalTestCase): agent = Agent.objects.get(pk=self.agent.pk) data = AgentSerializer(agent).data - self.assertEqual(data["site"], site.id) + self.assertEqual(data["site"], site.pk) policy = WinUpdatePolicy.objects.get(agent=self.agent) data = WinUpdatePolicySerializer(policy).data @@ -169,9 +172,9 @@ class TestAgentViews(TacticalTestCase): # test adding custom fields field = baker.make("core.CustomField", model="agent", type="number") data = { - "site": site.id, + "site": site.pk, "description": "asjdk234andasd", - "custom_fields": [{"field": field.id, "string_value": "123"}], + "custom_fields": [{"field": field.pk, "string_value": "123"}], } r = self.client.put(url, data, format="json") @@ -182,9 +185,9 @@ class TestAgentViews(TacticalTestCase): # test edit custom field data = { - "site": site.id, + "site": site.pk, "description": "asjdk234andasd", - "custom_fields": [{"field": field.id, "string_value": "456"}], + "custom_fields": [{"field": field.pk, "string_value": "456"}], } r = self.client.put(url, data, format="json") @@ -488,8 +491,8 @@ class TestAgentViews(TacticalTestCase): site = baker.make("clients.Site") data = { - "client": site.client.id, - "site": site.id, + "client": site.client.pk, + "site": site.pk, "arch": "64", "expires": 23, "installMethod": "manual", @@ -660,7 +663,7 @@ class TestAgentViews(TacticalTestCase): "output": "collector", "args": ["hello", "world"], "timeout": 22, - "custom_field": custom_field.id, + "custom_field": custom_field.pk, "save_all_output": True, } @@ -691,7 +694,7 @@ class TestAgentViews(TacticalTestCase): "output": "collector", "args": ["hello", "world"], "timeout": 22, - "custom_field": custom_field.id, + "custom_field": custom_field.pk, "save_all_output": False, } @@ -724,7 +727,7 @@ class TestAgentViews(TacticalTestCase): "output": "collector", "args": ["hello", "world"], "timeout": 22, - "custom_field": custom_field.id, + "custom_field": custom_field.pk, "save_all_output": False, } @@ -814,7 +817,7 @@ class TestAgentViews(TacticalTestCase): # setup agent = baker.make_recipe("agents.agent") note = baker.make("agents.Note", agent=agent) - url = f"{base_url}/notes/{note.id}/" + url = f"{base_url}/notes/{note.pk}/" # test not found r = self.client.get(f"{base_url}/notes/500/") @@ -829,7 +832,7 @@ class TestAgentViews(TacticalTestCase): # setup agent = baker.make_recipe("agents.agent") note = baker.make("agents.Note", agent=agent) - url = f"{base_url}/notes/{note.id}/" + url = f"{base_url}/notes/{note.pk}/" # test not found r = self.client.put(f"{base_url}/notes/500/") @@ -839,7 +842,7 @@ class TestAgentViews(TacticalTestCase): r = self.client.put(url, data) self.assertEqual(r.status_code, 200) - new_note = Note.objects.get(pk=note.id) + new_note = Note.objects.get(pk=note.pk) self.assertEqual(new_note.note, data["note"]) self.check_not_authenticated("put", url) @@ -848,7 +851,7 @@ class TestAgentViews(TacticalTestCase): # setup agent = baker.make_recipe("agents.agent") note = baker.make("agents.Note", agent=agent) - url = f"{base_url}/notes/{note.id}/" + url = f"{base_url}/notes/{note.pk}/" # test not found r = self.client.delete(f"{base_url}/notes/500/") @@ -857,7 +860,7 @@ class TestAgentViews(TacticalTestCase): r = self.client.delete(url) self.assertEqual(r.status_code, 200) - self.assertFalse(Note.objects.filter(pk=note.id).exists()) + self.assertFalse(Note.objects.filter(pk=note.pk).exists()) self.check_not_authenticated("delete", url) @@ -1085,9 +1088,9 @@ class TestAgentPermissions(TacticalTestCase): site = baker.make("clients.Site") client = baker.make("clients.Client") - site_data = {"id": site.id, "type": "Site", "action": True} + site_data = {"id": site.pk, "type": "Site", "action": True} - client_data = {"id": client.id, "type": "Client", "action": True} + client_data = {"id": client.pk, "type": "Client", "action": True} url = f"{base_url}/maintenance/bulk/" @@ -1229,8 +1232,8 @@ class TestAgentPermissions(TacticalTestCase): user.role.can_view_clients.set([client]) data = { - "client": client.id, - "site": client_site.id, + "client": client.pk, + "site": client_site.pk, "version": settings.LATEST_AGENT_VER, "arch": "64", } @@ -1238,8 +1241,8 @@ class TestAgentPermissions(TacticalTestCase): self.check_authorized("post", url, data) data = { - "client": site.client.id, - "site": site.id, + "client": site.client.pk, + "site": site.pk, "version": settings.LATEST_AGENT_VER, "arch": "64", } @@ -1250,8 +1253,8 @@ class TestAgentPermissions(TacticalTestCase): user.role.can_view_clients.clear() user.role.can_view_sites.set([site]) data = { - "client": site.client.id, - "site": site.id, + "client": site.client.pk, + "site": site.pk, "version": settings.LATEST_AGENT_VER, "arch": "64", } @@ -1259,8 +1262,8 @@ class TestAgentPermissions(TacticalTestCase): self.check_authorized("post", url, data) data = { - "client": client.id, - "site": client_site.id, + "client": client.pk, + "site": client_site.pk, "version": settings.LATEST_AGENT_VER, "arch": "64", } @@ -1281,17 +1284,17 @@ class TestAgentPermissions(TacticalTestCase): {"url": f"{base_url}/notes/", "method": "get", "role": "can_list_notes"}, {"url": f"{base_url}/notes/", "method": "post", "role": "can_manage_notes"}, { - "url": f"{base_url}/notes/{notes[0].id}/", + "url": f"{base_url}/notes/{notes[0].pk}/", "method": "get", "role": "can_list_notes", }, { - "url": f"{base_url}/notes/{notes[0].id}/", + "url": f"{base_url}/notes/{notes[0].pk}/", "method": "put", "role": "can_manage_notes", }, { - "url": f"{base_url}/notes/{notes[0].id}/", + "url": f"{base_url}/notes/{notes[0].pk}/", "method": "delete", "role": "can_manage_notes", }, @@ -1335,19 +1338,19 @@ class TestAgentPermissions(TacticalTestCase): # test post get, put, and delete and make sure unauthorized is returned with unauthorized agent and works for authorized self.check_authorized("post", f"{base_url}/notes/", authorized_data) self.check_not_authorized("post", f"{base_url}/notes/", unauthorized_data) - self.check_authorized("get", f"{base_url}/notes/{notes[2].id}/") + self.check_authorized("get", f"{base_url}/notes/{notes[2].pk}/") self.check_not_authorized( - "get", f"{base_url}/notes/{unauthorized_notes[2].id}/" + "get", f"{base_url}/notes/{unauthorized_notes[2].pk}/" ) self.check_authorized( - "put", f"{base_url}/notes/{notes[3].id}/", authorized_data + "put", f"{base_url}/notes/{notes[3].pk}/", authorized_data ) self.check_not_authorized( - "put", f"{base_url}/notes/{unauthorized_notes[3].id}/", unauthorized_data + "put", f"{base_url}/notes/{unauthorized_notes[3].pk}/", unauthorized_data ) - self.check_authorized("delete", f"{base_url}/notes/{notes[3].id}/") + self.check_authorized("delete", f"{base_url}/notes/{notes[3].pk}/") self.check_not_authorized( - "delete", f"{base_url}/notes/{unauthorized_notes[3].id}/" + "delete", f"{base_url}/notes/{unauthorized_notes[3].pk}/" ) def test_get_agent_history_permissions(self): diff --git a/api/tacticalrmm/alerts/migrations/0011_alter_alert_agent.py b/api/tacticalrmm/alerts/migrations/0011_alter_alert_agent.py new file mode 100644 index 00000000..36873600 --- /dev/null +++ b/api/tacticalrmm/alerts/migrations/0011_alter_alert_agent.py @@ -0,0 +1,21 @@ +# Generated by Django 4.0.3 on 2022-04-07 17:28 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('agents', '0047_alter_agent_plat_alter_agent_site'), + ('alerts', '0010_auto_20210917_1954'), + ] + + operations = [ + migrations.AlterField( + model_name='alert', + name='agent', + field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.CASCADE, related_name='agent', to='agents.agent'), + preserve_default=False, + ), + ] diff --git a/api/tacticalrmm/alerts/models.py b/api/tacticalrmm/alerts/models.py index f19630be..de1385d9 100644 --- a/api/tacticalrmm/alerts/models.py +++ b/api/tacticalrmm/alerts/models.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Union, Optional, List, cast +from typing import TYPE_CHECKING, Union, Optional, Dict, Any, List, cast from django.contrib.postgres.fields import ArrayField from django.db import models @@ -15,6 +15,7 @@ if TYPE_CHECKING: from agents.models import Agent from autotasks.models import AutomatedTask, TaskResult from checks.models import Check, CheckResult + from clients.models import Client, Site SEVERITY_CHOICES = [ @@ -38,8 +39,6 @@ class Alert(models.Model): "agents.Agent", related_name="agent", on_delete=models.CASCADE, - null=True, - blank=True, ) assigned_check = models.ForeignKey( "checks.Check", @@ -87,18 +86,18 @@ class Alert(models.Model): return self.message @property - def assigned_agent(self): + def assigned_agent(self) -> "Agent": return self.agent @property - def site(self): + def site(self) -> "Site": return self.agent.site @property - def client(self): + def client(self) -> "Client": return self.agent.client - def resolve(self): + def resolve(self) -> None: self.resolved = True self.resolved_on = djangotime.now() self.snoozed = False @@ -113,24 +112,30 @@ class Alert(models.Model): if skip_create: return None - return cls.objects.create( - agent=agent, - alert_type="availability", - severity="error", - message=f"{agent.hostname} in {agent.client.name}\\{agent.site.name} is overdue.", - hidden=True, + return cast( + Alert, + cls.objects.create( + agent=agent, + alert_type="availability", + severity="error", + message=f"{agent.hostname} in {agent.client.name}\\{agent.site.name} is overdue.", + hidden=True, + ), ) else: try: - return cls.objects.get( - agent=agent, alert_type="availability", resolved=False + return cast( + Alert, + cls.objects.get( + agent=agent, alert_type="availability", resolved=False + ), ) except cls.MultipleObjectsReturned: alerts = cls.objects.filter( agent=agent, alert_type="availability", resolved=False ) - last_alert = cast(cls, alerts.last()) + last_alert = cast(Alert, alerts.last()) # cycle through other alerts and resolve for alert in alerts: @@ -159,22 +164,29 @@ class Alert(models.Model): if skip_create: return None - return cls.objects.create( - assigned_check=check, - agent=agent if check.policy else check.agent, - alert_type="check", - severity=check.alert_severity - if check.check_type not in ["memory", "cpuload", "diskspace", "script"] - else alert_severity, - message=f"{agent.hostname if agent else check.agent.hostname} has a {check.check_type} check: {check.readable_desc} that failed.", - hidden=True, + return cast( + Alert, + cls.objects.create( + assigned_check=check, + agent=agent if check.policy else check.agent, + alert_type="check", + severity=check.alert_severity + if check.check_type + not in ["memory", "cpuload", "diskspace", "script"] + else alert_severity, + message=f"{agent.hostname if agent else check.agent.hostname} has a {check.check_type} check: {check.readable_desc} that failed.", + hidden=True, + ), ) else: try: - return cls.objects.get( - assigned_check=check, - agent=agent if check.policy else check.agent, - resolved=False, + return cast( + Alert, + cls.objects.get( + assigned_check=check, + agent=agent if check.policy else check.agent, + resolved=False, + ), ) except cls.MultipleObjectsReturned: alerts = cls.objects.filter( @@ -182,7 +194,7 @@ class Alert(models.Model): agent=agent if check.policy else check.agent, resolved=False, ) - last_alert = cast(cls, alerts.last()) + last_alert = cast(Alert, alerts.last()) # cycle through other alerts and resolve for alert in alerts: @@ -702,10 +714,10 @@ class AlertTemplate(BaseAuditModel): "agents.Agent", related_name="alert_exclusions", blank=True ) - def __str__(self): + def __str__(self) -> str: return self.name - def is_agent_excluded(self, agent): + def is_agent_excluded(self, agent: "Agent") -> bool: return ( agent in self.excluded_agents.all() or agent.site in self.excluded_sites.all() @@ -717,7 +729,7 @@ class AlertTemplate(BaseAuditModel): ) @staticmethod - def serialize(alert_template): + def serialize(alert_template: AlertTemplate) -> Dict[str, Any]: # serializes the agent and returns json from .serializers import AlertTemplateAuditSerializer diff --git a/api/tacticalrmm/automation/models.py b/api/tacticalrmm/automation/models.py index b985404a..5778f7cb 100644 --- a/api/tacticalrmm/automation/models.py +++ b/api/tacticalrmm/automation/models.py @@ -3,7 +3,11 @@ from clients.models import Client, Site from django.db import models from logs.models import BaseAuditModel -from typing import Optional +from typing import Optional, Dict, Any, List, TYPE_CHECKING + +if TYPE_CHECKING: + from checks.models import Check + from autotasks.models import AutomatedTask class Policy(BaseAuditModel): @@ -28,11 +32,13 @@ class Policy(BaseAuditModel): "agents.Agent", related_name="policy_exclusions", blank=True ) - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: from alerts.tasks import cache_agents_alert_template # get old policy if exists - old_policy = type(self).objects.get(pk=self.pk) if self.pk else None + old_policy: Optional[Policy] = ( + type(self).objects.get(pk=self.pk) if self.pk else None + ) super(Policy, self).save(old_model=old_policy, *args, **kwargs) # check if alert template was changes and cache on agents @@ -40,23 +46,18 @@ class Policy(BaseAuditModel): if old_policy.alert_template != self.alert_template: cache_agents_alert_template.delay() - def __str__(self): + def __str__(self) -> str: return self.name @property - def is_default_server_policy(self): - return self.default_server_policy.exists() # type: ignore + def is_default_server_policy(self) -> bool: + return self.default_server_policy.exists() @property - def is_default_workstation_policy(self): - return self.default_workstation_policy.exists() # type: ignore - - def is_agent_excluded(self, agent): - # will prefetch the many to many relations in a single query versus 3. results are cached on the object - models.prefetch_related_objects( - [self], "excluded_agents", "excluded_sites", "excluded_clients" - ) + def is_default_workstation_policy(self) -> bool: + return self.default_workstation_policy.exists() + def is_agent_excluded(self, agent: "Agent") -> bool: return ( agent in self.excluded_agents.all() or agent.site in self.excluded_sites.all() @@ -180,14 +181,14 @@ class Policy(BaseAuditModel): ) @staticmethod - def serialize(policy): + def serialize(policy: "Policy") -> Dict[str, Any]: # serializes the policy and returns json from .serializers import PolicyAuditSerializer return PolicyAuditSerializer(policy).data @staticmethod - def get_policy_tasks(agent): + def get_policy_tasks(agent: "Agent") -> "List[AutomatedTask]": # List of all tasks to be applied tasks = list() @@ -206,7 +207,7 @@ class Policy(BaseAuditModel): return tasks @staticmethod - def get_policy_checks(agent): + def get_policy_checks(agent: "Agent") -> "List[Check]": # Get checks added to agent directly agent_checks = list(agent.agentchecks.all()) @@ -231,22 +232,22 @@ class Policy(BaseAuditModel): policy_checks.append(check) # Sorted Checks already added - added_diskspace_checks = list() - added_ping_checks = list() - added_winsvc_checks = list() - added_script_checks = list() - added_eventlog_checks = list() - added_cpuload_checks = list() - added_memory_checks = list() + added_diskspace_checks: List[str] = list() + added_ping_checks: List[str] = list() + added_winsvc_checks: List[str] = list() + added_script_checks: List[int] = list() + added_eventlog_checks: List[List[str]] = list() + added_cpuload_checks: List[int] = list() + added_memory_checks: List[int] = list() # Lists all agent and policy checks that will be returned - diskspace_checks = list() - ping_checks = list() - winsvc_checks = list() - script_checks = list() - eventlog_checks = list() - cpuload_checks = list() - memory_checks = list() + diskspace_checks: "List[Check]" = list() + ping_checks: "List[Check]" = list() + winsvc_checks: "List[Check]" = list() + script_checks: "List[Check]" = list() + eventlog_checks: "List[Check]" = list() + cpuload_checks: "List[Check]" = list() + memory_checks: "List[Check]" = list() overridden_checks = list() @@ -275,7 +276,7 @@ class Policy(BaseAuditModel): elif check.check_type == "cpuload" and agent.plat == "windows": # Check if cpuload list is empty if not added_cpuload_checks: - added_cpuload_checks.append(check) + added_cpuload_checks.append(check.pk) # Dont create the check if it is an agent check if not check.agent: cpuload_checks.append(check) @@ -285,7 +286,7 @@ class Policy(BaseAuditModel): elif check.check_type == "memory" and agent.plat == "windows": # Check if memory check list is empty if not added_memory_checks: - added_memory_checks.append(check) + added_memory_checks.append(check.pk) # Dont create the check if it is an agent check if not check.agent: memory_checks.append(check) diff --git a/api/tacticalrmm/core/management/commands/post_update_tasks.py b/api/tacticalrmm/core/management/commands/post_update_tasks.py index f984a614..e3655fde 100644 --- a/api/tacticalrmm/core/management/commands/post_update_tasks.py +++ b/api/tacticalrmm/core/management/commands/post_update_tasks.py @@ -13,7 +13,7 @@ from tacticalrmm.constants import AGENT_DEFER class Command(BaseCommand): help = "Collection of tasks to run after updating the rmm, after migrations" - def handle(self, *args, **kwargs): + def handle(self, *args, **kwargs) -> None: self.stdout.write("Running post update tasks") # load community scripts into the db diff --git a/api/tacticalrmm/core/management/commands/pre_update_tasks.py b/api/tacticalrmm/core/management/commands/pre_update_tasks.py index 89f61eab..113a27b1 100644 --- a/api/tacticalrmm/core/management/commands/pre_update_tasks.py +++ b/api/tacticalrmm/core/management/commands/pre_update_tasks.py @@ -1,4 +1,5 @@ from django.core.management.base import BaseCommand +from alerts.models import Alert class Command(BaseCommand): @@ -6,4 +7,4 @@ class Command(BaseCommand): def handle(self, *args, **kwargs): # adding this now for future updates - pass + Alert.objects.filter(agent=None).delete() diff --git a/api/tacticalrmm/core/models.py b/api/tacticalrmm/core/models.py index 80c9bfa8..70a048b0 100644 --- a/api/tacticalrmm/core/models.py +++ b/api/tacticalrmm/core/models.py @@ -1,6 +1,7 @@ import smtplib from email.message import EmailMessage +from typing import Optional, Union, List, cast, TYPE_CHECKING import pytz import requests from django.conf import settings @@ -11,19 +12,20 @@ from logs.models import LOG_LEVEL_CHOICES, BaseAuditModel, DebugLog from twilio.base.exceptions import TwilioRestException from twilio.rest import Client as TwClient +if TYPE_CHECKING: + from alerts.models import AlertTemplate + TZ_CHOICES = [(_, _) for _ in pytz.all_timezones] class CoreSettings(BaseAuditModel): email_alert_recipients = ArrayField( models.EmailField(null=True, blank=True), - null=True, blank=True, default=list, ) sms_alert_recipients = ArrayField( models.CharField(max_length=255, null=True, blank=True), - null=True, blank=True, default=list, ) @@ -31,18 +33,16 @@ class CoreSettings(BaseAuditModel): twilio_account_sid = models.CharField(max_length=255, null=True, blank=True) twilio_auth_token = models.CharField(max_length=255, null=True, blank=True) smtp_from_email = models.CharField( - max_length=255, null=True, blank=True, default="from@example.com" - ) - smtp_host = models.CharField( - max_length=255, null=True, blank=True, default="smtp.gmail.com" + max_length=255, blank=True, default="from@example.com" ) + smtp_host = models.CharField(max_length=255, blank=True, default="smtp.gmail.com") smtp_host_user = models.CharField( - max_length=255, null=True, blank=True, default="admin@example.com" + max_length=255, blank=True, default="admin@example.com" ) smtp_host_password = models.CharField( - max_length=255, null=True, blank=True, default="changeme" + max_length=255, blank=True, default="changeme" ) - smtp_port = models.PositiveIntegerField(default=587, null=True, blank=True) + smtp_port = models.PositiveIntegerField(default=587, blank=True) smtp_requires_auth = models.BooleanField(default=True) default_time_zone = models.CharField( max_length=255, choices=TZ_CHOICES, default="America/Los_Angeles" @@ -89,7 +89,7 @@ class CoreSettings(BaseAuditModel): max_length=30, blank=True, default="MMM-DD-YYYY - HH:mm" ) - def save(self, *args, **kwargs): + def save(self, *args, **kwargs) -> None: from alerts.tasks import cache_agents_alert_template if not self.pk and CoreSettings.objects.exists(): @@ -110,11 +110,11 @@ class CoreSettings(BaseAuditModel): if old_settings and old_settings.alert_template != self.alert_template: cache_agents_alert_template.delay() - def __str__(self): + def __str__(self) -> str: return "Global Site Settings" @property - def sms_is_configured(self): + def sms_is_configured(self) -> bool: return all( [ self.twilio_auth_token, @@ -124,7 +124,7 @@ class CoreSettings(BaseAuditModel): ) @property - def email_is_configured(self): + def email_is_configured(self) -> bool: # smtp with username/password authentication if ( self.smtp_requires_auth @@ -146,12 +146,18 @@ class CoreSettings(BaseAuditModel): return False - def send_mail(self, subject, body, alert_template=None, test=False): + def send_mail( + self, + subject: str, + body: str, + alert_template: "Optional[AlertTemplate]" = None, + test: bool = False, + ) -> Union[bool, str]: if test and not self.email_is_configured: return "There needs to be at least one email recipient configured" # return since email must be configured to continue elif not self.email_is_configured: - return False + return "SMTP messaging not configured." # override email from if alert_template is passed and is set if alert_template and alert_template.email_from: @@ -162,10 +168,9 @@ class CoreSettings(BaseAuditModel): # override email recipients if alert_template is passed and is set if alert_template and alert_template.email_recipients: email_recipients = ", ".join(alert_template.email_recipients) + elif self.email_alert_recipients: + email_recipients = ", ".join(cast(List[str], self.email_alert_recipients)) else: - email_recipients = ", ".join(self.email_alert_recipients) - - if not email_recipients: return "There needs to be at least one email recipient configured" try: @@ -179,7 +184,10 @@ class CoreSettings(BaseAuditModel): if self.smtp_requires_auth: server.ehlo() server.starttls() - server.login(self.smtp_host_user, self.smtp_host_password) + server.login( + self.smtp_host_user, + self.smtp_host_password, + ) server.send_message(msg) server.quit() else: @@ -191,20 +199,24 @@ class CoreSettings(BaseAuditModel): DebugLog.error(message=f"Sending email failed with error: {e}") if test: return str(e) - else: + finally: return True - def send_sms(self, body, alert_template=None, test=False): + def send_sms( + self, + body: str, + alert_template: "Optional[AlertTemplate]" = None, + test: bool = False, + ) -> Union[str, bool]: if not self.sms_is_configured: return "Sms alerting is not setup correctly." # override email recipients if alert_template is passed and is set if alert_template and alert_template.text_recipients: text_recipients = alert_template.text_recipients + elif self.sms_alert_recipients: + text_recipients = cast(List[str], self.sms_alert_recipients) else: - text_recipients = self.sms_alert_recipients - - if not text_recipients: return "No sms recipients found" tw_client = TwClient(self.twilio_account_sid, self.twilio_auth_token) @@ -249,7 +261,7 @@ class CustomField(BaseAuditModel): blank=True, default=list, ) - name = models.TextField(null=True, blank=True) + name = models.CharField(max_length=30, blank=True) required = models.BooleanField(blank=True, default=False) default_value_string = models.TextField(null=True, blank=True) default_value_bool = models.BooleanField(default=False) @@ -264,7 +276,7 @@ class CustomField(BaseAuditModel): class Meta: unique_together = (("model", "name"),) - def __str__(self): + def __str__(self) -> str: return self.name @staticmethod diff --git a/api/tacticalrmm/logs/models.py b/api/tacticalrmm/logs/models.py index f9439bd6..3b72d7e3 100644 --- a/api/tacticalrmm/logs/models.py +++ b/api/tacticalrmm/logs/models.py @@ -2,12 +2,17 @@ from abc import abstractmethod from django.db import models from core.utils import get_core_settings - +from typing import Optional, Dict, Any, Union, cast, Tuple, TYPE_CHECKING from tacticalrmm.middleware import get_debug_info, get_username from tacticalrmm.models import PermissionQuerySet +if TYPE_CHECKING: + from agents.models import Agent + from clients.models import Client, Site + from core.models import URLAction -def get_debug_level(): + +def get_debug_level() -> str: return get_core_settings().agent_debug_level @@ -75,10 +80,10 @@ class AuditLog(models.Model): message = models.CharField(max_length=255, null=True, blank=True) debug_info = models.JSONField(null=True, blank=True) - def __str__(self): + def __str__(self) -> str: return f"{self.username} {self.action} {self.object_type}" - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: if not self.pk and self.message: # truncate message field if longer than 255 characters @@ -89,7 +94,9 @@ class AuditLog(models.Model): return super(AuditLog, self).save(*args, **kwargs) @staticmethod - def audit_mesh_session(username, agent, debug_info={}): + def audit_mesh_session( + username: str, agent: "Agent", debug_info: Dict[Any, Any] = {} + ) -> None: AuditLog.objects.create( username=username, agent=agent.hostname, @@ -101,7 +108,13 @@ class AuditLog(models.Model): ) @staticmethod - def audit_raw_command(username, agent, cmd, shell, debug_info={}): + def audit_raw_command( + username: str, + agent: "Agent", + cmd: str, + shell: str, + debug_info: Dict[Any, Any] = {}, + ) -> None: AuditLog.objects.create( username=username, agent=agent.hostname, @@ -115,8 +128,13 @@ class AuditLog(models.Model): @staticmethod def audit_object_changed( - username, object_type, before, after, name="", debug_info={} - ): + username: str, + object_type: str, + before: Dict[Any, Any], + after: Dict[Any, Any], + name: str = "", + debug_info: Dict[Any, Any] = {}, + ) -> None: AuditLog.objects.create( username=username, object_type=object_type, @@ -130,7 +148,13 @@ class AuditLog(models.Model): ) @staticmethod - def audit_object_add(username, object_type, after, name="", debug_info={}): + def audit_object_add( + username: str, + object_type: str, + after: Dict[Any, Any], + name: str = "", + debug_info: Dict[Any, Any] = {}, + ) -> None: AuditLog.objects.create( username=username, object_type=object_type, @@ -143,7 +167,13 @@ class AuditLog(models.Model): ) @staticmethod - def audit_object_delete(username, object_type, before, name="", debug_info={}): + def audit_object_delete( + username: str, + object_type: str, + before: Dict[Any, Any], + name: str = "", + debug_info: Dict[Any, Any] = {}, + ) -> None: AuditLog.objects.create( username=username, object_type=object_type, @@ -156,7 +186,9 @@ class AuditLog(models.Model): ) @staticmethod - def audit_script_run(username, agent, script, debug_info={}): + def audit_script_run( + username: str, agent: "Agent", script: str, debug_info: Dict[Any, Any] = {} + ) -> None: AuditLog.objects.create( agent=agent.hostname, agent_id=agent.agent_id, @@ -168,7 +200,7 @@ class AuditLog(models.Model): ) @staticmethod - def audit_user_failed_login(username, debug_info={}): + def audit_user_failed_login(username: str, debug_info: Dict[Any, Any] = {}) -> None: AuditLog.objects.create( username=username, object_type="user", @@ -178,7 +210,9 @@ class AuditLog(models.Model): ) @staticmethod - def audit_user_failed_twofactor(username, debug_info={}): + def audit_user_failed_twofactor( + username: str, debug_info: Dict[Any, Any] = {} + ) -> None: AuditLog.objects.create( username=username, object_type="user", @@ -188,7 +222,9 @@ class AuditLog(models.Model): ) @staticmethod - def audit_user_login_successful(username, debug_info={}): + def audit_user_login_successful( + username: str, debug_info: Dict[Any, Any] = {} + ) -> None: AuditLog.objects.create( username=username, object_type="user", @@ -198,14 +234,19 @@ class AuditLog(models.Model): ) @staticmethod - def audit_url_action(username, urlaction, instance, debug_info={}): + def audit_url_action( + username: str, + urlaction: "URLAction", + instance: "Union[Agent, Client, Site]", + debug_info: Dict[Any, Any] = {}, + ) -> None: - name = instance.hostname if hasattr(instance, "hostname") else instance.name + name = instance.hostname if isinstance(instance, Agent) else instance.name classname = type(instance).__name__ AuditLog.objects.create( username=username, - agent=instance.hostname if classname == "Agent" else None, - agent_id=instance.agent_id if classname == "Agent" else None, + agent=name if isinstance(instance, Agent) else None, + agent_id=instance.agent_id if isinstance(instance, Agent) else None, object_type=classname.lower(), action="url_action", message=f"{username} ran url action: {urlaction.pattern} on {classname}: {name}", @@ -213,7 +254,12 @@ class AuditLog(models.Model): ) @staticmethod - def audit_bulk_action(username, action, affected, debug_info={}): + def audit_bulk_action( + username: str, + action: str, + affected: Dict[str, Any], + debug_info: Dict[Any, Any] = {}, + ) -> None: from agents.models import Agent from clients.models import Client, Site from scripts.models import Script @@ -290,31 +336,46 @@ class DebugLog(models.Model): @classmethod def info( cls, - message, - agent=None, - log_type="system_issues", - ): + message: str, + agent: "Optional[Agent]" = None, + log_type: str = "system_issues", + ) -> None: if get_debug_level() in ["info"]: cls.objects.create( log_level="info", agent=agent, log_type=log_type, message=message ) @classmethod - def warning(cls, message, agent=None, log_type="system_issues"): + def warning( + cls, + message: str, + agent: "Optional[Agent]" = None, + log_type: str = "system_issues", + ) -> None: if get_debug_level() in ["info", "warning"]: cls.objects.create( log_level="warning", agent=agent, log_type=log_type, message=message ) @classmethod - def error(cls, message, agent=None, log_type="system_issues"): + def error( + cls, + message: str, + agent: "Optional[Agent]" = None, + log_type: str = "system_issues", + ) -> None: if get_debug_level() in ["info", "warning", "error"]: cls.objects.create( log_level="error", agent=agent, log_type=log_type, message=message ) @classmethod - def critical(cls, message, agent=None, log_type="system_issues"): + def critical( + cls, + message: str, + agent: "Optional[Agent]" = None, + log_type: str = "system_issues", + ) -> None: if get_debug_level() in ["info", "warning", "error", "critical"]: cls.objects.create( log_level="critical", agent=agent, log_type=log_type, message=message @@ -342,13 +403,13 @@ class PendingAction(models.Model): celery_id = models.CharField(null=True, blank=True, max_length=255) details = models.JSONField(null=True, blank=True) - def __str__(self): + def __str__(self) -> str: return f"{self.agent.hostname} - {self.action_type}" @property - def due(self): + def due(self) -> str: if self.action_type == "schedreboot": - return self.details["time"] + return cast(str, self.details["time"]) elif self.action_type == "agentupdate": return "Next update cycle" elif self.action_type == "chocoinstall": @@ -357,7 +418,7 @@ class PendingAction(models.Model): return "On next checkin" @property - def description(self): + def description(self) -> Optional[str]: if self.action_type == "schedreboot": return "Device pending reboot" @@ -374,6 +435,8 @@ class PendingAction(models.Model): "runpatchinstall", ]: return f"{self.action_type}" + else: + return None class BaseAuditModel(models.Model): @@ -388,16 +451,16 @@ class BaseAuditModel(models.Model): modified_time = models.DateTimeField(auto_now=True, null=True, blank=True) @abstractmethod - def serialize(class_name): + def serialize(class_name: models.Model) -> Dict[str, Any]: pass - def save(self, old_model=None, *args, **kwargs): + def save(self, old_model: Optional[models.Model] = None, *args, **kwargs) -> None: - if get_username(): + username = get_username() + if username: object_class = type(self) object_name = object_class.__name__.lower() - username = get_username() after_value = object_class.serialize(self) # populate created_by and modified_by fields on instance @@ -407,7 +470,7 @@ class BaseAuditModel(models.Model): self.modified_by = username # dont create entry for agent add since that is done in view - if not self.pk: + if not self.pk and username: AuditLog.audit_object_add( username, object_name, @@ -424,7 +487,7 @@ class BaseAuditModel(models.Model): object_class.objects.get(pk=self.pk) ) # only create an audit entry if the values have changed - if before_value != after_value: + if before_value != after_value and username: AuditLog.audit_object_changed( username, @@ -437,14 +500,14 @@ class BaseAuditModel(models.Model): super(BaseAuditModel, self).save(*args, **kwargs) - def delete(self, *args, **kwargs): + def delete(self, *args, **kwargs) -> Tuple[int, Dict[str, int]]: super(BaseAuditModel, self).delete(*args, **kwargs) - if get_username(): - + username = get_username() + if username: object_class = type(self) AuditLog.audit_object_delete( - get_username(), + username, object_class.__name__.lower(), object_class.serialize(self), self.__str__(), diff --git a/api/tacticalrmm/tacticalrmm/middleware.py b/api/tacticalrmm/tacticalrmm/middleware.py index 73fbadd9..723c8bb6 100644 --- a/api/tacticalrmm/tacticalrmm/middleware.py +++ b/api/tacticalrmm/tacticalrmm/middleware.py @@ -2,6 +2,7 @@ import threading from django.conf import settings from ipware import get_client_ip +from typing import Dict, Optional, Any from rest_framework.exceptions import AuthenticationFailed from tacticalrmm.constants import DEMO_NOT_ALLOWED, LINUX_NOT_IMPLEMENTED @@ -9,11 +10,11 @@ from tacticalrmm.constants import DEMO_NOT_ALLOWED, LINUX_NOT_IMPLEMENTED request_local = threading.local() -def get_username(): +def get_username() -> Optional[str]: return getattr(request_local, "username", None) -def get_debug_info(): +def get_debug_info() -> Dict[str, Any]: return getattr(request_local, "debug_info", {}) diff --git a/docker/containers/tactical/entrypoint.sh b/docker/containers/tactical/entrypoint.sh index e88066b9..13855444 100644 --- a/docker/containers/tactical/entrypoint.sh +++ b/docker/containers/tactical/entrypoint.sh @@ -143,6 +143,7 @@ EOF # run migrations and init scripts + python manage.py pre_update_tasks python manage.py migrate --no-input python manage.py collectstatic --no-input python manage.py initial_db_setup diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..4d316c7d --- /dev/null +++ b/mypy.ini @@ -0,0 +1,16 @@ +[mypy] +mypy_path = api/tacticalrmm +strict_optional = True +check_untyped_defs = True +show_traceback = True +allow_redefinition = True +incremental = True +files = **/*.py +exclude = (env | migrations) + +plugins = + mypy_django_plugin.main, + mypy_drf_plugin.main + +[mypy.plugins.django-stubs] +django_settings_module = tacticalrmm.settings \ No newline at end of file