1
0
Fork 0
mirror of https://gitlab.com/bramw/baserow.git synced 2025-04-03 04:35:31 +00:00

Add import/export for chart widget and its service

This commit is contained in:
Petr Stribny 2025-03-05 14:13:17 +00:00
parent 3f8354d946
commit 3d504a20c7
4 changed files with 526 additions and 22 deletions
enterprise/backend
src/baserow_enterprise
integrations/local_baserow
services
tests/baserow_enterprise_tests

View file

@ -1,3 +1,5 @@
import re
from django.conf import settings
from django.db.models import F
@ -204,6 +206,9 @@ class LocalBaserowGroupedAggregateRowsUserServiceType(
code="invalid_field",
)
if group_by["field_id"] is None:
return True
field = next(
(
field
@ -366,30 +371,52 @@ class LocalBaserowGroupedAggregateRowsUserServiceType(
if prop_name == "filters":
return self.serialize_filters(service)
# FIXME: aggregation_series, aggregation_group_bys
if prop_name == "service_aggregation_series":
return [
{
"field_id": series.field_id,
"aggregation_type": series.aggregation_type,
}
for series in service.service_aggregation_series.all()
]
if prop_name == "service_aggregation_group_bys":
return [
{
"field_id": group_by.field_id,
}
for group_by in service.service_aggregation_group_bys.all()
]
if prop_name == "service_aggregation_sorts":
return [
{
"sort_on": sort.sort_on,
"reference": sort.reference,
"direction": sort.direction,
}
for sort in service.service_aggregation_sorts.all()
]
return super().serialize_property(
service, prop_name, files_zip=files_zip, storage=storage, cache=cache
)
def deserialize_property(
def create_instance_from_serialized(
self,
prop_name: str,
value: any,
id_mapping: dict[str, any],
serialized_values,
id_mapping,
files_zip=None,
storage=None,
cache=None,
**kwargs,
):
if prop_name == "filters":
return self.deserialize_filters(value, id_mapping)
) -> "LocalBaserowGroupedAggregateRowsUserServiceType":
series = serialized_values.pop("service_aggregation_series", [])
group_bys = serialized_values.pop("service_aggregation_group_bys", [])
sorts = serialized_values.pop("service_aggregation_sorts", [])
# FIXME: aggregation_series, aggregation_group_bys
return super().deserialize_property(
prop_name,
value,
service = super().create_instance_from_serialized(
serialized_values,
id_mapping,
files_zip=files_zip,
storage=storage,
@ -397,6 +424,38 @@ class LocalBaserowGroupedAggregateRowsUserServiceType(
**kwargs,
)
if "database_fields" in id_mapping:
for current_series in series:
if current_series["field_id"] is not None:
current_series["field_id"] = id_mapping["database_fields"].get(
current_series["field_id"], None
)
for group_by in group_bys:
if group_by["field_id"] is not None:
group_by["field_id"] = id_mapping["database_fields"].get(
group_by["field_id"], None
)
for sort in sorts:
match = re.search(r"\d+", sort["reference"])
sort_field_id = match.group()
remapped_id = id_mapping["database_fields"].get(
int(sort_field_id), None
)
if remapped_id is not None:
sort["reference"] = sort["reference"].replace(
sort_field_id, str(remapped_id)
)
else:
sort["reference"] = None
self._update_service_aggregation_series(service, series)
self._update_service_aggregation_group_bys(service, group_bys)
self._update_service_sorts(
service, [sort for sort in sorts if sort["reference"] is not None]
)
return service
def dispatch_data(
self,
service: LocalBaserowGroupedAggregateRows,

View file

@ -2,12 +2,12 @@ from typing import TypedDict
class ServiceAggregationSeriesDict(TypedDict):
field_id: int
field_id: int | None
aggregation_type: str
class ServiceAggregationGroupByDict(TypedDict):
field_id: int
field_id: int | None
class ServiceAggregationSortByDict(TypedDict):

View file

@ -0,0 +1,271 @@
import json
from decimal import Decimal
from typing import cast
from django.contrib.contenttypes.models import ContentType
from django.test.utils import override_settings
import pytest
from baserow.contrib.dashboard.application_types import DashboardApplicationType
from baserow.contrib.dashboard.data_sources.models import DashboardDataSource
from baserow.contrib.dashboard.models import Dashboard
from baserow.contrib.dashboard.widgets.models import Widget
from baserow.contrib.dashboard.widgets.service import WidgetService
from baserow.contrib.integrations.local_baserow.models import LocalBaserowIntegration
from baserow.core.handler import CoreHandler
from baserow.core.integrations.models import Integration
from baserow.core.registries import ImportExportConfig
from baserow.core.utils import ChildProgressBuilder, Progress
from baserow_enterprise.dashboard.widgets.models import ChartWidget
from baserow_enterprise.integrations.local_baserow.models import (
LocalBaserowGroupedAggregateRows,
LocalBaserowTableServiceAggregationGroupBy,
LocalBaserowTableServiceAggregationSeries,
LocalBaserowTableServiceAggregationSortBy,
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_dashboard_export_serialized_with_chart_widget(enterprise_data_fixture):
    """
    Exporting a dashboard containing a chart widget should serialize the
    widget's data source service together with its aggregation series,
    group bys and sorts.
    """
    enterprise_data_fixture.enable_enterprise()
    user = enterprise_data_fixture.create_user()
    workspace = enterprise_data_fixture.create_workspace(user=user)
    database = enterprise_data_fixture.create_database_application(
        user=user, workspace=workspace
    )
    table = enterprise_data_fixture.create_database_table(database=database)
    field = enterprise_data_fixture.create_number_field(table=table)
    dashboard = cast(
        Dashboard,
        CoreHandler().create_application(
            user,
            workspace,
            type_name="dashboard",
            description="Dashboard description",
            init_with_data=True,
        ),
    )
    integration = Integration.objects.filter(application=dashboard).first()
    dashboard_widget = WidgetService().create_widget(
        user, "chart", dashboard.id, title="Widget 1", description="Description 1"
    )
    # Point the widget's underlying service at our table and attach one
    # series, one group by (deliberately with no field) and one series sort.
    service = dashboard_widget.data_source.service
    service.table = table
    service.save()
    LocalBaserowTableServiceAggregationSeries.objects.create(
        service=service, field=field, aggregation_type="sum", order=1
    )
    LocalBaserowTableServiceAggregationGroupBy.objects.create(
        service=service, field=None, order=1
    )
    LocalBaserowTableServiceAggregationSortBy.objects.create(
        service=service,
        sort_on="SERIES",
        reference=f"field_{field.id}_sum",
        order=1,
        direction="ASC",
    )
    serialized = DashboardApplicationType().export_serialized(
        dashboard, ImportExportConfig(include_permission_data=True)
    )
    # Round-trip through JSON so the comparison below is made on plain
    # JSON-compatible types only.
    serialized = json.loads(json.dumps(serialized))
    assert serialized == {
        "id": dashboard.id,
        "name": dashboard.name,
        "description": "Dashboard description",
        "order": dashboard.order,
        "type": "dashboard",
        "integrations": [
            {
                "authorized_user": user.email,
                "id": integration.id,
                "name": "",
                "order": "1.00000000000000000000",
                "type": "local_baserow",
            },
        ],
        "data_sources": [
            {
                "id": dashboard_widget.data_source.id,
                "name": dashboard_widget.data_source.name,
                "order": "1.00000000000000000000",
                "service": {
                    "filter_type": "AND",
                    "filters": [],
                    "id": service.id,
                    "integration_id": service.integration.id,
                    "service_aggregation_group_bys": [
                        {"field_id": None},
                    ],
                    "service_aggregation_series": [
                        {"aggregation_type": "sum", "field_id": field.id},
                    ],
                    "service_aggregation_sorts": [
                        {
                            "direction": "ASC",
                            "reference": f"field_{field.id}_sum",
                            "sort_on": "SERIES",
                        },
                    ],
                    "table_id": table.id,
                    "type": "local_baserow_grouped_aggregate_rows",
                    "view_id": None,
                },
            },
        ],
        "widgets": [
            {
                "data_source_id": dashboard_widget.data_source.id,
                "description": "Description 1",
                "id": dashboard_widget.id,
                "order": "1.00000000000000000000",
                "title": "Widget 1",
                "type": "chart",
            },
        ],
        "role_assignments": [],
    }
@pytest.mark.django_db()
@override_settings(DEBUG=True)
def test_dashboard_import_serialized_with_widgets(enterprise_data_fixture):
    """
    Importing a serialized dashboard should recreate the integration, the
    data source with its grouped aggregate rows service (series, group bys
    and sorts remapped through ``id_mapping``) and the chart widget, while
    reporting progress from 0 to 100.
    """
    enterprise_data_fixture.enable_enterprise()
    user = enterprise_data_fixture.create_user()
    workspace = enterprise_data_fixture.create_workspace(user=user)
    database = enterprise_data_fixture.create_database_application(
        user=user, workspace=workspace
    )
    table = enterprise_data_fixture.create_database_table(database=database)
    field = enterprise_data_fixture.create_number_field(table=table)
    field_2 = enterprise_data_fixture.create_number_field(table=table, primary=True)
    # The payload below references table id 1 and field id 1; this mapping
    # remaps them onto the real objects created above.
    id_mapping = {
        "database_tables": {1: table.id},
        "database_fields": {1: field.id},
    }
    serialized = {
        "id": "999",
        "name": "Dashboard 1",
        "description": "Description 1",
        "order": 99,
        "type": "dashboard",
        "integrations": [
            {
                "authorized_user": user.email,
                "id": 1,
                "name": "IntegrationName",
                "order": "1.00000000000000000000",
                "type": "local_baserow",
            },
        ],
        "data_sources": [
            {
                "id": 1,
                "name": "DataSource1",
                "order": "1.00000000000000000000",
                "service": {
                    "filter_type": "AND",
                    "filters": [],
                    "id": 1,
                    "integration_id": 1,
                    "service_aggregation_group_bys": [
                        {"field_id": None},
                    ],
                    "service_aggregation_series": [
                        {"aggregation_type": "sum", "field_id": 1},
                    ],
                    "service_aggregation_sorts": [
                        {
                            "direction": "ASC",
                            # Fixed: was a pointless f-string with no placeholder.
                            "reference": "field_1_sum",
                            "sort_on": "SERIES",
                        },
                    ],
                    "table_id": 1,
                    "type": "local_baserow_grouped_aggregate_rows",
                    "view_id": None,
                },
            },
        ],
        "widgets": [
            {
                "data_source_id": 1,
                "description": "Description 1",
                "id": 45,
                "order": "1.00000000000000000000",
                "title": "Widget 1",
                "type": "chart",
            },
        ],
    }
    progress = Progress(100)
    progress_builder = ChildProgressBuilder(parent=progress, represents_progress=100)
    assert progress.progress == 0
    dashboard = DashboardApplicationType().import_serialized(
        workspace,
        serialized,
        ImportExportConfig(include_permission_data=True),
        id_mapping,
        progress_builder=progress_builder,
    )
    assert dashboard.name == "Dashboard 1"
    assert dashboard.description == "Description 1"
    assert dashboard.order == 99
    integrations = Integration.objects.filter(application=dashboard)
    integration = integrations[0].specific
    assert integrations.count() == 1
    assert integration.content_type == ContentType.objects.get_for_model(
        LocalBaserowIntegration
    )
    assert integration.authorized_user.id == user.id
    assert integration.name == "IntegrationName"
    assert integration.order == Decimal("1.0")
    data_sources = DashboardDataSource.objects.filter(dashboard=dashboard)
    assert data_sources.count() == 1
    ds1 = data_sources[0]
    # Fixed: these were plain assignments (`ds1.name = ...`) which asserted
    # nothing about the imported data source.
    assert ds1.name == "DataSource1"
    assert ds1.order == Decimal("1.0")
    service = ds1.service.specific
    assert service.content_type == ContentType.objects.get_for_model(
        LocalBaserowGroupedAggregateRows
    )
    assert service.integration_id == integration.id
    assert service.filter_type == "AND"
    series = service.service_aggregation_series.all()
    assert series.count() == 1
    assert series[0].aggregation_type == "sum"
    assert series[0].field_id == field.id
    group_bys = service.service_aggregation_group_bys.all()
    assert group_bys.count() == 1
    assert group_bys[0].field_id is None
    sorts = service.service_aggregation_sorts.all()
    assert sorts.count() == 1
    assert sorts[0].direction == "ASC"
    assert sorts[0].sort_on == "SERIES"
    # The sort reference embeds the field id and must have been remapped.
    assert sorts[0].reference == f"field_{field.id}_sum"
    widgets = Widget.objects.filter(dashboard=dashboard)
    assert widgets.count() == 1
    widget1 = widgets[0].specific
    assert widget1.content_type == ContentType.objects.get_for_model(ChartWidget)
    assert widget1.title == "Widget 1"
    assert widget1.description == "Description 1"
    assert widget1.order == Decimal("1.0")
    assert widget1.data_source.id == ds1.id
    assert progress.progress == 100

View file

@ -7,9 +7,6 @@ import pytest
from rest_framework.exceptions import ValidationError
from baserow.contrib.database.rows.handler import RowHandler
from baserow.contrib.integrations.local_baserow.models import (
LocalBaserowTableServiceSort,
)
from baserow.core.services.exceptions import ServiceImproperlyConfigured
from baserow.core.services.handler import ServiceHandler
from baserow.core.services.registries import service_type_registry
@ -20,6 +17,9 @@ from baserow_enterprise.integrations.local_baserow.models import (
LocalBaserowTableServiceAggregationSeries,
LocalBaserowTableServiceAggregationSortBy,
)
from baserow_enterprise.integrations.local_baserow.service_types import (
LocalBaserowGroupedAggregateRowsUserServiceType,
)
def test_grouped_aggregate_rows_service_get_schema_name():
@ -1947,11 +1947,19 @@ def test_grouped_aggregate_rows_service_dispatch_sort_by_series_without_group_by
LocalBaserowTableServiceAggregationSeries.objects.create(
service=service, field=field_3, aggregation_type="sum", order=3
)
LocalBaserowTableServiceSort.objects.create(
service=service, field=field_3, order=1, order_by="ASC"
LocalBaserowTableServiceAggregationSortBy.objects.create(
service=service,
sort_on="SERIES",
reference=f"field_{field.id}_sum",
order=1,
direction="ASC",
)
LocalBaserowTableServiceSort.objects.create(
service=service, field=field_2, order=2, order_by="DESC"
LocalBaserowTableServiceAggregationSortBy.objects.create(
service=service,
sort_on="SERIES",
reference=f"field_{field_2.id}_sum",
order=2,
direction="DESC",
)
RowHandler().create_rows(
@ -2759,3 +2767,169 @@ def test_grouped_aggregate_rows_service_dispatch_max_buckets_sort_on_primary_fie
},
],
}
@pytest.mark.django_db
def test_grouped_aggregate_rows_service_export_serialized(
    data_fixture,
):
    """
    The grouped aggregate rows service type should serialize its table,
    view, filter type and all related aggregation series, group bys and
    sorts (in their `order`).
    """
    user = data_fixture.create_user()
    dashboard = data_fixture.create_dashboard_application(user=user)
    table = data_fixture.create_database_table(user=user)
    field = data_fixture.create_number_field(table=table)
    field_2 = data_fixture.create_number_field(table=table)
    field_3 = data_fixture.create_number_field(table=table)
    view = data_fixture.create_grid_view(user=user, table=table)
    integration = data_fixture.create_local_baserow_integration(
        application=dashboard, user=user
    )
    service = data_fixture.create_service(
        LocalBaserowGroupedAggregateRows,
        integration=integration,
        table=table,
        view=view,
    )
    LocalBaserowTableServiceAggregationSeries.objects.create(
        service=service, field=field, aggregation_type="sum", order=1
    )
    LocalBaserowTableServiceAggregationSeries.objects.create(
        service=service, field=field_2, aggregation_type="min", order=2
    )
    LocalBaserowTableServiceAggregationSeries.objects.create(
        service=service, field=field_3, aggregation_type="max", order=3
    )
    LocalBaserowTableServiceAggregationGroupBy.objects.create(
        service=service, field=field_3, order=1
    )
    LocalBaserowTableServiceAggregationSortBy.objects.create(
        service=service,
        sort_on="SERIES",
        reference=f"field_{field.id}_sum",
        order=1,
        direction="ASC",
    )
    # Fixed: this sort previously also used order=1, leaving the database
    # free to return the two sorts in either order, which would make the
    # exact-payload assertion below flaky.
    LocalBaserowTableServiceAggregationSortBy.objects.create(
        service=service,
        sort_on="SERIES",
        reference=f"field_{field_2.id}_min",
        order=2,
        direction="ASC",
    )
    result = LocalBaserowGroupedAggregateRowsUserServiceType().export_serialized(
        service, import_export_config=None, files_zip=None, storage=None, cache=None
    )
    assert result == {
        "filter_type": "AND",
        "filters": [],
        "id": service.id,
        "integration_id": service.integration.id,
        "service_aggregation_group_bys": [
            {"field_id": field_3.id},
        ],
        "service_aggregation_series": [
            {"aggregation_type": "sum", "field_id": field.id},
            {"aggregation_type": "min", "field_id": field_2.id},
            {"aggregation_type": "max", "field_id": field_3.id},
        ],
        "service_aggregation_sorts": [
            {
                "direction": "ASC",
                "reference": f"field_{field.id}_sum",
                "sort_on": "SERIES",
            },
            {
                "direction": "ASC",
                "reference": f"field_{field_2.id}_min",
                "sort_on": "SERIES",
            },
        ],
        "table_id": table.id,
        "type": "local_baserow_grouped_aggregate_rows",
        "view_id": view.id,
    }
@pytest.mark.django_db
def test_grouped_aggregate_rows_service_import_serialized(data_fixture):
    """
    Importing a serialized grouped aggregate rows service should create a
    new service instance (fresh id, not the serialized 999) and recreate
    all related aggregation series, group bys and sorts.
    """
    user = data_fixture.create_user()
    dashboard = data_fixture.create_dashboard_application(user=user)
    table = data_fixture.create_database_table(user=user)
    field = data_fixture.create_number_field(table=table)
    field_2 = data_fixture.create_number_field(table=table)
    field_3 = data_fixture.create_number_field(table=table)
    view = data_fixture.create_grid_view(user=user, table=table)
    integration = data_fixture.create_local_baserow_integration(
        application=dashboard, user=user
    )
    serialized_service = {
        "filter_type": "AND",
        "filters": [],
        "id": 999,
        "integration_id": integration.id,
        "service_aggregation_group_bys": [
            {"field_id": field_3.id},
        ],
        "service_aggregation_series": [
            {"aggregation_type": "sum", "field_id": field.id},
            {"aggregation_type": "min", "field_id": field_2.id},
            {"aggregation_type": "max", "field_id": field_3.id},
        ],
        "service_aggregation_sorts": [
            {
                "direction": "ASC",
                "reference": f"field_{field.id}_sum",
                "sort_on": "SERIES",
            },
            {
                "direction": "DESC",
                "reference": f"field_{field_2.id}_min",
                "sort_on": "SERIES",
            },
        ],
        "table_id": table.id,
        "type": "local_baserow_grouped_aggregate_rows",
        "view_id": view.id,
    }
    # id_mapping is empty here, so the real field/table ids in the payload
    # are used as-is instead of being remapped.
    id_mapping = {}
    # NOTE(review): `Mock` must be imported at module level (unittest.mock)
    # — the import is not visible in this chunk; confirm it exists.
    instance = LocalBaserowGroupedAggregateRowsUserServiceType().import_serialized(
        parent=integration,
        serialized_values=serialized_service,
        id_mapping=id_mapping,
        import_formula=Mock(),
    )
    assert instance.content_type == ContentType.objects.get_for_model(
        LocalBaserowGroupedAggregateRows
    )
    assert instance.filter_type == "AND"
    assert instance.service_filters.count() == 0
    # A new row must have been created rather than reusing the exported id.
    assert instance.id != 999
    assert instance.integration_id == integration.id
    assert instance.table_id == table.id
    assert instance.view_id == view.id
    series = instance.service_aggregation_series.all()
    assert series.count() == 3
    assert series[0].aggregation_type == "sum"
    assert series[0].field_id == field.id
    assert series[1].aggregation_type == "min"
    assert series[1].field_id == field_2.id
    assert series[2].aggregation_type == "max"
    assert series[2].field_id == field_3.id
    group_bys = instance.service_aggregation_group_bys.all()
    assert group_bys.count() == 1
    assert group_bys[0].field_id == field_3.id
    sorts = instance.service_aggregation_sorts.all()
    assert sorts.count() == 2
    assert sorts[0].direction == "ASC"
    assert sorts[0].sort_on == "SERIES"
    assert sorts[0].reference == f"field_{field.id}_sum"
    assert sorts[1].direction == "DESC"
    assert sorts[1].sort_on == "SERIES"
    assert sorts[1].reference == f"field_{field_2.id}_min"