Remove pending action duplicates and make policy check/task propogation more efficient

This commit is contained in:
sadnub
2021-01-16 17:44:27 -05:00
parent 2655964113
commit 61eeb60c19
9 changed files with 99 additions and 100 deletions

View File

@@ -382,32 +382,18 @@ class Agent(BaseAuditModel):
return patch_policy
# clear is used to delete managed policy checks from agent
# parent_checks specifies a list of checks to delete from agent with matching parent_check field
def generate_checks_from_policies(self, clear=False):
def generate_checks_from_policies(self):
from automation.models import Policy
# Clear agent checks managed by policy
if clear:
self.agentchecks.filter(managed_by_policy=True).delete()
# Clear agent checks that have overriden_by_policy set
self.agentchecks.update(overriden_by_policy=False)
# Generate checks based on policies
Policy.generate_policy_checks(self)
# clear is used to delete managed policy tasks from agent
# parent_tasks specifies a list of tasks to delete from agent with matching parent_task field
def generate_tasks_from_policies(self, clear=False):
from autotasks.tasks import delete_win_task_schedule
def generate_tasks_from_policies(self):
from automation.models import Policy
# Clear agent tasks managed by policy
if clear:
for task in self.autotasks.filter(managed_by_policy=True):
delete_win_task_schedule.delay(task.pk)
# Generate tasks based on policies
Policy.generate_policy_tasks(self)
@@ -625,6 +611,13 @@ class Agent(BaseAuditModel):
elif action.details["action"] == "taskdelete":
delete_win_task_schedule.delay(task_id, pending_action=action.id)
# for clearing duplicate pending actions on agent
def remove_matching_pending_task_actions(self, task_id):
# remove any other pending actions on agent with same task_id
for action in self.pendingactions.exclude(status="completed"):
if action.details["task_id"] == task_id:
action.delete()
class AgentOutage(models.Model):
agent = models.ForeignKey(

View File

@@ -114,8 +114,8 @@ def edit_agent(request):
# check if site changed and initiate generating correct policies
if old_site != request.data["site"]:
agent.generate_checks_from_policies(clear=True)
agent.generate_tasks_from_policies(clear=True)
agent.generate_checks_from_policies()
agent.generate_tasks_from_policies()
return Response("ok")

View File

@@ -1,6 +1,5 @@
from django.db import models
from agents.models import Agent
from clients.models import Site, Client
from core.models import CoreSettings
from logs.models import BaseAuditModel
@@ -58,6 +57,11 @@ class Policy(BaseAuditModel):
@staticmethod
def cascade_policy_tasks(agent):
from autotasks.tasks import delete_win_task_schedule
from autotasks.models import AutomatedTask
from logs.models import PendingAction
# List of all tasks to be applied
tasks = list()
added_task_pks = list()
@@ -107,6 +111,33 @@ class Policy(BaseAuditModel):
tasks.append(task)
added_task_pks.append(task.pk)
# remove policy tasks from agent not included in policy
for task in agent.autotasks.filter(
parent_task__in=[
taskpk
for taskpk in agent_tasks_parent_pks
if taskpk not in added_task_pks
]
):
delete_win_task_schedule.delay(task.pk)
# handle matching tasks that haven't synced to agent yet or pending deletion due to agent being offline
for action in agent.pendingactions.exclude(status="completed"):
task = AutomatedTask.objects.get(pk=action.details["task_id"])
if (
task.parent_task in agent_tasks_parent_pks
and task.parent_task in added_task_pks
):
agent.remove_matching_pending_task_actions(task.id)
PendingAction(
agent=agent,
action_type="taskaction",
details={"action": "taskcreate", "task_id": task.id},
).save()
task.sync_status = "notsynced"
task.save(update_fields=["sync_status"])
return [task for task in tasks if task.pk not in agent_tasks_parent_pks]
@staticmethod
@@ -280,6 +311,15 @@ class Policy(BaseAuditModel):
+ eventlog_checks
)
# remove policy checks from agent that fell out of policy scope
agent.agentchecks.filter(
parent_check__in=[
checkpk
for checkpk in agent_checks_parent_pks
if checkpk not in [check.pk for check in final_list]
]
).delete()
return [
check for check in final_list if check.pk not in agent_checks_parent_pks
]

View File

