mirror of
https://gitlab.com/bramw/baserow.git
synced 2025-04-12 16:28:06 +00:00
[2/2] AI improvements: AI field choice output type
This commit is contained in:
parent
861a2729e5
commit
d01dccd0e1
47 changed files with 1533 additions and 146 deletions
backend
src/baserow
contrib/database
test_utils
tests/baserow/contrib
database
api
field
import_export
ws
integrations/local_baserow
changelog/entries/unreleased/feature
enterprise/backend
src/baserow_enterprise/data_sync
tests/baserow_enterprise_tests/data_sync
premium
backend
web-frontend/modules/baserow_premium
web-frontend/modules
core/assets/scss/components/views/grid
database
|
@ -254,8 +254,20 @@ def construct_all_possible_field_kwargs(
|
||||||
"name": "ai",
|
"name": "ai",
|
||||||
"ai_generative_ai_type": "test_generative_ai",
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
"ai_generative_ai_model": "test_1",
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_output_type": "text",
|
||||||
"ai_prompt": "'Who are you?'",
|
"ai_prompt": "'Who are you?'",
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"name": "ai_choice",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_prompt": "'What are you?'",
|
||||||
|
"ai_output_type": "choice",
|
||||||
|
"select_options": [
|
||||||
|
{"id": 5, "value": "Object", "color": "orange"},
|
||||||
|
{"id": 6, "value": "Else", "color": "yellow"},
|
||||||
|
],
|
||||||
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
# If you have added a new field please add an entry into the dict above with any
|
# If you have added a new field please add an entry into the dict above with any
|
||||||
|
|
|
@ -975,7 +975,7 @@ class FieldType(
|
||||||
field_id = serialized_copy.pop("id")
|
field_id = serialized_copy.pop("id")
|
||||||
serialized_copy.pop("type")
|
serialized_copy.pop("type")
|
||||||
select_options = (
|
select_options = (
|
||||||
serialized_copy.pop("select_options")
|
serialized_copy.pop("select_options", [])
|
||||||
if self.can_have_select_options
|
if self.can_have_select_options
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
|
|
@ -1976,6 +1976,8 @@ class ViewHandler(metaclass=baserow_trace_methods(tracer)):
|
||||||
:return: The created view sort instance.
|
:return: The created view sort instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
field = field.specific
|
||||||
|
|
||||||
workspace = view.table.database.workspace
|
workspace = view.table.database.workspace
|
||||||
CoreHandler().check_permissions(
|
CoreHandler().check_permissions(
|
||||||
user, ReadFieldOperationType.type, workspace=workspace, context=field
|
user, ReadFieldOperationType.type, workspace=workspace, context=field
|
||||||
|
|
|
@ -274,6 +274,7 @@ def public_rows_updated(
|
||||||
table_id=PUBLIC_PLACEHOLDER_ENTITY_ID,
|
table_id=PUBLIC_PLACEHOLDER_ENTITY_ID,
|
||||||
serialized_rows_before_update=visible_fields_only_old_rows,
|
serialized_rows_before_update=visible_fields_only_old_rows,
|
||||||
serialized_rows=visible_fields_only_updated_rows,
|
serialized_rows=visible_fields_only_updated_rows,
|
||||||
|
updated_field_ids=list(updated_field_ids),
|
||||||
metadata={},
|
metadata={},
|
||||||
),
|
),
|
||||||
slug=public_view.slug,
|
slug=public_view.slug,
|
||||||
|
|
|
@ -81,6 +81,9 @@ def rows_updated(
|
||||||
serialized_rows=get_row_serializer_class(
|
serialized_rows=get_row_serializer_class(
|
||||||
model, RowSerializer, is_response=True
|
model, RowSerializer, is_response=True
|
||||||
)(rows, many=True).data,
|
)(rows, many=True).data,
|
||||||
|
# Broadcast a list of updated fields so that the listener can take
|
||||||
|
# action even if the value didn't change.
|
||||||
|
updated_field_ids=list(updated_field_ids),
|
||||||
metadata=row_metadata_registry.generate_and_merge_metadata_for_rows(
|
metadata=row_metadata_registry.generate_and_merge_metadata_for_rows(
|
||||||
user, table, [row.id for row in rows]
|
user, table, [row.id for row in rows]
|
||||||
),
|
),
|
||||||
|
@ -213,6 +216,7 @@ class RealtimeRowMessages:
|
||||||
serialized_rows_before_update: List[Dict[str, Any]],
|
serialized_rows_before_update: List[Dict[str, Any]],
|
||||||
serialized_rows: List[Dict[str, Any]],
|
serialized_rows: List[Dict[str, Any]],
|
||||||
metadata: Dict[int, Dict[str, Any]],
|
metadata: Dict[int, Dict[str, Any]],
|
||||||
|
updated_field_ids: List[int],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"type": "rows_updated",
|
"type": "rows_updated",
|
||||||
|
@ -223,6 +227,7 @@ class RealtimeRowMessages:
|
||||||
"rows_before_update": serialized_rows_before_update,
|
"rows_before_update": serialized_rows_before_update,
|
||||||
"rows": serialized_rows,
|
"rows": serialized_rows,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
|
"updated_field_ids": updated_field_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -250,6 +250,9 @@ def setup_interesting_test_table(
|
||||||
"phone_number": "+4412345678",
|
"phone_number": "+4412345678",
|
||||||
"password": "test",
|
"password": "test",
|
||||||
"ai": "I'm an AI.",
|
"ai": "I'm an AI.",
|
||||||
|
"ai_choice": SelectOption.objects.get(
|
||||||
|
value="Object", field_id=name_to_field_id["ai_choice"]
|
||||||
|
).id,
|
||||||
}
|
}
|
||||||
|
|
||||||
with freeze_time("2020-02-01 01:23"):
|
with freeze_time("2020-02-01 01:23"):
|
||||||
|
|
|
@ -378,6 +378,11 @@ def test_get_row_serializer_with_user_field_names(data_fixture):
|
||||||
"autonumber": 2,
|
"autonumber": 2,
|
||||||
"password": True,
|
"password": True,
|
||||||
"ai": "I'm an AI.",
|
"ai": "I'm an AI.",
|
||||||
|
"ai_choice": {
|
||||||
|
"color": "orange",
|
||||||
|
"id": SelectOption.objects.get(value="Object").id,
|
||||||
|
"value": "Object",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -84,6 +84,8 @@ def test_serialize_group_by_metadata_on_all_fields_in_interesting_table(data_fix
|
||||||
|
|
||||||
actual_result_per_field_name = {}
|
actual_result_per_field_name = {}
|
||||||
|
|
||||||
|
ai_choice_select_options = Field.objects.get(name="ai_choice").select_options.all()
|
||||||
|
|
||||||
for field in fields_to_group_by:
|
for field in fields_to_group_by:
|
||||||
counts = handler.get_group_by_metadata_in_rows([field], rows, queryset)
|
counts = handler.get_group_by_metadata_in_rows([field], rows, queryset)
|
||||||
serialized = serialize_group_by_metadata(counts)[field.db_column]
|
serialized = serialize_group_by_metadata(counts)[field.db_column]
|
||||||
|
@ -258,4 +260,12 @@ def test_serialize_group_by_metadata_on_all_fields_in_interesting_table(data_fix
|
||||||
{"count": 1, "field_duration_dhms": 90066.0},
|
{"count": 1, "field_duration_dhms": 90066.0},
|
||||||
{"count": 1, "field_duration_dhms": None},
|
{"count": 1, "field_duration_dhms": None},
|
||||||
],
|
],
|
||||||
|
"ai": [
|
||||||
|
{"count": 1, "field_ai": "I'm an AI."},
|
||||||
|
{"count": 1, "field_ai": None},
|
||||||
|
],
|
||||||
|
"ai_choice": [
|
||||||
|
{"count": 1, "field_ai_choice": ai_choice_select_options[0].id},
|
||||||
|
{"count": 1, "field_ai_choice": None},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
|
@ -303,7 +303,9 @@ def test_link_row_field_type_with_text_values(data_fixture):
|
||||||
for field_type in [
|
for field_type in [
|
||||||
f
|
f
|
||||||
for f in field_type_registry.get_all()
|
for f in field_type_registry.get_all()
|
||||||
if f.can_get_unique_values and not f.read_only
|
# The AI field is not compatible because it some field kwargs are required and
|
||||||
|
# not passed in.
|
||||||
|
if f.can_get_unique_values and not f.read_only and f.type != "ai"
|
||||||
]:
|
]:
|
||||||
field_type_name = field_type.type
|
field_type_name = field_type.type
|
||||||
field_name = f"Field {field_type_name}"
|
field_name = f"Field {field_type_name}"
|
||||||
|
|
|
@ -247,12 +247,12 @@ def test_can_export_every_interesting_different_field_to_csv(
|
||||||
"phone_number,formula_text,formula_int,formula_bool,formula_decimal,formula_dateinterval,"
|
"phone_number,formula_text,formula_int,formula_bool,formula_decimal,formula_dateinterval,"
|
||||||
"formula_date,formula_singleselect,formula_email,formula_link_with_label,"
|
"formula_date,formula_singleselect,formula_email,formula_link_with_label,"
|
||||||
"formula_link_url_only,formula_multipleselect,count,rollup,duration_rollup_sum,"
|
"formula_link_url_only,formula_multipleselect,count,rollup,duration_rollup_sum,"
|
||||||
"duration_rollup_avg,lookup,uuid,autonumber,password,ai\r\n"
|
"duration_rollup_avg,lookup,uuid,autonumber,password,ai,ai_choice\r\n"
|
||||||
"1,,,,,,,,,0,False,,,,,,,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
"1,,,,,,,,,0,False,,,,,,,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
||||||
"02/01/2021 13:00,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,02/01/2021 13:00,"
|
"02/01/2021 13:00,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,02/01/2021 13:00,"
|
||||||
"user@example.com,user@example.com,,,,,,,,,,,,,,,,,,,test FORMULA,1,True,33.3333333333,"
|
"user@example.com,user@example.com,,,,,,,,,,,,,,,,,,,test FORMULA,1,True,33.3333333333,"
|
||||||
"1d 0:00,2020-01-01,,,label (https://google.com),https://google.com,,0,0.000,"
|
"1d 0:00,2020-01-01,,,label (https://google.com),https://google.com,,0,0.000,"
|
||||||
"0:00,0:00,,00000000-0000-4000-8000-000000000001,1,,\r\n"
|
"0:00,0:00,,00000000-0000-4000-8000-000000000001,1,,,\r\n"
|
||||||
"2,text,long_text,https://www.google.com,test@example.com,-1,1,-1.2,1.2,3,True,"
|
"2,text,long_text,https://www.google.com,test@example.com,-1,1,-1.2,1.2,3,True,"
|
||||||
"02/01/2020 01:23,02/01/2020,01/02/2020 01:23,01/02/2020,01/02/2020 02:23,"
|
"02/01/2020 01:23,02/01/2020,01/02/2020 01:23,01/02/2020,01/02/2020 02:23,"
|
||||||
"01/02/2020 02:23,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
"01/02/2020 02:23,01/02/2021 12:00,01/02/2021,02/01/2021 12:00,02/01/2021,"
|
||||||
|
@ -267,7 +267,7 @@ def test_can_export_every_interesting_different_field_to_csv(
|
||||||
'"user2@example.com,user3@example.com",\'+4412345678,test FORMULA,1,True,33.3333333333,'
|
'"user2@example.com,user3@example.com",\'+4412345678,test FORMULA,1,True,33.3333333333,'
|
||||||
"1d 0:00,2020-01-01,A,test@example.com,label (https://google.com),https://google.com,"
|
"1d 0:00,2020-01-01,A,test@example.com,label (https://google.com),https://google.com,"
|
||||||
'"C,D,E",3,-122.222,0:04,0:02,"linked_row_1,linked_row_2,",'
|
'"C,D,E",3,-122.222,0:04,0:02,"linked_row_1,linked_row_2,",'
|
||||||
"00000000-0000-4000-8000-000000000002,2,True,I'm an AI.\r\n"
|
"00000000-0000-4000-8000-000000000002,2,True,I'm an AI.,Object\r\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert contents == expected
|
assert contents == expected
|
||||||
|
|
|
@ -1014,6 +1014,7 @@ def test_batch_update_rows_some_not_visible_in_public_view_to_be_visible_event_s
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"updated_field_ids": [hidden_field.id],
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
@ -1151,6 +1152,7 @@ def test_batch_update_rows_visible_in_public_view_to_some_not_be_visible_event_s
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"updated_field_ids": [hidden_field.id],
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
@ -1454,6 +1456,7 @@ def test_given_row_visible_in_public_view_when_updated_to_still_be_visible_event
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"updated_field_ids": [visible_field.id, hidden_field.id],
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
@ -1580,6 +1583,7 @@ def test_batch_update_rows_visible_in_public_view_still_be_visible_event_sent(
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"updated_field_ids": [visible_field.id, hidden_field.id],
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
@ -1661,6 +1665,7 @@ def test_batch_update_subset_rows_visible_in_public_view_no_filters(
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"updated_field_ids": [visible_field.id],
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
@ -2021,6 +2026,7 @@ def test_given_row_visible_in_public_view_when_moved_row_updated_sent(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"updated_field_ids": [],
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -315,6 +315,7 @@ def test_rows_history_updated(mock_broadcast_channel_group, data_fixture):
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"updated_field_ids": [field.id],
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -864,6 +864,21 @@ def test_local_baserow_table_service_generate_schema_with_interesting_test_table
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"type": "string",
|
"type": "string",
|
||||||
},
|
},
|
||||||
|
field_db_column_by_name["ai_choice"]: {
|
||||||
|
"title": "ai_choice",
|
||||||
|
"default": None,
|
||||||
|
"searchable": True,
|
||||||
|
"sortable": True,
|
||||||
|
"filterable": False,
|
||||||
|
"original_type": "ai",
|
||||||
|
"metadata": {},
|
||||||
|
"properties": {
|
||||||
|
"color": {"title": "color", "type": "string"},
|
||||||
|
"id": {"title": "id", "type": "number"},
|
||||||
|
"value": {"title": "value", "type": "string"},
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
"id": {
|
"id": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"title": "Id",
|
"title": "Id",
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
{
|
||||||
|
"type": "feature",
|
||||||
|
"message": "AI choice output type for classification purposes.",
|
||||||
|
"issue_number": 3143,
|
||||||
|
"bullet_points": [],
|
||||||
|
"created_at": "2024-11-10"
|
||||||
|
}
|
|
@ -5,6 +5,7 @@ from uuid import UUID
|
||||||
from django.db.models import Prefetch
|
from django.db.models import Prefetch
|
||||||
|
|
||||||
from baserow_premium.fields.field_types import AIFieldType
|
from baserow_premium.fields.field_types import AIFieldType
|
||||||
|
from baserow_premium.fields.registries import ai_field_output_registry
|
||||||
from baserow_premium.license.handler import LicenseHandler
|
from baserow_premium.license.handler import LicenseHandler
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
@ -33,7 +34,6 @@ from baserow.contrib.database.fields.field_types import (
|
||||||
from baserow.contrib.database.fields.models import (
|
from baserow.contrib.database.fields.models import (
|
||||||
DateField,
|
DateField,
|
||||||
Field,
|
Field,
|
||||||
LongTextField,
|
|
||||||
NumberField,
|
NumberField,
|
||||||
SelectOption,
|
SelectOption,
|
||||||
TextField,
|
TextField,
|
||||||
|
@ -51,16 +51,21 @@ from baserow_enterprise.features import DATA_SYNC
|
||||||
from .models import LocalBaserowTableDataSync
|
from .models import LocalBaserowTableDataSync
|
||||||
|
|
||||||
|
|
||||||
def prepare_single_select_value(value, enabled_property):
|
def prepare_single_select_value(value, field, metadata):
|
||||||
try:
|
try:
|
||||||
# The metadata contains a mapping of the select options where the key is the
|
# The metadata contains a mapping of the select options where the key is the
|
||||||
# old ID and the value is the new ID. For some reason the key is converted to
|
# old ID and the value is the new ID. For some reason the key is converted to
|
||||||
# a string when moved into the JSON field.
|
# a string when moved into the JSON field.
|
||||||
return enabled_property.metadata["select_options_mapping"][str(value)]
|
return metadata["select_options_mapping"][str(value)]
|
||||||
except (KeyError, TypeError):
|
except (KeyError, TypeError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_ai_value(value, field, metadata):
|
||||||
|
output_type = ai_field_output_registry.get(field.ai_output_type)
|
||||||
|
return output_type.prepare_data_sync_value(value, field, metadata)
|
||||||
|
|
||||||
|
|
||||||
# List of Baserow supported field types that can be included in the data sync.
|
# List of Baserow supported field types that can be included in the data sync.
|
||||||
supported_field_types = {
|
supported_field_types = {
|
||||||
TextFieldType.type: {},
|
TextFieldType.type: {},
|
||||||
|
@ -78,7 +83,7 @@ supported_field_types = {
|
||||||
LastModifiedFieldType.type: {},
|
LastModifiedFieldType.type: {},
|
||||||
UUIDFieldType.type: {},
|
UUIDFieldType.type: {},
|
||||||
AutonumberFieldType.type: {},
|
AutonumberFieldType.type: {},
|
||||||
AIFieldType.type: {},
|
AIFieldType.type: {"prepare_value": prepare_ai_value},
|
||||||
SingleSelectFieldType.type: {"prepare_value": prepare_single_select_value},
|
SingleSelectFieldType.type: {"prepare_value": prepare_single_select_value},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,7 +104,6 @@ class BaserowFieldDataSyncProperty(DataSyncProperty):
|
||||||
LastModifiedFieldType.type: DateField,
|
LastModifiedFieldType.type: DateField,
|
||||||
UUIDFieldType.type: TextField,
|
UUIDFieldType.type: TextField,
|
||||||
AutonumberFieldType.type: NumberField,
|
AutonumberFieldType.type: NumberField,
|
||||||
AIFieldType.type: LongTextField,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, field, immutable_properties, **kwargs):
|
def __init__(self, field, immutable_properties, **kwargs):
|
||||||
|
@ -332,7 +336,9 @@ class LocalBaserowTableDataSyncType(DataSyncType):
|
||||||
if "prepare_value" in supported_field:
|
if "prepare_value" in supported_field:
|
||||||
for row in rows_queryset:
|
for row in rows_queryset:
|
||||||
row[enabled_property.key] = supported_field["prepare_value"](
|
row[enabled_property.key] = supported_field["prepare_value"](
|
||||||
row[enabled_property.key], enabled_property
|
row[enabled_property.key],
|
||||||
|
enabled_property.field,
|
||||||
|
enabled_property.metadata,
|
||||||
)
|
)
|
||||||
progress.increment(by=2) # makes the total `10`
|
progress.increment(by=2) # makes the total `10`
|
||||||
|
|
||||||
|
|
|
@ -389,6 +389,7 @@ def test_sync_data_sync_table_with_interesting_table_as_source(enterprise_data_f
|
||||||
"last_modified_datetime_eu_tzone": "02/01/2021 13:00",
|
"last_modified_datetime_eu_tzone": "02/01/2021 13:00",
|
||||||
"autonumber": "1",
|
"autonumber": "1",
|
||||||
"ai": "",
|
"ai": "",
|
||||||
|
"ai_choice": "",
|
||||||
"uuid": "00000000-0000-4000-8000-000000000001",
|
"uuid": "00000000-0000-4000-8000-000000000001",
|
||||||
}
|
}
|
||||||
assert results == {
|
assert results == {
|
||||||
|
@ -432,6 +433,7 @@ def test_sync_data_sync_table_with_interesting_table_as_source(enterprise_data_f
|
||||||
"last_modified_datetime_eu_tzone": "02/01/2021 13:00",
|
"last_modified_datetime_eu_tzone": "02/01/2021 13:00",
|
||||||
"autonumber": "2",
|
"autonumber": "2",
|
||||||
"ai": "I'm an AI.",
|
"ai": "I'm an AI.",
|
||||||
|
"ai_choice": "Object",
|
||||||
"uuid": "00000000-0000-4000-8000-000000000002",
|
"uuid": "00000000-0000-4000-8000-000000000002",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,13 +23,21 @@ class BaserowPremiumConfig(AppConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
from .fields.actions import GenerateFormulaWithAIActionType
|
from .fields.actions import GenerateFormulaWithAIActionType
|
||||||
|
from .fields.ai_field_output_types import (
|
||||||
|
ChoiceAIFieldOutputType,
|
||||||
|
TextAIFieldOutputType,
|
||||||
|
)
|
||||||
from .fields.field_converters import AIFieldConverter
|
from .fields.field_converters import AIFieldConverter
|
||||||
from .fields.field_types import AIFieldType
|
from .fields.field_types import AIFieldType
|
||||||
|
from .fields.registries import ai_field_output_registry
|
||||||
|
|
||||||
field_type_registry.register(AIFieldType())
|
field_type_registry.register(AIFieldType())
|
||||||
|
|
||||||
field_converter_registry.register(AIFieldConverter())
|
field_converter_registry.register(AIFieldConverter())
|
||||||
|
|
||||||
|
ai_field_output_registry.register(TextAIFieldOutputType())
|
||||||
|
ai_field_output_registry.register(ChoiceAIFieldOutputType())
|
||||||
|
|
||||||
from baserow.contrib.database.rows.registries import row_metadata_registry
|
from baserow.contrib.database.rows.registries import row_metadata_registry
|
||||||
from baserow.contrib.database.views.registries import (
|
from baserow.contrib.database.views.registries import (
|
||||||
decorator_type_registry,
|
decorator_type_registry,
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
import enum
|
||||||
|
import json
|
||||||
|
from difflib import get_close_matches
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.output_parsers.enum import EnumOutputParser
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
|
from baserow.contrib.database.fields.field_types import (
|
||||||
|
LongTextFieldType,
|
||||||
|
SingleSelectFieldType,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .registries import AIFieldOutputType
|
||||||
|
|
||||||
|
|
||||||
|
class TextAIFieldOutputType(AIFieldOutputType):
|
||||||
|
type = "text"
|
||||||
|
baserow_field_type = LongTextFieldType
|
||||||
|
|
||||||
|
|
||||||
|
class StrictEnumOutputParser(EnumOutputParser):
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
json_array = json.dumps(self._valid_values)
|
||||||
|
return f"""Categorize the result following these requirements:
|
||||||
|
|
||||||
|
- Select only one option from the JSON array below.
|
||||||
|
- Don't use quotes or commas or partial values, just the option name.
|
||||||
|
- Choose the option that most closely matches the row values.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{json_array}
|
||||||
|
```""" # nosec this falsely marks as hardcoded sql expression, but it's not related
|
||||||
|
# to SQL at all.
|
||||||
|
|
||||||
|
def parse(self, response: str) -> Any:
|
||||||
|
response = response.strip()
|
||||||
|
# Sometimes the LLM responds with a quotes value or with part of the value if
|
||||||
|
# it contains a comma. Finding the close matches helps with selecting the
|
||||||
|
# right value.
|
||||||
|
closest_matches = get_close_matches(
|
||||||
|
response, self._valid_values, n=1, cutoff=0.0
|
||||||
|
)
|
||||||
|
return super().parse(closest_matches[0])
|
||||||
|
|
||||||
|
|
||||||
|
class ChoiceAIFieldOutputType(AIFieldOutputType):
|
||||||
|
type = "choice"
|
||||||
|
baserow_field_type = SingleSelectFieldType
|
||||||
|
|
||||||
|
def get_output_parser(self, ai_field):
|
||||||
|
choices = enum.Enum(
|
||||||
|
"Choices",
|
||||||
|
{
|
||||||
|
f"OPTION_{option.id}": option.value
|
||||||
|
for option in ai_field.select_options.all()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return StrictEnumOutputParser(enum=choices)
|
||||||
|
|
||||||
|
def format_prompt(self, prompt, ai_field):
|
||||||
|
output_parser = self.get_output_parser(ai_field)
|
||||||
|
format_instructions = output_parser.get_format_instructions()
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
template=prompt + "Given this user query: \n\n{format_instructions}",
|
||||||
|
input_variables=[],
|
||||||
|
partial_variables={"format_instructions": format_instructions},
|
||||||
|
)
|
||||||
|
message = prompt.format()
|
||||||
|
return message
|
||||||
|
|
||||||
|
def parse_output(self, output, ai_field):
|
||||||
|
if not output:
|
||||||
|
return None
|
||||||
|
|
||||||
|
output_parser = self.get_output_parser(ai_field)
|
||||||
|
try:
|
||||||
|
parsed_output = output_parser.parse(output)
|
||||||
|
except OutputParserException:
|
||||||
|
return None
|
||||||
|
select_option_id = int(parsed_output.name.split("_")[1])
|
||||||
|
try:
|
||||||
|
return next(
|
||||||
|
o for o in ai_field.select_options.all() if o.id == select_option_id
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def prepare_data_sync_value(self, value, field, metadata):
|
||||||
|
try:
|
||||||
|
# The metadata contains a mapping of the select options where the key is the
|
||||||
|
# old ID and the value is the new ID. For some reason the key is converted
|
||||||
|
# to a string when moved into the JSON field.
|
||||||
|
return int(metadata["select_options_mapping"][str(value)])
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
return None
|
|
@ -1,5 +1,4 @@
|
||||||
from baserow.contrib.database.fields.field_converters import RecreateFieldConverter
|
from baserow.contrib.database.fields.field_converters import RecreateFieldConverter
|
||||||
from baserow.contrib.database.fields.models import LongTextField, TextField
|
|
||||||
|
|
||||||
from .models import AIField
|
from .models import AIField
|
||||||
|
|
||||||
|
@ -10,5 +9,5 @@ class AIFieldConverter(RecreateFieldConverter):
|
||||||
def is_applicable(self, from_model, from_field, to_field):
|
def is_applicable(self, from_model, from_field, to_field):
|
||||||
from_ai = isinstance(from_field, AIField)
|
from_ai = isinstance(from_field, AIField)
|
||||||
to_ai = isinstance(to_field, AIField)
|
to_ai = isinstance(to_field, AIField)
|
||||||
to_text_fields = isinstance(to_field, (TextField, LongTextField))
|
# If any field converts to the AI field, then we want to recreate the field
|
||||||
return from_ai and not (to_text_fields or to_ai) or not from_ai and to_ai
|
return not from_ai and to_ai
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.db import IntegrityError, models
|
from django.db import IntegrityError
|
||||||
from django.db.models import Value
|
from django.utils.functional import lazy
|
||||||
|
|
||||||
from baserow_premium.api.fields.exceptions import (
|
from baserow_premium.api.fields.exceptions import (
|
||||||
ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD,
|
ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD,
|
||||||
|
@ -15,15 +15,13 @@ from baserow.api.generative_ai.errors import (
|
||||||
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
|
ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
|
||||||
)
|
)
|
||||||
from baserow.contrib.database.api.fields.errors import ERROR_FIELD_DOES_NOT_EXIST
|
from baserow.contrib.database.api.fields.errors import ERROR_FIELD_DOES_NOT_EXIST
|
||||||
from baserow.contrib.database.fields.field_filters import (
|
from baserow.contrib.database.fields.field_types import (
|
||||||
contains_filter,
|
CollationSortMixin,
|
||||||
contains_word_filter,
|
SelectOptionBaseFieldType,
|
||||||
)
|
)
|
||||||
from baserow.contrib.database.fields.field_types import CollationSortMixin, TextField
|
|
||||||
from baserow.contrib.database.fields.models import Field
|
from baserow.contrib.database.fields.models import Field
|
||||||
from baserow.contrib.database.fields.registries import FieldType
|
from baserow.contrib.database.fields.registries import field_type_registry
|
||||||
from baserow.contrib.database.formula import BaserowFormulaTextType, BaserowFormulaType
|
from baserow.contrib.database.formula import BaserowFormulaType
|
||||||
from baserow.core.db import collate_expression
|
|
||||||
from baserow.core.formula.serializers import FormulaSerializerField
|
from baserow.core.formula.serializers import FormulaSerializerField
|
||||||
from baserow.core.generative_ai.exceptions import (
|
from baserow.core.generative_ai.exceptions import (
|
||||||
GenerativeAITypeDoesNotExist,
|
GenerativeAITypeDoesNotExist,
|
||||||
|
@ -35,6 +33,7 @@ from baserow.core.generative_ai.registries import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .models import AIField
|
from .models import AIField
|
||||||
|
from .registries import ai_field_output_registry
|
||||||
from .visitors import replace_field_id_references
|
from .visitors import replace_field_id_references
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
|
@ -43,7 +42,7 @@ if TYPE_CHECKING:
|
||||||
from baserow.contrib.database.table.models import GeneratedTableModel
|
from baserow.contrib.database.table.models import GeneratedTableModel
|
||||||
|
|
||||||
|
|
||||||
class AIFieldType(CollationSortMixin, FieldType):
|
class AIFieldType(CollationSortMixin, SelectOptionBaseFieldType):
|
||||||
"""
|
"""
|
||||||
The AI field can automatically query a generative AI model based on the provided
|
The AI field can automatically query a generative AI model based on the provided
|
||||||
prompt. It's possible to reference other fields to generate a unique output.
|
prompt. It's possible to reference other fields to generate a unique output.
|
||||||
|
@ -53,21 +52,29 @@ class AIFieldType(CollationSortMixin, FieldType):
|
||||||
model_class = AIField
|
model_class = AIField
|
||||||
can_be_in_form_view = False
|
can_be_in_form_view = False
|
||||||
keep_data_on_duplication = True
|
keep_data_on_duplication = True
|
||||||
allowed_fields = [
|
allowed_fields = SelectOptionBaseFieldType.allowed_fields + [
|
||||||
"ai_generative_ai_type",
|
"ai_generative_ai_type",
|
||||||
"ai_generative_ai_model",
|
"ai_generative_ai_model",
|
||||||
|
"ai_output_type",
|
||||||
"ai_temperature",
|
"ai_temperature",
|
||||||
"ai_prompt",
|
"ai_prompt",
|
||||||
"ai_file_field_id",
|
"ai_file_field_id",
|
||||||
]
|
]
|
||||||
serializer_field_names = [
|
serializer_field_names = SelectOptionBaseFieldType.allowed_fields + [
|
||||||
"ai_generative_ai_type",
|
"ai_generative_ai_type",
|
||||||
"ai_generative_ai_model",
|
"ai_generative_ai_model",
|
||||||
|
"ai_output_type",
|
||||||
"ai_temperature",
|
"ai_temperature",
|
||||||
"ai_prompt",
|
"ai_prompt",
|
||||||
"ai_file_field_id",
|
"ai_file_field_id",
|
||||||
]
|
]
|
||||||
serializer_field_overrides = {
|
serializer_field_overrides = {
|
||||||
|
"ai_output_type": serializers.ChoiceField(
|
||||||
|
required=False,
|
||||||
|
choices=lazy(ai_field_output_registry.get_types, list)(),
|
||||||
|
help_text="The desired output type of the field. It will try to force the "
|
||||||
|
"response of the prompt to match it.",
|
||||||
|
),
|
||||||
"ai_temperature": serializers.FloatField(
|
"ai_temperature": serializers.FloatField(
|
||||||
required=False,
|
required=False,
|
||||||
allow_null=True,
|
allow_null=True,
|
||||||
|
@ -89,6 +96,7 @@ class AIFieldType(CollationSortMixin, FieldType):
|
||||||
allow_null=True,
|
allow_null=True,
|
||||||
default=None,
|
default=None,
|
||||||
),
|
),
|
||||||
|
**SelectOptionBaseFieldType.serializer_field_overrides,
|
||||||
}
|
}
|
||||||
api_exceptions_map = {
|
api_exceptions_map = {
|
||||||
GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
|
GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
|
||||||
|
@ -96,52 +104,173 @@ class AIFieldType(CollationSortMixin, FieldType):
|
||||||
GenerativeAITypeDoesNotSupportFileField: ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD,
|
GenerativeAITypeDoesNotSupportFileField: ERROR_GENERATIVE_AI_DOES_NOT_SUPPORT_FILE_FIELD,
|
||||||
IntegrityError: ERROR_FIELD_DOES_NOT_EXIST,
|
IntegrityError: ERROR_FIELD_DOES_NOT_EXIST,
|
||||||
}
|
}
|
||||||
can_get_unique_values = False
|
can_get_unique_values = True
|
||||||
|
can_have_select_options = True
|
||||||
|
|
||||||
|
def get_baserow_field_type(self, instance):
|
||||||
|
output_type = ai_field_output_registry.get(instance.ai_output_type)
|
||||||
|
baserow_field_type = field_type_registry.get_by_type(
|
||||||
|
output_type.baserow_field_type
|
||||||
|
)
|
||||||
|
return baserow_field_type
|
||||||
|
|
||||||
def get_serializer_field(self, instance, **kwargs):
|
def get_serializer_field(self, instance, **kwargs):
|
||||||
required = kwargs.get("required", False)
|
kwargs["read_only"] = True
|
||||||
return serializers.CharField(
|
baserow_field_type = self.get_baserow_field_type(instance)
|
||||||
**{
|
return baserow_field_type.get_serializer_field(instance, **kwargs)
|
||||||
"required": required,
|
|
||||||
"allow_null": not required,
|
def get_response_serializer_field(self, instance, **kwargs):
|
||||||
"allow_blank": not required,
|
baserow_field_type = self.get_baserow_field_type(instance)
|
||||||
**kwargs,
|
return baserow_field_type.get_response_serializer_field(instance, **kwargs)
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_model_field(self, instance, **kwargs):
|
def get_model_field(self, instance, **kwargs):
|
||||||
return models.TextField(null=True, **kwargs)
|
baserow_field_type = self.get_baserow_field_type(instance)
|
||||||
|
return baserow_field_type.get_model_field(instance, **kwargs)
|
||||||
|
|
||||||
def get_serializer_help_text(self, instance):
|
def get_serializer_help_text(self, instance):
|
||||||
return (
|
return (
|
||||||
"Holds a text value that is generated by a generative AI model using a "
|
"Holds a value that is generated by a generative AI model using a "
|
||||||
"dynamic prompt."
|
"dynamic prompt."
|
||||||
)
|
)
|
||||||
|
|
||||||
def random_value(self, instance, fake, cache):
|
def random_value(self, instance, fake, cache):
|
||||||
return fake.name()
|
baserow_field_type = self.get_baserow_field_type(instance)
|
||||||
|
return baserow_field_type.random_value(instance, fake, cache)
|
||||||
|
|
||||||
def to_baserow_formula_type(self, field) -> BaserowFormulaType:
|
def to_baserow_formula_type(self, field) -> BaserowFormulaType:
|
||||||
return BaserowFormulaTextType(nullable=True)
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.to_baserow_formula_type(field)
|
||||||
def from_baserow_formula_type(
|
|
||||||
self, formula_type: BaserowFormulaTextType
|
|
||||||
) -> TextField:
|
|
||||||
return TextField()
|
|
||||||
|
|
||||||
def get_value_for_filter(self, row: "GeneratedTableModel", field: Field) -> any:
|
def get_value_for_filter(self, row: "GeneratedTableModel", field: Field) -> any:
|
||||||
value = getattr(row, field.db_column)
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
return collate_expression(Value(value))
|
return baserow_field_type.get_value_for_filter(row, field)
|
||||||
|
|
||||||
def contains_query(self, *args):
|
def get_alter_column_prepare_old_value(self, connection, from_field, to_field):
|
||||||
return contains_filter(*args)
|
baserow_field_type = self.get_baserow_field_type(from_field)
|
||||||
|
return baserow_field_type.get_alter_column_prepare_old_value(
|
||||||
|
connection, from_field, to_field
|
||||||
|
)
|
||||||
|
|
||||||
def contains_word_query(self, *args):
|
def get_alter_column_prepare_new_value(self, connection, from_field, to_field):
|
||||||
return contains_word_filter(*args)
|
baserow_field_type = self.get_baserow_field_type(to_field)
|
||||||
|
return baserow_field_type.get_alter_column_prepare_new_value(
|
||||||
|
connection, from_field, to_field
|
||||||
|
)
|
||||||
|
|
||||||
|
def contains_query(self, field_name, value, model_field, field):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.contains_query(field_name, value, model_field, field)
|
||||||
|
|
||||||
|
def contains_word_query(self, field_name, value, model_field, field):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.contains_word_query(
|
||||||
|
field_name, value, model_field, field
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_can_order_by(self, field: Field) -> bool:
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.check_can_order_by(field)
|
||||||
|
|
||||||
|
def check_can_group_by(self, field: Field) -> bool:
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.check_can_group_by(field)
|
||||||
|
|
||||||
|
def get_search_expression(self, field: Field, queryset):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.get_search_expression(field, queryset)
|
||||||
|
|
||||||
|
def is_searchable(self, field):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.is_searchable(field)
|
||||||
|
|
||||||
|
def enhance_queryset(self, queryset, field, name, **kwargs):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.enhance_queryset(queryset, field, name)
|
||||||
|
|
||||||
|
def get_order(self, field, field_name, order_direction):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.get_order(field, field_name, order_direction)
|
||||||
|
|
||||||
|
def serialize_to_input_value(self, field, value):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.serialize_to_input_value(field, value)
|
||||||
|
|
||||||
|
def valid_for_bulk_update(self, field):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.valid_for_bulk_update(field)
|
||||||
|
|
||||||
|
def prepare_value_for_db(self, instance, value):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(instance)
|
||||||
|
return baserow_field_type.prepare_value_for_db(instance, value)
|
||||||
|
|
||||||
|
def prepare_value_for_db_in_bulk(
|
||||||
|
self, instance, values_by_row, continue_on_error=False
|
||||||
|
):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(instance)
|
||||||
|
return baserow_field_type.prepare_value_for_db_in_bulk(
|
||||||
|
instance, values_by_row, continue_on_error
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_group_by_serializer_field(self, field, **kwargs):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.get_group_by_serializer_field(field, **kwargs)
|
||||||
|
|
||||||
|
def get_group_by_field_unique_value(self, field, field_name, value):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.get_group_by_field_unique_value(
|
||||||
|
field, field_name, value
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_group_by_field_filters_and_annotations(
|
||||||
|
self, field, field_name, base_queryset, value
|
||||||
|
):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field)
|
||||||
|
return baserow_field_type.get_group_by_field_filters_and_annotations(
|
||||||
|
field, field_name, base_queryset, value
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_export_serialized_value(
|
||||||
|
self,
|
||||||
|
row,
|
||||||
|
field_name,
|
||||||
|
cache,
|
||||||
|
files_zip=None,
|
||||||
|
storage=None,
|
||||||
|
):
|
||||||
|
field_object = row.get_field_object(field_name)
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||||
|
return baserow_field_type.get_export_serialized_value(
|
||||||
|
row, field_name, cache, files_zip, storage
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_import_serialized_value(
|
||||||
|
self,
|
||||||
|
row,
|
||||||
|
field_name,
|
||||||
|
value,
|
||||||
|
id_mapping,
|
||||||
|
cache,
|
||||||
|
files_zip=None,
|
||||||
|
storage=None,
|
||||||
|
):
|
||||||
|
field_object = row.get_field_object(field_name)
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||||
|
return baserow_field_type.set_import_serialized_value(
|
||||||
|
row, field_name, value, id_mapping, cache, files_zip, storage
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_export_value(self, value, field_object, rich_value=False):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||||
|
return baserow_field_type.get_export_value(value, field_object, rich_value)
|
||||||
|
|
||||||
|
def get_human_readable_value(self, value, field_object):
|
||||||
|
baserow_field_type = self.get_baserow_field_type(field_object["field"])
|
||||||
|
return baserow_field_type.get_human_readable_value(value, field_object)
|
||||||
|
|
||||||
def _validate_field_kwargs(
|
def _validate_field_kwargs(
|
||||||
self, ai_type, model_type, ai_file_field_id, workspace=None
|
self, ai_output_type, ai_type, model_type, ai_file_field_id, workspace=None
|
||||||
):
|
):
|
||||||
|
ai_field_output_registry.get(ai_output_type)
|
||||||
ai_type = generative_ai_model_type_registry.get(ai_type)
|
ai_type = generative_ai_model_type_registry.get(ai_type)
|
||||||
models = ai_type.get_enabled_models(workspace=workspace)
|
models = ai_type.get_enabled_models(workspace=workspace)
|
||||||
if model_type not in models:
|
if model_type not in models:
|
||||||
|
@ -154,12 +283,19 @@ class AIFieldType(CollationSortMixin, FieldType):
|
||||||
def before_create(
|
def before_create(
|
||||||
self, table, primary, allowed_field_values, order, user, field_kwargs
|
self, table, primary, allowed_field_values, order, user, field_kwargs
|
||||||
):
|
):
|
||||||
|
ai_output_type = field_kwargs.get(
|
||||||
|
"ai_output_type", AIField._meta.get_field("ai_output_type").default
|
||||||
|
)
|
||||||
ai_type = field_kwargs.get("ai_generative_ai_type", None)
|
ai_type = field_kwargs.get("ai_generative_ai_type", None)
|
||||||
model_type = field_kwargs.get("ai_generative_ai_model", None)
|
model_type = field_kwargs.get("ai_generative_ai_model", None)
|
||||||
ai_file_field_id = field_kwargs.get("ai_file_field_id", None)
|
ai_file_field_id = field_kwargs.get("ai_file_field_id", None)
|
||||||
workspace = table.database.workspace
|
workspace = table.database.workspace
|
||||||
self._validate_field_kwargs(
|
self._validate_field_kwargs(
|
||||||
ai_type, model_type, ai_file_field_id, workspace=workspace
|
ai_output_type, ai_type, model_type, ai_file_field_id, workspace=workspace
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().before_create(
|
||||||
|
table, primary, allowed_field_values, order, user, field_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def before_update(self, from_field, to_field_values, user, field_kwargs):
|
def before_update(self, from_field, to_field_values, user, field_kwargs):
|
||||||
|
@ -167,6 +303,11 @@ class AIFieldType(CollationSortMixin, FieldType):
|
||||||
if isinstance(from_field, AIField):
|
if isinstance(from_field, AIField):
|
||||||
update_field = from_field
|
update_field = from_field
|
||||||
|
|
||||||
|
ai_output_type = (
|
||||||
|
field_kwargs.get("ai_output_type", None)
|
||||||
|
or getattr(update_field, "ai_output_type", None)
|
||||||
|
or AIField._meta.get_field("ai_output_type").default
|
||||||
|
)
|
||||||
ai_type = field_kwargs.get("ai_generative_ai_type", None) or getattr(
|
ai_type = field_kwargs.get("ai_generative_ai_type", None) or getattr(
|
||||||
update_field, "ai_generative_ai_type", None
|
update_field, "ai_generative_ai_type", None
|
||||||
)
|
)
|
||||||
|
@ -179,9 +320,11 @@ class AIFieldType(CollationSortMixin, FieldType):
|
||||||
ai_file_field_id = getattr(update_field, "ai_file_field_id", None)
|
ai_file_field_id = getattr(update_field, "ai_file_field_id", None)
|
||||||
workspace = from_field.table.database.workspace
|
workspace = from_field.table.database.workspace
|
||||||
self._validate_field_kwargs(
|
self._validate_field_kwargs(
|
||||||
ai_type, model_type, ai_file_field_id, workspace=workspace
|
ai_output_type, ai_type, model_type, ai_file_field_id, workspace=workspace
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return super().before_update(from_field, to_field_values, user, field_kwargs)
|
||||||
|
|
||||||
def after_import_serialized(
|
def after_import_serialized(
|
||||||
self,
|
self,
|
||||||
field: AIField,
|
field: AIField,
|
||||||
|
@ -209,3 +352,17 @@ class AIFieldType(CollationSortMixin, FieldType):
|
||||||
|
|
||||||
if save:
|
if save:
|
||||||
field.save()
|
field.save()
|
||||||
|
|
||||||
|
def should_backup_field_data_for_same_type_update(
|
||||||
|
self, old_field, new_field_attrs
|
||||||
|
) -> bool:
|
||||||
|
backup = super().should_backup_field_data_for_same_type_update(
|
||||||
|
old_field, new_field_attrs
|
||||||
|
)
|
||||||
|
# Backup the field if the output type changes because
|
||||||
|
ai_output_changed = (
|
||||||
|
"ai_output_type" in new_field_attrs
|
||||||
|
and new_field_attrs["ai_output_type"]
|
||||||
|
and new_field_attrs["ai_output_type"] != old_field.ai_output_type
|
||||||
|
)
|
||||||
|
return backup or ai_output_changed
|
||||||
|
|
|
@ -3,12 +3,34 @@ from django.db import models
|
||||||
from baserow.contrib.database.fields.models import Field
|
from baserow.contrib.database.fields.models import Field
|
||||||
from baserow.core.formula.field import FormulaField as ModelFormulaField
|
from baserow.core.formula.field import FormulaField as ModelFormulaField
|
||||||
|
|
||||||
|
from .ai_field_output_types import TextAIFieldOutputType
|
||||||
|
from .registries import ai_field_output_registry
|
||||||
|
|
||||||
|
|
||||||
class AIField(Field):
|
class AIField(Field):
|
||||||
ai_generative_ai_type = models.CharField(max_length=32, null=True)
|
ai_generative_ai_type = models.CharField(max_length=32, null=True)
|
||||||
ai_generative_ai_model = models.CharField(max_length=32, null=True)
|
ai_generative_ai_model = models.CharField(max_length=32, null=True)
|
||||||
|
ai_output_type = models.CharField(
|
||||||
|
max_length=32,
|
||||||
|
db_default=TextAIFieldOutputType.type,
|
||||||
|
default=TextAIFieldOutputType.type,
|
||||||
|
)
|
||||||
ai_temperature = models.FloatField(null=True)
|
ai_temperature = models.FloatField(null=True)
|
||||||
ai_prompt = ModelFormulaField(default="")
|
ai_prompt = ModelFormulaField(default="")
|
||||||
ai_file_field = models.ForeignKey(
|
ai_file_field = models.ForeignKey(
|
||||||
Field, null=True, on_delete=models.SET_NULL, related_name="ai_field"
|
Field, null=True, on_delete=models.SET_NULL, related_name="ai_field"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
"""
|
||||||
|
When a property is called on the field object, it tries to return the default
|
||||||
|
value of the field object related to the `ai_output_type` `model_class`. This
|
||||||
|
will make it more compatible with the check functions like `check_can_group_by`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
ai_output_type = ai_field_output_registry.get(self.ai_output_type)
|
||||||
|
output_field = ai_output_type.baserow_field_type.model_class
|
||||||
|
return output_field._meta.get_field(name).default
|
||||||
|
except Exception:
|
||||||
|
super().__getattr__(name)
|
||||||
|
|
70
premium/backend/src/baserow_premium/fields/registries.py
Normal file
70
premium/backend/src/baserow_premium/fields/registries.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
import abc
|
||||||
|
import typing
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from baserow.contrib.database.fields.models import Field
|
||||||
|
from baserow.core.registry import Instance, Registry
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from baserow_premium.fields.models import AIField
|
||||||
|
|
||||||
|
|
||||||
|
class AIFieldOutputType(abc.ABC, Instance):
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def baserow_field_type(self) -> str:
|
||||||
|
"""
|
||||||
|
The Baserow field type that corresponds to this AI output type and should be
|
||||||
|
used to do various Baserow operations like filtering, sorting, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def format_prompt(self, prompt: str, ai_field: "AIField"):
|
||||||
|
"""
|
||||||
|
Hook that can be used to change and format the provided prompt for this output
|
||||||
|
type. It accepts the original already resolved prompt and should return the
|
||||||
|
updated one.
|
||||||
|
|
||||||
|
It can be used to include the format instructions of an output parser, for
|
||||||
|
example.
|
||||||
|
|
||||||
|
:param prompt: The resolved prompt provided by the user. This already contains
|
||||||
|
the resolved variables.
|
||||||
|
:param ai_field: The AI field related to the output type.
|
||||||
|
:return: Should return the formatted prompt. This can include additional
|
||||||
|
information that can change the outcome of the prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def parse_output(self, output: str, ai_field: "AIField"):
|
||||||
|
"""
|
||||||
|
Hook that can be used to parse the output of the generative AI prompt. If an
|
||||||
|
output parser formatting instructions are added in `format_prompt`, then this
|
||||||
|
hook can be used to parse it.
|
||||||
|
|
||||||
|
:param output: The text output of the generative AI.
|
||||||
|
:param ai_field: The AI field related to the output type.
|
||||||
|
:return: Should return the parsed output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def prepare_data_sync_value(self, value: Any, field: Field, metadata: dict) -> Any:
|
||||||
|
"""
|
||||||
|
Hook that's called when preparing the value in the local Baserow data sync.
|
||||||
|
It's for example used to map the value of the single select option.
|
||||||
|
|
||||||
|
:param value: The original value.
|
||||||
|
:param field: The field that's synced.
|
||||||
|
:param metadata: The metadata related to the datasync property.
|
||||||
|
:return: The updated value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class AIFieldOutputRegistry(Registry):
|
||||||
|
name = "ai_field_output"
|
||||||
|
|
||||||
|
|
||||||
|
ai_field_output_registry: AIFieldOutputRegistry = AIFieldOutputRegistry()
|
|
@ -20,6 +20,7 @@ from baserow.core.handler import CoreHandler
|
||||||
from baserow.core.user.handler import User
|
from baserow.core.user.handler import User
|
||||||
|
|
||||||
from .models import AIField
|
from .models import AIField
|
||||||
|
from .registries import ai_field_output_registry
|
||||||
|
|
||||||
|
|
||||||
@app.task(bind=True, queue="export")
|
@app.task(bind=True, queue="export")
|
||||||
|
@ -28,9 +29,9 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
||||||
|
|
||||||
ai_field = FieldHandler().get_field(
|
ai_field = FieldHandler().get_field(
|
||||||
field_id,
|
field_id,
|
||||||
base_queryset=AIField.objects.all().select_related(
|
base_queryset=AIField.objects.all()
|
||||||
"table__database__workspace"
|
.select_related("table__database__workspace")
|
||||||
),
|
.prefetch_related("select_options"),
|
||||||
)
|
)
|
||||||
table = ai_field.table
|
table = ai_field.table
|
||||||
workspace = table.database.workspace
|
workspace = table.database.workspace
|
||||||
|
@ -71,6 +72,8 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
||||||
)
|
)
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
|
ai_output_type = ai_field_output_registry.get(ai_field.ai_output_type)
|
||||||
|
|
||||||
for i, row in enumerate(rows):
|
for i, row in enumerate(rows):
|
||||||
context = HumanReadableRowContext(row, exclude_field_ids=[ai_field.id])
|
context = HumanReadableRowContext(row, exclude_field_ids=[ai_field.id])
|
||||||
message = str(
|
message = str(
|
||||||
|
@ -79,6 +82,11 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The AI output type should be able to format the prompt because it can add
|
||||||
|
# additional instructions to it. The choice output type for example adds
|
||||||
|
# additional prompt trying to force the out, for example.
|
||||||
|
message = ai_output_type.format_prompt(message, ai_field)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if ai_field.ai_file_field_id is not None and isinstance(
|
if ai_field.ai_file_field_id is not None and isinstance(
|
||||||
generative_ai_model_type, GenerativeAIWithFilesModelType
|
generative_ai_model_type, GenerativeAIWithFilesModelType
|
||||||
|
@ -105,6 +113,12 @@ def generate_ai_values_for_rows(self, user_id: int, field_id: int, row_ids: list
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
temperature=ai_field.ai_temperature,
|
temperature=ai_field.ai_temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Because the AI output type can change the prompt to try to force the
|
||||||
|
# output a certain way, then it should give the opportunity to parse the
|
||||||
|
# output when it's given. With the choice output type, it will try to match
|
||||||
|
# it to a `SelectOption`, for example.
|
||||||
|
value = ai_output_type.parse_output(value, ai_field)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# If the prompt fails once, we should not continue with the other rows.
|
# If the prompt fails once, we should not continue with the other rows.
|
||||||
rows_ai_values_generation_error.send(
|
rows_ai_values_generation_error.send(
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Generated by Django 5.0.9 on 2024-10-28 13:30
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("baserow_premium", "0022_aifield_ai_temperature"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="aifield",
|
||||||
|
name="ai_output_type",
|
||||||
|
field=models.CharField(db_default="text", default="text", max_length=32),
|
||||||
|
),
|
||||||
|
]
|
|
@ -111,7 +111,8 @@ def test_can_export_every_interesting_different_field_to_json(
|
||||||
"uuid": "00000000-0000-4000-8000-000000000001",
|
"uuid": "00000000-0000-4000-8000-000000000001",
|
||||||
"autonumber": 1,
|
"autonumber": 1,
|
||||||
"password": "",
|
"password": "",
|
||||||
"ai": ""
|
"ai": "",
|
||||||
|
"ai_choice": ""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2,
|
"id": 2,
|
||||||
|
@ -230,7 +231,8 @@ def test_can_export_every_interesting_different_field_to_json(
|
||||||
"uuid": "00000000-0000-4000-8000-000000000002",
|
"uuid": "00000000-0000-4000-8000-000000000002",
|
||||||
"autonumber": 2,
|
"autonumber": 2,
|
||||||
"password": true,
|
"password": true,
|
||||||
"ai": "I'm an AI."
|
"ai": "I'm an AI.",
|
||||||
|
"ai_choice": "Object"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
|
@ -393,6 +395,7 @@ def test_can_export_every_interesting_different_field_to_xml(
|
||||||
<autonumber>1</autonumber>
|
<autonumber>1</autonumber>
|
||||||
<password/>
|
<password/>
|
||||||
<ai/>
|
<ai/>
|
||||||
|
<ai-choice/>
|
||||||
</row>
|
</row>
|
||||||
<row>
|
<row>
|
||||||
<id>2</id>
|
<id>2</id>
|
||||||
|
@ -512,6 +515,7 @@ def test_can_export_every_interesting_different_field_to_xml(
|
||||||
<autonumber>2</autonumber>
|
<autonumber>2</autonumber>
|
||||||
<password>true</password>
|
<password>true</password>
|
||||||
<ai>I'm an AI.</ai>
|
<ai>I'm an AI.</ai>
|
||||||
|
<ai-choice>Object</ai-choice>
|
||||||
</row>
|
</row>
|
||||||
</rows>
|
</rows>
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_dynamic_get_attr(premium_data_fixture):
|
||||||
|
field = premium_data_fixture.create_ai_field(ai_output_type="text")
|
||||||
|
assert field.long_text_enable_rich_text is False
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
field.non_existing_property
|
|
@ -0,0 +1,109 @@
|
||||||
|
import enum
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from baserow_premium.fields.ai_field_output_types import StrictEnumOutputParser
|
||||||
|
from baserow_premium.fields.tasks import generate_ai_values_for_rows
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
|
from baserow.core.generative_ai.registries import (
|
||||||
|
GenerativeAIModelType,
|
||||||
|
generative_ai_model_type_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strict_enum_output_parser():
|
||||||
|
choices = enum.Enum(
|
||||||
|
"Choices",
|
||||||
|
{
|
||||||
|
"OPTION_1": "Object",
|
||||||
|
"OPTION_2": "Animal",
|
||||||
|
"OPTION_3": "Human",
|
||||||
|
"OPTION_4": "A,B,C",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
output_parser = StrictEnumOutputParser(enum=choices)
|
||||||
|
format_instructions = output_parser.get_format_instructions()
|
||||||
|
prompt = "What is a motorcycle?"
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
template=prompt + "\n\n{format_instructions}",
|
||||||
|
input_variables=[],
|
||||||
|
partial_variables={"format_instructions": format_instructions},
|
||||||
|
)
|
||||||
|
message = prompt.format()
|
||||||
|
|
||||||
|
assert '["Object", "Animal", "Human", "A,B,C"]' in message
|
||||||
|
|
||||||
|
assert output_parser.parse("Object") == choices.OPTION_1
|
||||||
|
assert output_parser.parse("Animal") == choices.OPTION_2
|
||||||
|
assert output_parser.parse("Human") == choices.OPTION_3
|
||||||
|
assert output_parser.parse("A,B,C") == choices.OPTION_4
|
||||||
|
|
||||||
|
assert output_parser.parse(" Object ") == choices.OPTION_1
|
||||||
|
assert output_parser.parse("'Object'") == choices.OPTION_1
|
||||||
|
assert output_parser.parse("'A'") == choices.OPTION_4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_choice_output_type(premium_data_fixture, api_client):
|
||||||
|
class TestAIChoiceOutputTypeGenerativeAIModelType(GenerativeAIModelType):
|
||||||
|
type = "test_ai_choice_ouput_type"
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
def is_enabled(self, workspace=None):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_enabled_models(self, workspace=None):
|
||||||
|
return ["test_1"]
|
||||||
|
|
||||||
|
def prompt(self, model, prompt, workspace=None, temperature=None):
|
||||||
|
self.i += 1
|
||||||
|
if self.i == 1:
|
||||||
|
# Existing option should be matches based on the string.
|
||||||
|
return "Object"
|
||||||
|
else:
|
||||||
|
return "Else"
|
||||||
|
|
||||||
|
def get_settings_serializer(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
generative_ai_model_type_registry.register(
|
||||||
|
TestAIChoiceOutputTypeGenerativeAIModelType()
|
||||||
|
)
|
||||||
|
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
user = premium_data_fixture.create_user()
|
||||||
|
database = premium_data_fixture.create_database_application(
|
||||||
|
user=user, name="Placeholder"
|
||||||
|
)
|
||||||
|
table = premium_data_fixture.create_database_table(
|
||||||
|
name="Example", database=database
|
||||||
|
)
|
||||||
|
field = premium_data_fixture.create_ai_field(
|
||||||
|
table=table,
|
||||||
|
order=0,
|
||||||
|
name="ai",
|
||||||
|
ai_output_type="choice",
|
||||||
|
ai_generative_ai_type="test_ai_choice_ouput_type",
|
||||||
|
ai_generative_ai_model="test_1",
|
||||||
|
ai_prompt="'Option'",
|
||||||
|
)
|
||||||
|
option_1 = premium_data_fixture.create_select_option(
|
||||||
|
field=field, value="Object", color="red"
|
||||||
|
)
|
||||||
|
option_2 = premium_data_fixture.create_select_option(
|
||||||
|
field=field, value="Else", color="red"
|
||||||
|
)
|
||||||
|
premium_data_fixture.create_select_option(field=field, value="Animal", color="blue")
|
||||||
|
|
||||||
|
model = table.get_model()
|
||||||
|
row_1 = model.objects.create()
|
||||||
|
row_2 = model.objects.create()
|
||||||
|
|
||||||
|
generate_ai_values_for_rows(user.id, field.id, [row_1.id, row_2.id])
|
||||||
|
|
||||||
|
row_1.refresh_from_db()
|
||||||
|
row_2.refresh_from_db()
|
||||||
|
|
||||||
|
assert getattr(row_1, f"field_{field.id}").id == option_1.id
|
||||||
|
assert getattr(row_2, f"field_{field.id}").id == option_2.id
|
|
@ -1,10 +1,12 @@
|
||||||
from django.shortcuts import reverse
|
from django.shortcuts import reverse
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from baserow_premium.fields.field_types import AIFieldType
|
||||||
from baserow_premium.fields.models import AIField
|
from baserow_premium.fields.models import AIField
|
||||||
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
from baserow.contrib.database.fields.handler import FieldHandler
|
from baserow.contrib.database.fields.handler import FieldHandler
|
||||||
|
from baserow.contrib.database.fields.registries import field_type_registry
|
||||||
from baserow.contrib.database.table.handler import TableHandler
|
from baserow.contrib.database.table.handler import TableHandler
|
||||||
from baserow.core.db import specific_iterator
|
from baserow.core.db import specific_iterator
|
||||||
|
|
||||||
|
@ -28,6 +30,7 @@ def test_create_ai_field_type(premium_data_fixture):
|
||||||
ai_prompt="'Who are you?'",
|
ai_prompt="'Who are you?'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert ai_field.ai_output_type == "text" # default value
|
||||||
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
||||||
assert ai_field.ai_generative_ai_model == "test_1"
|
assert ai_field.ai_generative_ai_model == "test_1"
|
||||||
assert ai_field.ai_prompt == "'Who are you?'"
|
assert ai_field.ai_prompt == "'Who are you?'"
|
||||||
|
@ -51,6 +54,7 @@ def test_update_ai_field_type(premium_data_fixture):
|
||||||
ai_prompt="'Who are you?'",
|
ai_prompt="'Who are you?'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert ai_field.ai_output_type == "text" # default value
|
||||||
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
assert ai_field.ai_generative_ai_type == "test_generative_ai"
|
||||||
assert ai_field.ai_generative_ai_model == "test_1"
|
assert ai_field.ai_generative_ai_model == "test_1"
|
||||||
assert ai_field.ai_prompt == "'Who are you?'"
|
assert ai_field.ai_prompt == "'Who are you?'"
|
||||||
|
@ -98,6 +102,68 @@ def test_create_ai_field_type_via_api(premium_data_fixture, api_client):
|
||||||
)
|
)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
assert response.status_code == HTTP_200_OK
|
assert response.status_code == HTTP_200_OK
|
||||||
|
assert response_json["ai_output_type"] == "text"
|
||||||
|
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||||
|
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||||
|
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||||
|
assert response_json["ai_temperature"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_create_ai_field_type_via_api_with_non_existing_ai_output_type(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
user, token = premium_data_fixture.create_user_and_token()
|
||||||
|
table = premium_data_fixture.create_database_table(user=user)
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||||
|
|
||||||
|
response = api_client.post(
|
||||||
|
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||||
|
{
|
||||||
|
"name": "Test 1",
|
||||||
|
"type": "ai",
|
||||||
|
"ai_output_type": "DOES_NOT_EXIST",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_prompt": "'Who are you?'",
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||||
|
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
|
||||||
|
assert response_json["detail"]["ai_output_type"][0]["code"] == "invalid_choice"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_create_ai_field_type_via_api_with_ai_output_type(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
user, token = premium_data_fixture.create_user_and_token()
|
||||||
|
table = premium_data_fixture.create_database_table(user=user)
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||||
|
|
||||||
|
response = api_client.post(
|
||||||
|
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||||
|
{
|
||||||
|
"name": "Test 1",
|
||||||
|
"type": "ai",
|
||||||
|
"ai_output_type": "text",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_prompt": "'Who are you?'",
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
assert response.status_code == HTTP_200_OK
|
||||||
|
assert response_json["ai_output_type"] == "text"
|
||||||
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||||
assert response_json["ai_generative_ai_model"] == "test_1"
|
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||||
assert response_json["ai_prompt"] == "'Who are you?'"
|
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||||
|
@ -191,9 +257,6 @@ def test_update_ai_field_temperature_none_via_api(premium_data_fixture, api_clie
|
||||||
field = premium_data_fixture.create_ai_field(
|
field = premium_data_fixture.create_ai_field(
|
||||||
table=table, order=1, name="name", ai_temperature=0.7
|
table=table, order=1, name="name", ai_temperature=0.7
|
||||||
)
|
)
|
||||||
file_field = premium_data_fixture.create_file_field(
|
|
||||||
table=table, order=2, name="file"
|
|
||||||
)
|
|
||||||
|
|
||||||
response = api_client.patch(
|
response = api_client.patch(
|
||||||
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||||
|
@ -209,6 +272,121 @@ def test_update_ai_field_temperature_none_via_api(premium_data_fixture, api_clie
|
||||||
assert response.json()["ai_temperature"] is None
|
assert response.json()["ai_temperature"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_update_ai_field_via_api_invalid_output_type(premium_data_fixture, api_client):
|
||||||
|
user, token = premium_data_fixture.create_user_and_token()
|
||||||
|
table = premium_data_fixture.create_database_table(user=user)
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
field = premium_data_fixture.create_ai_field(
|
||||||
|
table=table, order=1, name="name", ai_temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
response = api_client.patch(
|
||||||
|
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||||
|
{
|
||||||
|
"ai_output_type": "INVALID_CHOICE",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai_with_files",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_temperature": None,
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
assert response.status_code == HTTP_400_BAD_REQUEST
|
||||||
|
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
|
||||||
|
assert response_json["detail"]["ai_output_type"][0]["code"] == "invalid_choice"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_update_ai_field_via_api_valid_output_type(premium_data_fixture, api_client):
|
||||||
|
user, token = premium_data_fixture.create_user_and_token()
|
||||||
|
table = premium_data_fixture.create_database_table(user=user)
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
field = premium_data_fixture.create_ai_field(
|
||||||
|
table=table, order=1, name="name", ai_temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
response = api_client.patch(
|
||||||
|
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||||
|
{
|
||||||
|
"ai_output_type": "text",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_temperature": None,
|
||||||
|
"ai_prompt": "'Who are you?'",
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
assert response.status_code == HTTP_200_OK
|
||||||
|
assert response_json["ai_output_type"] == "text"
|
||||||
|
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||||
|
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||||
|
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||||
|
assert response_json["ai_temperature"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_update_to_ai_field_with_all_parameters(premium_data_fixture, api_client):
|
||||||
|
user, token = premium_data_fixture.create_user_and_token()
|
||||||
|
table = premium_data_fixture.create_database_table(user=user)
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
field = premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||||
|
|
||||||
|
response = api_client.patch(
|
||||||
|
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"ai_output_type": "text",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_temperature": None,
|
||||||
|
"ai_prompt": "'Who are you?'",
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
assert response.status_code == HTTP_200_OK
|
||||||
|
assert response_json["ai_output_type"] == "text"
|
||||||
|
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||||
|
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||||
|
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||||
|
assert response_json["ai_temperature"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_update_to_ai_field_without_parameters(premium_data_fixture, api_client):
|
||||||
|
user, token = premium_data_fixture.create_user_and_token()
|
||||||
|
table = premium_data_fixture.create_database_table(user=user)
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
field = premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||||
|
|
||||||
|
response = api_client.patch(
|
||||||
|
reverse("api:database:fields:item", kwargs={"field_id": field.id}),
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
assert response.status_code == HTTP_200_OK
|
||||||
|
assert response_json["ai_output_type"] == "text"
|
||||||
|
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||||
|
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||||
|
assert response_json["ai_prompt"] == ""
|
||||||
|
assert response_json["ai_temperature"] is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@pytest.mark.field_ai
|
@pytest.mark.field_ai
|
||||||
def test_create_ai_field_type_via_api_invalid_formula(premium_data_fixture, api_client):
|
def test_create_ai_field_type_via_api_invalid_formula(premium_data_fixture, api_client):
|
||||||
|
@ -606,3 +784,266 @@ def test_duplicate_table_with_ai_field_broken_references(premium_data_fixture):
|
||||||
duplicated_ai_field = duplicated_fields[2]
|
duplicated_ai_field = duplicated_fields[2]
|
||||||
|
|
||||||
assert duplicated_ai_field.ai_prompt == f"concat('test:',get('fields.field_0'))"
|
assert duplicated_ai_field.ai_prompt == f"concat('test:',get('fields.field_0'))"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_can_set_select_options_to_choice_ai_output_type(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
user, token = premium_data_fixture.create_user_and_token()
|
||||||
|
table = premium_data_fixture.create_database_table(user=user)
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
premium_data_fixture.create_text_field(table=table, order=1, name="name")
|
||||||
|
|
||||||
|
response = api_client.post(
|
||||||
|
reverse("api:database:fields:list", kwargs={"table_id": table.id}),
|
||||||
|
{
|
||||||
|
"name": "Test 1",
|
||||||
|
"type": "ai",
|
||||||
|
"ai_output_type": "choice",
|
||||||
|
"ai_generative_ai_type": "test_generative_ai",
|
||||||
|
"ai_generative_ai_model": "test_1",
|
||||||
|
"ai_prompt": "'Who are you?'",
|
||||||
|
"select_options": [
|
||||||
|
{"value": "Small", "color": "red"},
|
||||||
|
{"value": "Medium", "color": "blue"},
|
||||||
|
{"value": "Large", "color": "green"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
HTTP_AUTHORIZATION=f"JWT {token}",
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
|
||||||
|
assert response.status_code == HTTP_200_OK
|
||||||
|
assert response_json["ai_output_type"] == "choice"
|
||||||
|
assert response_json["ai_generative_ai_type"] == "test_generative_ai"
|
||||||
|
assert response_json["ai_generative_ai_model"] == "test_1"
|
||||||
|
assert response_json["ai_prompt"] == "'Who are you?'"
|
||||||
|
assert len(response_json["select_options"]) == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_should_backup(premium_data_fixture, api_client):
|
||||||
|
ai_field_type = field_type_registry.get(AIFieldType.type)
|
||||||
|
ai_field = premium_data_fixture.create_ai_field()
|
||||||
|
file_field = premium_data_fixture.create_file_field(table=ai_field.table)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
ai_field_type.should_backup_field_data_for_same_type_update(
|
||||||
|
ai_field,
|
||||||
|
{
|
||||||
|
"ai_generative_ai_type": "test_generative_ai_2",
|
||||||
|
"ai_generative_ai_model": "test_model_2",
|
||||||
|
"ai_prompt": "'New AI prompt'",
|
||||||
|
"ai_output_type": "text", # same as before
|
||||||
|
"ai_temperature": 1,
|
||||||
|
"ai_file_field": file_field,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
ai_field_type.should_backup_field_data_for_same_type_update(
|
||||||
|
ai_field,
|
||||||
|
{
|
||||||
|
"ai_generative_ai_type": "test_generative_ai_2",
|
||||||
|
"ai_generative_ai_model": "test_model_2",
|
||||||
|
"ai_prompt": "'New AI prompt'",
|
||||||
|
"ai_output_type": "choice", # new one
|
||||||
|
"ai_temperature": 1,
|
||||||
|
"ai_file_field": file_field,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
) # Expect to make a backup when output type changes.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_can_convert_from_text_output_type_to_choice_output_type(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
user = premium_data_fixture.create_user()
|
||||||
|
database = premium_data_fixture.create_database_application(
|
||||||
|
user=user, name="Placeholder"
|
||||||
|
)
|
||||||
|
table = premium_data_fixture.create_database_table(
|
||||||
|
name="Example", database=database
|
||||||
|
)
|
||||||
|
field = premium_data_fixture.create_ai_field(
|
||||||
|
table=table, order=0, name="ai", ai_output_type="text"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = table.get_model()
|
||||||
|
model.objects.create(**{f"field_{field.id}": "Option 1"})
|
||||||
|
model.objects.create(**{f"field_{field.id}": "Something else"})
|
||||||
|
|
||||||
|
field = FieldHandler().update_field(
|
||||||
|
user=user,
|
||||||
|
field=field,
|
||||||
|
ai_output_type="choice",
|
||||||
|
select_options=[
|
||||||
|
{"value": "Option 1", "color": "red"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
table.refresh_from_db()
|
||||||
|
model = table.get_model()
|
||||||
|
rows = list(model.objects.all())
|
||||||
|
|
||||||
|
# Converting text ai field to choice field should try to convert the text values to
|
||||||
|
# the new choices.
|
||||||
|
assert getattr(rows[0], f"field_{field.id}").value == "Option 1"
|
||||||
|
assert getattr(rows[1], f"field_{field.id}") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_can_convert_from_choice_output_type_to_text_output_type(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
user = premium_data_fixture.create_user()
|
||||||
|
database = premium_data_fixture.create_database_application(
|
||||||
|
user=user, name="Placeholder"
|
||||||
|
)
|
||||||
|
table = premium_data_fixture.create_database_table(
|
||||||
|
name="Example", database=database
|
||||||
|
)
|
||||||
|
field = premium_data_fixture.create_ai_field(
|
||||||
|
table=table, order=0, name="ai", ai_output_type="choice"
|
||||||
|
)
|
||||||
|
select_option = premium_data_fixture.create_select_option(
|
||||||
|
field=field, value="Option 1", color="blue", order=0
|
||||||
|
)
|
||||||
|
|
||||||
|
model = table.get_model()
|
||||||
|
model.objects.create(**{f"field_{field.id}_id": select_option.id})
|
||||||
|
model.objects.create(**{f"field_{field.id}": None})
|
||||||
|
|
||||||
|
field = FieldHandler().update_field(
|
||||||
|
user=user,
|
||||||
|
field=field,
|
||||||
|
ai_output_type="text",
|
||||||
|
)
|
||||||
|
|
||||||
|
table.refresh_from_db()
|
||||||
|
model = table.get_model()
|
||||||
|
rows = list(model.objects.all())
|
||||||
|
|
||||||
|
# Converting choice ai field to text ai field should try to convert the choices to
|
||||||
|
# text values.
|
||||||
|
assert getattr(rows[0], f"field_{field.id}") == "Option 1"
|
||||||
|
assert getattr(rows[1], f"field_{field.id}") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_can_convert_from_text_field_to_text_output_type(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
user = premium_data_fixture.create_user()
|
||||||
|
database = premium_data_fixture.create_database_application(
|
||||||
|
user=user, name="Placeholder"
|
||||||
|
)
|
||||||
|
table = premium_data_fixture.create_database_table(
|
||||||
|
name="Example", database=database
|
||||||
|
)
|
||||||
|
field = premium_data_fixture.create_text_field(table=table, order=0, name="text")
|
||||||
|
|
||||||
|
model = table.get_model()
|
||||||
|
model.objects.create(**{f"field_{field.id}": "Test"})
|
||||||
|
|
||||||
|
field = FieldHandler().update_field(
|
||||||
|
user=user,
|
||||||
|
field=field,
|
||||||
|
new_type_name="ai",
|
||||||
|
ai_generative_ai_type="test_generative_ai",
|
||||||
|
ai_generative_ai_model="test_1",
|
||||||
|
ai_prompt="'test'",
|
||||||
|
ai_output_type="text",
|
||||||
|
)
|
||||||
|
|
||||||
|
table.refresh_from_db()
|
||||||
|
model = table.get_model()
|
||||||
|
rows = list(model.objects.all())
|
||||||
|
|
||||||
|
# Expect the value to be reset because we don't want to keep the existing cell
|
||||||
|
# value when converting from any other field.
|
||||||
|
assert getattr(rows[0], f"field_{field.id}") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_can_convert_from_text_field_to_choice_output_type(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
user = premium_data_fixture.create_user()
|
||||||
|
database = premium_data_fixture.create_database_application(
|
||||||
|
user=user, name="Placeholder"
|
||||||
|
)
|
||||||
|
table = premium_data_fixture.create_database_table(
|
||||||
|
name="Example", database=database
|
||||||
|
)
|
||||||
|
field = premium_data_fixture.create_text_field(table=table, order=0, name="text")
|
||||||
|
|
||||||
|
model = table.get_model()
|
||||||
|
model.objects.create(**{f"field_{field.id}": "Test"})
|
||||||
|
|
||||||
|
field = FieldHandler().update_field(
|
||||||
|
user=user,
|
||||||
|
field=field,
|
||||||
|
new_type_name="ai",
|
||||||
|
ai_generative_ai_type="test_generative_ai",
|
||||||
|
ai_generative_ai_model="test_1",
|
||||||
|
ai_prompt="'test'",
|
||||||
|
ai_output_type="choice",
|
||||||
|
)
|
||||||
|
|
||||||
|
table.refresh_from_db()
|
||||||
|
model = table.get_model()
|
||||||
|
rows = list(model.objects.all())
|
||||||
|
|
||||||
|
# Expect the value to be reset because we don't want to keep the existing cell
|
||||||
|
# value when converting from any other field.
|
||||||
|
assert getattr(rows[0], f"field_{field.id}") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.field_ai
|
||||||
|
def test_can_convert_from_text_output_type_to_text_field(
|
||||||
|
premium_data_fixture, api_client
|
||||||
|
):
|
||||||
|
premium_data_fixture.register_fake_generate_ai_type()
|
||||||
|
user = premium_data_fixture.create_user()
|
||||||
|
database = premium_data_fixture.create_database_application(
|
||||||
|
user=user, name="Placeholder"
|
||||||
|
)
|
||||||
|
table = premium_data_fixture.create_database_table(
|
||||||
|
name="Example", database=database
|
||||||
|
)
|
||||||
|
field = premium_data_fixture.create_ai_field(table=table, order=0, name="ai")
|
||||||
|
|
||||||
|
model = table.get_model()
|
||||||
|
model.objects.create(**{f"field_{field.id}": "Test"})
|
||||||
|
|
||||||
|
field = FieldHandler().update_field(
|
||||||
|
user=user,
|
||||||
|
field=field,
|
||||||
|
new_type_name="text",
|
||||||
|
)
|
||||||
|
|
||||||
|
table.refresh_from_db()
|
||||||
|
model = table.get_model()
|
||||||
|
rows = list(model.objects.all())
|
||||||
|
|
||||||
|
# Converting text ai field to text field should keep the values because the text
|
||||||
|
# field conversion is automatically used.
|
||||||
|
assert getattr(rows[0], f"field_{field.id}") == "Test"
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
import { Registerable } from '@baserow/modules/core/registry'
|
||||||
|
|
||||||
|
import {
|
||||||
|
LongTextFieldType,
|
||||||
|
SingleSelectFieldType,
|
||||||
|
} from '@baserow/modules/database/fieldTypes'
|
||||||
|
import FieldSelectOptionsSubForm from '@baserow/modules/database/components/field/FieldSelectOptionsSubForm.vue'
|
||||||
|
|
||||||
|
export class AIFieldOutputType extends Registerable {
|
||||||
|
/**
|
||||||
|
* A human readable name of the AI output type. This will be shown in in the dropdown
|
||||||
|
* where the user chosen the type.
|
||||||
|
*/
|
||||||
|
getName() {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A human-readable description of the AI output type.
|
||||||
|
*/
|
||||||
|
getDescription() {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor(...args) {
|
||||||
|
super(...args)
|
||||||
|
this.type = this.getType()
|
||||||
|
|
||||||
|
if (this.type === null) {
|
||||||
|
throw new Error('The type name of an admin type must be set.')
|
||||||
|
}
|
||||||
|
if (this.name === null) {
|
||||||
|
throw new Error('The name of an admin type must be set.')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
getBaserowFieldType() {
|
||||||
|
throw new Error('The Baserow field type must be set on the AI output type.')
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Can optionally return a form component that will be added to the `FieldForm` is
|
||||||
|
* the output type is chosen.
|
||||||
|
*/
|
||||||
|
getFormComponent() {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class TextAIFieldOutputType extends AIFieldOutputType {
|
||||||
|
static getType() {
|
||||||
|
return 'text'
|
||||||
|
}
|
||||||
|
|
||||||
|
getName() {
|
||||||
|
const { i18n } = this.app
|
||||||
|
return i18n.t('aiOutputType.text')
|
||||||
|
}
|
||||||
|
|
||||||
|
getDescription() {
|
||||||
|
const { i18n } = this.app
|
||||||
|
return i18n.t('aiOutputType.textDescription')
|
||||||
|
}
|
||||||
|
|
||||||
|
getBaserowFieldType() {
|
||||||
|
return this.app.$registry.get('field', LongTextFieldType.getType())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ChoiceAIFieldOutputType extends AIFieldOutputType {
|
||||||
|
static getType() {
|
||||||
|
return 'choice'
|
||||||
|
}
|
||||||
|
|
||||||
|
getName() {
|
||||||
|
const { i18n } = this.app
|
||||||
|
return i18n.t('aiOutputType.choice')
|
||||||
|
}
|
||||||
|
|
||||||
|
getDescription() {
|
||||||
|
const { i18n } = this.app
|
||||||
|
return i18n.t('aiOutputType.choiceDescription')
|
||||||
|
}
|
||||||
|
|
||||||
|
getBaserowFieldType() {
|
||||||
|
return this.app.$registry.get('field', SingleSelectFieldType.getType())
|
||||||
|
}
|
||||||
|
|
||||||
|
getFormComponent() {
|
||||||
|
return FieldSelectOptionsSubForm
|
||||||
|
}
|
||||||
|
}
|
|
@ -35,6 +35,31 @@
|
||||||
/>
|
/>
|
||||||
</Dropdown>
|
</Dropdown>
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
|
|
||||||
|
<FormGroup
|
||||||
|
required
|
||||||
|
small-label
|
||||||
|
:label="$t('fieldAISubForm.outputType')"
|
||||||
|
:help-icon-tooltip="$t('fieldAISubForm.outputTypeTooltip')"
|
||||||
|
>
|
||||||
|
<Dropdown
|
||||||
|
v-model="values.ai_output_type"
|
||||||
|
class="dropdown--floating"
|
||||||
|
:fixed-items="true"
|
||||||
|
>
|
||||||
|
<DropdownItem
|
||||||
|
v-for="outputType in outputTypes"
|
||||||
|
:key="outputType.getType()"
|
||||||
|
:name="outputType.getName()"
|
||||||
|
:value="outputType.getType()"
|
||||||
|
:description="outputType.getDescription()"
|
||||||
|
/>
|
||||||
|
</Dropdown>
|
||||||
|
<template v-if="changedOutputType" #warning>
|
||||||
|
{{ $t('fieldAISubForm.outputTypeChangedWarning') }}
|
||||||
|
</template>
|
||||||
|
</FormGroup>
|
||||||
|
|
||||||
<FormGroup
|
<FormGroup
|
||||||
small-label
|
small-label
|
||||||
:label="$t('fieldAISubForm.prompt')"
|
:label="$t('fieldAISubForm.prompt')"
|
||||||
|
@ -52,6 +77,12 @@
|
||||||
</div>
|
</div>
|
||||||
<template #error> {{ $t('error.requiredField') }}</template>
|
<template #error> {{ $t('error.requiredField') }}</template>
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
|
|
||||||
|
<component
|
||||||
|
:is="outputType.getFormComponent()"
|
||||||
|
ref="childForm"
|
||||||
|
v-bind="$props"
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div v-else>
|
<div v-else>
|
||||||
<p>
|
<p>
|
||||||
|
@ -67,6 +98,7 @@ import form from '@baserow/modules/core/mixins/form'
|
||||||
import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm'
|
import fieldSubForm from '@baserow/modules/database/mixins/fieldSubForm'
|
||||||
import FormulaInputField from '@baserow/modules/core/components/formula/FormulaInputField'
|
import FormulaInputField from '@baserow/modules/core/components/formula/FormulaInputField'
|
||||||
import SelectAIModelForm from '@baserow/modules/core/components/ai/SelectAIModelForm'
|
import SelectAIModelForm from '@baserow/modules/core/components/ai/SelectAIModelForm'
|
||||||
|
import { TextAIFieldOutputType } from '@baserow_premium/aiFieldOutputTypes'
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: 'FieldAISubForm',
|
name: 'FieldAISubForm',
|
||||||
|
@ -74,9 +106,10 @@ export default {
|
||||||
mixins: [form, fieldSubForm],
|
mixins: [form, fieldSubForm],
|
||||||
data() {
|
data() {
|
||||||
return {
|
return {
|
||||||
allowedValues: ['ai_prompt', 'ai_file_field_id'],
|
allowedValues: ['ai_prompt', 'ai_file_field_id', 'ai_output_type'],
|
||||||
values: {
|
values: {
|
||||||
ai_prompt: '',
|
ai_prompt: '',
|
||||||
|
ai_output_type: TextAIFieldOutputType.getType(),
|
||||||
ai_file_field_id: null,
|
ai_file_field_id: null,
|
||||||
},
|
},
|
||||||
fileFieldSupported: false,
|
fileFieldSupported: false,
|
||||||
|
@ -113,6 +146,19 @@ export default {
|
||||||
return t.canRepresentFiles(field)
|
return t.canRepresentFiles(field)
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
outputTypes() {
|
||||||
|
return Object.values(this.$registry.getAll('aiFieldOutputType'))
|
||||||
|
},
|
||||||
|
outputType() {
|
||||||
|
return this.$registry.get('aiFieldOutputType', this.values.ai_output_type)
|
||||||
|
},
|
||||||
|
changedOutputType() {
|
||||||
|
return (
|
||||||
|
this.defaultValues.id &&
|
||||||
|
this.defaultValues.type === this.values.type &&
|
||||||
|
this.defaultValues.ai_output_type !== this.values.ai_output_type
|
||||||
|
)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
setFileFieldSupported(generativeAIType) {
|
setFileFieldSupported(generativeAIType) {
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
<template>
|
<template>
|
||||||
<div class="control__elements">
|
<div>
|
||||||
<FormTextarea
|
<component
|
||||||
ref="input"
|
:is="outputRowEditFieldComponent"
|
||||||
v-model="value"
|
ref="field"
|
||||||
type="text"
|
v-bind="$props"
|
||||||
class="margin-bottom-2"
|
:read-only="true"
|
||||||
:rows="6"
|
></component>
|
||||||
:disabled="true"
|
<div v-if="!readOnly" class="margin-top-2">
|
||||||
/>
|
|
||||||
|
|
||||||
<template v-if="!readOnly">
|
|
||||||
<Button
|
<Button
|
||||||
v-if="isDeactivated && rowIsCreated"
|
v-if="isDeactivated && rowIsCreated"
|
||||||
type="secondary"
|
type="secondary"
|
||||||
|
@ -19,7 +16,7 @@
|
||||||
{{ $t('rowEditFieldAI.generate') }}
|
{{ $t('rowEditFieldAI.generate') }}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
v-if="rowIsCreated"
|
v-else-if="rowIsCreated"
|
||||||
type="secondary"
|
type="secondary"
|
||||||
:loading="generating"
|
:loading="generating"
|
||||||
@click="generate()"
|
@click="generate()"
|
||||||
|
@ -33,7 +30,7 @@
|
||||||
:workspace="workspace"
|
:workspace="workspace"
|
||||||
:name="fieldName"
|
:name="fieldName"
|
||||||
></component>
|
></component>
|
||||||
</template>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
|
@ -48,6 +45,12 @@ export default {
|
||||||
fieldName() {
|
fieldName() {
|
||||||
return this.$registry.get('field', this.field.type).getName()
|
return this.$registry.get('field', this.field.type).getName()
|
||||||
},
|
},
|
||||||
|
outputRowEditFieldComponent() {
|
||||||
|
return this.$registry
|
||||||
|
.get('aiFieldOutputType', this.field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getRowEditFieldComponent(this.field)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
|
@ -23,9 +23,15 @@
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div v-else class="grid-view__cell grid-field-long-text__cell">
|
<component
|
||||||
<div class="grid-field-long-text">{{ props.value }}</div>
|
:is="$options.methods.getFunctionalOutputFieldComponent(parent, props)"
|
||||||
</div>
|
v-else
|
||||||
|
:workspace-id="props.workspaceId"
|
||||||
|
:field="props.field"
|
||||||
|
:value="props.value"
|
||||||
|
:state="props.state"
|
||||||
|
:read-only="props.readOnly"
|
||||||
|
/>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
|
@ -41,6 +47,12 @@ export default {
|
||||||
.get('field', AIFieldType.getType())
|
.get('field', AIFieldType.getType())
|
||||||
.isDeactivated(props.workspaceId)
|
.isDeactivated(props.workspaceId)
|
||||||
},
|
},
|
||||||
|
getFunctionalOutputFieldComponent(parent, props) {
|
||||||
|
return parent.$registry
|
||||||
|
.get('aiFieldOutputType', props.field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getFunctionalGridViewFieldComponent(props.field)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
|
@ -18,40 +18,34 @@
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<component
|
||||||
|
:is="outputGridViewFieldComponent"
|
||||||
v-else
|
v-else
|
||||||
ref="cell"
|
ref="cell"
|
||||||
class="grid-view__cell grid-field-long-text__cell active"
|
v-bind="$props"
|
||||||
:class="{ editing: opened }"
|
:read-only="true"
|
||||||
@keyup.enter="opened = true"
|
|
||||||
>
|
>
|
||||||
<div v-if="!opened" class="grid-field-long-text">{{ value }}</div>
|
<template v-if="!readOnly && editing" #default="{ editing }">
|
||||||
<template v-else>
|
|
||||||
<div class="grid-field-long-text__textarea">
|
|
||||||
{{ value }}
|
|
||||||
</div>
|
|
||||||
<div style="background-color: #fff; padding: 8px">
|
<div style="background-color: #fff; padding: 8px">
|
||||||
<template v-if="!readOnly">
|
<ButtonText
|
||||||
<ButtonText
|
v-if="!isDeactivated"
|
||||||
v-if="!isDeactivated"
|
icon="iconoir-magic-wand"
|
||||||
icon="iconoir-magic-wand"
|
:disabled="!modelAvailable || generating"
|
||||||
:disabled="!modelAvailable || generating"
|
:loading="generating"
|
||||||
:loading="generating"
|
@click.prevent.stop="generate()"
|
||||||
@click.prevent.stop="generate()"
|
>
|
||||||
>
|
{{ $t('gridViewFieldAI.regenerate') }}
|
||||||
{{ $t('gridViewFieldAI.regenerate') }}
|
</ButtonText>
|
||||||
</ButtonText>
|
<ButtonText
|
||||||
<ButtonText
|
v-else
|
||||||
v-else
|
icon="iconoir-lock"
|
||||||
icon="iconoir-lock"
|
@click.prevent.stop="$refs.clickModal.show()"
|
||||||
@click.prevent.stop="$refs.clickModal.show()"
|
>
|
||||||
>
|
{{ $t('gridViewFieldAI.regenerate') }}
|
||||||
{{ $t('gridViewFieldAI.regenerate') }}
|
</ButtonText>
|
||||||
</ButtonText>
|
|
||||||
</template>
|
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
</div>
|
</component>
|
||||||
<component
|
<component
|
||||||
:is="deactivatedClickComponent"
|
:is="deactivatedClickComponent"
|
||||||
v-if="isDeactivated && workspace"
|
v-if="isDeactivated && workspace"
|
||||||
|
@ -75,6 +69,12 @@ export default {
|
||||||
fieldName() {
|
fieldName() {
|
||||||
return this.$registry.get('field', this.field.type).getName()
|
return this.$registry.get('field', this.field.type).getName()
|
||||||
},
|
},
|
||||||
|
outputGridViewFieldComponent() {
|
||||||
|
return this.$registry
|
||||||
|
.get('aiFieldOutputType', this.field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getGridViewFieldComponent(this.field)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
save() {
|
save() {
|
||||||
|
|
|
@ -2,13 +2,6 @@ import {
|
||||||
FieldType,
|
FieldType,
|
||||||
FormulaFieldType,
|
FormulaFieldType,
|
||||||
} from '@baserow/modules/database/fieldTypes'
|
} from '@baserow/modules/database/fieldTypes'
|
||||||
import RowHistoryFieldText from '@baserow/modules/database/components/row/RowHistoryFieldText'
|
|
||||||
import RowCardFieldText from '@baserow/modules/database/components/card/RowCardFieldText'
|
|
||||||
import { collatedStringCompare } from '@baserow/modules/core/utils/string'
|
|
||||||
import {
|
|
||||||
genericContainsFilter,
|
|
||||||
genericContainsWordFilter,
|
|
||||||
} from '@baserow/modules/database/utils/fieldFilters'
|
|
||||||
|
|
||||||
import GridViewFieldAI from '@baserow_premium/components/views/grid/fields/GridViewFieldAI'
|
import GridViewFieldAI from '@baserow_premium/components/views/grid/fields/GridViewFieldAI'
|
||||||
import FunctionalGridViewFieldAI from '@baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI'
|
import FunctionalGridViewFieldAI from '@baserow_premium/components/views/grid/fields/FunctionalGridViewFieldAI'
|
||||||
|
@ -33,6 +26,10 @@ export class AIFieldType extends FieldType {
|
||||||
return i18n.t('premiumFieldType.ai')
|
return i18n.t('premiumFieldType.ai')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getIsReadOnly() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
getGridViewFieldComponent() {
|
getGridViewFieldComponent() {
|
||||||
return GridViewFieldAI
|
return GridViewFieldAI
|
||||||
}
|
}
|
||||||
|
@ -45,18 +42,21 @@ export class AIFieldType extends FieldType {
|
||||||
return RowEditFieldAI
|
return RowEditFieldAI
|
||||||
}
|
}
|
||||||
|
|
||||||
getCardComponent() {
|
|
||||||
return RowCardFieldText
|
|
||||||
}
|
|
||||||
|
|
||||||
getRowHistoryEntryComponent() {
|
|
||||||
return RowHistoryFieldText
|
|
||||||
}
|
|
||||||
|
|
||||||
getFormComponent() {
|
getFormComponent() {
|
||||||
return FieldAISubForm
|
return FieldAISubForm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getCardComponent(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getCardComponent(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getRowHistoryEntryComponent(field) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
getFormViewFieldComponents(field) {
|
getFormViewFieldComponents(field) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
@ -65,17 +65,32 @@ export class AIFieldType extends FieldType {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
getSort(name, order) {
|
getCardValueHeight(field) {
|
||||||
return (a, b) => {
|
return this.app.$registry
|
||||||
const stringA = a[name] === null ? '' : '' + a[name]
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
const stringB = b[name] === null ? '' : '' + b[name]
|
.getBaserowFieldType()
|
||||||
|
.getCardValueHeight(field)
|
||||||
|
}
|
||||||
|
|
||||||
return collatedStringCompare(stringA, stringB, order)
|
getSort(name, order, field) {
|
||||||
}
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getSort(name, order, field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getCanSortInView(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getCanSortInView(field)
|
||||||
}
|
}
|
||||||
|
|
||||||
getDocsDataType(field) {
|
getDocsDataType(field) {
|
||||||
return 'string'
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getDocsDataType(field)
|
||||||
}
|
}
|
||||||
|
|
||||||
getDocsDescription(field) {
|
getDocsDescription(field) {
|
||||||
|
@ -84,15 +99,147 @@ export class AIFieldType extends FieldType {
|
||||||
}
|
}
|
||||||
|
|
||||||
getDocsRequestExample(field) {
|
getDocsRequestExample(field) {
|
||||||
return 'string'
|
return 'read only'
|
||||||
}
|
}
|
||||||
|
|
||||||
getContainsFilterFunction() {
|
getDocsResponseExample(field) {
|
||||||
return genericContainsFilter
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getDocsResponseExample(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
prepareValueForCopy(field, value) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.prepareValueForCopy(field, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
getContainsFilterFunction(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getContainsFilterFunction(field)
|
||||||
}
|
}
|
||||||
|
|
||||||
getContainsWordFilterFunction(field) {
|
getContainsWordFilterFunction(field) {
|
||||||
return genericContainsWordFilter
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getContainsWordFilterFunction(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
toHumanReadableString(field, value) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.toHumanReadableString(field, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
getSortIndicator(field, registry) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getSortIndicator(field, registry)
|
||||||
|
}
|
||||||
|
|
||||||
|
canRepresentDate(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.canRepresentDate(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getCanGroupByInView(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getCanGroupByInView(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
parseInputValue(field, value) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.parseInputValue(field, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
canRepresentFiles(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.canRepresentFiles(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getHasEmptyValueFilterFunction(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getHasEmptyValueFilterFunction(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getHasValueContainsFilterFunction(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getHasValueContainsFilterFunction(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getHasValueContainsWordFilterFunction(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getHasValueContainsWordFilterFunction(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getHasValueLengthIsLowerThanFilterFunction(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getHasValueLengthIsLowerThanFilterFunction(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getGroupByComponent(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getGroupByComponent(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
getGroupByIndicator(field, registry) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getGroupByIndicator(field, registry)
|
||||||
|
}
|
||||||
|
|
||||||
|
getRowValueFromGroupValue(field, value) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getRowValueFromGroupValue(field, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
getGroupValueFromRowValue(field, value) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.getGroupValueFromRowValue(field, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
isEqual(field, value1, value2) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.isEqual(field, value1, value2)
|
||||||
|
}
|
||||||
|
|
||||||
|
canBeReferencedByFormulaField(field) {
|
||||||
|
return this.app.$registry
|
||||||
|
.get('aiFieldOutputType', field.ai_output_type)
|
||||||
|
.getBaserowFieldType()
|
||||||
|
.canBeReferencedByFormulaField(field)
|
||||||
}
|
}
|
||||||
|
|
||||||
getGridViewContextItemsOnCellsSelection(field) {
|
getGridViewContextItemsOnCellsSelection(field) {
|
||||||
|
|
|
@ -409,7 +409,10 @@
|
||||||
"promptPlaceholder": "What is Baserow?",
|
"promptPlaceholder": "What is Baserow?",
|
||||||
"premiumFeature": "The AI field is a premium feature",
|
"premiumFeature": "The AI field is a premium feature",
|
||||||
"emptyFileField": "None",
|
"emptyFileField": "None",
|
||||||
"fileFieldHelp": "The first compatible file in the field will be used as the knowledge base for the prompt. The file has to be a text file with the supported file extension like .txt, .md, .pdf, .docx."
|
"fileFieldHelp": "The first compatible file in the field will be used as the knowledge base for the prompt. The file has to be a text file with the supported file extension like .txt, .md, .pdf, .docx.",
|
||||||
|
"outputType": "Output type",
|
||||||
|
"outputTypeTooltip": "Select your desired output format to guide the LLM in generating responses that match your desired options.",
|
||||||
|
"outputTypeChangedWarning": "Clears the generated cell values."
|
||||||
},
|
},
|
||||||
"rowEditFieldAI": {
|
"rowEditFieldAI": {
|
||||||
"generate": "Generate",
|
"generate": "Generate",
|
||||||
|
@ -426,5 +429,11 @@
|
||||||
"formulaFieldAI": {
|
"formulaFieldAI": {
|
||||||
"generateWithAI": "Generate using AI",
|
"generateWithAI": "Generate using AI",
|
||||||
"featureName": "Generate formula using AI"
|
"featureName": "Generate formula using AI"
|
||||||
|
},
|
||||||
|
"aiOutputType": {
|
||||||
|
"text": "Text",
|
||||||
|
"textDescription": "Generates free text based on the prompt.",
|
||||||
|
"choice": "Choice",
|
||||||
|
"choiceDescription": "Chooses only one of the field options."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,6 +52,10 @@ import {
|
||||||
AIFieldType,
|
AIFieldType,
|
||||||
PremiumFormulaFieldType,
|
PremiumFormulaFieldType,
|
||||||
} from '@baserow_premium/fieldTypes'
|
} from '@baserow_premium/fieldTypes'
|
||||||
|
import {
|
||||||
|
ChoiceAIFieldOutputType,
|
||||||
|
TextAIFieldOutputType,
|
||||||
|
} from '@baserow_premium/aiFieldOutputTypes'
|
||||||
|
|
||||||
export default (context) => {
|
export default (context) => {
|
||||||
const { store, app, isDev } = context
|
const { store, app, isDev } = context
|
||||||
|
@ -94,6 +98,8 @@ export default (context) => {
|
||||||
store.registerModule('template/view/timeline', timelineStore)
|
store.registerModule('template/view/timeline', timelineStore)
|
||||||
store.registerModule('impersonating', impersonatingStore)
|
store.registerModule('impersonating', impersonatingStore)
|
||||||
|
|
||||||
|
app.$registry.registerNamespace('aiFieldOutputType')
|
||||||
|
|
||||||
app.$registry.register('plugin', new PremiumPlugin(context))
|
app.$registry.register('plugin', new PremiumPlugin(context))
|
||||||
app.$registry.register('admin', new DashboardType(context))
|
app.$registry.register('admin', new DashboardType(context))
|
||||||
app.$registry.register('admin', new UsersAdminType(context))
|
app.$registry.register('admin', new UsersAdminType(context))
|
||||||
|
@ -160,4 +166,13 @@ export default (context) => {
|
||||||
'rowModalSidebar',
|
'rowModalSidebar',
|
||||||
new CommentsRowModalSidebarType(context)
|
new CommentsRowModalSidebarType(context)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.$registry.register(
|
||||||
|
'aiFieldOutputType',
|
||||||
|
new TextAIFieldOutputType(context)
|
||||||
|
)
|
||||||
|
app.$registry.register(
|
||||||
|
'aiFieldOutputType',
|
||||||
|
new ChoiceAIFieldOutputType(context)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -180,6 +180,7 @@ export class KanbanViewType extends PremiumViewType {
|
||||||
row,
|
row,
|
||||||
values,
|
values,
|
||||||
metadata,
|
metadata,
|
||||||
|
updatedFieldIds,
|
||||||
storePrefix = ''
|
storePrefix = ''
|
||||||
) {
|
) {
|
||||||
if (this.isCurrentView(store, tableId)) {
|
if (this.isCurrentView(store, tableId)) {
|
||||||
|
@ -386,6 +387,7 @@ export class CalendarViewType extends PremiumViewType {
|
||||||
row,
|
row,
|
||||||
values,
|
values,
|
||||||
metadata,
|
metadata,
|
||||||
|
updatedFieldIds,
|
||||||
storePrefix = ''
|
storePrefix = ''
|
||||||
) {
|
) {
|
||||||
if (this.isCurrentView(store, tableId)) {
|
if (this.isCurrentView(store, tableId)) {
|
||||||
|
|
|
@ -1,3 +1,13 @@
|
||||||
|
.grid-field-single-select__cell {
|
||||||
|
&.active {
|
||||||
|
bottom: auto;
|
||||||
|
right: auto;
|
||||||
|
height: auto;
|
||||||
|
min-height: calc(100% + 4px);
|
||||||
|
min-width: calc(100% + 4px);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
.grid-field-single-select {
|
.grid-field-single-select {
|
||||||
position: relative;
|
position: relative;
|
||||||
display: block;
|
display: block;
|
||||||
|
|
|
@ -84,7 +84,7 @@ export default {
|
||||||
const isNotThisField = f.id !== this.defaultValues.id
|
const isNotThisField = f.id !== this.defaultValues.id
|
||||||
const canBeReferencedByFormulaField = this.$registry
|
const canBeReferencedByFormulaField = this.$registry
|
||||||
.get('field', f.type)
|
.get('field', f.type)
|
||||||
.canBeReferencedByFormulaField()
|
.canBeReferencedByFormulaField(f)
|
||||||
return isNotThisField && canBeReferencedByFormulaField
|
return isNotThisField && canBeReferencedByFormulaField
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
|
|
@ -114,7 +114,7 @@ export default {
|
||||||
.filter((f) => {
|
.filter((f) => {
|
||||||
return this.$registry
|
return this.$registry
|
||||||
.get('field', f.type)
|
.get('field', f.type)
|
||||||
.canBeReferencedByFormulaField()
|
.canBeReferencedByFormulaField(f)
|
||||||
})
|
})
|
||||||
.filter((f) => {
|
.filter((f) => {
|
||||||
return this.$hasPermission(
|
return this.$hasPermission(
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
<template functional>
|
<template functional>
|
||||||
<div ref="cell" class="grid-view__cell" :class="data.staticClass || ''">
|
<div
|
||||||
|
ref="cell"
|
||||||
|
class="grid-view__cell grid-field-single-select__cell"
|
||||||
|
:class="data.staticClass || ''"
|
||||||
|
>
|
||||||
<div class="grid-field-single-select">
|
<div class="grid-field-single-select">
|
||||||
<div
|
<div
|
||||||
v-if="props.value"
|
v-if="props.value"
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
class="grid-field-long-text__textarea"
|
class="grid-field-long-text__textarea"
|
||||||
/>
|
/>
|
||||||
<div v-else class="grid-field-long-text__textarea">{{ value }}</div>
|
<div v-else class="grid-field-long-text__textarea">{{ value }}</div>
|
||||||
|
<slot name="default" :slot-props="{ editing, opened }"></slot>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
<template>
|
<template>
|
||||||
<div ref="cell" class="grid-view__cell active">
|
<div ref="cell" class="grid-view__cell grid-field-single-select__cell active">
|
||||||
<div
|
<div
|
||||||
ref="dropdownLink"
|
ref="dropdownLink"
|
||||||
class="grid-field-single-select grid-field-single-select--selected"
|
class="grid-field-single-select grid-field-single-select--selected"
|
||||||
|
@ -18,6 +18,7 @@
|
||||||
class="iconoir-nav-arrow-down grid-field-single-select__icon"
|
class="iconoir-nav-arrow-down grid-field-single-select__icon"
|
||||||
></i>
|
></i>
|
||||||
</div>
|
</div>
|
||||||
|
<slot name="default" :slot-props="{ editing, opened: true }"></slot>
|
||||||
<FieldSelectOptionsDropdown
|
<FieldSelectOptionsDropdown
|
||||||
v-if="!readOnly"
|
v-if="!readOnly"
|
||||||
ref="dropdown"
|
ref="dropdown"
|
||||||
|
|
|
@ -727,7 +727,7 @@ export class FieldType extends Registerable {
|
||||||
* Override and return true if the field type can be referenced by a formula field.
|
* Override and return true if the field type can be referenced by a formula field.
|
||||||
* @return {boolean}
|
* @return {boolean}
|
||||||
*/
|
*/
|
||||||
canBeReferencedByFormulaField() {
|
canBeReferencedByFormulaField(field) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -202,6 +202,7 @@ export const registerRealtimeEvents = (realtime) => {
|
||||||
rowBeforeUpdate,
|
rowBeforeUpdate,
|
||||||
row,
|
row,
|
||||||
data.metadata[row.id],
|
data.metadata[row.id],
|
||||||
|
data.updated_field_ids,
|
||||||
'page/'
|
'page/'
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2506,7 +2506,7 @@ export const actions = {
|
||||||
*/
|
*/
|
||||||
async updatedExistingRow(
|
async updatedExistingRow(
|
||||||
{ commit, getters, dispatch },
|
{ commit, getters, dispatch },
|
||||||
{ view, fields, row, values, metadata }
|
{ view, fields, row, values, metadata, updatedFieldIds = [] }
|
||||||
) {
|
) {
|
||||||
const oldRow = clone(row)
|
const oldRow = clone(row)
|
||||||
const newRow = Object.assign(clone(row), values)
|
const newRow = Object.assign(clone(row), values)
|
||||||
|
@ -2646,15 +2646,21 @@ export const actions = {
|
||||||
// sure the loading state will stop if the value is updated. This is done even
|
// sure the loading state will stop if the value is updated. This is done even
|
||||||
// if the row is not found in the buffer because it could have been removed from
|
// if the row is not found in the buffer because it could have been removed from
|
||||||
// the buffer when scrolling outside the buffer range.
|
// the buffer when scrolling outside the buffer range.
|
||||||
const updatedFieldIds = Object.entries(values)
|
const getFieldId = (key) => parseInt(key.split('_')[1])
|
||||||
|
const fieldIdsToClearPendingOperationsFor = Object.entries(values)
|
||||||
.filter(
|
.filter(
|
||||||
([key, value]) =>
|
([key, value]) =>
|
||||||
key.startsWith('field_') && !_.isEqual(value, oldRow[key])
|
key.startsWith('field_') &&
|
||||||
|
// Either the value has changed.
|
||||||
|
(_.isEqual(value, oldRow[key]) ||
|
||||||
|
// Or the backend has just recalculated the value, even if it hasn't
|
||||||
|
// actually changed.
|
||||||
|
updatedFieldIds.includes(getFieldId(key)))
|
||||||
)
|
)
|
||||||
.map(([key, value]) => parseInt(key.split('_')[1]))
|
.map(([key, value]) => getFieldId(key))
|
||||||
|
|
||||||
commit('CLEAR_PENDING_FIELD_OPERATIONS', {
|
commit('CLEAR_PENDING_FIELD_OPERATIONS', {
|
||||||
fieldIds: updatedFieldIds,
|
fieldIds: fieldIdsToClearPendingOperationsFor,
|
||||||
rowId: row.id,
|
rowId: row.id,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -267,7 +267,16 @@ export class ViewType extends Registerable {
|
||||||
* via a real time event by another user. It can be used to check if data in an store
|
* via a real time event by another user. It can be used to check if data in an store
|
||||||
* needs to be updated.
|
* needs to be updated.
|
||||||
*/
|
*/
|
||||||
rowUpdated(context, tableId, fields, row, values, metadata, storePrefix) {}
|
rowUpdated(
|
||||||
|
context,
|
||||||
|
tableId,
|
||||||
|
fields,
|
||||||
|
row,
|
||||||
|
values,
|
||||||
|
metadata,
|
||||||
|
updatedFieldIds,
|
||||||
|
storePrefix
|
||||||
|
) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Event that is called when something went wrong while generating AI values
|
* Event that is called when something went wrong while generating AI values
|
||||||
|
@ -722,6 +731,7 @@ export class GridViewType extends ViewType {
|
||||||
row,
|
row,
|
||||||
values,
|
values,
|
||||||
metadata,
|
metadata,
|
||||||
|
updatedFieldIds,
|
||||||
storePrefix = ''
|
storePrefix = ''
|
||||||
) {
|
) {
|
||||||
if (this.isCurrentView(store, tableId)) {
|
if (this.isCurrentView(store, tableId)) {
|
||||||
|
@ -731,6 +741,7 @@ export class GridViewType extends ViewType {
|
||||||
row,
|
row,
|
||||||
values,
|
values,
|
||||||
metadata,
|
metadata,
|
||||||
|
updatedFieldIds,
|
||||||
})
|
})
|
||||||
await store.dispatch(storePrefix + 'view/grid/fetchByScrollTopDelayed', {
|
await store.dispatch(storePrefix + 'view/grid/fetchByScrollTopDelayed', {
|
||||||
scrollTop: store.getters[storePrefix + 'view/grid/getScrollTop'],
|
scrollTop: store.getters[storePrefix + 'view/grid/getScrollTop'],
|
||||||
|
@ -971,6 +982,7 @@ export const BaseBufferedRowViewTypeMixin = (Base) =>
|
||||||
row,
|
row,
|
||||||
values,
|
values,
|
||||||
metadata,
|
metadata,
|
||||||
|
updatedFieldIds,
|
||||||
storePrefix = ''
|
storePrefix = ''
|
||||||
) {
|
) {
|
||||||
if (this.isCurrentView(store, tableId)) {
|
if (this.isCurrentView(store, tableId)) {
|
||||||
|
|
Loading…
Add table
Reference in a new issue