finish up tests and some code rework
This commit is contained in:
147
api/tacticalrmm/ee/reporting/tests/test_base_template_views.py
Normal file
147
api/tacticalrmm/ee/reporting/tests/test_base_template_views.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import pytest
|
||||
from model_bakery import baker
|
||||
from rest_framework.test import APIClient
|
||||
from rest_framework import status
|
||||
|
||||
from ..models import ReportHTMLTemplate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client():
|
||||
client = APIClient()
|
||||
user = baker.make("accounts.User")
|
||||
client.force_authenticate(user=user)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client():
|
||||
return APIClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def report_html_template():
|
||||
return baker.make("reporting.ReportHTMLTemplate")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def report_html_template_data():
|
||||
return {"name": "Test Template", "html": "<div>Test HTML</div>"}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetAddReportHTMLTemplate:
|
||||
def test_get_all_report_html_templates(
|
||||
self, authenticated_client, report_html_template
|
||||
):
|
||||
response = authenticated_client.get("/reporting/htmltemplates/")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0]["name"] == report_html_template.name
|
||||
|
||||
def test_post_new_report_html_template(
|
||||
self, authenticated_client, report_html_template_data
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/htmltemplates/", data=report_html_template_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert ReportHTMLTemplate.objects.filter(
|
||||
name=report_html_template_data["name"]
|
||||
).exists()
|
||||
|
||||
def test_post_invalid_data(self, authenticated_client):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/htmltemplates/", data={"name": ""}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_unauthenticated_get_html_templates_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.get("/reporting/htmltemplates/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthenticated_add_html_template_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/htmltemplates/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetEditDeleteReportHTMLTemplate:
|
||||
def test_get_specific_report_html_template(
|
||||
self, authenticated_client, report_html_template
|
||||
):
|
||||
response = authenticated_client.get(
|
||||
f"/reporting/htmltemplates/{report_html_template.id}/"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["name"] == report_html_template.name
|
||||
|
||||
def test_get_non_existent_template(self, authenticated_client):
|
||||
response = authenticated_client.get("/reporting/htmltemplates/999/")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_put_update_report_html_template(
|
||||
self, authenticated_client, report_html_template, report_html_template_data
|
||||
):
|
||||
response = authenticated_client.put(
|
||||
f"/reporting/htmltemplates/{report_html_template.id}/",
|
||||
data=report_html_template_data,
|
||||
)
|
||||
|
||||
report_html_template.refresh_from_db()
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert report_html_template.name == report_html_template_data["name"]
|
||||
|
||||
def test_put_invalid_data(self, authenticated_client, report_html_template):
|
||||
response = authenticated_client.put(
|
||||
f"/reporting/htmltemplates/{report_html_template.id}/", data={"name": ""}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_delete_report_html_template(
|
||||
self, authenticated_client, report_html_template
|
||||
):
|
||||
response = authenticated_client.delete(
|
||||
f"/reporting/htmltemplates/{report_html_template.id}/"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert not ReportHTMLTemplate.objects.filter(
|
||||
id=report_html_template.id
|
||||
).exists()
|
||||
|
||||
def test_delete_non_existent_template(self, authenticated_client):
|
||||
response = authenticated_client.delete("/reporting/htmltemplates/999/")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_unauthenticated_get_html_template_view(
|
||||
self, unauthenticated_client, report_html_template
|
||||
):
|
||||
response = unauthenticated_client.get(
|
||||
f"/reporting/htmltemplates/{report_html_template.id}/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthenticated_edit_html_template_view(
|
||||
self, unauthenticated_client, report_html_template
|
||||
):
|
||||
response = unauthenticated_client.put(
|
||||
f"/reporting/htmltemplates/{report_html_template.id}/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthenticated_delete_html_template_view(
|
||||
self, unauthenticated_client, report_html_template
|
||||
):
|
||||
response = unauthenticated_client.delete(
|
||||
f"/reporting/htmltemplates/{report_html_template.id}/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
513
api/tacticalrmm/ee/reporting/tests/test_data_queries.py
Normal file
513
api/tacticalrmm/ee/reporting/tests/test_data_queries.py
Normal file
@@ -0,0 +1,513 @@
|
||||
import pytest
|
||||
import yaml
|
||||
from unittest.mock import patch
|
||||
from model_bakery import baker
|
||||
from django.apps import apps
|
||||
from ..utils import (
|
||||
add_custom_fields,
|
||||
make_dataqueries_inline,
|
||||
build_queryset,
|
||||
resolve_model,
|
||||
ResolveModelException,
|
||||
InvalidDBOperationException,
|
||||
)
|
||||
from ..constants import REPORTING_MODELS
|
||||
from agents.models import Agent
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestMakeVariablesInline:
|
||||
def test_make_dataqueries_inline_valid_reference(self):
|
||||
data_query = baker.make(
|
||||
"reporting.ReportDataQuery", name="test_query", json_query={"test": "query"}
|
||||
)
|
||||
variables = yaml.dump({"data_sources": {"source1": "test_query"}})
|
||||
|
||||
result = make_dataqueries_inline(variables=variables)
|
||||
|
||||
assert yaml.safe_load(result) == {
|
||||
"data_sources": {"source1": {"test": "query"}}
|
||||
}
|
||||
|
||||
def test_make_dataqueries_inline_invalid_reference(self):
|
||||
variables = yaml.dump({"data_sources": {"source1": "nonexistent_query"}})
|
||||
|
||||
result = make_dataqueries_inline(variables=variables)
|
||||
|
||||
assert yaml.safe_load(result) == {
|
||||
"data_sources": {"source1": "nonexistent_query"}
|
||||
}
|
||||
|
||||
def test_make_dataqueries_inline_no_reference(self):
|
||||
variables = yaml.dump({"key": "value"})
|
||||
|
||||
result = make_dataqueries_inline(variables=variables)
|
||||
|
||||
assert yaml.safe_load(result) == {"key": "value"}
|
||||
|
||||
def test_make_dataqueries_inline_invalid_yaml(self):
|
||||
variables = "{some: invalid: yaml}"
|
||||
|
||||
result = make_dataqueries_inline(variables=variables)
|
||||
|
||||
assert yaml.safe_load(result) == {}
|
||||
|
||||
|
||||
class TestResolvingModels:
|
||||
def test_all_reporting_models_valid(self):
|
||||
for model_name, app_name in REPORTING_MODELS:
|
||||
try:
|
||||
model = apps.get_model(app_name, model_name)
|
||||
except LookupError:
|
||||
pytest.fail(f"Model: {model_name} does not exist in app: {app_name}")
|
||||
|
||||
def test_resolve_model_valid_model(self):
|
||||
data_source = {"model": "Agent"}
|
||||
|
||||
result = resolve_model(data_source=data_source)
|
||||
|
||||
# Assuming 'agents.Agent' is a valid model in your Django app.
|
||||
from agents.models import Agent
|
||||
|
||||
assert result["model"] == Agent
|
||||
|
||||
def test_resolve_model_invalid_model_name(self):
|
||||
data_source = {"model": "InvalidModel"}
|
||||
|
||||
with pytest.raises(
|
||||
ResolveModelException, match="Model lookup failed for InvalidModel"
|
||||
):
|
||||
resolve_model(data_source=data_source)
|
||||
|
||||
def test_resolve_model_no_model_key(self):
|
||||
data_source = {"key": "value"}
|
||||
|
||||
with pytest.raises(
|
||||
ResolveModelException, match="Model key must be present on data_source"
|
||||
):
|
||||
resolve_model(data_source=data_source)
|
||||
|
||||
|
||||
@patch("agents.models.Agent.objects.using", return_value=Agent.objects.using("default"))
|
||||
@pytest.mark.django_db()
|
||||
class TestBuildingQueryset:
|
||||
@pytest.fixture
|
||||
def setup_agents(self):
|
||||
agent1 = baker.make("agents.Agent", hostname="ZAgent1", plat="windows")
|
||||
agent2 = baker.make("agents.Agent", hostname="Agent2", plat="windows")
|
||||
return [agent1, agent2]
|
||||
|
||||
def test_build_queryset_with_valid_model(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
}
|
||||
result = build_queryset(data_source=data_source)
|
||||
assert result is not None, "Queryset should not be None for a valid model."
|
||||
|
||||
def test_build_queryset_invalid_operation(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"invalid_op": "value",
|
||||
}
|
||||
with pytest.raises(InvalidDBOperationException):
|
||||
build_queryset(data_source=data_source)
|
||||
|
||||
def test_build_queryset_without_model(self, mock, setup_agents):
|
||||
data_source = {}
|
||||
with pytest.raises(
|
||||
Exception
|
||||
): # This could be a more specific exception if you expect one.
|
||||
build_queryset(data_source=data_source)
|
||||
|
||||
def test_build_queryset_only_operation(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "only": ["hostname", "operating_system"]}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(result) == 2
|
||||
for agent_data in result:
|
||||
assert "hostname" in agent_data
|
||||
assert "operating_system" in agent_data
|
||||
assert "plat" not in agent_data
|
||||
|
||||
def test_build_queryset_id_is_appended_if_only_exists(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "only": ["hostname"]}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(result) == 2
|
||||
for agent_data in result:
|
||||
assert "id" in agent_data
|
||||
|
||||
def test_build_queryset_filter_operation(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"filter": {"hostname": setup_agents[0].hostname},
|
||||
}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_filtering_operation_with_multiple_fields(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"filter": {
|
||||
"hostname": setup_agents[0].hostname,
|
||||
"operating_system": setup_agents[0].operating_system,
|
||||
},
|
||||
}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_filtering_with_non_existing_condition(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "filter": {"hostname": "doesn't exist"}}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
def test_build_queryset_exclude_operation(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"exclude": {"hostname": setup_agents[0].hostname},
|
||||
}
|
||||
|
||||
results = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["hostname"] != setup_agents[0].hostname
|
||||
|
||||
def test_build_query_get_operation(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "get": {"agent_id": setup_agents[0].agent_id}}
|
||||
|
||||
agent = build_queryset(data_source=data_source)
|
||||
assert agent["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_all_operation(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"all": True,
|
||||
}
|
||||
result = build_queryset(data_source=data_source)
|
||||
assert len(result) == 2
|
||||
|
||||
# test filter and only
|
||||
def test_build_queryset_filter_only_operation(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"filter": {"hostname": setup_agents[0].hostname},
|
||||
"only": ["agent_id", "hostname"],
|
||||
}
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
# should only return 1 result
|
||||
assert len(result) == 1
|
||||
assert result[0]["hostname"] == setup_agents[0].hostname
|
||||
assert "plat" not in result[0]
|
||||
|
||||
def test_build_queryset_limit_operation(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"limit": 1,
|
||||
}
|
||||
result = build_queryset(data_source=data_source)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_build_queryset_field_defer_operation(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "defer": ["wmi_detail", "services"]}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
assert "wmi_detail" not in result[0]
|
||||
assert "services" not in result[0]
|
||||
assert result[0]["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_first_operation(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "first": True}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert result["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_count_operation(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "count": True}
|
||||
|
||||
count = build_queryset(data_source=data_source)
|
||||
|
||||
assert count == 2
|
||||
|
||||
def test_build_queryset_order_by_operation(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "order_by": ["hostname"]}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["hostname"] == setup_agents[1].hostname
|
||||
assert result[1]["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_json_presentation(self, mock, setup_agents):
|
||||
import json
|
||||
|
||||
data_source = {"model": Agent, "json": True}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
# Deserializing the result to check the content.
|
||||
result_data = json.loads(result)
|
||||
assert result_data[0]["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_custom_fields(self, mock, setup_agents):
|
||||
default_value = "Default Value"
|
||||
|
||||
field1 = baker.make(
|
||||
"core.CustomField", name="custom_1", model="agent", type="text"
|
||||
)
|
||||
baker.make(
|
||||
"core.CustomField",
|
||||
name="custom_2",
|
||||
model="agent",
|
||||
type="text",
|
||||
default_value_string=default_value,
|
||||
)
|
||||
|
||||
baker.make(
|
||||
"agents.AgentCustomField",
|
||||
agent=setup_agents[0],
|
||||
field=field1,
|
||||
string_value="Agent1",
|
||||
)
|
||||
baker.make(
|
||||
"agents.AgentCustomField",
|
||||
agent=setup_agents[1],
|
||||
field=field1,
|
||||
string_value="Agent2",
|
||||
)
|
||||
|
||||
data_source = {"model": Agent, "custom_fields": ["custom_1", "custom_2"]}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
assert len(result) == 2
|
||||
|
||||
# check agent 1
|
||||
assert result[0]["custom_fields"]["custom_1"] == "Agent1"
|
||||
assert result[0]["custom_fields"]["custom_2"] == default_value
|
||||
|
||||
# check agent 2
|
||||
assert result[1]["custom_fields"]["custom_1"] == "Agent2"
|
||||
assert result[1]["custom_fields"]["custom_2"] == default_value
|
||||
|
||||
def test_build_queryset_filter_only_json_combination(self, mock, setup_agents):
|
||||
import json
|
||||
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"filter": {"agent_id": setup_agents[0].agent_id},
|
||||
"only": ["hostname", "agent_id"],
|
||||
"json": True,
|
||||
}
|
||||
|
||||
result_json = build_queryset(data_source=data_source)
|
||||
result = json.loads(result_json)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "operating_system" not in result[0]
|
||||
assert result[0]["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_get_only_json_combination(self, mock, setup_agents):
|
||||
import json
|
||||
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"get": {"agent_id": setup_agents[0].agent_id},
|
||||
"only": ["hostname", "agent_id"],
|
||||
"json": True,
|
||||
}
|
||||
|
||||
result_json = build_queryset(data_source=data_source)
|
||||
result = json.loads(result_json)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "operating_system" not in result
|
||||
assert result["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_filter_order_by_combination(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"filter": {"plat": "windows"},
|
||||
"order_by": ["hostname"],
|
||||
}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["hostname"] == setup_agents[1].hostname
|
||||
assert result[1]["hostname"] == setup_agents[0].hostname
|
||||
|
||||
def test_build_queryset_defer_used_over_only(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"only": ["hostname", "operating_system"],
|
||||
"defer": ["operating_system"],
|
||||
}
|
||||
|
||||
result = build_queryset(data_source=data_source)[0]
|
||||
|
||||
assert "hostname" in result
|
||||
assert "operating_system" not in result
|
||||
|
||||
def test_build_queryset_limit_ignored_with_first_or_get(self, mock, setup_agents):
|
||||
data_source = {"model": Agent, "limit": 1, "first": True}
|
||||
|
||||
result_first = build_queryset(data_source=data_source)
|
||||
assert isinstance(result_first, dict)
|
||||
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
"limit": 1,
|
||||
"get": {"agent_id": setup_agents[0].agent_id},
|
||||
}
|
||||
|
||||
result_get = build_queryset(data_source=data_source)
|
||||
assert isinstance(result_get, dict)
|
||||
|
||||
def test_build_queryset_result_type_with_get_or_first(self, mock, setup_agents):
|
||||
# Test with "get"
|
||||
data_source_get = {
|
||||
"model": Agent,
|
||||
"get": {"hostname": setup_agents[0].hostname},
|
||||
}
|
||||
result_get = build_queryset(data_source=data_source_get)
|
||||
|
||||
# Test with "first"
|
||||
data_source_first = {"model": Agent, "first": True}
|
||||
result_first = build_queryset(data_source=data_source_first)
|
||||
|
||||
assert not isinstance(result_get, list)
|
||||
assert not isinstance(result_first, list)
|
||||
|
||||
def test_build_queryset_result_type_without_get_or_first(self, mock, setup_agents):
|
||||
data_source = {
|
||||
"model": Agent,
|
||||
}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_build_queryset_result_in_json_format(self, mock, setup_agents):
|
||||
import json
|
||||
|
||||
data_source = {"model": Agent, "json": True}
|
||||
|
||||
result = build_queryset(data_source=data_source)
|
||||
|
||||
try:
|
||||
parsed_result = json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
assert False
|
||||
|
||||
assert isinstance(parsed_result, list)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestAddingCustomFields:
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,custom_field_model",
|
||||
[
|
||||
("agent", "agents.AgentCustomField"),
|
||||
("client", "clients.ClientCustomField"),
|
||||
("site", "clients.SiteCustomField"),
|
||||
],
|
||||
)
|
||||
def test_add_custom_fields_with_list_of_dicts(self, model_name, custom_field_model):
|
||||
custom_field = baker.make("core.CustomField", name="field1", model=model_name)
|
||||
default_value = "Default Value"
|
||||
custom_field2 = baker.make(
|
||||
"core.CustomField",
|
||||
name="field2",
|
||||
model=model_name,
|
||||
default_value_string=default_value,
|
||||
)
|
||||
|
||||
custom_model_instance1 = baker.make(
|
||||
custom_field_model, field=custom_field, string_value="Value"
|
||||
)
|
||||
custom_model_instance2 = baker.make(
|
||||
custom_field_model, field=custom_field, string_value="Value"
|
||||
)
|
||||
|
||||
data = [
|
||||
{"id": getattr(custom_model_instance1, f"{model_name}_id")},
|
||||
{"id": getattr(custom_model_instance2, f"{model_name}_id")},
|
||||
]
|
||||
fields_to_add = ["field1", "field2"]
|
||||
result = add_custom_fields(
|
||||
data=data, fields_to_add=fields_to_add, model_name=model_name
|
||||
)
|
||||
|
||||
# Assert logic here based on what you expect the result to be
|
||||
assert result[0]["custom_fields"]["field1"] == custom_model_instance1.value
|
||||
assert result[1]["custom_fields"]["field1"] == custom_model_instance2.value
|
||||
|
||||
assert result[0]["custom_fields"]["field2"] == default_value
|
||||
assert result[1]["custom_fields"]["field2"] == default_value
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,custom_field_model",
|
||||
[
|
||||
("agent", "agents.AgentCustomField"),
|
||||
("client", "clients.ClientCustomField"),
|
||||
("site", "clients.SiteCustomField"),
|
||||
],
|
||||
)
|
||||
def test_add_custom_fields_to_dictionary(self, model_name, custom_field_model):
|
||||
custom_field = baker.make("core.CustomField", name="field1", model=model_name)
|
||||
custom_model_instance = baker.make(
|
||||
custom_field_model, field=custom_field, string_value="default_value"
|
||||
)
|
||||
|
||||
data = {"id": getattr(custom_model_instance, f"{model_name}_id")}
|
||||
fields_to_add = ["field1"]
|
||||
result = add_custom_fields(
|
||||
data=data,
|
||||
fields_to_add=fields_to_add,
|
||||
model_name=model_name,
|
||||
dict_value=True,
|
||||
)
|
||||
|
||||
# Assert logic here based on what you expect the result to be
|
||||
assert result["custom_fields"]["field1"] == custom_model_instance.value
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"agent",
|
||||
"client",
|
||||
"site",
|
||||
],
|
||||
)
|
||||
def test_add_custom_fields_with_default_value(self, model_name):
|
||||
default_value = "default_value"
|
||||
custom_field = baker.make(
|
||||
"core.CustomField",
|
||||
name="field1",
|
||||
model=model_name,
|
||||
default_value_string=default_value,
|
||||
)
|
||||
|
||||
# Note: Not creating an instance of the custom_field_model here to ensure the default value is used
|
||||
|
||||
data = {"id": 999} # ID not associated with any custom field model instance
|
||||
fields_to_add = ["field1"]
|
||||
result = add_custom_fields(
|
||||
data=data,
|
||||
fields_to_add=fields_to_add,
|
||||
model_name=model_name,
|
||||
dict_value=True,
|
||||
)
|
||||
|
||||
# Assert that the default value is used
|
||||
assert result["custom_fields"]["field1"] == default_value
|
||||
184
api/tacticalrmm/ee/reporting/tests/test_dataquery_views.py
Normal file
184
api/tacticalrmm/ee/reporting/tests/test_dataquery_views.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import pytest
|
||||
import json
|
||||
from model_bakery import baker
|
||||
from rest_framework.test import APIClient
|
||||
from rest_framework import status
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
from ..models import ReportDataQuery
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client():
|
||||
client = APIClient()
|
||||
user = baker.make("accounts.User")
|
||||
client.force_authenticate(user=user)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client():
|
||||
return APIClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def report_data_query():
|
||||
return baker.make("reporting.ReportDataQuery")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def report_data_query_data():
|
||||
return {"name": "Test Data Query", "json_query": {"test": "value"}}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetAddReportDataQuery:
|
||||
def test_get_all_report_data_queries(self, authenticated_client, report_data_query):
|
||||
url = "/reporting/dataqueries/"
|
||||
response = authenticated_client.get(url)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0]["name"] == report_data_query.name
|
||||
|
||||
def test_post_new_report_data_query(
|
||||
self, authenticated_client, report_data_query_data
|
||||
):
|
||||
url = "/reporting/dataqueries/"
|
||||
response = authenticated_client.post(
|
||||
url, data=report_data_query_data, format="json"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert ReportDataQuery.objects.filter(
|
||||
name=report_data_query_data["name"]
|
||||
).exists()
|
||||
|
||||
def test_post_invalid_data(self, authenticated_client):
|
||||
url = "/reporting/dataqueries/"
|
||||
response = authenticated_client.post(url, data={"name": ""})
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_unauthenticated_get_data_queries_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.get("/reporting/dataqueries/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthenticated_add_data_query_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/dataqueries/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetEditDeleteReportDataQuery:
|
||||
def test_get_specific_report_data_query(
|
||||
self, authenticated_client, report_data_query
|
||||
):
|
||||
url = f"/reporting/dataqueries/{report_data_query.id}/"
|
||||
response = authenticated_client.get(url)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["name"] == report_data_query.name
|
||||
|
||||
def test_get_non_existent_data_query(self, authenticated_client):
|
||||
url = "/reporting/dataqueries/9999/"
|
||||
response = authenticated_client.get(url)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_put_update_report_data_query(
|
||||
self, authenticated_client, report_data_query, report_data_query_data
|
||||
):
|
||||
url = f"/reporting/dataqueries/{report_data_query.id}/"
|
||||
response = authenticated_client.put(
|
||||
url, data=report_data_query_data, format="json"
|
||||
)
|
||||
|
||||
report_data_query.refresh_from_db()
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert report_data_query.name == report_data_query_data["name"]
|
||||
|
||||
def test_put_invalid_data(self, authenticated_client, report_data_query):
|
||||
url = f"/reporting/dataqueries/{report_data_query.id}/"
|
||||
response = authenticated_client.put(url, data={"name": ""})
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_delete_report_data_query(self, authenticated_client, report_data_query):
|
||||
url = f"/reporting/dataqueries/{report_data_query.id}/"
|
||||
response = authenticated_client.delete(url)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert not ReportDataQuery.objects.filter(id=report_data_query.id).exists()
|
||||
|
||||
def test_delete_non_existent_data_query(self, authenticated_client):
|
||||
url = "/reporting/dataqueries/9999/"
|
||||
response = authenticated_client.delete(url)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_unauthenticated_get_data_queries_view(
|
||||
self, unauthenticated_client, report_data_query
|
||||
):
|
||||
response = unauthenticated_client.get(
|
||||
f"/reporting/dataqueries/{report_data_query.id}/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthenticated_edit_html_template_view(
|
||||
self, unauthenticated_client, report_data_query
|
||||
):
|
||||
response = unauthenticated_client.put(
|
||||
f"/reporting/dataqueries/{report_data_query.id}/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthenticated_delete_html_template_view(
|
||||
self, unauthenticated_client, report_data_query
|
||||
):
|
||||
response = unauthenticated_client.delete(
|
||||
f"/reporting/dataqueries/{report_data_query.id}/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestQuerySchema:
|
||||
def test_get_query_schema_in_debug_mode(self, settings, authenticated_client):
|
||||
# Set DEBUG mode
|
||||
settings.DEBUG = True
|
||||
|
||||
expected_data = {"sample": "json"}
|
||||
|
||||
# mock the file
|
||||
mopen = mock_open(read_data=json.dumps({"sample": "json"}))
|
||||
with patch("builtins.open", mopen):
|
||||
response = authenticated_client.get("/reporting/queryschema/")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == expected_data
|
||||
|
||||
def test_get_query_schema_in_production_mode(self, settings, authenticated_client):
|
||||
# Set production mode (DEBUG = False)
|
||||
settings.DEBUG = False
|
||||
|
||||
response = authenticated_client.get("/reporting/queryschema/")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
# Check that the X-Accel-Redirect header is set correctly
|
||||
assert (
|
||||
response["X-Accel-Redirect"]
|
||||
== "/static/reporting/schemas/query_schema.json"
|
||||
)
|
||||
|
||||
def test_get_query_schema_file_missing(self, settings, authenticated_client):
|
||||
# Set DEBUG mode
|
||||
settings.DEBUG = True
|
||||
|
||||
with patch("builtins.open", side_effect=FileNotFoundError):
|
||||
response = authenticated_client.get("/reporting/queryschema/")
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_unauthenticated_query_schema_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.delete(f"/reporting/queryschema/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -0,0 +1,249 @@
|
||||
import pytest
|
||||
import uuid
|
||||
import base64
|
||||
import json
|
||||
from model_bakery import baker
|
||||
from rest_framework.test import APIClient
|
||||
from unittest.mock import patch
|
||||
from rest_framework import status
|
||||
from ..models import ReportAsset, ReportTemplate, ReportHTMLTemplate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client():
|
||||
client = APIClient()
|
||||
user = baker.make("accounts.User")
|
||||
client.force_authenticate(user=user)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client():
|
||||
return APIClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def report_template():
|
||||
return baker.make(
|
||||
"reporting.ReportTemplate",
|
||||
name="test_template",
|
||||
template_md="# Test MD",
|
||||
template_css="body { color: red; }",
|
||||
type="markdown",
|
||||
depends_on=["some_dependency"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def report_template_with_base_template():
|
||||
base_template = baker.make("reporting.ReportHTMLTemplate")
|
||||
return baker.make(
|
||||
"reporting.ReportTemplate",
|
||||
name="test_template",
|
||||
template_md="# Test MD",
|
||||
template_css="body { color: red; }",
|
||||
template_html=base_template,
|
||||
type="markdown",
|
||||
depends_on=["some_dependency"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestExportReportTemplate:
|
||||
@patch(
|
||||
"ee.reporting.views.base64_encode_assets", return_value="some_encoded_assets"
|
||||
)
|
||||
def test_export_report_template_with_base_template(
|
||||
self,
|
||||
mock_encode_assets,
|
||||
authenticated_client,
|
||||
report_template_with_base_template,
|
||||
):
|
||||
url = f"/reporting/templates/{report_template_with_base_template.id}/export/"
|
||||
response = authenticated_client.post(url)
|
||||
|
||||
expected_response = {
|
||||
"base_template": {
|
||||
"name": report_template_with_base_template.template_html.name,
|
||||
"html": report_template_with_base_template.template_html.html,
|
||||
},
|
||||
"template": {
|
||||
"name": report_template_with_base_template.name,
|
||||
"template_css": report_template_with_base_template.template_css,
|
||||
"template_md": report_template_with_base_template.template_md,
|
||||
"type": report_template_with_base_template.type,
|
||||
"depends_on": report_template_with_base_template.depends_on,
|
||||
"template_variables": "{}\n",
|
||||
},
|
||||
"assets": "some_encoded_assets",
|
||||
}
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == expected_response
|
||||
mock_encode_assets.assert_called()
|
||||
|
||||
@patch(
|
||||
"ee.reporting.views.base64_encode_assets", return_value="some_encoded_assets"
|
||||
)
|
||||
def test_export_report_template_without_base_template(
|
||||
self,
|
||||
mock_encode_assets,
|
||||
authenticated_client,
|
||||
report_template,
|
||||
):
|
||||
url = f"/reporting/templates/{report_template.id}/export/"
|
||||
response = authenticated_client.post(url)
|
||||
|
||||
expected_response = {
|
||||
"base_template": None,
|
||||
"template": {
|
||||
"name": report_template.name,
|
||||
"template_css": report_template.template_css,
|
||||
"template_md": report_template.template_md,
|
||||
"type": report_template.type,
|
||||
"depends_on": report_template.depends_on,
|
||||
"template_variables": "{}\n",
|
||||
},
|
||||
"assets": "some_encoded_assets",
|
||||
}
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == expected_response
|
||||
mock_encode_assets.assert_called()
|
||||
|
||||
def test_unauthenticated_export_report_template_view(
|
||||
self, unauthenticated_client, report_template
|
||||
):
|
||||
response = unauthenticated_client.post(
|
||||
f"/reporting/templates/{report_template.id}/export/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestImportReportTemplate:
|
||||
@pytest.fixture
|
||||
def valid_template_data(self):
|
||||
"""Returns a sample valid template data."""
|
||||
return {
|
||||
"template": {
|
||||
"name": "test_template",
|
||||
"template_md": "# Test MD",
|
||||
"type": "markdown",
|
||||
"depends_on": ["some_dependency"],
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def valid_base_template_data(self):
|
||||
"""Returns a sample valid base template data."""
|
||||
return {"name": "base_test_template", "html": "<div>Test</div>"}
|
||||
|
||||
@pytest.fixture
|
||||
def valid_assets_data(self):
|
||||
"""Returns a sample valid assets data."""
|
||||
return [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": "asset1.png",
|
||||
"file": base64.b64encode(b"mock_content1").decode("utf-8"),
|
||||
},
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": "asset2.png",
|
||||
"file": base64.b64encode(b"mock_content2").decode("utf-8"),
|
||||
},
|
||||
]
|
||||
|
||||
def test_basic_import(self, authenticated_client, valid_template_data):
|
||||
url = "/reporting/templates/import/"
|
||||
response = authenticated_client.post(
|
||||
url, data={"template": json.dumps(valid_template_data)}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert ReportTemplate.objects.filter(name="test_template").exists()
|
||||
|
||||
def test_import_with_base_template(
|
||||
self, authenticated_client, valid_template_data, valid_base_template_data
|
||||
):
|
||||
valid_template_data["base_template"] = valid_base_template_data
|
||||
url = "/reporting/templates/import/"
|
||||
response = authenticated_client.post(
|
||||
url, data={"template": json.dumps(valid_template_data)}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert ReportHTMLTemplate.objects.filter(name="base_test_template").exists()
|
||||
assert ReportTemplate.objects.filter(name="test_template").exists()
|
||||
|
||||
def test_import_with_assets(
|
||||
self, authenticated_client, valid_template_data, valid_assets_data
|
||||
):
|
||||
valid_template_data["assets"] = valid_assets_data
|
||||
url = "/reporting/templates/import/"
|
||||
response = authenticated_client.post(
|
||||
url, data={"template": json.dumps(valid_template_data)}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert ReportAsset.objects.filter(id=valid_assets_data[0]["id"]).exists()
|
||||
assert ReportAsset.objects.filter(id=valid_assets_data[1]["id"]).exists()
|
||||
|
||||
@patch(
|
||||
"ee.reporting.views.ImportReportTemplate._generate_random_string",
|
||||
return_value="_randomized",
|
||||
)
|
||||
def test_name_conflict_on_repeated_calls(
|
||||
self, generate_random_string_mock, authenticated_client, valid_template_data
|
||||
):
|
||||
url = "/reporting/templates/import/"
|
||||
response = authenticated_client.post(
|
||||
url, data={"template": json.dumps(valid_template_data)}
|
||||
)
|
||||
assert ReportTemplate.objects.filter(name="test_template").exists()
|
||||
|
||||
response = authenticated_client.post(
|
||||
url, data={"template": json.dumps(valid_template_data)}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert ReportTemplate.objects.filter(name="test_template_randomized").exists()
|
||||
|
||||
def test_invalid_data(self, authenticated_client, valid_template_data):
|
||||
valid_template_data["template"].pop("name")
|
||||
url = "/reporting/templates/import/"
|
||||
response = authenticated_client.post(
|
||||
url, data={"template": json.dumps(valid_template_data)}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "name" in response.data
|
||||
|
||||
def test_import_with_assets_with_conflicting_paths(
|
||||
self, authenticated_client, valid_template_data, valid_assets_data
|
||||
):
|
||||
conflicting_asset = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": valid_assets_data[0]["name"],
|
||||
"file": base64.b64encode(b"mock_content1").decode("utf-8"),
|
||||
}
|
||||
valid_assets_data.append(conflicting_asset)
|
||||
valid_template_data["assets"] = valid_assets_data
|
||||
url = "/reporting/templates/import/"
|
||||
response = authenticated_client.post(
|
||||
url, data={"template": json.dumps(valid_template_data)}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert ReportAsset.objects.filter(id=valid_assets_data[0]["id"]).exists()
|
||||
assert ReportAsset.objects.filter(id=valid_assets_data[1]["id"]).exists()
|
||||
assert ReportAsset.objects.filter(id=conflicting_asset["id"]).exists()
|
||||
|
||||
# check if the renaming logic is working
|
||||
asset = ReportAsset.objects.get(id=conflicting_asset["id"])
|
||||
assert asset.file.name != valid_assets_data[0]["name"]
|
||||
|
||||
def test_unauthenticated_import_report_template_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post(f"/reporting/templates/import/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
17
api/tacticalrmm/ee/reporting/tests/test_mgmt_commands.py
Normal file
17
api/tacticalrmm/ee/reporting/tests/test_mgmt_commands.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import pytest
|
||||
import os
|
||||
|
||||
from django.core import management
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestSchemaGeneration:
|
||||
def test_generate_json_schema(self, settings):
|
||||
management.call_command("generate_json_schemas")
|
||||
|
||||
schema_path = (
|
||||
f"{settings.STATICFILES_DIRS[0]}reporting/schemas/query_schema.json"
|
||||
)
|
||||
assert os.path.exists(schema_path)
|
||||
|
||||
os.remove(schema_path)
|
||||
534
api/tacticalrmm/ee/reporting/tests/test_report_asset_views.py
Normal file
534
api/tacticalrmm/ee/reporting/tests/test_report_asset_views.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import pytest
|
||||
import uuid
|
||||
import os
|
||||
from rest_framework.test import APIClient
|
||||
from unittest.mock import patch, mock_open
|
||||
from model_bakery import baker
|
||||
from rest_framework import status
|
||||
from django.core.exceptions import SuspiciousFileOperation
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from ..models import ReportAsset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client():
|
||||
client = APIClient()
|
||||
user = baker.make("accounts.User")
|
||||
client.force_authenticate(user=user)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client():
|
||||
return APIClient()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetReportAssets:
|
||||
@patch("ee.reporting.views.report_assets_fs")
|
||||
def test_valid_path_with_dir_and_files(
|
||||
self, mock_report_assets_fs, authenticated_client
|
||||
):
|
||||
# Set up the mock behavior for report_assets_fs
|
||||
mock_report_assets_fs.listdir.return_value = (["folder1"], ["file1.txt"])
|
||||
mock_report_assets_fs.size.return_value = 100
|
||||
mock_report_assets_fs.url.return_value = "/mocked/url/to/resource"
|
||||
|
||||
path = "some/valid/path"
|
||||
url = f"/reporting/assets/?path={path}"
|
||||
expected_response_data = [
|
||||
{
|
||||
"name": "folder1",
|
||||
"path": os.path.join(path, "folder1"),
|
||||
"type": "folder",
|
||||
"size": None,
|
||||
"url": "/mocked/url/to/resource",
|
||||
},
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": os.path.join(path, "file1.txt"),
|
||||
"type": "file",
|
||||
"size": "100",
|
||||
"url": "/mocked/url/to/resource",
|
||||
},
|
||||
]
|
||||
|
||||
response = authenticated_client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.data == expected_response_data
|
||||
|
||||
@patch("ee.reporting.views.report_assets_fs")
|
||||
def test_no_path(self, mock_report_assets_fs, authenticated_client):
|
||||
# Set up the mock behavior for report_assets_fs
|
||||
mock_report_assets_fs.listdir.return_value = (["folder1"], ["file1.txt"])
|
||||
mock_report_assets_fs.size.return_value = 100
|
||||
mock_report_assets_fs.url.return_value = "/mocked/url/to/resource"
|
||||
|
||||
url = "/reporting/assets/"
|
||||
expected_response_data = [
|
||||
{
|
||||
"name": "folder1",
|
||||
"path": "folder1",
|
||||
"type": "folder",
|
||||
"size": None,
|
||||
"url": "/mocked/url/to/resource",
|
||||
},
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "file1.txt",
|
||||
"type": "file",
|
||||
"size": "100",
|
||||
"url": "/mocked/url/to/resource",
|
||||
},
|
||||
]
|
||||
|
||||
response = authenticated_client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.data == expected_response_data
|
||||
|
||||
def test_invalid_path(self, authenticated_client):
|
||||
url = "/reporting/assets/?path=some/invalid/path"
|
||||
|
||||
response = authenticated_client.get(url)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_unauthenticated_get_report_assets_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/assets/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetAllAssets:
|
||||
@patch("ee.reporting.views.os.walk")
|
||||
def test_general_functionality(self, mock_os_walk, authenticated_client):
|
||||
mock_os_walk.return_value = iter(
|
||||
[(".", ["subdir"], ["file1.txt"]), ("./subdir", [], ["subdirfile.txt"])]
|
||||
)
|
||||
|
||||
asset1 = baker.make("reporting.ReportAsset", file="file1.txt")
|
||||
asset2 = baker.make("reporting.ReportAsset", file="subdir/subdirfile.txt")
|
||||
|
||||
expected_data = [
|
||||
{
|
||||
"type": "folder",
|
||||
"name": "Report Assets",
|
||||
"path": ".",
|
||||
"children": [
|
||||
{
|
||||
"type": "folder",
|
||||
"name": "subdir",
|
||||
"path": "subdir",
|
||||
"children": [
|
||||
{
|
||||
"id": asset2.id,
|
||||
"type": "file",
|
||||
"name": "subdirfile.txt",
|
||||
"path": "subdir/subdirfile.txt",
|
||||
"icon": "description",
|
||||
}
|
||||
],
|
||||
"selectable": False,
|
||||
"icon": "folder",
|
||||
"iconColor": "yellow-9",
|
||||
},
|
||||
{
|
||||
"id": asset1.id,
|
||||
"type": "file",
|
||||
"name": "file1.txt",
|
||||
"path": "file1.txt",
|
||||
"icon": "description",
|
||||
},
|
||||
],
|
||||
"selectable": False,
|
||||
"icon": "folder",
|
||||
"iconColor": "yellow-9",
|
||||
}
|
||||
]
|
||||
|
||||
response = authenticated_client.get("/reporting/assets/all/")
|
||||
assert response.status_code == 200
|
||||
assert expected_data == response.data
|
||||
|
||||
@patch("ee.reporting.views.os.chdir", side_effect=FileNotFoundError)
|
||||
def test_invalid_report_assets_dir(self, mock_os_walk, authenticated_client):
|
||||
response = authenticated_client.get("/reporting/assets/all/")
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@patch("ee.reporting.views.os.walk")
|
||||
def test_only_folders(self, mock_os_walk, authenticated_client):
|
||||
mock_os_walk.return_value = iter(
|
||||
[(".", ["subdir"], ["file1.txt"]), ("./subdir", [], ["subdirfile.txt"])]
|
||||
)
|
||||
|
||||
asset1 = baker.make("reporting.ReportAsset", file="file1.txt")
|
||||
asset2 = baker.make("reporting.ReportAsset", file="subdir/subdirfile.txt")
|
||||
|
||||
response = authenticated_client.get("/reporting/assets/all/?onlyFolders=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
for node in response.data:
|
||||
assert node["type"] != "file"
|
||||
|
||||
def test_unauthenticated_get_report_assets_all_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/assets/all/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestRenameAsset:
|
||||
def test_rename_file(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.rename",
|
||||
return_value="path/to/newname.txt",
|
||||
) as mock_rename, patch(
|
||||
"ee.reporting.views.report_assets_fs.isfile", return_value=True
|
||||
) as mock_isfile, patch(
|
||||
"ee.reporting.views.report_assets_fs.exists", return_value=True
|
||||
) as mock_exists:
|
||||
asset = baker.make("reporting.ReportAsset", file="path/to/file.txt")
|
||||
|
||||
response = authenticated_client.put(
|
||||
"/reporting/assets/rename/",
|
||||
data={"path": "path/to/file.txt", "newName": "newname.txt"},
|
||||
)
|
||||
|
||||
mock_rename.assert_called_with(
|
||||
path="path/to/file.txt", new_name="newname.txt"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.data == "path/to/newname.txt"
|
||||
|
||||
asset.refresh_from_db()
|
||||
assert asset.file.name == "path/to/newname.txt"
|
||||
|
||||
def test_rename_folder(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.rename",
|
||||
return_value="path/to/newfolder",
|
||||
) as mock_rename, patch(
|
||||
"ee.reporting.views.report_assets_fs.isfile", return_value=False
|
||||
) as mock_isfile, patch(
|
||||
"ee.reporting.views.report_assets_fs.exists", return_value=True
|
||||
) as mock_exists:
|
||||
response = authenticated_client.put(
|
||||
"/reporting/assets/rename/",
|
||||
data={"path": "path/to/folder", "newName": "newfolder"},
|
||||
)
|
||||
|
||||
mock_rename.assert_called_with(path="path/to/folder", new_name="newfolder")
|
||||
assert response.status_code == 200
|
||||
assert response.data == "path/to/newfolder"
|
||||
|
||||
def test_rename_non_existent_file(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.rename",
|
||||
side_effect=OSError("File not found"),
|
||||
) as mock_rename, patch(
|
||||
"ee.reporting.views.report_assets_fs.exists", return_value=True
|
||||
) as mock_exists:
|
||||
response = authenticated_client.put(
|
||||
"/reporting/assets/rename/",
|
||||
data={
|
||||
"path": "non_existent_path/to/file.txt",
|
||||
"newName": "newname.txt",
|
||||
},
|
||||
)
|
||||
|
||||
mock_rename.assert_called_with(
|
||||
path="non_existent_path/to/file.txt", new_name="newname.txt"
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_suspicious_operation(self, authenticated_client):
|
||||
response = authenticated_client.put(
|
||||
"/reporting/assets/rename/",
|
||||
data={"path": "../outside/path", "newName": "newname.txt"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_unauthenticated_rename_report_asset_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/assets/rename/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestNewFolder:
|
||||
def test_create_folder_success(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.createfolder",
|
||||
return_value="new/path/to/folder",
|
||||
) as mock_createfolder:
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/newfolder/", data={"path": "new/path/to/folder"}
|
||||
)
|
||||
|
||||
mock_createfolder.assert_called_with(path="new/path/to/folder")
|
||||
assert response.status_code == 200
|
||||
assert response.data == "new/path/to/folder"
|
||||
|
||||
def test_create_folder_missing_path(self, authenticated_client):
|
||||
response = authenticated_client.post("/reporting/assets/newfolder/")
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_create_folder_os_error(self, authenticated_client):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/newfolder/", data={"path": "invalid/path"}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_create_folder_suspicious_operation(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.createfolder",
|
||||
side_effect=SuspiciousFileOperation("Invalid path"),
|
||||
) as mock_createfolder:
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/newfolder/", data={"path": "../outside/path"}
|
||||
)
|
||||
|
||||
mock_createfolder.assert_called_with(path="../outside/path")
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_unauthenticated_new_folder_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/assets/newfolder/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestDeleteAssets:
|
||||
def test_delete_directory_success(self, authenticated_client):
|
||||
from ..settings import settings
|
||||
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir", return_value=True
|
||||
), patch("ee.reporting.views.shutil.rmtree") as mock_rmtree:
|
||||
asset1 = baker.make("reporting.ReportAsset", file="path/to/dir/file1.txt")
|
||||
asset2 = baker.make("reporting.ReportAsset", file="path/to/dir/file2.txt")
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/delete/", data={"paths": ["path/to/dir"]}
|
||||
)
|
||||
|
||||
mock_rmtree.assert_called_with(
|
||||
f"{settings.REPORTING_ASSETS_BASE_PATH}/path/to/dir"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# make sure the assets within the deleted folder also got deleted
|
||||
assert not ReportAsset.objects.filter(id=asset1.id).exists()
|
||||
assert not ReportAsset.objects.filter(id=asset2.id).exists()
|
||||
|
||||
def test_delete_file_success(self, authenticated_client):
|
||||
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=False):
|
||||
asset = baker.make("reporting.ReportAsset", file="path/to/file.txt")
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/delete/", data={"paths": ["path/to/file.txt"]}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert not ReportAsset.objects.filter(id=asset.id).exists()
|
||||
|
||||
def test_delete_os_error(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir", return_value=True
|
||||
), patch(
|
||||
"ee.reporting.views.shutil.rmtree", side_effect=OSError("Unable to delete")
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/delete/", data={"paths": ["invalid/path"]}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_delete_suspicious_operation(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir", return_value=True
|
||||
), patch(
|
||||
"ee.reporting.views.shutil.rmtree",
|
||||
side_effect=SuspiciousFileOperation("Invalid path"),
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/delete/", data={"paths": ["../outside/path"]}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_delete_asset_not_in_db_but_on_fs(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir", return_value=False
|
||||
), patch("ee.reporting.views.report_assets_fs.delete") as mock_fs_delete:
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/delete/", data={"paths": ["path/to/file.txt"]}
|
||||
)
|
||||
|
||||
mock_fs_delete.assert_called_with("path/to/file.txt")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_unauthenticated_delete_assets_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/assets/delete/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestUploadAssets:
|
||||
def test_upload_success(self, authenticated_client):
|
||||
mock_file = SimpleUploadedFile("test123.txt", b"file_content")
|
||||
|
||||
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=True):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/upload/",
|
||||
data={"parentPath": "path", "test123.txt": mock_file},
|
||||
)
|
||||
|
||||
assert response.data["test123.txt"]["filename"] == "path/test123.txt"
|
||||
assert response.status_code == 200
|
||||
|
||||
assert ReportAsset.objects.filter(file="path/test123.txt").exists()
|
||||
|
||||
# cleanup file so future tests don't break
|
||||
asset = ReportAsset.objects.get(file="path/test123.txt")
|
||||
asset.file.delete()
|
||||
asset.delete()
|
||||
|
||||
def test_upload_invalid_directory(self, authenticated_client):
|
||||
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=False):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/upload/",
|
||||
data={"parentPath": "invalid_path"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_upload_suspicious_operation(self, authenticated_client):
|
||||
mock_file = SimpleUploadedFile("test2.txt", b"file_content")
|
||||
|
||||
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=True):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/assets/upload/",
|
||||
data={"parentPath": "../path", "test2.txt": mock_file},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_unauthenticated_uploads_assets_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/assets/upload/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestAssetDownload:
|
||||
def test_download_file_success(self, authenticated_client):
|
||||
m = mock_open()
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir", return_value=False
|
||||
), patch("builtins.open", m), patch(
|
||||
"ee.reporting.views.report_assets_fs.path", return_value="path/test.txt"
|
||||
):
|
||||
response = authenticated_client.get(
|
||||
"/reporting/assets/download/", {"path": "test.txt"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_download_directory_success(self, authenticated_client):
|
||||
m = mock_open()
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir", return_value=True
|
||||
), patch(
|
||||
"ee.reporting.views.shutil.make_archive", return_value="path/test.zip"
|
||||
), patch(
|
||||
"builtins.open", m
|
||||
), patch(
|
||||
"ee.reporting.views.report_assets_fs.path", return_value="path/test"
|
||||
), patch(
|
||||
"ee.reporting.views.os.remove", return_value=None
|
||||
):
|
||||
response = authenticated_client.get(
|
||||
"/reporting/assets/download/", {"path": "test"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
m.assert_called_once_with("path/test.zip", "rb")
|
||||
|
||||
def test_download_invalid_path(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir",
|
||||
side_effect=OSError("Path does not exist"),
|
||||
):
|
||||
response = authenticated_client.get(
|
||||
"/reporting/assets/download/", {"path": "invalid_path"}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_download_os_error(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir",
|
||||
side_effect=OSError("Download failed"),
|
||||
):
|
||||
response = authenticated_client.get(
|
||||
"/reporting/assets/download/", {"path": "test.txt"}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_download_suspicious_operation(self, authenticated_client):
|
||||
with patch(
|
||||
"ee.reporting.views.report_assets_fs.isdir",
|
||||
side_effect=SuspiciousFileOperation("Suspicious path"),
|
||||
):
|
||||
response = authenticated_client.get(
|
||||
"/reporting/assets/download/", {"path": "test.txt"}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_unauthenticated_uploads_assets_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/assets/download/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestAssetNginxRedirect:
|
||||
@pytest.fixture
|
||||
def asset(self):
|
||||
return baker.make("reporting.ReportAsset", file="test_asset.txt")
|
||||
|
||||
def test_valid_uuid_and_path(self, unauthenticated_client, asset):
|
||||
response = unauthenticated_client.get(
|
||||
f"/reporting/assets/test_asset.txt?id={asset.id}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response["X-Accel-Redirect"] == "/assets/test_asset.txt"
|
||||
|
||||
def test_valid_uuid_wrong_path(self, unauthenticated_client, asset):
|
||||
response = unauthenticated_client.get(
|
||||
f"/reporting/assets/wrong_path.txt?id={asset.id}"
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_valid_uuid_no_asset(self, unauthenticated_client):
|
||||
non_existent_uuid = uuid.uuid4()
|
||||
url = f"/reporting/assets/test_asset.txt?id={non_existent_uuid}"
|
||||
response = unauthenticated_client.get(url)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_invalid_uuid(self, unauthenticated_client):
|
||||
response = unauthenticated_client.get(
|
||||
"/reporting/assets/test_asset.txt?id=invalid_uuid"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "There was a error processing the request" in response.content.decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
def test_no_id(self, unauthenticated_client):
|
||||
response = unauthenticated_client.get("/reporting/assets/test_asset.txt")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "There was a error processing the request" in response.content.decode(
|
||||
"utf-8"
|
||||
)
|
||||
368
api/tacticalrmm/ee/reporting/tests/test_report_template_views.py
Normal file
368
api/tacticalrmm/ee/reporting/tests/test_report_template_views.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import pytest
|
||||
|
||||
from rest_framework.test import APIClient
|
||||
from unittest.mock import patch
|
||||
from model_bakery import baker
|
||||
from rest_framework import status
|
||||
from jinja2.exceptions import TemplateError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client():
|
||||
client = APIClient()
|
||||
user = baker.make("accounts.User")
|
||||
client.force_authenticate(user=user)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client():
|
||||
return APIClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def report_template():
|
||||
return baker.make("reporting.ReportTemplate")
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestReportTemplateViews:
|
||||
def test_get_all_report_templates_empty_db(self, authenticated_client):
|
||||
response = authenticated_client.get("/reporting/templates/")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.data) == 0
|
||||
|
||||
def test_get_all_report_templates(self, authenticated_client):
|
||||
# Create some sample ReportTemplates using model_bakery
|
||||
baker.make("reporting.ReportTemplate", _quantity=5)
|
||||
response = authenticated_client.get("/reporting/templates/")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.data) == 5
|
||||
|
||||
def test_get_report_templates_with_filter(self, authenticated_client):
|
||||
# Create templates with specific dependencies
|
||||
baker.make("reporting.ReportTemplate", depends_on=["agent"], _quantity=3)
|
||||
baker.make("reporting.ReportTemplate", depends_on=["client"], _quantity=2)
|
||||
|
||||
response = authenticated_client.get(
|
||||
"/reporting/templates/", {"dependsOn[]": ["agent"]}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.data) == 3
|
||||
|
||||
def test_post_report_template_valid_data(self, authenticated_client):
|
||||
valid_data = {"name": "Test Template", "template_md": "Template Text"}
|
||||
response = authenticated_client.post("/reporting/templates/", valid_data)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["name"] == "Test Template"
|
||||
|
||||
def test_post_report_template_invalid_data(self, authenticated_client):
|
||||
invalid_data = {"name": "Test Template"}
|
||||
response = authenticated_client.post("/reporting/templates/", invalid_data)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_get_report_template(self, authenticated_client, report_template):
|
||||
url = f"/reporting/templates/{report_template.pk}/"
|
||||
response = authenticated_client.get(url)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["id"] == report_template.id
|
||||
|
||||
def test_edit_report_template(self, authenticated_client, report_template):
|
||||
url = f"/reporting/templates/{report_template.pk}/"
|
||||
updated_name = "Updated name"
|
||||
|
||||
response = authenticated_client.put(url, {"name": updated_name}, format="json")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
report_template.refresh_from_db()
|
||||
assert report_template.name == updated_name
|
||||
|
||||
def test_delete_report_template(self, authenticated_client, report_template):
|
||||
url = f"/reporting/templates/{report_template.pk}/"
|
||||
response = authenticated_client.delete(url)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
with pytest.raises(report_template.DoesNotExist):
|
||||
report_template.refresh_from_db()
|
||||
|
||||
# test unauthorized access
|
||||
def test_unauthorized_get_report_templates_view(self, unauthenticated_client):
|
||||
url = f"/reporting/templates/"
|
||||
response = unauthenticated_client.get(url)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthorized_add_report_template_view(self, unauthenticated_client):
|
||||
url = f"/reporting/templates/"
|
||||
response = unauthenticated_client.post(url)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthorized_get_report_template_view(
|
||||
self, unauthenticated_client, report_template
|
||||
):
|
||||
url = f"/reporting/templates/{report_template.pk}/"
|
||||
response = unauthenticated_client.get(url)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthorized_edit_report_template_view(
|
||||
self, unauthenticated_client, report_template
|
||||
):
|
||||
url = f"/reporting/templates/{report_template.pk}/"
|
||||
updated_name = "Updated name"
|
||||
|
||||
response = unauthenticated_client.put(
|
||||
url, {"name": updated_name}, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_unauthorized_delete_report_template_view(
|
||||
self, unauthenticated_client, report_template
|
||||
):
|
||||
url = f"/reporting/templates/{report_template.pk}/"
|
||||
response = unauthenticated_client.delete(url)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestReportTemplateGenerateView:
|
||||
def test_generate_html_report(self, authenticated_client, report_template):
|
||||
data = {"format": "html", "dependencies": {}}
|
||||
response = authenticated_client.post(
|
||||
f"/reporting/templates/{report_template.id}/run/", data=data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert report_template.template_md in response.data
|
||||
|
||||
def test_generate_pdf_report(self, authenticated_client, report_template):
|
||||
data = {"format": "pdf", "dependencies": {}}
|
||||
response = authenticated_client.post(
|
||||
f"/reporting/templates/{report_template.id}/run/", data=data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response["content-type"] == "application/pdf"
|
||||
|
||||
def test_generate_invalid_format_report(
|
||||
self, authenticated_client, report_template
|
||||
):
|
||||
data = {"format": "invalid_format", "dependencies": {}}
|
||||
response = authenticated_client.post(
|
||||
f"/reporting/templates/{report_template.id}/run/", data=data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_generate_report_template_error(self, authenticated_client):
|
||||
template = baker.make("reporting.ReportTemplate", template_md="{{invalid}")
|
||||
data = {"format": "html", "dependencies": {}}
|
||||
response = authenticated_client.post(
|
||||
f"/reporting/templates/{template.id}/run/", data=data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_generate_report_with_dependencies(
|
||||
self, authenticated_client, report_template
|
||||
):
|
||||
sample_html = "<html><body>Sample Report</body></html>"
|
||||
data = {"format": "html", "dependencies": {"client": 1}}
|
||||
|
||||
with patch(
|
||||
"ee.reporting.views.generate_html", return_value=(sample_html, None)
|
||||
) as mock_generate_html:
|
||||
url = f"/reporting/templates/{report_template.id}/run/"
|
||||
response = authenticated_client.post(url, data, format="json")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == sample_html
|
||||
|
||||
mock_generate_html.assert_called_with(
|
||||
template=report_template.template_md,
|
||||
template_type=report_template.type,
|
||||
css=report_template.template_css if report_template.template_css else "",
|
||||
html_template=report_template.template_html.id
|
||||
if report_template.template_html
|
||||
else None,
|
||||
variables=report_template.template_variables,
|
||||
dependencies={"client": 1},
|
||||
)
|
||||
|
||||
def test_unauthenticated_generate_report_view(
|
||||
self, unauthenticated_client, report_template
|
||||
):
|
||||
response = unauthenticated_client.post(
|
||||
f"/reporting/templates/{report_template.id}/run/"
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGenerateReportPreview:
|
||||
def test_generate_report_preview_html_format(self, authenticated_client):
|
||||
data = {
|
||||
"template_md": "some template md",
|
||||
"type": "some type",
|
||||
"template_css": "some css",
|
||||
"template_variables": {},
|
||||
"dependencies": {},
|
||||
"format": "html",
|
||||
"debug": False,
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ee.reporting.views.generate_html", return_value=("<html></html>", {})
|
||||
) as mock_generate_html:
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/", data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == "<html></html>"
|
||||
mock_generate_html.assert_called()
|
||||
|
||||
@patch("ee.reporting.views.generate_html", return_value=("<html></html>", {}))
|
||||
@patch("ee.reporting.views.generate_pdf", return_value=b"some_pdf_bytes")
|
||||
def test_generate_report_preview_pdf_format(
|
||||
self, mock_generate_html, mock_generate_pdf, authenticated_client
|
||||
):
|
||||
data = {
|
||||
"template_md": "some template md",
|
||||
"type": "some type",
|
||||
"template_css": "some css",
|
||||
"template_variables": {},
|
||||
"dependencies": {},
|
||||
"format": "pdf",
|
||||
"debug": False,
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/", data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response["Content-Type"] == "application/pdf"
|
||||
mock_generate_html.assert_called()
|
||||
mock_generate_pdf.assert_called()
|
||||
|
||||
def test_generate_report_preview_debug(self, authenticated_client):
|
||||
data = {
|
||||
"template_md": "some template md",
|
||||
"type": "markdown",
|
||||
"template_css": "some css",
|
||||
"template_variables": {},
|
||||
"dependencies": {},
|
||||
"format": "html",
|
||||
"debug": True,
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ee.reporting.views.generate_html",
|
||||
return_value=("<html></html>", {"agent": baker.prepare("agents.Agent")}),
|
||||
) as mock_generate_html:
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/", data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert "template" in response.data
|
||||
assert "variables" in response.data
|
||||
mock_generate_html.assert_called()
|
||||
|
||||
def test_generate_report_preview_invalid_data(self, authenticated_client):
|
||||
data = {
|
||||
"template_md": "some template md",
|
||||
# Missing 'type'
|
||||
"template_css": "some css",
|
||||
"template_variables": {},
|
||||
"dependencies": {},
|
||||
"format": "invalid_format",
|
||||
"debug": True,
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/", data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_generate_report_preview_template_error(self, authenticated_client):
|
||||
data = {
|
||||
"template_md": "some template md",
|
||||
"type": "some type",
|
||||
"template_css": "some css",
|
||||
"template_variables": {},
|
||||
"dependencies": {},
|
||||
"format": "html",
|
||||
"debug": True,
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ee.reporting.views.generate_html",
|
||||
side_effect=TemplateError("Some template error"),
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/", data, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "Some template error" in response.data
|
||||
|
||||
def test_unauthenticated_generate_report_preview_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/templates/preview/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetAllowedValues:
|
||||
def test_valid_input(self, authenticated_client):
|
||||
data = {
|
||||
"variables": {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"roles": ["admin", "user"],
|
||||
},
|
||||
"count": 1,
|
||||
},
|
||||
"dependencies": {},
|
||||
}
|
||||
|
||||
expected_response_data = {
|
||||
"user": "Object",
|
||||
"user.name": "str",
|
||||
"user.roles": "Array (2 Results)",
|
||||
"user.roles[0]": "str",
|
||||
"count": "int",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ee.reporting.views.prep_variables_for_template",
|
||||
return_value=data["variables"],
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/analysis/", data, format="json"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.data == expected_response_data
|
||||
|
||||
def test_empty_variables(self, authenticated_client):
|
||||
data = {"variables": {}, "dependencies": {}}
|
||||
with patch(
|
||||
"ee.reporting.views.prep_variables_for_template",
|
||||
return_value=data["variables"],
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/analysis/", data, format="json"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.data == None
|
||||
|
||||
def test_invalid_input(self, authenticated_client):
|
||||
data = {"invalidKey": {}}
|
||||
|
||||
response = authenticated_client.post(
|
||||
"/reporting/templates/preview/analysis/", data, format="json"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_unauthenticated_variable_analysis_view(self, unauthenticated_client):
|
||||
response = unauthenticated_client.post("/reporting/templates/preview/analysis/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -257,4 +257,12 @@ def test_move_folder_with_name_conflict(tmp_path: Path) -> None:
|
||||
|
||||
|
||||
def test_move_directory_traversal(tmp_path: Path) -> None:
|
||||
assert False
|
||||
storage = ReportAssetStorage(location=tmp_path)
|
||||
|
||||
with pytest.raises(SuspiciousFileOperation):
|
||||
# relative
|
||||
storage.move(source="../../file", destination="../..")
|
||||
|
||||
with pytest.raises(SuspiciousFileOperation):
|
||||
# absolute
|
||||
storage.move(source="/etc", destination="/newpath")
|
||||
|
||||
107
api/tacticalrmm/ee/reporting/tests/test_template_generation.py
Normal file
107
api/tacticalrmm/ee/reporting/tests/test_template_generation.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pytest
|
||||
from model_bakery import baker
|
||||
|
||||
from ..utils import generate_html, db_template_loader
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGenerateHTML:
|
||||
@pytest.fixture
|
||||
def base_template(self):
|
||||
html = "<html>{% block content %}{% endblock %}</html>"
|
||||
return baker.make(
|
||||
"reporting.ReportHTMLTemplate", name="Base Template", html=html
|
||||
)
|
||||
|
||||
def test_markdown_conversion(self):
|
||||
template = "# This is a header"
|
||||
result, _ = generate_html(
|
||||
template=template, template_type="markdown", css="", variables=""
|
||||
)
|
||||
assert "<h1>This is a header</h1>" in result
|
||||
|
||||
def test_html_unchanged(self):
|
||||
template = "<h1>This is a header</h1>"
|
||||
result, _ = generate_html(
|
||||
template=template, template_type="html", css="", variables=""
|
||||
)
|
||||
assert "<h1>This is a header</h1>" in result
|
||||
|
||||
def test_html_template_exists(self, base_template):
|
||||
template = "{% block content %}<h1>This is a header</h1>{% endblock %}"
|
||||
result, _ = generate_html(
|
||||
template=template,
|
||||
template_type="html",
|
||||
html_template=base_template.id,
|
||||
css="",
|
||||
variables="",
|
||||
)
|
||||
|
||||
assert "<html><h1>This is a header</h1></html>" == result
|
||||
|
||||
def test_html_template_does_not_exist(self):
|
||||
template = "<h1>This is a header</h1>"
|
||||
# check it doesn't raise an error.
|
||||
generate_html(
|
||||
template=template,
|
||||
template_type="html",
|
||||
html_template=999,
|
||||
css="",
|
||||
variables="",
|
||||
)
|
||||
|
||||
def test_variables_processing(self):
|
||||
template = "Hello {{ name }}"
|
||||
variables = "name: John"
|
||||
result, _ = generate_html(
|
||||
template=template, template_type="html", css="", variables=variables
|
||||
)
|
||||
|
||||
assert "Hello John" in result
|
||||
|
||||
def test_css_incorporation(self):
|
||||
template = "<html><head><style>{{css}}</style></head></html>"
|
||||
css = ".my-class { color: red; }"
|
||||
result, _ = generate_html(
|
||||
template=template, template_type="html", css=css, variables=""
|
||||
)
|
||||
assert css in result
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestJinjaDBLoader:
|
||||
@pytest.fixture
|
||||
def report_base_template(self):
|
||||
return baker.make(
|
||||
"reporting.ReportHTMLTemplate", name="test_html_template", html="Test HTML"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def report_template(self):
|
||||
return baker.make(
|
||||
"reporting.ReportTemplate", name="test_md_template", template_md="Test MD"
|
||||
)
|
||||
|
||||
def test_load_base_template(self, report_base_template):
|
||||
result = db_template_loader(report_base_template.name)
|
||||
assert result == "Test HTML"
|
||||
|
||||
def test_fallback_to_md_template(self, report_template):
|
||||
result = db_template_loader(report_template.name)
|
||||
assert result == "Test MD"
|
||||
|
||||
def test_no_template_found(self):
|
||||
# Will return None
|
||||
result = db_template_loader("nonexistent_template")
|
||||
assert result is None
|
||||
|
||||
def test_html_template_priority(self):
|
||||
# Create both a ReportHTMLTemplate and a ReportTemplate with the same name
|
||||
template_name = "common_template"
|
||||
baker.make("reporting.ReportHTMLTemplate", name=template_name, html="Test HTML")
|
||||
baker.make(
|
||||
"reporting.ReportTemplate", name=template_name, template_md="Test MD"
|
||||
)
|
||||
|
||||
result = db_template_loader(template_name)
|
||||
assert result == "Test HTML" # HTML has priority
|
||||
318
api/tacticalrmm/ee/reporting/tests/test_template_variables.py
Normal file
318
api/tacticalrmm/ee/reporting/tests/test_template_variables.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from model_bakery import baker
|
||||
from ..utils import (
|
||||
process_dependencies,
|
||||
process_data_sources,
|
||||
process_chart_variables,
|
||||
prep_variables_for_template,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProcessingDependencies:
|
||||
# Fixtures for creating model instances
|
||||
@pytest.fixture
|
||||
def test_client(db):
|
||||
return baker.make("clients.Client", name="Test Client Name")
|
||||
|
||||
@pytest.fixture
|
||||
def test_site(db):
|
||||
return baker.make("clients.Site", name="Test Site Name")
|
||||
|
||||
@pytest.fixture
|
||||
def test_agent(db):
|
||||
return baker.make(
|
||||
"agents.Agent", agent_id="Test Agent ID", hostname="Test Agent Name"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def test_global(db):
|
||||
return baker.make(
|
||||
"core.GlobalKVStore", name="some_global_value", value="Some Global Value"
|
||||
)
|
||||
|
||||
def test_replace_with_client_db_value(self, test_client):
|
||||
variables = """
|
||||
some_field:
|
||||
client_field: {{client.name}}
|
||||
"""
|
||||
dependencies = {"client": test_client.id}
|
||||
result = process_dependencies(variables=variables, dependencies=dependencies)
|
||||
assert result["some_field"]["client_field"] == test_client.name
|
||||
|
||||
def test_replace_with_site_db_value(self, test_site):
|
||||
variables = """
|
||||
some_field:
|
||||
site_field: {{site.name}}
|
||||
"""
|
||||
dependencies = {"site": test_site.id}
|
||||
result = process_dependencies(variables=variables, dependencies=dependencies)
|
||||
assert result["some_field"]["site_field"] == test_site.name
|
||||
|
||||
def test_replace_with_agent_db_value(self, test_agent):
|
||||
variables = """
|
||||
some_field:
|
||||
agent_field: {{agent.hostname}}
|
||||
"""
|
||||
dependencies = {"agent": test_agent.agent_id}
|
||||
result = process_dependencies(variables=variables, dependencies=dependencies)
|
||||
assert result["some_field"]["agent_field"] == test_agent.hostname
|
||||
|
||||
def test_replace_with_global_value(self, test_global):
|
||||
variables = """
|
||||
some_field:
|
||||
global_field: {{global.some_global_value}}
|
||||
"""
|
||||
result = process_dependencies(variables=variables, dependencies={})
|
||||
# Assuming you have a global value with key 'some_global_value' set to 'Some Global Value'
|
||||
assert result["some_field"]["global_field"] == test_global.value
|
||||
|
||||
def test_replace_with_non_db_dependencies(self):
|
||||
variables = """
|
||||
some_field:
|
||||
dependency_field: {{some_dependency}}
|
||||
"""
|
||||
dependencies = {"some_dependency": "Some Value"}
|
||||
result = process_dependencies(variables=variables, dependencies=dependencies)
|
||||
assert result["some_field"]["dependency_field"] == "Some Value"
|
||||
|
||||
def test_missing_non_db_dependencies(self):
|
||||
variables = """
|
||||
some_field:
|
||||
missing_dependency: "{{missing_dependency}}"
|
||||
"""
|
||||
dependencies = {} # Empty dependencies, simulating a missing dependency
|
||||
result = process_dependencies(variables=variables, dependencies=dependencies)
|
||||
assert result["some_field"]["missing_dependency"] == "{{missing_dependency}}"
|
||||
|
||||
def test_variables_dependencies_merge(self, test_agent):
|
||||
variables = """
|
||||
some_field:
|
||||
agent_field: {{agent.agent_id}}
|
||||
dependency_field: {{some_dependency}}
|
||||
"""
|
||||
dependencies = {"agent": test_agent.agent_id, "some_dependency": "Some Value"}
|
||||
result = process_dependencies(variables=variables, dependencies=dependencies)
|
||||
assert result["some_field"]["agent_field"] == "Test Agent ID"
|
||||
assert result["some_field"]["dependency_field"] == "Some Value"
|
||||
# Additionally, assert the merged structure has both processed variables and dependencies
|
||||
assert "agent" in result
|
||||
assert isinstance(result["agent"], type(test_agent))
|
||||
assert "some_dependency" in result
|
||||
assert result["some_dependency"] == "Some Value"
|
||||
|
||||
def test_multiple_replacements(self, test_agent, test_client, test_site):
|
||||
variables = """
|
||||
fields:
|
||||
agent_name: {{agent.hostname}}
|
||||
client_name: {{client.name}}
|
||||
site_name: {{site.name}}
|
||||
dependency_1: {{dep_1}}
|
||||
dependency_2: {{dep_2}}
|
||||
"""
|
||||
|
||||
dependencies = {
|
||||
"agent": test_agent.agent_id,
|
||||
"client": test_client.id,
|
||||
"site": test_site.id,
|
||||
"dep_1": "Dependency Value 1",
|
||||
"dep_2": "Dependency Value 2",
|
||||
}
|
||||
|
||||
result = process_dependencies(variables=variables, dependencies=dependencies)
|
||||
|
||||
assert result["fields"]["agent_name"] == test_agent.hostname
|
||||
assert result["fields"]["client_name"] == test_client.name
|
||||
assert result["fields"]["site_name"] == test_site.name
|
||||
assert result["fields"]["dependency_1"] == "Dependency Value 1"
|
||||
assert result["fields"]["dependency_2"] == "Dependency Value 2"
|
||||
|
||||
# Also verify that the non-replaced fields from dependencies are present in the result.
|
||||
assert "agent" in result
|
||||
assert "client" in result
|
||||
assert "site" in result
|
||||
assert result["dep_1"] == "Dependency Value 1"
|
||||
assert result["dep_2"] == "Dependency Value 2"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProcessDataSourceVariables:
|
||||
def test_process_data_sources_without_data_sources(self):
|
||||
variables = {"other_key": "some_value"}
|
||||
result = process_data_sources(variables=variables)
|
||||
assert result == variables
|
||||
|
||||
def test_process_data_sources_with_non_dict_data_sources(self):
|
||||
variables = {"data_sources": "some_string_value"}
|
||||
result = process_data_sources(variables=variables)
|
||||
assert result == variables
|
||||
|
||||
def test_process_data_sources_with_dict_data_sources(self):
|
||||
variables = {
|
||||
"data_sources": {
|
||||
"source1": {"model": "agent", "other_field": "value"},
|
||||
"source2": "some_string_value",
|
||||
}
|
||||
}
|
||||
|
||||
mock_queryset = {"data": "sample_data"}
|
||||
|
||||
# Mock build_queryset to return the mock_queryset
|
||||
with patch("ee.reporting.utils.build_queryset", return_value=mock_queryset):
|
||||
result = process_data_sources(variables=variables)
|
||||
|
||||
# Assert that the data_sources for "source1" is replaced with mock_queryset
|
||||
assert result["data_sources"]["source1"] == mock_queryset
|
||||
|
||||
# Assert that the "source2" data remains unchanged
|
||||
assert result["data_sources"]["source2"] == "some_string_value"
|
||||
|
||||
def test_process_data_sources_with_non_dict_data_sources(self):
|
||||
variables = {
|
||||
"data_sources": {
|
||||
"source1": {"model": "agent", "other_field": "value"},
|
||||
"source2": "some_string_value",
|
||||
}
|
||||
}
|
||||
|
||||
mock_queryset = 5
|
||||
|
||||
# Mock build_queryset to return the mock_queryset
|
||||
with patch("ee.reporting.utils.build_queryset", return_value=mock_queryset):
|
||||
result = process_data_sources(variables=variables)
|
||||
|
||||
# Assert that the data_sources for "source1" is replaced with mock_queryset
|
||||
assert result["data_sources"]["source1"] == mock_queryset
|
||||
|
||||
# Assert that the "source2" data remains unchanged
|
||||
assert result["data_sources"]["source2"] == "some_string_value"
|
||||
|
||||
|
||||
class TestProcessChartVariables:
|
||||
def test_process_chart_no_replace_data_frame(self):
|
||||
# Scenario where path doesn't exist in variables
|
||||
variables = {
|
||||
"charts": {
|
||||
"chart1": {
|
||||
"chartType": "type1",
|
||||
"outputType": "html",
|
||||
"options": {"data_frame": "path.to.nonexistent"},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert process_chart_variables(variables=variables) == variables
|
||||
|
||||
def test_process_chart_generate_chart_invocation(self):
|
||||
# Ensure generate_chart is invoked with expected parameters
|
||||
variables = {
|
||||
"charts": {
|
||||
"chart1": {
|
||||
"chartType": "type1",
|
||||
"outputType": "html",
|
||||
"options": {},
|
||||
"traces": "Some Traces",
|
||||
"layout": "Some Layout",
|
||||
}
|
||||
}
|
||||
}
|
||||
with patch("ee.reporting.utils.generate_chart") as mock_generate_chart:
|
||||
mock_generate_chart.return_value = "<html>Chart Here</html>"
|
||||
_ = process_chart_variables(variables=variables)
|
||||
|
||||
mock_generate_chart.assert_called_once_with(
|
||||
type="type1",
|
||||
format="html",
|
||||
options={},
|
||||
traces="Some Traces",
|
||||
layout="Some Layout",
|
||||
)
|
||||
|
||||
def test_process_chart_missing_keys(self):
|
||||
# Scenario where necessary keys are missing
|
||||
variables = {"charts": {"chart1": {}}}
|
||||
|
||||
result = process_chart_variables(variables=variables)
|
||||
assert result == variables # Expecting unchanged variables
|
||||
|
||||
def test_process_chart_no_charts(self):
|
||||
# Scenario with no charts key or charts not a dict
|
||||
variables1 = {}
|
||||
variables2 = {"charts": "Not a dict"}
|
||||
|
||||
assert process_chart_variables(variables=variables1) == variables1
|
||||
assert process_chart_variables(variables=variables2) == variables2
|
||||
|
||||
def test_process_chart_replaces_data_frame(self):
|
||||
# Sample input
|
||||
variables = {
|
||||
"charts": {
|
||||
"myChart": {
|
||||
"chartType": "bar",
|
||||
"outputType": "html",
|
||||
"options": {"data_frame": "data_sources.sample_data"},
|
||||
}
|
||||
},
|
||||
"data_sources": {"sample_data": [{"x": 1, "y": 2}, {"x": 2, "y": 3}]},
|
||||
}
|
||||
|
||||
with patch("ee.reporting.utils.generate_chart") as mock_generate_chart:
|
||||
mock_generate_chart.return_value = "<html>Chart Here</html>"
|
||||
|
||||
result = process_chart_variables(variables=variables)
|
||||
|
||||
# Check if the generate_chart function was called correctly
|
||||
mock_generate_chart.assert_called_once_with(
|
||||
type="bar",
|
||||
format="html",
|
||||
options={"data_frame": [{"x": 1, "y": 2}, {"x": 2, "y": 3}]},
|
||||
traces=None,
|
||||
layout=None,
|
||||
)
|
||||
|
||||
# Check if the returned data has the chart in place
|
||||
assert "<html>Chart Here</html>" in result["charts"]["myChart"]
|
||||
|
||||
|
||||
class TestPrepVariablesFunction:
|
||||
def test_prep_variables_base(self):
|
||||
result = prep_variables_for_template(variables="")
|
||||
assert isinstance(result, dict)
|
||||
assert not result
|
||||
|
||||
def test_prep_variables_with_dataqueries(self):
|
||||
with patch(
|
||||
"ee.reporting.utils.make_dataqueries_inline", return_value="test_yaml: true"
|
||||
):
|
||||
result = prep_variables_for_template(variables="dataquery: test")
|
||||
assert result == {"test_yaml": True}
|
||||
|
||||
def test_prep_variables_with_dependencies(self):
|
||||
with patch(
|
||||
"ee.reporting.utils.process_dependencies",
|
||||
return_value={"dependency_key": "dependency_value"},
|
||||
):
|
||||
result = prep_variables_for_template(
|
||||
variables="", dependencies={"some_dependency": "value"}
|
||||
)
|
||||
assert "dependency_key" in result
|
||||
assert result["dependency_key"] == "dependency_value"
|
||||
|
||||
def test_prep_variables_with_data_sources(self):
|
||||
with patch(
|
||||
"ee.reporting.utils.process_data_sources",
|
||||
return_value={"data_source_key": "data_value"},
|
||||
):
|
||||
result = prep_variables_for_template(variables="data_sources: some_data")
|
||||
assert "data_source_key" in result
|
||||
assert result["data_source_key"] == "data_value"
|
||||
|
||||
def test_prep_variables_with_charts(self):
|
||||
with patch(
|
||||
"ee.reporting.utils.process_chart_variables",
|
||||
return_value={"chart_key": "chart_value"},
|
||||
):
|
||||
result = prep_variables_for_template(variables="charts: some_chart")
|
||||
assert "chart_key" in result
|
||||
assert result["chart_key"] == "chart_value"
|
||||
108
api/tacticalrmm/ee/reporting/tests/test_util_functions.py
Normal file
108
api/tacticalrmm/ee/reporting/tests/test_util_functions.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
from model_bakery import baker
|
||||
from ..utils import (
|
||||
normalize_asset_url,
|
||||
base64_encode_assets,
|
||||
decode_base64_asset,
|
||||
)
|
||||
import base64
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestBase64EncodeDecodeAssets:
|
||||
def test_base64_encode_assets_multiple_valid_urls(self):
|
||||
asset1 = baker.make("reporting.ReportAsset", _create_files=True)
|
||||
asset2 = baker.make("reporting.ReportAsset", _create_files=True)
|
||||
template = f"Some content with link asset://{asset1.id} and another link asset://{asset2.id}"
|
||||
|
||||
result = base64_encode_assets(template)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# check asset 1
|
||||
assert result[0]["id"] == asset1.id
|
||||
# Checking if the file content is correctly encoded to base64
|
||||
encoded_content = base64.b64encode(asset1.file.file.read()).decode("utf-8")
|
||||
assert result[0]["file"] == encoded_content
|
||||
|
||||
# check asset 2
|
||||
assert result[1]["id"] == asset2.id
|
||||
# Checking if the file content is correctly encoded to base64
|
||||
encoded_content = base64.b64encode(asset2.file.file.read()).decode("utf-8")
|
||||
assert result[1]["file"] == encoded_content
|
||||
|
||||
def test_base64_encode_assets_some_invalid_urls(self):
|
||||
asset = baker.make("reporting.ReportAsset", _create_files=True)
|
||||
invalid_id = "11111111-1111-1111-1111-111111111111"
|
||||
template = f"Some content with link asset://{asset.id} and invalid link asset://{invalid_id}"
|
||||
|
||||
result = base64_encode_assets(template)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == asset.id
|
||||
|
||||
def test_base64_encode_assets_no_urls(self):
|
||||
template = "Some content with no assets"
|
||||
result = base64_encode_assets(template)
|
||||
assert result == []
|
||||
|
||||
def test_base64_encode_assets_duplicate_urls(self):
|
||||
asset = baker.make("reporting.ReportAsset", _create_files=True)
|
||||
template = f"Some content with link asset://{asset.id} and another link asset://{asset.id}"
|
||||
|
||||
result = base64_encode_assets(template)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == asset.id
|
||||
|
||||
def test_decode_base64_asset_valid_input(self):
|
||||
original_data = b"Hello, world!"
|
||||
encoded_data = base64.b64encode(original_data).decode("utf-8")
|
||||
|
||||
result = decode_base64_asset(encoded_data)
|
||||
|
||||
assert result == original_data
|
||||
|
||||
def test_decode_base64_asset_invalid_input(self):
|
||||
invalid_data = "Not a base64 encoded string."
|
||||
|
||||
with pytest.raises(base64.binascii.Error):
|
||||
decode_base64_asset(invalid_data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestNormalizeAssets:
|
||||
def test_normalize_asset_url_valid_html_type(self):
|
||||
asset = baker.make("reporting.ReportAsset", _create_files=True)
|
||||
text = f"Some content with link asset://{asset.id} and more content"
|
||||
|
||||
result = normalize_asset_url(text, "html")
|
||||
|
||||
assert f"{asset.file.url}?id={asset.id}" in result
|
||||
assert f"asset://{asset.id}" not in result
|
||||
|
||||
def test_normalize_asset_url_valid_pdf_type(self):
|
||||
asset = baker.make("reporting.ReportAsset", _create_files=True)
|
||||
text = f"Some content with link asset://{asset.id} and more content"
|
||||
|
||||
result = normalize_asset_url(text, "pdf")
|
||||
|
||||
assert f"file://{asset.file.path}" in result
|
||||
assert f"asset://{asset.id}" not in result
|
||||
|
||||
def test_normalize_asset_url_invalid_asset(self):
|
||||
invalid_id = "11111111-1111-1111-1111-111111111111" # UUID that's not in the DB
|
||||
text = f"Some content with link asset://{invalid_id} and more content"
|
||||
|
||||
result = normalize_asset_url(text, "html")
|
||||
|
||||
# Since the asset doesn't exist, the URL should remain unchanged
|
||||
assert f"asset://{invalid_id}" in result
|
||||
|
||||
def test_normalize_asset_url_no_asset(self):
|
||||
text = "Some content with no assets"
|
||||
|
||||
result = normalize_asset_url(text, "html")
|
||||
|
||||
# The text remains unchanged
|
||||
assert text == result
|
||||
@@ -18,7 +18,7 @@ from weasyprint.text.fonts import FontConfiguration
|
||||
|
||||
from .constants import REPORTING_MODELS
|
||||
from .markdown.config import Markdown
|
||||
from .models import ReportAsset, ReportHTMLTemplate, ReportTemplate
|
||||
from .models import ReportAsset, ReportHTMLTemplate, ReportTemplate, ReportDataQuery
|
||||
|
||||
|
||||
# regex for db data replacement
|
||||
@@ -30,6 +30,8 @@ RE_ASSET_URL = re.compile(
|
||||
r"(asset://([0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}))"
|
||||
)
|
||||
|
||||
RE_DEPENDENCY_VALUE = re.compile(r"(\{\{\s*(.*)\s*\}\})")
|
||||
|
||||
|
||||
# this will lookup the Jinja parent template in the DB
|
||||
# Example: {% extends "MASTER_TEMPLATE_NAME or REPORT_TEMPLATE_NAME" %}
|
||||
@@ -43,7 +45,7 @@ def db_template_loader(template_name: str) -> Optional[str]:
|
||||
try:
|
||||
template = ReportTemplate.objects.get(name=template_name)
|
||||
return template.template_md
|
||||
except ReportHTMLTemplate.DoesNotExist:
|
||||
except ReportTemplate.DoesNotExist:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -75,12 +77,15 @@ def generate_html(
|
||||
css: str = "",
|
||||
html_template: Optional[int] = None,
|
||||
variables: str = "",
|
||||
dependencies: Dict[str, int] = {},
|
||||
) -> Tuple[str, Optional[Dict[str, Any]]]:
|
||||
# validate the template before doing anything. This will throw a TemplateError exception
|
||||
dependencies: Optional[Dict[str, int]] = None,
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
if dependencies is None:
|
||||
dependencies = {}
|
||||
|
||||
# validate the template
|
||||
env.parse(template)
|
||||
|
||||
# convert template from markdown to html if type is markdown
|
||||
# convert template
|
||||
template_string = (
|
||||
Markdown.convert(template) if template_type == "markdown" else template
|
||||
)
|
||||
@@ -89,7 +94,6 @@ def generate_html(
|
||||
if html_template:
|
||||
try:
|
||||
html_template_name = ReportHTMLTemplate.objects.get(pk=html_template).name
|
||||
|
||||
template_string = (
|
||||
f"""{{% extends "{html_template_name}" %}}\n{template_string}"""
|
||||
)
|
||||
@@ -98,28 +102,26 @@ def generate_html(
|
||||
|
||||
tm = env.from_string(template_string)
|
||||
|
||||
variables = prep_variables_for_template(
|
||||
variables_dict = prep_variables_for_template(
|
||||
variables=variables, dependencies=dependencies
|
||||
)
|
||||
if variables:
|
||||
return (tm.render(css=css, **variables), variables)
|
||||
else:
|
||||
return (tm.render(css=css), None)
|
||||
|
||||
return (tm.render(css=css, **variables_dict), variables_dict)
|
||||
|
||||
|
||||
def make_dataqueries_inline(*, variables: str) -> str:
|
||||
variables_obj = yaml.safe_load(variables) or {}
|
||||
if "data_sources" in variables_obj.keys() and isinstance(
|
||||
variables_obj["data_sources"], dict
|
||||
):
|
||||
for key, value in variables_obj["data_sources"].items():
|
||||
# data_source is referencing a saved data query
|
||||
try:
|
||||
variables_obj = yaml.safe_load(variables) or {}
|
||||
except (yaml.parser.ParserError, yaml.YAMLError):
|
||||
variables_obj = {}
|
||||
|
||||
data_sources = variables_obj.get("data_sources", {})
|
||||
if isinstance(data_sources, dict):
|
||||
for key, value in data_sources.items():
|
||||
if isinstance(value, str):
|
||||
ReportDataQuery = apps.get_model("reporting", "ReportDataQuery")
|
||||
try:
|
||||
variables_obj["data_sources"][key] = ReportDataQuery.objects.get(
|
||||
name=value
|
||||
).json_query
|
||||
query = ReportDataQuery.objects.get(name=value).json_query
|
||||
variables_obj["data_sources"][key] = query
|
||||
except ReportDataQuery.DoesNotExist:
|
||||
continue
|
||||
|
||||
@@ -129,101 +131,33 @@ def make_dataqueries_inline(*, variables: str) -> str:
|
||||
def prep_variables_for_template(
|
||||
*,
|
||||
variables: str,
|
||||
dependencies: Dict[str, Any] = {},
|
||||
dependencies: Optional[Dict[str, Any]] = None,
|
||||
limit_query_results: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if not dependencies:
|
||||
dependencies = {}
|
||||
|
||||
if not variables:
|
||||
variables = ""
|
||||
|
||||
# replace any data queries in data_sources with the yaml
|
||||
variables = make_dataqueries_inline(variables=variables)
|
||||
|
||||
# resolve dependencies that are agent, site, or client
|
||||
if "client" in dependencies.keys():
|
||||
Model = apps.get_model("clients", "Client")
|
||||
dependencies["client"] = Model.objects.get(id=dependencies["client"])
|
||||
elif "site" in dependencies.keys():
|
||||
Model = apps.get_model("clients", "Site")
|
||||
dependencies["site"] = Model.objects.get(id=dependencies["site"])
|
||||
elif "agent" in dependencies.keys():
|
||||
Model = apps.get_model("agents", "Agent")
|
||||
dependencies["agent"] = Model.objects.get(agent_id=dependencies["agent"])
|
||||
|
||||
# check for variables that need to be replaced with the database values ({{client.name}}, {{agent.hostname}}, etc)
|
||||
if variables and isinstance(variables, str):
|
||||
# returns {{ model.prop }}, prop, model
|
||||
for string, model, prop in re.findall(RE_DB_VALUE, variables):
|
||||
value: Any = ""
|
||||
# will be agent, site, client, or global
|
||||
if model == "global":
|
||||
value = get_db_value(string=f"{model}.{prop}")
|
||||
elif model in dependencies.keys():
|
||||
instance = dependencies[model]
|
||||
value = (
|
||||
get_db_value(string=prop, instance=instance) if instance else None
|
||||
)
|
||||
|
||||
if value:
|
||||
variables = variables.replace(string, str(value))
|
||||
|
||||
# check for any non-database dependencies and replace in variables
|
||||
if variables and isinstance(variables, str):
|
||||
RE_DEP_VALUE = re.compile(r"(\{\{\s*(.*)\s*\}\})")
|
||||
|
||||
for string, dep in re.findall(RE_DEP_VALUE, variables):
|
||||
if dep in dependencies.keys():
|
||||
variables = variables.replace(string, str(dependencies[dep]))
|
||||
|
||||
# load yaml variables if they exist
|
||||
variables = yaml.safe_load(variables) or {}
|
||||
# process report dependencies
|
||||
variables_dict = process_dependencies(
|
||||
variables=variables, dependencies=dependencies
|
||||
)
|
||||
|
||||
# replace the data_sources with the actual data from DB. This will be passed to the template
|
||||
# in the form of {{data_sources.data_source_name}}
|
||||
if "data_sources" in variables.keys() and isinstance(
|
||||
variables["data_sources"], dict
|
||||
):
|
||||
for key, value in variables["data_sources"].items():
|
||||
if isinstance(value, dict):
|
||||
data_source = value
|
||||
|
||||
_ = data_source.pop("meta") if "meta" in data_source.keys() else None
|
||||
|
||||
modified_datasource = resolve_model(data_source=data_source)
|
||||
queryset = build_queryset(
|
||||
data_source=modified_datasource, limit=limit_query_results
|
||||
)
|
||||
variables["data_sources"][key] = queryset
|
||||
variables_dict = process_data_sources(
|
||||
variables=variables_dict, limit_query_results=limit_query_results
|
||||
)
|
||||
|
||||
# generate and replace charts in the variables
|
||||
if "charts" in variables.keys() and isinstance(variables["charts"], dict):
|
||||
for key, chart in variables["charts"].items():
|
||||
# make sure chart options are present and a dict
|
||||
if "options" not in chart.keys() and not isinstance(chart["options"], dict):
|
||||
break
|
||||
variables_dict = process_chart_variables(variables=variables_dict)
|
||||
|
||||
options = chart["options"]
|
||||
# if data_frame is present and a str that means we need to replace it with a value from variables
|
||||
if "data_frame" in options.keys() and isinstance(
|
||||
options["data_frame"], str
|
||||
):
|
||||
# dot dotation to datasource if exists
|
||||
data_source = options["data_frame"].split(".")
|
||||
data = variables
|
||||
for path in data_source:
|
||||
if path in data.keys():
|
||||
data = data[path]
|
||||
else:
|
||||
break
|
||||
|
||||
if data:
|
||||
chart["options"]["data_frame"] = data
|
||||
|
||||
variables["charts"][key] = generate_chart(
|
||||
type=chart["chartType"],
|
||||
format=chart["outputType"],
|
||||
options=chart["options"],
|
||||
traces=chart["traces"] if "traces" in chart.keys() else None,
|
||||
layout=chart["layout"] if "layout" in chart.keys() else None,
|
||||
)
|
||||
|
||||
return {**variables, **dependencies}
|
||||
return variables_dict
|
||||
|
||||
|
||||
class ResolveModelException(Exception):
|
||||
@@ -250,31 +184,33 @@ def resolve_model(*, data_source: Dict[str, Any]) -> Dict[str, Any]:
|
||||
raise ResolveModelException("Model key must be present on data_source")
|
||||
|
||||
|
||||
ALLOWED_OPERATIONS = (
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AllowedOperations(Enum):
|
||||
# filtering
|
||||
"only",
|
||||
"defer",
|
||||
"filter",
|
||||
"exclude",
|
||||
"limit",
|
||||
"get",
|
||||
"first",
|
||||
"all",
|
||||
ONLY = "only"
|
||||
DEFER = "defer"
|
||||
FILTER = "filter"
|
||||
EXCLUDE = "exclude"
|
||||
LIMIT = "limit"
|
||||
GET = "get"
|
||||
FIRST = "first"
|
||||
ALL = "all"
|
||||
# custom fields
|
||||
"custom_fields",
|
||||
CUSTOM_FIELDS = "custom_fields"
|
||||
# presentation
|
||||
"json",
|
||||
JSON = "json"
|
||||
# relations
|
||||
"select_related",
|
||||
"prefetch_related",
|
||||
SELECT_RELATED = "select_related"
|
||||
PREFETCH_RELATED = "prefetch_related"
|
||||
# operations
|
||||
"aggregate",
|
||||
"annotate",
|
||||
"count",
|
||||
"values",
|
||||
AGGREGATE = "aggregate"
|
||||
ANNOTATE = "annotate"
|
||||
COUNT = "count"
|
||||
VALUES = "values"
|
||||
# ordering
|
||||
"order_by",
|
||||
)
|
||||
ORDER_BY = "order_by"
|
||||
|
||||
|
||||
class InvalidDBOperationException(Exception):
|
||||
@@ -284,22 +220,23 @@ class InvalidDBOperationException(Exception):
|
||||
def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None) -> Any:
|
||||
local_data_source = data_source
|
||||
Model = local_data_source.pop("model")
|
||||
limit = limit
|
||||
count = False
|
||||
get = False
|
||||
first = False
|
||||
all = False
|
||||
isJson = False
|
||||
columns = local_data_source["only"] if "only" in local_data_source.keys() else None
|
||||
defer = local_data_source.get("defer", None)
|
||||
columns = local_data_source.get("only", None)
|
||||
fields_to_add = []
|
||||
|
||||
# create a base reporting queryset
|
||||
queryset = Model.objects.using("reporting")
|
||||
model_name = Model.__name__.lower()
|
||||
for operation, values in local_data_source.items():
|
||||
if operation not in ALLOWED_OPERATIONS:
|
||||
# Usage in the build_queryset function:
|
||||
if operation not in [op.value for op in AllowedOperations]:
|
||||
raise InvalidDBOperationException(
|
||||
f"DB operation: {operation} not allowed. Supported operations: only, defer, filter, get, first, all, custom_fields, exclude, limit, select_related, prefetch_related, annotate, aggregate, order_by, count"
|
||||
f"DB operation: {operation} not allowed. Supported operations: {', '.join(op.value for op in AllowedOperations)}"
|
||||
)
|
||||
|
||||
if operation == "meta":
|
||||
@@ -308,10 +245,9 @@ def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None)
|
||||
from core.models import CustomField
|
||||
|
||||
if model_name in ["client", "site", "agent"]:
|
||||
fields = CustomField.objects.filter(model=model_name)
|
||||
fields_to_add = [
|
||||
field
|
||||
for field in values
|
||||
if CustomField.objects.filter(model=model_name, name=field).exists()
|
||||
field for field in values if fields.filter(name=field).exists()
|
||||
]
|
||||
|
||||
elif operation == "limit":
|
||||
@@ -319,6 +255,9 @@ def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None)
|
||||
elif operation == "count":
|
||||
count = True
|
||||
elif operation == "get":
|
||||
# need to add a filter for the get if present
|
||||
if isinstance(values, dict):
|
||||
queryset = queryset.filter(**values)
|
||||
get = True
|
||||
elif operation == "first":
|
||||
first = True
|
||||
@@ -343,9 +282,20 @@ def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None)
|
||||
queryset = queryset[:limit]
|
||||
|
||||
if columns:
|
||||
# remove columns from only if defer is also present
|
||||
if defer:
|
||||
columns = [column for column in columns if column not in defer]
|
||||
if "id" not in columns:
|
||||
columns.append("id")
|
||||
|
||||
queryset = queryset.values(*columns)
|
||||
elif defer:
|
||||
# Since values seems to ignore only and defer, we need to get all columns and remove the defered ones.
|
||||
# Then we can pass the rest of the columns in
|
||||
included_fields = [
|
||||
field.name for field in Model._meta.local_fields if field.name not in defer
|
||||
]
|
||||
queryset = queryset.values(*included_fields)
|
||||
else:
|
||||
queryset = queryset.values()
|
||||
|
||||
@@ -386,7 +336,7 @@ def add_custom_fields(
|
||||
fields_to_add: List[str],
|
||||
model_name: Literal["client", "site", "agent"],
|
||||
dict_value: bool = False,
|
||||
):
|
||||
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
from core.models import CustomField
|
||||
from agents.models import AgentCustomField
|
||||
from clients.models import ClientCustomField, SiteCustomField
|
||||
@@ -454,6 +404,8 @@ def add_custom_fields(
|
||||
else:
|
||||
row["custom_fields"][custom_field.name] = custom_field.default_value
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def normalize_asset_url(text: str, type: Literal["pdf", "html"]) -> str:
|
||||
RE_ASSET_URL = re.compile(
|
||||
@@ -496,7 +448,9 @@ def base64_encode_assets(template: str) -> List[Dict[str, Any]]:
|
||||
"file": encoded_base64_str,
|
||||
}
|
||||
)
|
||||
added_ids.append(asset.id)
|
||||
added_ids.append(
|
||||
str(asset.id)
|
||||
) # need to convert uuid to str for easy comparison
|
||||
except ReportAsset.DoesNotExist:
|
||||
continue
|
||||
|
||||
@@ -509,6 +463,106 @@ def decode_base64_asset(asset: str) -> bytes:
|
||||
return base64.b64decode(asset.encode("utf-8"))
|
||||
|
||||
|
||||
def process_data_sources(
|
||||
*, variables: Dict[str, Any], limit_query_results: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
data_sources = variables.get("data_sources")
|
||||
|
||||
if isinstance(data_sources, dict):
|
||||
for key, value in data_sources.items():
|
||||
if isinstance(value, dict):
|
||||
modified_datasource = resolve_model(data_source=value)
|
||||
queryset = build_queryset(
|
||||
data_source=modified_datasource, limit=limit_query_results
|
||||
)
|
||||
data_sources[key] = queryset
|
||||
|
||||
return variables
|
||||
|
||||
|
||||
def process_dependencies(
|
||||
*, variables: str, dependencies: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
DEPENDENCY_MODELS = {
|
||||
"client": ("clients", "Client"),
|
||||
"site": ("clients", "Site"),
|
||||
"agent": ("agents", "Agent"),
|
||||
}
|
||||
|
||||
# Resolve dependencies that are agent, site, or client
|
||||
for dep, (app_label, model_name) in DEPENDENCY_MODELS.items():
|
||||
if dep in dependencies:
|
||||
Model = apps.get_model(app_label, model_name)
|
||||
# Assumes each model has a unique lookup mechanism
|
||||
lookup_param = "agent_id" if dep == "agent" else "id"
|
||||
dependencies[dep] = Model.objects.get(**{lookup_param: dependencies[dep]})
|
||||
|
||||
# Handle database value placeholders
|
||||
for string, model, prop in re.findall(RE_DB_VALUE, variables):
|
||||
value = get_value_for_model(model, prop, dependencies)
|
||||
if value:
|
||||
variables = variables.replace(string, str(value))
|
||||
|
||||
# Handle non-database dependencies
|
||||
for string, dep in re.findall(RE_DEPENDENCY_VALUE, variables):
|
||||
value = dependencies.get(dep)
|
||||
if value:
|
||||
variables = variables.replace(string, str(value))
|
||||
|
||||
# Load yaml variables if they exist
|
||||
variables = yaml.safe_load(variables) or {}
|
||||
|
||||
return {**variables, **dependencies}
|
||||
|
||||
|
||||
def get_value_for_model(model: str, prop: str, dependencies: Dict[str, Any]) -> Any:
|
||||
if model == "global":
|
||||
return get_db_value(string=f"{model}.{prop}")
|
||||
instance = dependencies.get(model)
|
||||
return get_db_value(string=prop, instance=instance) if instance else None
|
||||
|
||||
|
||||
def process_chart_variables(*, variables: Dict[str, Any]) -> Dict[str, Any]:
|
||||
charts = variables.get("charts")
|
||||
|
||||
if not isinstance(charts, dict):
|
||||
return variables
|
||||
|
||||
for key, chart in charts.items():
|
||||
options = chart.get("options")
|
||||
if not isinstance(options, dict):
|
||||
continue
|
||||
|
||||
data_frame = options.get("data_frame")
|
||||
if isinstance(data_frame, str):
|
||||
data_source = data_frame.split(".")
|
||||
data = variables
|
||||
|
||||
path_exists = True
|
||||
for path in data_source:
|
||||
data = data.get(path)
|
||||
if data is None:
|
||||
path_exists = False
|
||||
break
|
||||
|
||||
if not path_exists:
|
||||
continue
|
||||
|
||||
options["data_frame"] = data
|
||||
|
||||
traces = chart.get("traces")
|
||||
layout = chart.get("layout")
|
||||
charts[key] = generate_chart(
|
||||
type=chart["chartType"],
|
||||
format=chart["outputType"],
|
||||
options=options,
|
||||
traces=traces,
|
||||
layout=layout,
|
||||
)
|
||||
|
||||
return variables
|
||||
|
||||
|
||||
def generate_chart(
|
||||
*,
|
||||
type: Literal["pie", "bar", "line"],
|
||||
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from django.conf import settings as djangosettings
|
||||
from django.core.exceptions import (
|
||||
@@ -16,6 +16,7 @@ from django.core.exceptions import (
|
||||
PermissionDenied,
|
||||
SuspiciousFileOperation,
|
||||
)
|
||||
from django.db import transaction
|
||||
from django.core.files.base import ContentFile
|
||||
from django.http import FileResponse, HttpResponse, JsonResponse
|
||||
from django.shortcuts import get_object_or_404
|
||||
@@ -26,6 +27,10 @@ from rest_framework.response import Response
|
||||
from rest_framework.serializers import (
|
||||
CharField,
|
||||
ListField,
|
||||
IntegerField,
|
||||
JSONField,
|
||||
ChoiceField,
|
||||
BooleanField,
|
||||
ModelSerializer,
|
||||
Serializer,
|
||||
ValidationError,
|
||||
@@ -35,7 +40,6 @@ from rest_framework.views import APIView
|
||||
from tacticalrmm.utils import notify_error
|
||||
|
||||
from .models import ReportAsset, ReportDataQuery, ReportHTMLTemplate, ReportTemplate
|
||||
from .settings import settings
|
||||
from .storage import report_assets_fs
|
||||
from .utils import (
|
||||
base64_encode_assets,
|
||||
@@ -112,6 +116,9 @@ class GenerateReport(APIView):
|
||||
|
||||
format = request.data["format"]
|
||||
|
||||
if format not in ["pdf", "html"]:
|
||||
return notify_error("Report format is incorrect.")
|
||||
|
||||
try:
|
||||
html_report, _ = generate_html(
|
||||
template=template.template_md,
|
||||
@@ -128,7 +135,7 @@ class GenerateReport(APIView):
|
||||
|
||||
if format == "html":
|
||||
return Response(html_report)
|
||||
elif format == "pdf":
|
||||
else:
|
||||
pdf_bytes = generate_pdf(html=html_report)
|
||||
|
||||
return FileResponse(
|
||||
@@ -136,8 +143,7 @@ class GenerateReport(APIView):
|
||||
content_type="application/pdf",
|
||||
filename=f"{template.name}.pdf",
|
||||
)
|
||||
else:
|
||||
return notify_error("Report format is incorrect.")
|
||||
|
||||
except TemplateError as error:
|
||||
if hasattr(error, "lineno"):
|
||||
return notify_error(f"Line {error.lineno}: {error.message}")
|
||||
@@ -148,75 +154,94 @@ class GenerateReport(APIView):
|
||||
|
||||
|
||||
class GenerateReportPreview(APIView):
|
||||
class InputRequest:
|
||||
template_md: str
|
||||
type: Literal["markdown", "html"]
|
||||
template_css: str
|
||||
template_html: int
|
||||
template_variables: Dict[str, Any]
|
||||
dependencies: Dict[str, Any]
|
||||
format: Literal["html", "pdf"]
|
||||
debug: bool
|
||||
|
||||
class InputSerializer(Serializer[InputRequest]):
|
||||
template_md = CharField()
|
||||
type = CharField()
|
||||
template_css = CharField(required=False)
|
||||
template_html = IntegerField(required=False)
|
||||
template_variables = JSONField()
|
||||
dependencies = JSONField()
|
||||
format = ChoiceField(choices=["html", "pdf"])
|
||||
debug = BooleanField(default=False)
|
||||
|
||||
def post(self, request: Request) -> Union[FileResponse, Response]:
|
||||
try:
|
||||
debug = request.data["debug"]
|
||||
|
||||
report_data = self._parse_and_validate_request_data(request.data)
|
||||
html_report, variables = generate_html(
|
||||
template=request.data["template_md"],
|
||||
template_type=request.data["type"],
|
||||
css=request.data["template_css"],
|
||||
html_template=(
|
||||
request.data["template_html"]
|
||||
if "template_html" in request.data.keys()
|
||||
else None
|
||||
),
|
||||
variables=request.data["template_variables"],
|
||||
dependencies=request.data["dependencies"],
|
||||
template=report_data["template_md"],
|
||||
template_type=report_data["type"],
|
||||
css=report_data.get("template_css", ""),
|
||||
html_template=report_data.get("template_html"),
|
||||
variables=report_data["template_variables"],
|
||||
dependencies=report_data["dependencies"],
|
||||
)
|
||||
|
||||
response_format = request.data["format"]
|
||||
debug = request.data["debug"]
|
||||
|
||||
html_report = normalize_asset_url(html_report, response_format)
|
||||
|
||||
if debug:
|
||||
# need to serialize the models if an agent, site, or client is specified
|
||||
if variables:
|
||||
from django.forms.models import model_to_dict
|
||||
|
||||
if "agent" in variables.keys():
|
||||
variables["agent"] = model_to_dict(
|
||||
variables["agent"],
|
||||
fields=[
|
||||
field.name for field in variables["agent"]._meta.fields
|
||||
],
|
||||
)
|
||||
if "site" in variables.keys():
|
||||
variables["site"] = model_to_dict(
|
||||
variables["site"],
|
||||
fields=[
|
||||
field.name for field in variables["site"]._meta.fields
|
||||
],
|
||||
)
|
||||
if "client" in variables.keys():
|
||||
variables["client"] = model_to_dict(
|
||||
variables["client"],
|
||||
fields=[
|
||||
field.name for field in variables["client"]._meta.fields
|
||||
],
|
||||
)
|
||||
|
||||
return Response({"template": html_report, "variables": variables})
|
||||
|
||||
elif response_format == "html":
|
||||
return Response(html_report)
|
||||
else:
|
||||
pdf_bytes = generate_pdf(html=html_report)
|
||||
|
||||
return FileResponse(
|
||||
ContentFile(pdf_bytes),
|
||||
content_type="application/pdf",
|
||||
filename=f"preview.pdf",
|
||||
)
|
||||
if report_data["debug"]:
|
||||
return self._process_debug_response(html_report, variables)
|
||||
return self._generate_response_based_on_format(
|
||||
html_report, report_data["format"]
|
||||
)
|
||||
except TemplateError as error:
|
||||
if hasattr(error, "lineno"):
|
||||
return notify_error(f"Line {error.lineno}: {error.message}")
|
||||
else:
|
||||
return notify_error(str(error))
|
||||
return self._handle_template_error(error)
|
||||
except Exception as error:
|
||||
return notify_error(str(error))
|
||||
|
||||
def _parse_and_validate_request_data(self, data: Dict[str, Any]) -> Any:
|
||||
serializer = self.InputSerializer(data=data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
return serializer.validated_data
|
||||
|
||||
def _process_debug_response(
|
||||
self, html_report: str, variables: Dict[str, Any]
|
||||
) -> Response:
|
||||
if variables:
|
||||
from django.forms.models import model_to_dict
|
||||
|
||||
# serialize any model instances provided
|
||||
for model_name in ["agent", "site", "client"]:
|
||||
if model_name in variables:
|
||||
model_instance = variables[model_name]
|
||||
serialized_model = model_to_dict(
|
||||
model_instance,
|
||||
fields=[field.name for field in model_instance._meta.fields],
|
||||
)
|
||||
variables[model_name] = serialized_model
|
||||
|
||||
return Response({"template": html_report, "variables": variables})
|
||||
|
||||
def _generate_response_based_on_format(
|
||||
self, html_report: str, format: Literal["html", "pdf"]
|
||||
) -> Union[Response, FileResponse]:
|
||||
html_report = normalize_asset_url(html_report, format)
|
||||
|
||||
if format == "html":
|
||||
return Response(html_report)
|
||||
else:
|
||||
pdf_bytes = generate_pdf(html=html_report)
|
||||
return FileResponse(
|
||||
ContentFile(pdf_bytes),
|
||||
content_type="application/pdf",
|
||||
filename="preview.pdf",
|
||||
)
|
||||
|
||||
def _handle_template_error(self, error: TemplateError) -> Response:
|
||||
if hasattr(error, "lineno"):
|
||||
error_message = f"Line {error.lineno}: {error.message}"
|
||||
else:
|
||||
error_message = str(error)
|
||||
|
||||
return notify_error(error_message)
|
||||
|
||||
|
||||
class ExportReportTemplate(APIView):
|
||||
def post(self, request: Request, pk: int) -> Response:
|
||||
@@ -253,121 +278,156 @@ class ExportReportTemplate(APIView):
|
||||
|
||||
|
||||
class ImportReportTemplate(APIView):
|
||||
@transaction.atomic
|
||||
def post(self, request: Request) -> Response:
|
||||
import random
|
||||
import string
|
||||
|
||||
base_template = None
|
||||
report_template = None
|
||||
try:
|
||||
template_obj = json.loads(request.data["template"])
|
||||
|
||||
if "template" not in template_obj.keys():
|
||||
return notify_error("Missing template information")
|
||||
# import base template if exists
|
||||
base_template_id = self._import_base_template(
|
||||
template_obj.get("base_template")
|
||||
)
|
||||
|
||||
# create base template
|
||||
if "base_template" in template_obj.keys() and template_obj["base_template"]:
|
||||
# check if there is a name conflict and append some characters to the name if so
|
||||
if (
|
||||
"name" in template_obj["base_template"].keys()
|
||||
and ReportHTMLTemplate.objects.filter(
|
||||
name=template_obj["base_template"]["name"]
|
||||
).exists()
|
||||
):
|
||||
template_obj["base_template"]["name"] += "".join(
|
||||
random.choice(string.ascii_lowercase) for i in range(6)
|
||||
)
|
||||
base_template = ReportHTMLTemplate.objects.create(
|
||||
**template_obj["base_template"]
|
||||
)
|
||||
base_template.refresh_from_db()
|
||||
# import base template if exists
|
||||
report_template = self._import_report_template(
|
||||
template_obj.get("template"), base_template_id
|
||||
)
|
||||
|
||||
# create template
|
||||
if "template" in template_obj.keys() and template_obj["template"]:
|
||||
# check if there is a name conflict and append some characters to the name if so
|
||||
if (
|
||||
"name" in template_obj["template"].keys()
|
||||
and ReportTemplate.objects.filter(
|
||||
name=template_obj["template"]["name"]
|
||||
).exists()
|
||||
):
|
||||
template_obj["template"]["name"] += "".join(
|
||||
random.choice(string.ascii_lowercase) for i in range(6)
|
||||
)
|
||||
report_template = ReportTemplate.objects.create(
|
||||
**template_obj["template"],
|
||||
template_html=base_template if base_template else None,
|
||||
)
|
||||
|
||||
# import assets
|
||||
if "assets" in template_obj.keys() and isinstance(
|
||||
template_obj["assets"], list
|
||||
):
|
||||
for asset in template_obj["assets"]:
|
||||
# asset should have id, name, and file fields
|
||||
try:
|
||||
asset = ReportAsset(
|
||||
id=asset["id"], file=decode_base64_asset(asset["file"])
|
||||
)
|
||||
asset.file.name = os.path.join(
|
||||
settings.REPORTING_ASSETS_BASE_PATH, asset["name"]
|
||||
)
|
||||
asset.save()
|
||||
except:
|
||||
pass
|
||||
# import assets if exists
|
||||
self._import_assets(template_obj.get("assets"))
|
||||
|
||||
return Response(ReportTemplateSerializer(report_template).data)
|
||||
except:
|
||||
base_template.delete() if base_template else None
|
||||
report_template.delete() if report_template else None
|
||||
return notify_error("There was an error with the request")
|
||||
|
||||
except Exception as e:
|
||||
# rollback db transaction if any exception occurs
|
||||
transaction.set_rollback(True)
|
||||
return notify_error(str(e))
|
||||
|
||||
def _import_base_template(
|
||||
self, base_template_data: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[int]:
|
||||
if base_template_data:
|
||||
# Check name conflict and modify name if necessary
|
||||
name = base_template_data.get("name")
|
||||
html = base_template_data.get("html")
|
||||
|
||||
if not name:
|
||||
raise ValidationError("base_template is missing 'name' key")
|
||||
if not html:
|
||||
raise ValidationError("base_template is missing 'html' field")
|
||||
|
||||
if ReportHTMLTemplate.objects.filter(name=name).exists():
|
||||
name += self._generate_random_string()
|
||||
base_template = ReportHTMLTemplate.objects.create(name=name, html=html)
|
||||
base_template.refresh_from_db()
|
||||
return base_template.id
|
||||
return None
|
||||
|
||||
def _import_report_template(
|
||||
self,
|
||||
report_template_data: Dict[str, Any],
|
||||
base_template_id: Optional[int] = None,
|
||||
) -> "ReportTemplate":
|
||||
if report_template_data:
|
||||
name = report_template_data.pop("name", None)
|
||||
template_md = report_template_data.get("template_md")
|
||||
if not name:
|
||||
raise ValidationError("template requires a 'name' key")
|
||||
if not template_md:
|
||||
raise ValidationError("template requires a 'template_md' field")
|
||||
|
||||
if ReportTemplate.objects.filter(name=name).exists():
|
||||
name += self._generate_random_string()
|
||||
report_template = ReportTemplate.objects.create(
|
||||
name=name, template_html_id=base_template_id, **report_template_data
|
||||
)
|
||||
return report_template
|
||||
else:
|
||||
raise ValidationError("'template' key is required in input")
|
||||
|
||||
def _import_assets(self, assets: List[Dict[str, Any]]) -> None:
|
||||
from django.core.files import File
|
||||
import io
|
||||
from .storage import report_assets_fs
|
||||
|
||||
if isinstance(assets, list):
|
||||
for asset in assets:
|
||||
parent_folder = report_assets_fs.getreldir(path=asset["name"])
|
||||
path = report_assets_fs.get_available_name(
|
||||
os.path.join(parent_folder, asset["name"])
|
||||
)
|
||||
asset_obj = ReportAsset(
|
||||
id=asset["id"],
|
||||
file=File(
|
||||
io.BytesIO(decode_base64_asset(asset["file"])),
|
||||
name=path,
|
||||
),
|
||||
)
|
||||
asset_obj.save()
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_string(length: int = 6) -> str:
|
||||
import random
|
||||
import string
|
||||
|
||||
return "".join(random.choice(string.ascii_lowercase) for i in range(length))
|
||||
|
||||
|
||||
class GetAllowedValues(APIView):
|
||||
def post(self, request: Request) -> Response:
|
||||
# pass in blank template. We are just interested in variables
|
||||
variables = request.data.get("variables", None)
|
||||
if variables is None:
|
||||
return notify_error("'variables' is required")
|
||||
|
||||
dependencies = request.data.get("dependencies", None)
|
||||
|
||||
# process variables and dependencies
|
||||
variables = prep_variables_for_template(
|
||||
variables=request.data["variables"],
|
||||
dependencies=request.data["dependencies"],
|
||||
variables=variables,
|
||||
dependencies=dependencies,
|
||||
limit_query_results=1, # only get first item for querysets
|
||||
)
|
||||
|
||||
# recursive function to get properties on any embedded objects
|
||||
def get_dot_notation(
|
||||
d: Dict[str, Any], parent_key: str = "", path: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
items = {}
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}.{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items[new_key] = "Object"
|
||||
items.update(get_dot_notation(v, new_key, path=path))
|
||||
elif isinstance(v, list) or type(v).__name__ == "PermissionQuerySet":
|
||||
items[new_key] = f"Array ({len(v)} Results)"
|
||||
if v: # Ensure the list is not empty
|
||||
item = v[0]
|
||||
if isinstance(item, dict):
|
||||
items.update(
|
||||
get_dot_notation(item, f"{new_key}[0]", path=path)
|
||||
)
|
||||
else:
|
||||
items[f"{new_key}[0]"] = type(item).__name__
|
||||
|
||||
else:
|
||||
items[new_key] = type(v).__name__
|
||||
return items
|
||||
|
||||
if variables:
|
||||
return Response(get_dot_notation(variables))
|
||||
return Response(self._get_dot_notation(variables))
|
||||
else:
|
||||
return Response()
|
||||
|
||||
# recursive function to get properties on any embedded objects
|
||||
def _get_dot_notation(
|
||||
self, d: Dict[str, Any], parent_key: str = "", path: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
items = {}
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}.{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items[new_key] = "Object"
|
||||
items.update(self._get_dot_notation(v, new_key, path=path))
|
||||
elif isinstance(v, list) or type(v).__name__ == "PermissionQuerySet":
|
||||
items[new_key] = f"Array ({len(v)} Results)"
|
||||
if v: # Ensure the list is not empty
|
||||
item = v[0]
|
||||
if isinstance(item, dict):
|
||||
items.update(
|
||||
self._get_dot_notation(item, f"{new_key}[0]", path=path)
|
||||
)
|
||||
else:
|
||||
items[f"{new_key}[0]"] = type(item).__name__
|
||||
|
||||
else:
|
||||
items[new_key] = type(v).__name__
|
||||
return items
|
||||
|
||||
|
||||
class GetReportAssets(APIView):
|
||||
def get(self, request: Request) -> Response:
|
||||
path = request.query_params.get("path", "").lstrip("/")
|
||||
|
||||
directories, files = report_assets_fs.listdir(path)
|
||||
try:
|
||||
directories, files = report_assets_fs.listdir(path)
|
||||
except FileNotFoundError:
|
||||
return notify_error("The path is invalid")
|
||||
|
||||
response = list()
|
||||
|
||||
# parse directories
|
||||
@@ -401,62 +461,61 @@ class GetReportAssets(APIView):
|
||||
|
||||
class GetAllAssets(APIView):
|
||||
def get(self, request: Request) -> Response:
|
||||
only_folders = request.query_params.get("OnlyFolders", None)
|
||||
only_folders = request.query_params.get("onlyFolders", None)
|
||||
only_folders = True if only_folders and only_folders == "true" else False
|
||||
|
||||
# pull report assets from the database so we can pair with the file system assets
|
||||
assets = ReportAsset.objects.all()
|
||||
|
||||
# TODO: define a Type for file node
|
||||
def walk_folder_and_return_node(path: str):
|
||||
for current_dir, subdirs, files in os.walk(path):
|
||||
print(current_dir, subdirs, files)
|
||||
current_dir = "Report Assets" if current_dir == "." else current_dir
|
||||
node = {
|
||||
"type": "folder",
|
||||
"name": current_dir.replace("./", ""),
|
||||
"path": path.replace("./", ""),
|
||||
"children": list(),
|
||||
"selectable": False,
|
||||
"icon": "folder",
|
||||
"iconColor": "yellow-9",
|
||||
}
|
||||
for dirname in subdirs:
|
||||
dirpath = f"{path}/{dirname}"
|
||||
node["children"].append(
|
||||
walk_folder_and_return_node(dirpath) # recursively call
|
||||
)
|
||||
|
||||
if not only_folders:
|
||||
for filename in files:
|
||||
print(current_dir, filename)
|
||||
path = f"{current_dir}/{filename}".replace("./", "").replace(
|
||||
"Report Assets/", ""
|
||||
)
|
||||
try:
|
||||
# need to remove the relative path
|
||||
id = assets.get(file=path).id
|
||||
node["children"].append(
|
||||
{
|
||||
"id": id,
|
||||
"type": "file",
|
||||
"name": filename,
|
||||
"path": path,
|
||||
"icon": "description",
|
||||
}
|
||||
)
|
||||
except ReportAsset.DoesNotExist:
|
||||
pass
|
||||
|
||||
return node
|
||||
|
||||
try:
|
||||
os.chdir(report_assets_fs.base_location)
|
||||
response = [walk_folder_and_return_node(".")]
|
||||
response = [self._walk_folder_and_return_node(".", only_folders)]
|
||||
return Response(response)
|
||||
except FileNotFoundError:
|
||||
return notify_error("Unable to process request")
|
||||
|
||||
# TODO: define a Type for file node
|
||||
def _walk_folder_and_return_node(self, path: str, only_folders: bool = False):
|
||||
# pull report assets from the database so we can pair with the file system assets
|
||||
assets = ReportAsset.objects.all()
|
||||
|
||||
for current_dir, subdirs, files in os.walk(path):
|
||||
current_dir = "Report Assets" if current_dir == "." else current_dir
|
||||
node = {
|
||||
"type": "folder",
|
||||
"name": current_dir.replace("./", ""),
|
||||
"path": path.replace("./", ""),
|
||||
"children": list(),
|
||||
"selectable": False,
|
||||
"icon": "folder",
|
||||
"iconColor": "yellow-9",
|
||||
}
|
||||
for dirname in subdirs:
|
||||
dirpath = f"{path}/{dirname}"
|
||||
node["children"].append(
|
||||
# recursively call
|
||||
self._walk_folder_and_return_node(dirpath, only_folders)
|
||||
)
|
||||
|
||||
if not only_folders:
|
||||
for filename in files:
|
||||
path = f"{current_dir}/{filename}".replace("./", "").replace(
|
||||
"Report Assets/", ""
|
||||
)
|
||||
try:
|
||||
# need to remove the relative path
|
||||
id = assets.get(file=path).id
|
||||
node["children"].append(
|
||||
{
|
||||
"id": id,
|
||||
"type": "file",
|
||||
"name": filename,
|
||||
"path": path,
|
||||
"icon": "description",
|
||||
}
|
||||
)
|
||||
except ReportAsset.DoesNotExist:
|
||||
pass
|
||||
|
||||
return node
|
||||
|
||||
|
||||
class RenameReportAsset(APIView):
|
||||
class InputRequest:
|
||||
@@ -496,6 +555,8 @@ class CreateAssetFolder(APIView):
|
||||
def post(self, request: Request) -> Response:
|
||||
path = request.data["path"].lstrip("/") if "path" in request.data else ""
|
||||
|
||||
if not path:
|
||||
return notify_error("'path' in required.")
|
||||
try:
|
||||
new_path = report_assets_fs.createfolder(path=path)
|
||||
return Response(new_path)
|
||||
@@ -752,10 +813,13 @@ class QuerySchema(APIView):
|
||||
schema_path = "static/reporting/schemas/query_schema.json"
|
||||
|
||||
if djangosettings.DEBUG:
|
||||
with open(djangosettings.BASE_DIR / schema_path, "r") as f:
|
||||
data = json.load(f)
|
||||
try:
|
||||
with open(djangosettings.BASE_DIR / schema_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
return JsonResponse(data)
|
||||
return JsonResponse(data)
|
||||
except FileNotFoundError:
|
||||
return notify_error("There was an error getting the file")
|
||||
else:
|
||||
response = HttpResponse()
|
||||
response["X-Accel-Redirect"] = f"/{schema_path}"
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Literal, TYPE_CHECKING
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import requests
|
||||
@@ -41,6 +41,11 @@ from tacticalrmm.helpers import (
|
||||
notify_error,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agents.models import Agent
|
||||
from clients.models import Site, Client
|
||||
|
||||
|
||||
def generate_winagent_exe(
|
||||
*,
|
||||
client: int,
|
||||
@@ -286,35 +291,75 @@ def get_latest_trmm_ver() -> str:
|
||||
return "error"
|
||||
|
||||
|
||||
# Receives something like {{ client.name }} and a Model instance of Client, Site, or Agent. If an
|
||||
# Receives something like {{ client.name }} and a Model instance of Client, Site, or Agent. If an
|
||||
# agent instance is passed it will resolve the value of agent.client.name and return the agent's client name.
|
||||
#
|
||||
#
|
||||
# You can query custom fields by using their name. {{ site.Custom Field Name }}
|
||||
#
|
||||
# This will recursively lookup values for relations. {{ client.site.id }}
|
||||
#
|
||||
#
|
||||
# You can also use {{ global.value }} without an obj instance to use the global key store
|
||||
def get_db_value(*, string: str, instance=None) -> Union[str, List, True, False, None]:
|
||||
def get_db_value(
|
||||
*, string: str, instance: Optional[Union["Agent", "Client", "Site"]] = None
|
||||
) -> Union[str, List[str], Literal[True], Literal[False], None]:
|
||||
from core.models import CustomField, GlobalKVStore
|
||||
|
||||
# get properties into an array
|
||||
props = string.strip().split(".")
|
||||
|
||||
model = props[0]
|
||||
|
||||
# value is in the global keystore and replace value
|
||||
if props[0] == "global" and len(props) == 2:
|
||||
try:
|
||||
try:
|
||||
return GlobalKVStore.objects.get(name=props[1]).value
|
||||
except GlobalKVStore.DoesNotExist:
|
||||
DebugLog.error(
|
||||
log_type=DebugLogType.SCRIPTING,
|
||||
message=f"Couldn't lookup value for: {string}. Make sure it exists in CoreSettings > Key Store", # type:ignore
|
||||
message=f"Couldn't lookup value for: {string}. Make sure it exists in CoreSettings > Key Store",
|
||||
)
|
||||
return None
|
||||
|
||||
if not instance:
|
||||
# instance must be set if not global property
|
||||
return None
|
||||
|
||||
|
||||
# custom field lookup
|
||||
try:
|
||||
# looking up custom field directly on this instance
|
||||
if len(props) == 2:
|
||||
field = CustomField.objects.get(model=props[0], name=props[1])
|
||||
model_fields = getattr(field, f"{props[0]}_fields")
|
||||
|
||||
try:
|
||||
# resolve the correct model id
|
||||
if props[0] != instance.__class__.__name__.lower():
|
||||
value = model_fields.get(
|
||||
**{props[0]: getattr(instance, props[0])}
|
||||
).value
|
||||
else:
|
||||
value = model_fields.get(**{f"{props[0]}_id": instance.id}).value
|
||||
|
||||
if field.type != CustomFieldType.CHECKBOX:
|
||||
if value:
|
||||
return value
|
||||
else:
|
||||
return field.default_value
|
||||
else:
|
||||
return bool(value)
|
||||
except:
|
||||
return (
|
||||
field.default_value
|
||||
if field.type != CustomFieldType.CHECKBOX
|
||||
else bool(field.default_value)
|
||||
)
|
||||
except CustomField.DoesNotExist:
|
||||
pass
|
||||
|
||||
# if the instance is the same as the first prop. We remove it.
|
||||
if props[0] == instance.__class__.__name__.lower():
|
||||
del props[0]
|
||||
|
||||
instance_value = instance
|
||||
|
||||
# look through all properties and return the value
|
||||
@@ -324,27 +369,16 @@ def get_db_value(*, string: str, instance=None) -> Union[str, List, True, False,
|
||||
if callable(value):
|
||||
return None
|
||||
instance_value = value
|
||||
else:
|
||||
try:
|
||||
field = CustomField.objects.get(model=props[0], name=prop)
|
||||
model_fields = getattr(field, f"{props[0]}_fields")
|
||||
try:
|
||||
value = model_fields.get(**{props[0]: instance}).value
|
||||
return value if field.type != CustomFieldType.CHECKBOX else bool(field.default_value)
|
||||
except:
|
||||
return field.default_value if field.type != CustomFieldType.CHECKBOX else bool(field.default_value)
|
||||
except CustomField.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
if not instance_value:
|
||||
return None
|
||||
|
||||
|
||||
return instance_value
|
||||
|
||||
|
||||
|
||||
def replace_arg_db_values(
|
||||
string: str, instance=None, shell: str = None, quotes=True # type:ignore
|
||||
) -> Union[str, None]:
|
||||
|
||||
# resolve the value
|
||||
value = get_db_value(string=string, instance=instance)
|
||||
|
||||
@@ -365,12 +399,8 @@ def replace_arg_db_values(
|
||||
|
||||
# format args for list
|
||||
elif isinstance(value, list):
|
||||
return (
|
||||
f"'{format_shell_array(value)}'"
|
||||
if quotes
|
||||
else format_shell_array(value)
|
||||
)
|
||||
|
||||
return f"'{format_shell_array(value)}'" if quotes else format_shell_array(value)
|
||||
|
||||
# format args for bool
|
||||
elif value is True or value is False:
|
||||
return format_shell_bool(value, shell)
|
||||
|
||||
Reference in New Issue
Block a user