@@ -6,17 +6,7 @@ from tacticalrmm.celery import app
@app.task
def generate_agent_checks_from_policies_task(
###
# copies the policy checks to all affected agents
#
# clear: clears all policy checks first
# create_tasks: also create tasks after checks are generated
###
policypk,
clear=False,
create_tasks=False,
):
def generate_agent_checks_from_policies_task(policypk, create_tasks=False):
policy = Policy.objects.get(pk=policypk)
@@ -30,32 +20,28 @@ def generate_agent_checks_from_policies_task(
agents = policy.related_agents()
for agent in agents:
agent.generate_checks_from_policies(clear=clear)
agent.generate_checks_from_policies()
if create_tasks:
agent.generate_tasks_from_policies(
clear=clear,
)
agent.generate_tasks_from_policies()
@app.task
def generate_agent_checks_by_location_task(
location, mon_type, clear=False, create_tasks=False
):
def generate_agent_checks_by_location_task(location, mon_type, create_tasks=False):
for agent in Agent.objects.filter(**location).filter(monitoring_type=mon_type):
agent.generate_checks_from_policies(clear=clear)
agent.generate_checks_from_policies()
if create_tasks:
agent.generate_tasks_from_policies(clear=clear)
agent.generate_tasks_from_policies()
@app.task
def generate_all_agent_checks_task(mon_type, clear=False, create_tasks=False):
def generate_all_agent_checks_task(mon_type, create_tasks=False):
for agent in Agent.objects.filter(monitoring_type=mon_type):
agent.generate_checks_from_policies(clear=clear)
agent.generate_checks_from_policies()
if create_tasks:
agent.generate_tasks_from_policies(clear=clear)
agent.generate_tasks_from_policies()
@app.task
@@ -93,7 +79,7 @@ def update_policy_check_fields_task(checkpk):
@app.task
def generate_agent_tasks_from_policies_task(policypk, clear=False):
def generate_agent_tasks_from_policies_task(policypk):
policy = Policy.objects.get(pk=policypk)
@@ -107,14 +93,14 @@ def generate_agent_tasks_from_policies_task(policypk, clear=False):
agents = policy.related_agents()
for agent in agents:
agent.generate_tasks_from_policies(clear=clear)
agent.generate_tasks_from_policies()
@app.task
def generate_agent_tasks_by_location_task(location, mon_type, clear=False):
def generate_agent_tasks_by_location_task(location, mon_type):
for agent in Agent.objects.filter(**location).filter(monitoring_type=mon_type):
agent.generate_tasks_from_policies(clear=clear)
agent.generate_tasks_from_policies()
@app.task

View File

@@ -122,7 +122,7 @@ class TestPolicyViews(TacticalTestCase):
resp = self.client.put(url, data, format="json")
self.assertEqual(resp.status_code, 200)
mock_checks_task.assert_called_with(
policypk=policy.pk, clear=True, create_tasks=True
policypk=policy.pk, create_tasks=True
)
self.check_not_authenticated("put", url)
@@ -140,8 +140,8 @@ class TestPolicyViews(TacticalTestCase):
resp = self.client.delete(url, format="json")
self.assertEqual(resp.status_code, 200)
mock_checks_task.assert_called_with(policypk=policy.pk, clear=True)
mock_tasks_task.assert_called_with(policypk=policy.pk, clear=True)
mock_checks_task.assert_called_with(policypk=policy.pk)
mock_tasks_task.assert_called_with(policypk=policy.pk)
self.check_not_authenticated("delete", url)
@@ -298,7 +298,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site__client_id": client.id},
mon_type="server",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -311,7 +310,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site__client_id": client.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -324,7 +322,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site_id": site.id},
mon_type="server",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -337,7 +334,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site_id": site.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -347,7 +343,7 @@ class TestPolicyViews(TacticalTestCase):
self.assertEqual(resp.status_code, 200)
# called because the relation changed
mock_checks_task.assert_called_with(clear=True)
mock_checks_task.assert_called()
mock_checks_task.reset_mock()
# Adding the same relations shouldn't trigger mocks
@@ -396,7 +392,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site__client_id": client.id},
mon_type="server",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -409,7 +404,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site__client_id": client.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -422,7 +416,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site_id": site.id},
mon_type="server",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -435,7 +428,6 @@ class TestPolicyViews(TacticalTestCase):
mock_checks_location_task.assert_called_with(
location={"site_id": site.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
mock_checks_location_task.reset_mock()
@@ -444,7 +436,7 @@ class TestPolicyViews(TacticalTestCase):
resp = self.client.post(url, agent_payload, format="json")
self.assertEqual(resp.status_code, 200)
# called because the relation changed
mock_checks_task.assert_called_with(clear=True)
mock_checks_task.assert_called()
mock_checks_task.reset_mock()
# adding the same relations shouldn't trigger mocks
@@ -753,7 +745,7 @@ class TestPolicyTasks(TacticalTestCase):
agent = baker.make_recipe("agents.agent", site=site, policy=policy)
# test policy assigned to agent
generate_agent_checks_from_policies_task(policy.id, clear=True)
generate_agent_checks_from_policies_task(policy.id)
# make sure all checks were created. should be 7
agent_checks = Agent.objects.get(pk=agent.id).agentchecks.all()
@@ -832,7 +824,6 @@ class TestPolicyTasks(TacticalTestCase):
generate_agent_checks_by_location_task(
{"site_id": sites[0].id},
"server",
clear=True,
create_tasks=True,
)
@@ -846,7 +837,6 @@ class TestPolicyTasks(TacticalTestCase):
generate_agent_checks_by_location_task(
{"site__client_id": clients[0].id},
"workstation",
clear=True,
create_tasks=True,
)
# workstation_agent should now have policy checks and the other agents should not
@@ -875,7 +865,7 @@ class TestPolicyTasks(TacticalTestCase):
core.workstation_policy = policy
core.save()
generate_all_agent_checks_task("server", clear=True, create_tasks=True)
generate_all_agent_checks_task("server", create_tasks=True)
# all servers should have 7 checks
for agent in server_agents:
@@ -884,7 +874,7 @@ class TestPolicyTasks(TacticalTestCase):
for agent in workstation_agents:
self.assertEqual(Agent.objects.get(pk=agent.id).agentchecks.count(), 0)
generate_all_agent_checks_task("workstation", clear=True, create_tasks=True)
generate_all_agent_checks_task("workstation", create_tasks=True)
# all agents should have 7 checks now
for agent in server_agents:
@@ -961,7 +951,7 @@ class TestPolicyTasks(TacticalTestCase):
site = baker.make("clients.Site")
agent = baker.make_recipe("agents.server_agent", site=site, policy=policy)
generate_agent_tasks_from_policies_task(policy.id, clear=True)
generate_agent_tasks_from_policies_task(policy.id)
agent_tasks = Agent.objects.get(pk=agent.id).autotasks.all()
@@ -1000,9 +990,7 @@ class TestPolicyTasks(TacticalTestCase):
agent1 = baker.make_recipe("agents.agent", site=sites[1])
agent2 = baker.make_recipe("agents.agent", site=sites[3])
generate_agent_tasks_by_location_task(
{"site_id": sites[0].id}, "server", clear=True
)
generate_agent_tasks_by_location_task({"site_id": sites[0].id}, "server")
# all servers in site1 and site2 should have 3 tasks
self.assertEqual(
@@ -1013,8 +1001,7 @@ class TestPolicyTasks(TacticalTestCase):
self.assertEqual(Agent.objects.get(pk=agent2.id).autotasks.count(), 0)
generate_agent_tasks_by_location_task(
{"site__client_id": clients[0].id}, "workstation", clear=True
)
{"site__client_id": clients[0].id}, "workstation")
# all workstations in Default1 should have 3 tasks
self.assertEqual(

View File

@@ -83,7 +83,6 @@ class GetUpdateDeletePolicy(APIView):
if saved_policy.active != old_active or saved_policy.enforced != old_enforced:
generate_agent_checks_from_policies_task.delay(
policypk=policy.pk,
clear=(not saved_policy.active or not saved_policy.enforced),
create_tasks=(saved_policy.active != old_active),
)
@@ -93,8 +92,8 @@ class GetUpdateDeletePolicy(APIView):
policy = get_object_or_404(Policy, pk=pk)
# delete all managed policy checks off of agents
generate_agent_checks_from_policies_task.delay(policypk=policy.pk, clear=True)
generate_agent_tasks_from_policies_task.delay(policypk=policy.pk, clear=True)
generate_agent_checks_from_policies_task.delay(policypk=policy.pk)
generate_agent_tasks_from_policies_task.delay(policypk=policy.pk)
policy.delete()
return Response("ok")
@@ -218,7 +217,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site__client_id": client.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
@@ -236,7 +234,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site_id": site.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
@@ -258,7 +255,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site__client_id": client.id},
mon_type="server",
clear=True,
create_tasks=True,
)
@@ -276,7 +272,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site_id": site.id},
mon_type="server",
clear=True,
create_tasks=True,
)
@@ -296,7 +291,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site__client_id": client.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
@@ -311,7 +305,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site_id": site.id},
mon_type="workstation",
clear=True,
create_tasks=True,
)
@@ -329,7 +322,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site__client_id": client.id},
mon_type="server",
clear=True,
create_tasks=True,
)
@@ -343,7 +335,6 @@ class GetRelated(APIView):
generate_agent_checks_by_location_task.delay(
location={"site_id": site.pk},
mon_type="server",
clear=True,
create_tasks=True,
)
@@ -358,14 +349,14 @@ class GetRelated(APIView):
if not agent.policy or agent.policy and agent.policy.pk != policy.pk:
agent.policy = policy
agent.save()
agent.generate_checks_from_policies(clear=True)
agent.generate_tasks_from_policies(clear=True)
agent.generate_checks_from_policies()
agent.generate_tasks_from_policies()
else:
if agent.policy:
agent.policy = None
agent.save()
agent.generate_checks_from_policies(clear=True)
agent.generate_tasks_from_policies(clear=True)
agent.generate_checks_from_policies()
agent.generate_tasks_from_policies()
return Response("ok")

View File

@@ -6,7 +6,6 @@ import datetime as dt
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.db.models.fields import DateTimeField
from automation.models import Policy
from logs.models import BaseAuditModel
from tacticalrmm.utils import bitdays_to_string
@@ -43,7 +42,7 @@ class AutomatedTask(BaseAuditModel):
blank=True,
)
policy = models.ForeignKey(
Policy,
"automation.Policy",
related_name="autotasks",
null=True,
blank=True,

View File

@@ -76,9 +76,14 @@ def create_win_task_schedule(pk, pending_action=False):
return "error"
r = asyncio.run(task.agent.nats_cmd(nats_data, timeout=10))
print(r)
if r != "ok":
# don't create pending action if this task was initiated by a pending action
if not pending_action:
# complete any other pending actions on agent with same task_id
task.agent.remove_matching_pending_task_actions(task.id)
PendingAction(
agent=task.agent,
action_type="taskaction",
@@ -144,6 +149,7 @@ def enable_or_disable_win_task(pk, action, pending_action=False):
task.sync_status = "synced"
task.save(update_fields=["sync_status"])
return "ok"
@@ -156,10 +162,14 @@ def delete_win_task_schedule(pk, pending_action=False):
"schedtaskpayload": {"name": task.win_task_name},
}
r = asyncio.run(task.agent.nats_cmd(nats_data, timeout=10))
if r != "ok":
print(r)
if r != "ok" and "The system cannot find the file specified" not in r:
# don't create pending action if this task was initiated by a pending action
if not pending_action:
# complete any other pending actions on agent with same task_id
task.agent.remove_matching_pending_task_actions(task.id)
PendingAction(
agent=task.agent,
action_type="taskaction",
@@ -168,7 +178,7 @@ def delete_win_task_schedule(pk, pending_action=False):
task.sync_status = "pendingdeletion"
task.save(update_fields=["sync_status"])
return
return "timeout"
# complete pending action since it was successful
if pending_action:
@@ -177,10 +187,7 @@ def delete_win_task_schedule(pk, pending_action=False):
pendingaction.save(update_fields=["status"])
# complete any other pending actions on agent with same task_id
for action in task.agent.pendingactions.all():
if action.details["task_id"] == task.id:
action.status = "completed"
action.save()
task.agent.remove_matching_pending_task_actions(task.id)
task.delete()
return "ok"

View File

@@ -51,14 +51,10 @@ def edit_settings(request):
# check if default policies changed
if old_server_policy != new_settings.server_policy:
generate_all_agent_checks_task.delay(
mon_type="server", clear=True, create_tasks=True
)
generate_all_agent_checks_task.delay(mon_type="server", create_tasks=True)
if old_workstation_policy != new_settings.workstation_policy:
generate_all_agent_checks_task.delay(
mon_type="workstation", clear=True, create_tasks=True
)
generate_all_agent_checks_task.delay(mon_type="workstation", create_tasks=True)
return Response("ok")